1/* sort word by frequency, sharding version.
2
3  1. read input file, shard to N files:
4       word
5  2. assume each shard file fits in memory, read each shard file, count words and sort by count, then write to N count files:
6       count \t word
7  3. merge N count files using heap.
8
9Limits: each shard must fit in memory.
10*/
11
12#include <assert.h>
13
14#include "file.h"
15#include "merge.h"
16#include "timer.h"
17
18#include "absl/container/flat_hash_map.h"
19#include "absl/hash/hash.h"
20#include "absl/strings/str_format.h"
21#include "muduo/base/Logging.h"
22#include "muduo/base/ThreadPool.h"
23
24#include <algorithm>
25#include <memory>
26#include <string>
27#include <unordered_map>
28#include <vector>
29
30#include <fcntl.h>
31#include <string.h>
32#include <sys/mman.h>
33#include <sys/stat.h>
34#include <unistd.h>
35
36using absl::string_view;
37using std::string;
38using std::vector;
39using std::unique_ptr;
40
41int kShards = 10, kThreads = 4;
42bool g_verbose = false, g_keep = false;
43const char* shard_dir = ".";
44const char* g_output = "output";
45
46class Sharder // : boost::noncopyable
47{
48 public:
49  Sharder()
50    : files_(kShards)
51  {
52    for (int i = 0; i < kShards; ++i)
53    {
54      char name[256];
55      snprintf(name, sizeof name, "%s/shard-%05d-of-%05d", shard_dir, i, kShards);
56      files_[i].reset(new OutputFile(name));
57    }
58    assert(files_.size() == static_cast<size_t>(kShards));
59  }
60
61  void output(string_view word)
62  {
63    size_t shard = hash(word) % files_.size();
64    files_[shard]->appendRecord(word);
65  }
66
67  void finish()
68  {
69    int shard = 0;
70    for (const auto& file : files_)
71    {
72      // if (g_verbose)
73      printf("  shard %d: %ld bytes, %ld items\n", shard, file->tell(), file->items());
74      ++shard;
75      file->close();
76    }
77  }
78
79 private:
80  absl::Hash<string_view> hash;
81  vector<unique_ptr<OutputFile>> files_;
82};
83
84int64_t shard_(int argc, char* argv[])
85{
86  Sharder sharder;
87  Timer timer;
88  int64_t total = 0;
89  for (int i = optind; i < argc; ++i)
90  {
91    LOG_INFO << "Processing input file " << argv[i];
92    double t = Timer::now();
93    string line;
94    InputFile input(argv[i]);
95    while (input.getline(&line))
96    {
97      sharder.output(line);
98    }
99    size_t len = input.tell();
100    total += len;
101    double sec = Timer::now() - t;
102    LOG_INFO << "Done file " << argv[i] << absl::StrFormat(" %.3f sec %.2f MiB/s", sec, len / sec / 1024 / 1024);
103  }
104  sharder.finish();
105  LOG_INFO << "Sharding done " << timer.report(total);
106  return total;
107}
108
109// ======= count_shards =======
110
111void count_shard(int shard, int fd, size_t len)
112{
113  Timer timer;
114
115  double t = Timer::now();
116  LOG_INFO << absl::StrFormat("counting shard %d: input file size %ld", shard, len);
117  {
118  void* mapped = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
119  assert(mapped != MAP_FAILED);
120  const uint8_t* const start = static_cast<const uint8_t*>(mapped);
121  const uint8_t* const end = start + len;
122
123  // std::unordered_map<string_view, uint64_t> items;
124  absl::flat_hash_map<string_view, uint64_t> items;
125  int64_t count = 0;
126  for (const uint8_t* p = start; p < end;)
127  {
128    string_view s((const char*)p+1, *p);
129    items[s]++;
130    p += 1 + *p;
131    ++count;
132  }
133  LOG_INFO << "items " << count << " unique " << items.size();
134  if (g_verbose)
135  printf("  count %.3f sec %ld items\n", Timer::now() - t, items.size());
136
137  t = Timer::now();
138  vector<std::pair<size_t, string_view>> counts;
139  for (const auto& it : items)
140  {
141    if (it.second > 1)
142      counts.push_back(std::make_pair(it.second, it.first));
143  }
144  if (g_verbose)
145  printf("  select %.3f sec %ld\n", Timer::now() - t, counts.size());
146
147  t = Timer::now();
148  std::sort(counts.begin(), counts.end());
149  if (g_verbose)
150  printf("  sort %.3f sec\n", Timer::now() - t);
151
152  t = Timer::now();
153  int64_t out_len = 0;
154  {
155    char buf[256];
156    snprintf(buf, sizeof buf, "count-%05d", shard);
157    OutputFile output(buf);
158
159    for (auto it = counts.rbegin(); it != counts.rend(); ++it)
160    {
161      output.write(absl::StrFormat("%d\t%s\n", it->first, it->second));
162    }
163
164    for (const auto& it : items)
165    {
166      if (it.second == 1)
167      {
168        output.write(absl::StrFormat("1\t%s\n", it.first));
169      }
170    }
171    out_len = output.tell();
172  }
173  if (g_verbose)
174  printf("  output %.3f sec %lu\n", Timer::now() - t, out_len);
175
176  if (munmap(mapped, len))
177    perror("munmap");
178  }
179  ::close(fd);
180  LOG_INFO << "shard " << shard << " done " << timer.report(len);
181}
182
183void count_shards(int shards)
184{
185  assert(shards <= kShards);
186  Timer timer;
187  int64_t total = 0;
188  muduo::ThreadPool threadPool;
189  threadPool.setMaxQueueSize(2*kThreads);
190  threadPool.start(kThreads);
191
192  for (int shard = 0; shard < shards; ++shard)
193  {
194    char buf[256];
195    snprintf(buf, sizeof buf, "%s/shard-%05d-of-%05d", shard_dir, shard, kShards);
196    int fd = open(buf, O_RDONLY);
197    assert(fd >= 0);
198    if (!g_keep)
199      ::unlink(buf);
200
201    struct stat st;
202    if (::fstat(fd, &st) == 0)
203    {
204      size_t len = st.st_size;
205      total += len;
206      threadPool.run([shard, fd, len]{ count_shard(shard, fd, len); });
207    }
208  }
209  while (threadPool.queueSize() > 0)
210  {
211    LOG_DEBUG << "waiting for ThreadPool " << threadPool.queueSize();
212    muduo::CurrentThread::sleepUsec(1000*1000);
213  }
214  threadPool.stop();
215  LOG_INFO << "Counting done "<< timer.report(total);
216}
217
218// ======= merge =======
219
220int main(int argc, char* argv[])
221{
222  /*
223  int fd = open("shard-00000-of-00010", O_RDONLY);
224  double t = Timer::now();
225  int64_t len = count_shard(0, fd);
226  double sec = Timer::now() - t;
227  printf("count_shard %.3f sec %.2f MB/s\n", sec, len / sec / 1e6);
228  */
229  setlocale(LC_NUMERIC, "");
230
231  int opt;
232  int count_only = 0;
233  int merge_only = 0;
234  while ((opt = getopt(argc, argv, "c:km:o:p:s:t:v")) != -1)
235  {
236    switch (opt)
237    {
238      case 'c':
239        count_only = atoi(optarg);
240        break;
241      case 'k':
242        g_keep = true;
243        break;
244      case 'm':
245        merge_only = atoi(optarg);
246        break;
247      case 'o':
248        g_output = optarg;
249        break;
250      case 'p':  // Path for temp shard files
251        shard_dir = optarg;
252        break;
253      case 's':
254        kShards = atoi(optarg);
255        break;
256      case 't':
257        kThreads = atoi(optarg);
258        break;
259      case 'v':
260        g_verbose = true;
261        break;
262    }
263  }
264
265  if (count_only > 0 || merge_only)
266  {
267    g_keep = true;
268    //g_verbose = true;
269    count_only = std::min(count_only, kShards);
270
271    if (count_only > 0)
272    {
273      count_shards(count_only);
274    }
275
276    if (merge_only > 0)
277    {
278      merge(merge_only);
279    }
280  }
281  else
282  {
283    // Run all three steps
284    Timer timer;
285    LOG_INFO << argc - optind << " input files, " << kShards << " shards, "
286             << "output " << g_output <<" , temp " << shard_dir;
287    int64_t input = 0;
288    input = shard_(argc, argv);
289    count_shards(kShards);
290    int64_t output_size = merge(kShards);
291    LOG_INFO << "All done " << timer.report(input) << " output " << output_size;
292  }
293}
294