// word_freq_shards.cc revision cc454125
#include <boost/noncopyable.hpp>
#include <boost/ptr_container/ptr_vector.hpp>

#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

#ifdef STD_STRING
#warning "STD STRING"
#include <string>
using std::string;
#else
#include <ext/vstring.h>
typedef __gnu_cxx::__sso_string string;
#endif
16144e8e4eSShuo Chen
// Maximum number of distinct words shard() keeps in its in-memory table
// before it flushes the partial counts to the shard files and starts over.
constexpr size_t kMaxSize = 10 * 1000 * 1000;
18144e8e4eSShuo Chen
19144e8e4eSShuo Chenclass Sharder : boost::noncopyable
20144e8e4eSShuo Chen{
21144e8e4eSShuo Chen public:
22144e8e4eSShuo Chen  explicit Sharder(int nbuckets)
23144e8e4eSShuo Chen    : buckets_(nbuckets)
24144e8e4eSShuo Chen  {
25144e8e4eSShuo Chen    for (int i = 0; i < nbuckets; ++i)
26144e8e4eSShuo Chen    {
27144e8e4eSShuo Chen      char buf[256];
28144e8e4eSShuo Chen      snprintf(buf, sizeof buf, "shard-%05d-of-%05d", i, nbuckets);
29144e8e4eSShuo Chen      buckets_.push_back(new std::ofstream(buf));
30144e8e4eSShuo Chen    }
31144e8e4eSShuo Chen    assert(buckets_.size() == static_cast<size_t>(nbuckets));
32144e8e4eSShuo Chen  }
33144e8e4eSShuo Chen
34cc454125SShuo Chen  void output(const string& word, int64_t count)
35144e8e4eSShuo Chen  {
36cc454125SShuo Chen    size_t idx = std::hash<string>()(word) % buckets_.size();
37cc454125SShuo Chen    buckets_[idx] << word << '\t' << count << '\n';
38144e8e4eSShuo Chen  }
39144e8e4eSShuo Chen
40144e8e4eSShuo Chen protected:
41144e8e4eSShuo Chen  boost::ptr_vector<std::ofstream> buckets_;
42144e8e4eSShuo Chen};
43144e8e4eSShuo Chen
44144e8e4eSShuo Chenvoid shard(int nbuckets, int argc, char* argv[])
45144e8e4eSShuo Chen{
46144e8e4eSShuo Chen  Sharder sharder(nbuckets);
47144e8e4eSShuo Chen  for (int i = 1; i < argc; ++i)
48144e8e4eSShuo Chen  {
49144e8e4eSShuo Chen    std::cout << "  processing input file " << argv[i] << std::endl;
50cc454125SShuo Chen    std::unordered_map<string, int64_t> counts;
51144e8e4eSShuo Chen    std::ifstream in(argv[i]);
52144e8e4eSShuo Chen    while (in && !in.eof())
53144e8e4eSShuo Chen    {
54cc454125SShuo Chen      counts.clear();
55cc454125SShuo Chen      string word;
56cc454125SShuo Chen      while (in >> word)
57144e8e4eSShuo Chen      {
58cc454125SShuo Chen        counts[word] += 1;
59cc454125SShuo Chen        if (counts.size() > kMaxSize)
60144e8e4eSShuo Chen        {
61144e8e4eSShuo Chen          std::cout << "    split" << std::endl;
62144e8e4eSShuo Chen          break;
63144e8e4eSShuo Chen        }
64144e8e4eSShuo Chen      }
65144e8e4eSShuo Chen
66cc454125SShuo Chen      for (auto kv : counts)
67144e8e4eSShuo Chen      {
68144e8e4eSShuo Chen        sharder.output(kv.first, kv.second);
69144e8e4eSShuo Chen      }
70144e8e4eSShuo Chen    }
71144e8e4eSShuo Chen  }
72144e8e4eSShuo Chen  std::cout << "shuffling done" << std::endl;
73144e8e4eSShuo Chen}

// ======= combine =======

// Loads the shard file "shard-<idx>-of-<nbuckets>" (both zero-padded to five
// digits) and returns the total count per word.  Each line is expected to be
// "word<TAB>count"; lines without a tab or with a non-positive count are
// ignored, and repeated words are summed.  The shard file is deleted before
// returning.
std::unordered_map<string, int64_t> read_shard(int idx, int nbuckets)
{
  std::unordered_map<string, int64_t> counts;

  char buf[256];
  std::snprintf(buf, sizeof buf, "shard-%05d-of-%05d", idx, nbuckets);
  std::cout << "  reading " << buf << std::endl;
  {
    // Inner scope: the stream is closed before the file is unlinked below.
    std::ifstream in(buf);
    string line;

    while (getline(in, line))
    {
      const size_t tab = line.find('\t');
      if (tab == string::npos)
      {
        continue;
      }
      // strtoll rather than strtol: the counter is 64-bit even on platforms
      // where long is only 32 bits wide.
      const int64_t count = std::strtoll(line.c_str() + tab + 1, nullptr, 10);
      if (count > 0)
      {
        counts[line.substr(0, tab)] += count;
      }
    }
  }

  ::unlink(buf);
  return counts;
}
105144e8e4eSShuo Chen
106144e8e4eSShuo Chenvoid combine(const int nbuckets)
107144e8e4eSShuo Chen{
108144e8e4eSShuo Chen  for (int i = 0; i < nbuckets; ++i)
109144e8e4eSShuo Chen  {
110cc454125SShuo Chen    std::unordered_map<string, int64_t> counts(read_shard(i, nbuckets));
111144e8e4eSShuo Chen
112144e8e4eSShuo Chen    // std::cout << "  sorting " << std::endl;
113144e8e4eSShuo Chen    std::vector<std::pair<int64_t, string>> counts;
114cc454125SShuo Chen    for (const auto& entry : counts)
115144e8e4eSShuo Chen    {
116144e8e4eSShuo Chen      counts.push_back(make_pair(entry.second, entry.first));
117144e8e4eSShuo Chen    }
118144e8e4eSShuo Chen    std::sort(counts.begin(), counts.end());
119144e8e4eSShuo Chen
120144e8e4eSShuo Chen    char buf[256];
121144e8e4eSShuo Chen    snprintf(buf, sizeof buf, "count-%05d-of-%05d", i, nbuckets);
122144e8e4eSShuo Chen    std::ofstream out(buf);
123144e8e4eSShuo Chen    std::cout << "  writing " << buf << std::endl;
124144e8e4eSShuo Chen    for (auto it = counts.rbegin(); it != counts.rend(); ++it)
125144e8e4eSShuo Chen    {
126144e8e4eSShuo Chen      out << it->first << '\t' << it->second << '\n';
127144e8e4eSShuo Chen    }
128144e8e4eSShuo Chen  }
129144e8e4eSShuo Chen
130144e8e4eSShuo Chen  std::cout << "reducing done" << std::endl;
131144e8e4eSShuo Chen}

