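/* Sort words by frequency -- the "shard while counting" version.

  1. Read the input files and count words in memory. Whenever the in-memory table
     holds more than kMaxSize (10M) distinct words, and again at the end of each
     input file, flush the partial counts into 10 shard files (chosen by hashing
     the word), one line per word:
       word \t count
  2. Assuming each shard file fits in memory, read each shard file, add up the
     partial counts, sort by count, and write 10 count files:
       count \t word
  3. Merge the 10 count files (each sorted by descending count) using a heap.

Limits: each shard must fit in memory.
*/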
#include <unistd.h>  // unlink

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/noncopyable.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
17
// Type used to store words throughout this file.
// By default it is libstdc++'s __gnu_cxx::__sso_string from <ext/vstring.h>;
// compile with -DSTD_STRING to use std::string instead.
#ifdef STD_STRING
#warning "STD STRING"
#include <string>
using string = std::string;
#else
#include <ext/vstring.h>
using string = __gnu_cxx::__sso_string;
#endif
26
// shard() flushes its in-memory word counts to the shard files as soon as
// they hold more than this many distinct words (10 million).
constexpr size_t kMaxSize = 10 * 1000 * 1000;
28
29class Sharder : boost::noncopyable
30{
31 public:
32  explicit Sharder(int nbuckets)
33    : buckets_(nbuckets)
34  {
35    for (int i = 0; i < nbuckets; ++i)
36    {
37      char buf[256];
38      snprintf(buf, sizeof buf, "shard-%05d-of-%05d", i, nbuckets);
39      buckets_.push_back(new std::ofstream(buf));
40    }
41    assert(buckets_.size() == static_cast<size_t>(nbuckets));
42  }
43
44  void output(const string& word, int64_t count)
45  {
46    size_t idx = std::hash<string>()(word) % buckets_.size();
47    buckets_[idx] << word << '\t' << count << '\n';
48  }
49
50 protected:
51  boost::ptr_vector<std::ofstream> buckets_;
52};
53
54void shard(int nbuckets, int argc, char* argv[])
55{
56  Sharder sharder(nbuckets);
57  for (int i = 1; i < argc; ++i)
58  {
59    std::cout << "  processing input file " << argv[i] << std::endl;
60    std::unordered_map<string, int64_t> counts;
61    std::ifstream in(argv[i]);
62    while (in && !in.eof())
63    {
64      counts.clear();
65      string word;
66      while (in >> word)
67      {
68        counts[word]++;
69        if (counts.size() > kMaxSize)
70        {
71          std::cout << "    split" << std::endl;
72          break;
73        }
74      }
75
76      for (const auto& kv : counts)
77      {
78        sharder.output(kv.first, kv.second);
79      }
80    }
81  }
82  std::cout << "shuffling done" << std::endl;
83}
84
85// ======= sort_shards =======
86
// Phase 2 helper: loads shard file number idx (of nbuckets), as written by
// Sharder, and sums the partial counts of each word. The shard file is then
// deleted, and the function returns a map of word -> total count.
//
// Each line is expected to be "word \t count". Lines without a tab, or whose
// count does not parse to a positive number, are ignored.
std::unordered_map<string, int64_t> read_shard(int idx, int nbuckets)
{
  std::unordered_map<string, int64_t> counts;

  char buf[256];
  snprintf(buf, sizeof buf, "shard-%05d-of-%05d", idx, nbuckets);
  std::cout << "  reading " << buf << std::endl;
  {
    // Scoped so the stream is closed before the file is unlinked.
    std::ifstream in(buf);
    string line;

    while (getline(in, line))
    {
      const size_t tab = line.find('\t');
      if (tab != string::npos)
      {
        // strtoll rather than strtol: long is only 32 bits on some platforms
        // (32-bit Linux, Windows), which would truncate large counts.
        const int64_t count = std::strtoll(line.c_str() + tab + 1, nullptr, 10);
        if (count > 0)
        {
          counts[line.substr(0, tab)] += count;
        }
      }
    }
  }

  ::unlink(buf);
  return counts;
}
115
116void sort_shards(const int nbuckets)
117{
118  for (int i = 0; i < nbuckets; ++i)
119  {
120    // std::cout << "  sorting " << std::endl;
121    std::vector<std::pair<int64_t, string>> counts;
122    for (const auto& entry : read_shard(i, nbuckets))
123    {
124      counts.push_back(make_pair(entry.second, entry.first));
125    }
126    std::sort(counts.begin(), counts.end());
127
128    char buf[256];
129    snprintf(buf, sizeof buf, "count-%05d-of-%05d", i, nbuckets);
130    std::ofstream out(buf);
131    std::cout << "  writing " << buf << std::endl;
132    for (auto it = counts.rbegin(); it != counts.rend(); ++it)
133    {
134      out << it->first << '\t' << it->second << '\n';
135    }
136  }
137
138  std::cout << "reducing done" << std::endl;
139}
140
141// ======= merge =======
142
// One input of the final k-way merge: reads "count \t word" lines from a
// stream and exposes the most recently read record. Copies share the same
// (non-owned) stream pointer.
class Source  // copyable
{
 public:
  explicit Source(std::istream* in)
    : in_(in),
      count_(0),
      word_()
  {
  }

  // Advances to the next well-formed record: a line containing a tab whose
  // leading count parses to a positive number. Malformed lines are skipped,
  // as read_shard does, instead of ending the stream early and silently
  // dropping the rest of the file.
  // Returns false once the stream is exhausted.
  bool next()
  {
    string line;
    while (getline(*in_, line))
    {
      const size_t tab = line.find('\t');
      if (tab == string::npos)
      {
        continue;
      }
      // strtoll rather than strtol: long may be 32 bits, truncating big counts.
      const int64_t count = std::strtoll(line.c_str(), nullptr, 10);
      if (count > 0)
      {
        count_ = count;
        word_ = line.substr(tab+1);
        return true;
      }
    }
    return false;
  }

  // Compares sources by their current count only. With the std heap
  // functions' default ordering this yields a max-heap on count (see merge).
  bool operator<(const Source& rhs) const
  {
    return count_ < rhs.count_;
  }

  // Writes the current record as "count \t word\n".
  void outputTo(std::ostream& out) const
  {
    out << count_ << '\t' << word_ << '\n';
  }

 private:
  std::istream* in_;  // not owned
  int64_t count_;
  string word_;
};
187
188void merge(const int nbuckets)
189{
190  boost::ptr_vector<std::ifstream> inputs;
191  std::vector<Source> keys;
192
193  for (int i = 0; i < nbuckets; ++i)
194  {
195    char buf[256];
196    snprintf(buf, sizeof buf, "count-%05d-of-%05d", i, nbuckets);
197    inputs.push_back(new std::ifstream(buf));
198    Source rec(&inputs.back());
199    if (rec.next())
200    {
201      keys.push_back(rec);
202    }
203    ::unlink(buf);
204  }
205
206  std::ofstream out("output");
207  std::make_heap(keys.begin(), keys.end());
208  while (!keys.empty())
209  {
210    std::pop_heap(keys.begin(), keys.end());
211    keys.back().outputTo(out);
212
213    if (keys.back().next())
214    {
215      std::push_heap(keys.begin(), keys.end());
216    }
217    else
218    {
219      keys.pop_back();
220    }
221  }
222  std::cout << "merging done\n";
223}
224
225int main(int argc, char* argv[])
226{
227  int nbuckets = 10;
228  shard(nbuckets, argc, argv);
229  sort_shards(nbuckets);
230  merge(nbuckets);
231}
232