/* Sort words by frequency, sort-based version.

   1. Read the input files and sort each ~1 GB chunk into a segment file:
      word \t count  -- sorted by word
   2. Read all segment files, merging them and summing the counts. Whenever the
      in-memory block of words (512 MB) is full, write it out as a count file;
      each word goes to exactly one count file:
      count \t word  -- sorted by count
   3. Read all count files, merge them and write the output.
*/
9
#include "file.h"
#include "input.h"
#include "merge.h"
#include "timer.h"

#include "muduo/base/Logging.h"

#include <assert.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include <fcntl.h>
#include <locale.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <unistd.h>
33
34using std::pair;
35using std::string;
36using std::string_view;
37using std::vector;
38
// NOTE(review): not referenced anywhere in this file; presumably a leftover
// from a map-based counting version (10M keys). Confirm before removing.
constexpr size_t kMaxSize = 10 * 1000 * 1000;

// Settings controlled from the command line (see main()).
bool g_verbose = false;           // -v: print per-chunk timing details
bool g_keep = false;              // -k: do not unlink segment files in combine()
const char* segment_dir = ".";    // -d: directory holding segment-NNNNN files
const char* g_output = "output";  // -o: presumably the final output path used by merge() -- confirm
43
// Current wall-clock time as seconds since the Unix epoch, with the
// microsecond resolution provided by gettimeofday().
inline double now()
{
  struct timeval tv{};
  ::gettimeofday(&tv, nullptr);
  const double whole_seconds = static_cast<double>(tv.tv_sec);
  const double micros = static_cast<double>(tv.tv_usec);
  return whole_seconds + micros / 1000000.0;
}
50
51int64_t sort_segments(int* count, int fd)
52{
53  Timer timer;
54  const int64_t file_size = lseek(fd, 0, SEEK_END);
55  lseek(fd, 0, SEEK_SET);
56  if (g_verbose)
57    printf("  file size %ld\n", file_size);
58  int64_t offset = 0;
59  while (offset < file_size)
60  {
61    double t = now();
62    const int64_t len = std::min(file_size - offset, 1024 * 1000 * 1000L);
63    if (g_verbose)
64      printf("    reading segment %d: offset %ld len %ld", *count, offset, len);
65    std::unique_ptr<char[]> buf(new char[len]);
66    const ssize_t nr = ::pread(fd, buf.get(), len, offset);
67    double sec = now() - t;
68    if (g_verbose)
69    printf(" %.3f sec %.3f MB/s\n", sec, nr / sec / 1000 / 1000);
70
71    // TODO: move to another thread
72    t = now();
73    const char* const start = buf.get();
74    const char* const end = start + nr;
75    vector<string_view> items;
76    const char* p = start;
77    while (p < end)
78    {
79      const char* nl = static_cast<const char*>(memchr(p, '\n', end - p));
80      if (nl)
81      {
82        string_view s(p, nl - p);
83        items.push_back(s);
84        p = nl + 1;
85      }
86      else
87      {
88        break;
89      }
90    }
91    offset += p - start;
92    if (g_verbose)
93    printf("    parse %.3f sec %ld items %ld bytes\n", now() - t, items.size(), p - start);
94
95    t = now();
96    std::sort(items.begin(), items.end());
97    if (g_verbose)
98    printf("    sort %.3f sec\n", now() - t);
99
100    t = now();
101    char name[256];
102    snprintf(name, sizeof name, "%s/segment-%05d", segment_dir, *count);
103    ++*count;
104    int unique = 0;
105    {
106    // TODO: replace with OutputFile
107    std::ofstream out(name);
108    string_view curr;
109    int cnt = 0;
110    for (auto it = items.begin(); it != items.end(); ++it)
111    {
112      if (*it != curr)
113      {
114        if (cnt)
115        {
116          out << curr << '\t' << cnt << '\n';
117          ++unique;
118        }
119        curr = *it;
120        cnt = 1;
121      }
122      else
123        ++cnt;
124    }
125    if (cnt)
126    {
127      out << curr << '\t' << cnt << '\n';
128      ++unique;
129    }
130    }
131    if (g_verbose)
132    printf("    unique %.3f sec %d\n", now() - t, unique);
133    LOG_INFO << "  wrote " << name;
134  }
135  LOG_INFO << "  file done " << timer.report(file_size);
136  return file_size;
137}
138
139int input(int argc, char* argv[], int64_t* total_in = nullptr)
140{
141  int count = 0;
142  int64_t total = 0;
143  Timer timer;
144  for (int i = optind; i < argc; ++i)
145  {
146    LOG_INFO << "Reading input file " << argv[i];
147
148    int fd = open(argv[i], O_RDONLY);
149    if (fd >= 0)
150    {
151      total += sort_segments(&count, fd);
152      ::close(fd);
153    }
154    else
155      perror("open");
156  }
157  LOG_INFO << "Reading done " << count << " segments " << timer.report(total);
158  if (total_in)
159    *total_in = total;
160  return count;
161}
162
163// ======= combine =======
164
165class Segment  // copyable
166{
167 public:
168  string_view word;
169
170  explicit Segment(SegmentInput* in)
171    : in_(in)
172  {
173  }
174
175  bool next()
176  {
177    if (in_->next())
178    {
179      word = in_->current_word();
180      return true;
181    }
182    else
183      return false;
184  }
185
186  bool operator<(const Segment& rhs) const
187  {
188    return word > rhs.word;
189  }
190
191  int64_t count() const { return in_->current_count(); }
192
193 private:
194  SegmentInput* in_;
195};
196
197class CountOutput
198{
199 public:
200  CountOutput();
201
202  void add(string_view word, int64_t count)
203  {
204    if (block_->add(word, count))
205    {
206    }
207    else
208    {
209      // TODO: Move to another thread.
210      block_->output(merge_count_);
211      ++merge_count_;
212      block_.reset(new Block);
213      if (!block_->add(word, count))
214      {
215        abort();
216      }
217    }
218  }
219
220  int finish()
221  {
222    block_->output(merge_count_);
223    ++merge_count_;
224    return merge_count_;
225  }
226
227 private:
228  struct Count
229  {
230    int64_t count = 0;
231    int32_t offset = 0, len = 0;
232  };
233
234  struct Block
235  {
236    std::unique_ptr<char[]> data { new char[kSize] };
237    vector<Count> counts;
238    int start = 0;
239    static const int kSize = 512 * 1000 * 1000;
240
241
242    bool add(string_view word, int64_t count)
243    {
244      if (static_cast<size_t>(kSize - start) >= word.size())
245      {
246        memcpy(data.get() + start, word.data(), word.size());
247        Count c;
248        c.count = count;
249        c.offset = start;
250        c.len = word.size();
251        counts.push_back(c);
252        start += word.size();
253        return true;
254      }
255      else
256        return false;
257    }
258
259    void output(int n)
260    {
261      Timer t;
262      char buf[256];
263      snprintf(buf, sizeof buf, "count-%05d", n);
264      LOG_INFO << "  writing " << buf << " of " << counts.size() << " words";
265      std::sort(counts.begin(), counts.end(), [](const auto& lhs, const auto& rhs) {
266        return lhs.count > rhs.count;
267      });
268      int64_t file_size = 0;
269      {
270      OutputFile out(buf);
271
272      for (const auto& c : counts)
273      {
274        out.writeWord(c.count, string_view(data.get() + c.offset, c.len));
275      }
276      file_size = out.tell();
277      }
278      LOG_DEBUG << "  done " << t.report(file_size);
279    }
280  };
281
282  std::unique_ptr<Block> block_;
283  int merge_count_ = 0;
284};
285
286CountOutput::CountOutput()
287  : block_(new Block)
288{
289}
290
291int combine(int count)
292{
293  Timer timer;
294  std::vector<std::unique_ptr<SegmentInput>> inputs;
295  std::vector<Segment> keys;
296
297  int64_t total = 0;
298  for (int i = 0; i < count; ++i)
299  {
300    char buf[256];
301    snprintf(buf, sizeof buf, "%s/segment-%05d", segment_dir, i);
302    struct stat st;
303    if (::stat(buf, &st) == 0)
304    {
305      total += st.st_size;
306      inputs.emplace_back(new SegmentInput(buf));
307      Segment rec(inputs.back().get());
308      if (rec.next())
309      {
310        keys.push_back(rec);
311      }
312      if (!g_keep)
313        ::unlink(buf);
314    }
315    else
316    {
317      perror("Cannot open segment");
318    }
319  }
320  LOG_INFO << "Combining " << count << " files " << total << " bytes";
321
322  // std::cout << keys.size() << '\n';
323  string last = "Chen Shuo";
324  int64_t last_count = 0, total_count = 0;
325  int64_t lines_in = 0, lines_out = 0;
326  CountOutput out;
327  std::make_heap(keys.begin(), keys.end());
328
329  while (!keys.empty())
330  {
331    std::pop_heap(keys.begin(), keys.end());
332    lines_in++;
333    total_count += keys.back().count();
334
335    if (keys.back().word != last)
336    {
337      if (last_count > 0)
338      {
339        assert(last > keys.back().word);
340        lines_out++;
341        out.add(last, last_count);
342      }
343
344      last = keys.back().word;
345      last_count = keys.back().count();
346    }
347    else
348    {
349      last_count += keys.back().count();
350    }
351
352    if (keys.back().next())
353    {
354      std::push_heap(keys.begin(), keys.end());
355    }
356    else
357    {
358      keys.pop_back();
359    }
360  }
361
362  if (last_count > 0)
363  {
364    lines_out++;
365    out.add(last, last_count);
366  }
367  int m = out.finish();
368
369  LOG_INFO << "total count " << total_count << ", lines in " << lines_in << " out " << lines_out;
370  LOG_INFO << "Combine done " << timer.report(total);
371  return m;
372}
373
374int main(int argc, char* argv[])
375{
376  setlocale(LC_NUMERIC, "");
377  int opt;
378  bool sort_only = false;
379  int count_only = 0;
380  int merge_only = 0;
381  while ((opt = getopt(argc, argv, "c:d:km:o:sv")) != -1)
382  {
383    switch (opt)
384    {
385      case 'c':
386        count_only = atoi(optarg);
387        break;
388      case 'd':
389        segment_dir = optarg;
390        break;
391      case 'k':
392        g_keep = true;
393        break;
394      case 'm':
395        merge_only = atoi(optarg);
396        break;
397      case 'o':
398        g_output = optarg;
399        break;
400      case 's':
401        sort_only = true;
402        break;
403      case 'v':
404        g_verbose = true;
405        break;
406    }
407  }
408
409  if (sort_only || count_only > 0 || merge_only > 0)
410  {
411    g_keep = true;
412    if (sort_only)
413    {
414      int count = input(argc, argv);
415      LOG_INFO << "wrote " << count << " segments";
416    }
417    if (count_only)
418    {
419      int m = combine(count_only);
420      LOG_INFO << "wrote " << m << " counts";
421    }
422    if (merge_only)
423    {
424      merge(merge_only);
425    }
426  }
427  else
428  {
429    int64_t total = 0;
430    Timer timer;
431    int count = input(argc, argv, &total);
432    int m = combine(count);
433    merge(m);
434    LOG_INFO << "All done " << timer.report(total);
435  }
436}
437