// ======= merge =======

// Cursor over one "count<TAB>word" input stream used by merge().  It holds
// the most recently parsed record and is ordered by count only.  The stream
// is borrowed, not owned, so it must outlive every copy of the Source.
class Source  // copyable
{
 public:
  explicit Source(std::ifstream* in)
    : in_(in),
      count_(0),
      word_()
  {
  }

  // Advances to the next record.  Returns false when the stream is exhausted,
  // when the line has no tab, or when the count is not positive; the caller
  // treats any of these as the end of this source.
  bool next()
  {
    string line;
    if (!getline(*in_, line))
    {
      return false;
    }
    const size_t tab = line.find('\t');
    if (tab == string::npos)
    {
      return false;
    }
    // strtoll rather than strtol: count_ is 64-bit even where long is 32-bit.
    count_ = std::strtoll(line.c_str(), nullptr, 10);
    if (count_ <= 0)
    {
      return false;
    }
    word_ = line.substr(tab + 1);
    return true;
  }

  // Orders sources by the count of their current record.
  bool operator<(const Source& rhs) const
  {
    return count_ < rhs.count_;
  }

  // Writes the current record as "count<TAB>word<NEWLINE>".
  void outputTo(std::ostream& out) const
  {
    out << count_ << '\t' << word_ << '\n';
  }

 private:
  std::ifstream* in_;  // not owned
  int64_t count_;      // count of the current record
  string word_;        // word of the current record
};
179144e8e4eSShuo Chen
180144e8e4eSShuo Chenvoid merge(const int nbuckets)
181144e8e4eSShuo Chen{
182144e8e4eSShuo Chen  boost::ptr_vector<std::ifstream> inputs;
183144e8e4eSShuo Chen  std::vector<Source> keys;
184144e8e4eSShuo Chen
185144e8e4eSShuo Chen  for (int i = 0; i < nbuckets; ++i)
186144e8e4eSShuo Chen  {
187144e8e4eSShuo Chen    char buf[256];
188144e8e4eSShuo Chen    snprintf(buf, sizeof buf, "count-%05d-of-%05d", i, nbuckets);
189144e8e4eSShuo Chen    inputs.push_back(new std::ifstream(buf));
190144e8e4eSShuo Chen    Source rec(&inputs.back());
191144e8e4eSShuo Chen    if (rec.next())
192144e8e4eSShuo Chen    {
193144e8e4eSShuo Chen      keys.push_back(rec);
194144e8e4eSShuo Chen    }
195144e8e4eSShuo Chen    ::unlink(buf);
196144e8e4eSShuo Chen  }
197144e8e4eSShuo Chen
198144e8e4eSShuo Chen  std::ofstream out("output");
199144e8e4eSShuo Chen  std::make_heap(keys.begin(), keys.end());
200144e8e4eSShuo Chen  while (!keys.empty())
201144e8e4eSShuo Chen  {
202144e8e4eSShuo Chen    std::pop_heap(keys.begin(), keys.end());
203cc454125SShuo Chen    keys.back().outputTo(out);
204144e8e4eSShuo Chen
205144e8e4eSShuo Chen    if (keys.back().next())
206144e8e4eSShuo Chen    {
207144e8e4eSShuo Chen      std::push_heap(keys.begin(), keys.end());
208144e8e4eSShuo Chen    }
209144e8e4eSShuo Chen    else
210144e8e4eSShuo Chen    {
211144e8e4eSShuo Chen      keys.pop_back();
212144e8e4eSShuo Chen    }
213144e8e4eSShuo Chen  }
214144e8e4eSShuo Chen  std::cout << "merging done\n";
215144e8e4eSShuo Chen}
216144e8e4eSShuo Chen
217144e8e4eSShuo Chenint main(int argc, char* argv[])
218144e8e4eSShuo Chen{
219144e8e4eSShuo Chen  int nbuckets = 10;
220144e8e4eSShuo Chen  shard(nbuckets, argc, argv);
221144e8e4eSShuo Chen  combine(nbuckets);
222144e8e4eSShuo Chen  merge(nbuckets);
223144e8e4eSShuo Chen}
224