word_freq_shards_basic.cc revision 2cf09315
/* Sort words by frequency, sharding version.

  1. Read the input files and shard them into N files:
       word
  2. Assuming each shard file fits in memory, read each shard file, count words and sort by count, then write N count files:
       count \t word
  3. Merge the N count files using a heap.

Limits: each shard must fit in memory.
*/
110ab2e892SShuo Chen
120ab2e892SShuo Chen#include <assert.h>
130ab2e892SShuo Chen
1485147189SShuo Chen#include "file.h"
1585147189SShuo Chen#include "timer.h"
1685147189SShuo Chen
174136e585SShuo Chen#include "absl/container/flat_hash_map.h"
182cf09315SShuo Chen#include "absl/hash/hash.h"
192a129a12SShuo Chen#include "absl/strings/str_format.h"
2085147189SShuo Chen#include "muduo/base/BoundedBlockingQueue.h"
212a129a12SShuo Chen#include "muduo/base/Logging.h"
22a251380aSShuo Chen#include "muduo/base/ThreadPool.h"
234136e585SShuo Chen
240ab2e892SShuo Chen#include <algorithm>
250ab2e892SShuo Chen#include <memory>
260ab2e892SShuo Chen#include <string>
270ab2e892SShuo Chen#include <unordered_map>
280ab2e892SShuo Chen#include <vector>
290ab2e892SShuo Chen
302a129a12SShuo Chen#include <boost/program_options.hpp>
312a129a12SShuo Chen
320ab2e892SShuo Chen#include <fcntl.h>
330ab2e892SShuo Chen#include <string.h>
340ab2e892SShuo Chen#include <sys/mman.h>
354136e585SShuo Chen#include <sys/stat.h>
360ab2e892SShuo Chen#include <unistd.h>
370ab2e892SShuo Chen
382cf09315SShuo Chenusing absl::string_view;
390ab2e892SShuo Chenusing std::string;
400ab2e892SShuo Chenusing std::vector;
410ab2e892SShuo Chenusing std::unique_ptr;
420ab2e892SShuo Chen
// Run-time settings. The values below are the defaults; main() overwrites
// them from the command line (flag shown next to each one).
int kShards = 10;                 // -s: number of shard / count files
int kThreads = 4;                 // -t: worker threads used while counting
bool g_verbose = false;           // -v: print per-phase timing details
bool g_keep = false;              // -k: do not unlink intermediate files
const char* shard_dir = ".";      // -p: directory holding the shard files
const char* g_output = "output";  // -o: path of the merged result
47270b6cceSShuo Chen
480ab2e892SShuo Chenclass Sharder // : boost::noncopyable
490ab2e892SShuo Chen{
500ab2e892SShuo Chen public:
510ab2e892SShuo Chen  Sharder()
520ab2e892SShuo Chen    : files_(kShards)
530ab2e892SShuo Chen  {
540ab2e892SShuo Chen    for (int i = 0; i < kShards; ++i)
550ab2e892SShuo Chen    {
560ab2e892SShuo Chen      char name[256];
57a6693141SShuo Chen      snprintf(name, sizeof name, "%s/shard-%05d-of-%05d", shard_dir, i, kShards);
580ab2e892SShuo Chen      files_[i].reset(new OutputFile(name));
590ab2e892SShuo Chen    }
600ab2e892SShuo Chen    assert(files_.size() == static_cast<size_t>(kShards));
610ab2e892SShuo Chen  }
620ab2e892SShuo Chen
630ab2e892SShuo Chen  void output(string_view word)
640ab2e892SShuo Chen  {
650ab2e892SShuo Chen    size_t shard = hash(word) % files_.size();
66270b6cceSShuo Chen    files_[shard]->appendRecord(word);
670ab2e892SShuo Chen  }
680ab2e892SShuo Chen
690ab2e892SShuo Chen  void finish()
700ab2e892SShuo Chen  {
714136e585SShuo Chen    int shard = 0;
724136e585SShuo Chen    for (const auto& file : files_)
730ab2e892SShuo Chen    {
7485147189SShuo Chen      // if (g_verbose)
754136e585SShuo Chen      printf("  shard %d: %ld bytes, %ld items\n", shard, file->tell(), file->items());
764136e585SShuo Chen      ++shard;
774136e585SShuo Chen      file->close();
780ab2e892SShuo Chen    }
790ab2e892SShuo Chen  }
800ab2e892SShuo Chen
810ab2e892SShuo Chen private:
822cf09315SShuo Chen  absl::Hash<string_view> hash;
830ab2e892SShuo Chen  vector<unique_ptr<OutputFile>> files_;
840ab2e892SShuo Chen};
850ab2e892SShuo Chen
864136e585SShuo Chenint64_t shard_(int argc, char* argv[])
870ab2e892SShuo Chen{
880ab2e892SShuo Chen  Sharder sharder;
894136e585SShuo Chen  Timer timer;
904136e585SShuo Chen  int64_t total = 0;
912a129a12SShuo Chen  for (int i = optind; i < argc; ++i)
920ab2e892SShuo Chen  {
932a129a12SShuo Chen    LOG_INFO << "Processing input file " << argv[i];
9485147189SShuo Chen    double t = Timer::now();
9585147189SShuo Chen    string line;
9685147189SShuo Chen    InputFile input(argv[i]);
9785147189SShuo Chen    while (input.getline(&line))
980ab2e892SShuo Chen    {
9985147189SShuo Chen      sharder.output(line);
1000ab2e892SShuo Chen    }
10185147189SShuo Chen    size_t len = input.tell();
1024136e585SShuo Chen    total += len;
10385147189SShuo Chen    double sec = Timer::now() - t;
1042a129a12SShuo Chen    LOG_INFO << "Done file " << argv[i] << absl::StrFormat(" %.3f sec %.2f MiB/s", sec, len / sec / 1024 / 1024);
1050ab2e892SShuo Chen  }
1060ab2e892SShuo Chen  sharder.finish();
1072a129a12SShuo Chen  LOG_INFO << "Sharding done " << timer.report(total);
1084136e585SShuo Chen  return total;
1090ab2e892SShuo Chen}
1100ab2e892SShuo Chen
1110ab2e892SShuo Chen// ======= count_shards =======
1120ab2e892SShuo Chen
113ecd7048bSShuo Chenvoid count_shard(int shard, int fd, size_t len)
1140ab2e892SShuo Chen{
115ecd7048bSShuo Chen  Timer timer;
116ecd7048bSShuo Chen
11785147189SShuo Chen  double t = Timer::now();
1182a129a12SShuo Chen  LOG_INFO << absl::StrFormat("counting shard %d: input file size %ld", shard, len);
1194136e585SShuo Chen  {
1200ab2e892SShuo Chen  void* mapped = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
1210ab2e892SShuo Chen  assert(mapped != MAP_FAILED);
1220ab2e892SShuo Chen  const uint8_t* const start = static_cast<const uint8_t*>(mapped);
1230ab2e892SShuo Chen  const uint8_t* const end = start + len;
1240ab2e892SShuo Chen
1254136e585SShuo Chen  // std::unordered_map<string_view, uint64_t> items;
1264136e585SShuo Chen  absl::flat_hash_map<string_view, uint64_t> items;
1272a129a12SShuo Chen  int64_t count = 0;
1280ab2e892SShuo Chen  for (const uint8_t* p = start; p < end;)
1290ab2e892SShuo Chen  {
1300ab2e892SShuo Chen    string_view s((const char*)p+1, *p);
1310ab2e892SShuo Chen    items[s]++;
1320ab2e892SShuo Chen    p += 1 + *p;
1332a129a12SShuo Chen    ++count;
1340ab2e892SShuo Chen  }
135270b6cceSShuo Chen  LOG_INFO << "items " << count << " unique " << items.size();
13685147189SShuo Chen  if (g_verbose)
13785147189SShuo Chen  printf("  count %.3f sec %ld items\n", Timer::now() - t, items.size());
1380ab2e892SShuo Chen
13985147189SShuo Chen  t = Timer::now();
1400ab2e892SShuo Chen  vector<std::pair<size_t, string_view>> counts;
1410ab2e892SShuo Chen  for (const auto& it : items)
1420ab2e892SShuo Chen  {
1430ab2e892SShuo Chen    if (it.second > 1)
1442cf09315SShuo Chen      counts.push_back(std::make_pair(it.second, it.first));
1450ab2e892SShuo Chen  }
14685147189SShuo Chen  if (g_verbose)
14785147189SShuo Chen  printf("  select %.3f sec %ld\n", Timer::now() - t, counts.size());
1480ab2e892SShuo Chen
14985147189SShuo Chen  t = Timer::now();
1500ab2e892SShuo Chen  std::sort(counts.begin(), counts.end());
15185147189SShuo Chen  if (g_verbose)
15285147189SShuo Chen  printf("  sort %.3f sec\n", Timer::now() - t);
1530ab2e892SShuo Chen
15485147189SShuo Chen  t = Timer::now();
155c377920eSShuo Chen  int64_t out_len = 0;
1560ab2e892SShuo Chen  {
157ecd7048bSShuo Chen    char buf[256];
158ecd7048bSShuo Chen    snprintf(buf, sizeof buf, "count-%05d-of-%05d", shard, kShards);
159ecd7048bSShuo Chen    OutputFile output(buf);
160ecd7048bSShuo Chen
1614136e585SShuo Chen    for (auto it = counts.rbegin(); it != counts.rend(); ++it)
1620ab2e892SShuo Chen    {
163c377920eSShuo Chen      output.write(absl::StrFormat("%d\t%s\n", it->first, it->second));
1644136e585SShuo Chen    }
165270b6cceSShuo Chen
1664136e585SShuo Chen    for (const auto& it : items)
1674136e585SShuo Chen    {
1684136e585SShuo Chen      if (it.second == 1)
1694136e585SShuo Chen      {
170c377920eSShuo Chen        output.write(absl::StrFormat("1\t%s\n", it.first));
1714136e585SShuo Chen      }
1720ab2e892SShuo Chen    }
173c377920eSShuo Chen    out_len = output.tell();
1740ab2e892SShuo Chen  }
175c377920eSShuo Chen  if (g_verbose)
176c377920eSShuo Chen  printf("  output %.3f sec %lu\n", Timer::now() - t, out_len);
1770ab2e892SShuo Chen
1780ab2e892SShuo Chen  if (munmap(mapped, len))
1790ab2e892SShuo Chen    perror("munmap");
1804136e585SShuo Chen  }
181ecd7048bSShuo Chen  ::close(fd);
182ecd7048bSShuo Chen  LOG_INFO << "shard " << shard << " done " << timer.report(len);
1830ab2e892SShuo Chen}
1840ab2e892SShuo Chen
18585147189SShuo Chenvoid count_shards(int shards)
1860ab2e892SShuo Chen{
18785147189SShuo Chen  assert(shards <= kShards);
1884136e585SShuo Chen  Timer timer;
1894136e585SShuo Chen  int64_t total = 0;
190a251380aSShuo Chen  muduo::ThreadPool threadPool;
19185147189SShuo Chen  threadPool.setMaxQueueSize(2*kThreads);
19285147189SShuo Chen  threadPool.start(kThreads);
19385147189SShuo Chen
19485147189SShuo Chen  for (int shard = 0; shard < shards; ++shard)
1950ab2e892SShuo Chen  {
1960ab2e892SShuo Chen    char buf[256];
197a6693141SShuo Chen    snprintf(buf, sizeof buf, "%s/shard-%05d-of-%05d", shard_dir, shard, kShards);
1980ab2e892SShuo Chen    int fd = open(buf, O_RDONLY);
199ecd7048bSShuo Chen    assert(fd >= 0);
20085147189SShuo Chen    if (!g_keep)
201ecd7048bSShuo Chen      ::unlink(buf);
2022a129a12SShuo Chen
203ecd7048bSShuo Chen    struct stat st;
204ecd7048bSShuo Chen    if (::fstat(fd, &st) == 0)
205ecd7048bSShuo Chen    {
206ecd7048bSShuo Chen      size_t len = st.st_size;
207ecd7048bSShuo Chen      total += len;
208ecd7048bSShuo Chen      threadPool.run([shard, fd, len]{ count_shard(shard, fd, len); });
209ecd7048bSShuo Chen    }
210a251380aSShuo Chen  }
211a251380aSShuo Chen  while (threadPool.queueSize() > 0)
212a251380aSShuo Chen  {
21385147189SShuo Chen    LOG_DEBUG << "waiting for ThreadPool " << threadPool.queueSize();
214ecd7048bSShuo Chen    muduo::CurrentThread::sleepUsec(1000*1000);
2150ab2e892SShuo Chen  }
216a251380aSShuo Chen  threadPool.stop();
217270b6cceSShuo Chen  LOG_INFO << "Counting done "<< timer.report(total);
2180ab2e892SShuo Chen}
2190ab2e892SShuo Chen
2200ab2e892SShuo Chen// ======= merge =======
2210ab2e892SShuo Chen
2220ab2e892SShuo Chenclass Source  // copyable
2230ab2e892SShuo Chen{
2240ab2e892SShuo Chen public:
225270b6cceSShuo Chen  explicit Source(InputFile* in)
2260ab2e892SShuo Chen    : in_(in),
2270ab2e892SShuo Chen      count_(0),
2280ab2e892SShuo Chen      word_()
2290ab2e892SShuo Chen  {
2300ab2e892SShuo Chen  }
2310ab2e892SShuo Chen
2320ab2e892SShuo Chen  bool next()
2330ab2e892SShuo Chen  {
2340ab2e892SShuo Chen    string line;
235270b6cceSShuo Chen    if (in_->getline(&line))
2360ab2e892SShuo Chen    {
2370ab2e892SShuo Chen      size_t tab = line.find('\t');
2380ab2e892SShuo Chen      if (tab != string::npos)
2390ab2e892SShuo Chen      {
2400ab2e892SShuo Chen        count_ = strtol(line.c_str(), NULL, 10);
2410ab2e892SShuo Chen        if (count_ > 0)
2420ab2e892SShuo Chen        {
2430ab2e892SShuo Chen          word_ = line.substr(tab+1);
2440ab2e892SShuo Chen          return true;
2450ab2e892SShuo Chen        }
2460ab2e892SShuo Chen      }
2470ab2e892SShuo Chen    }
2480ab2e892SShuo Chen    return false;
2490ab2e892SShuo Chen  }
2500ab2e892SShuo Chen
2510ab2e892SShuo Chen  bool operator<(const Source& rhs) const
2520ab2e892SShuo Chen  {
2530ab2e892SShuo Chen    return count_ < rhs.count_;
2540ab2e892SShuo Chen  }
2550ab2e892SShuo Chen
256270b6cceSShuo Chen  void outputTo(OutputFile* out) const
2570ab2e892SShuo Chen  {
25885147189SShuo Chen    //char buf[1024];
25985147189SShuo Chen    //snprintf(buf, sizeof buf, "%ld\t%s\n", count_, word_.c_str());
26085147189SShuo Chen    //out->write(buf);
261270b6cceSShuo Chen    out->write(absl::StrFormat("%d\t%s\n", count_, word_));
2620ab2e892SShuo Chen  }
2630ab2e892SShuo Chen
26485147189SShuo Chen  std::pair<int64_t, string> item()
26585147189SShuo Chen  {
26685147189SShuo Chen    return make_pair(count_, std::move(word_));
26785147189SShuo Chen  }
26885147189SShuo Chen
2690ab2e892SShuo Chen private:
270270b6cceSShuo Chen  InputFile* in_;  // not owned
2710ab2e892SShuo Chen  int64_t count_;
2720ab2e892SShuo Chen  string word_;
2730ab2e892SShuo Chen};
2740ab2e892SShuo Chen
27585147189SShuo Chenint64_t merge()
2760ab2e892SShuo Chen{
2774136e585SShuo Chen  Timer timer;
278270b6cceSShuo Chen  vector<unique_ptr<InputFile>> inputs;
2790ab2e892SShuo Chen  vector<Source> keys;
2800ab2e892SShuo Chen
2814136e585SShuo Chen  int64_t total = 0;
2820ab2e892SShuo Chen  for (int i = 0; i < kShards; ++i)
2830ab2e892SShuo Chen  {
2840ab2e892SShuo Chen    char buf[256];
2850ab2e892SShuo Chen    snprintf(buf, sizeof buf, "count-%05d-of-%05d", i, kShards);
2864136e585SShuo Chen    struct stat st;
287a6693141SShuo Chen    if (::stat(buf, &st) == 0)
2880ab2e892SShuo Chen    {
289a6693141SShuo Chen      total += st.st_size;
29085147189SShuo Chen      // TODO: select buffer size based on kShards.
29185147189SShuo Chen      inputs.push_back(std::make_unique<InputFile>(buf, 32 * 1024 * 1024));
292a6693141SShuo Chen      Source rec(inputs.back().get());
293a6693141SShuo Chen      if (rec.next())
294a6693141SShuo Chen      {
295a6693141SShuo Chen        keys.push_back(rec);
296a6693141SShuo Chen      }
29785147189SShuo Chen      if (!g_keep)
298a6693141SShuo Chen        ::unlink(buf);
299a6693141SShuo Chen    }
300a6693141SShuo Chen    else
301a6693141SShuo Chen    {
302a6693141SShuo Chen      perror("Unable to stat file:");
3030ab2e892SShuo Chen    }
3040ab2e892SShuo Chen  }
3052a129a12SShuo Chen  LOG_INFO << "merging " << inputs.size() << " files of " << total << " bytes in total";
3060ab2e892SShuo Chen
3074136e585SShuo Chen  {
30885147189SShuo Chen  OutputFile out(g_output);
30985147189SShuo Chen  /*
31085147189SShuo Chen  muduo::BoundedBlockingQueue<vector<std::pair<int64_t, string>>> queue(1024);
31185147189SShuo Chen  muduo::Thread thr([&queue] {
31285147189SShuo Chen    OutputFile out(g_output);
31385147189SShuo Chen    while (true) {
31485147189SShuo Chen      auto vec = queue.take();
31585147189SShuo Chen      if (vec.size() == 0)
31685147189SShuo Chen        break;
31785147189SShuo Chen      for (const auto& x : vec)
31885147189SShuo Chen        out.write(absl::StrFormat("%d\t%s\n", x.first, x.second));
31985147189SShuo Chen    }
32085147189SShuo Chen  });
32185147189SShuo Chen  thr.start();
32285147189SShuo Chen
32385147189SShuo Chen  vector<std::pair<int64_t, string>> batch;
32485147189SShuo Chen  */
3250ab2e892SShuo Chen  std::make_heap(keys.begin(), keys.end());
3260ab2e892SShuo Chen  while (!keys.empty())
3270ab2e892SShuo Chen  {
3280ab2e892SShuo Chen    std::pop_heap(keys.begin(), keys.end());
329270b6cceSShuo Chen    keys.back().outputTo(&out);
33085147189SShuo Chen    /*
33185147189SShuo Chen    batch.push_back(std::move(keys.back().item()));
33285147189SShuo Chen    if (batch.size() >= 10*1024*1024)
33385147189SShuo Chen    {
33485147189SShuo Chen      queue.put(std::move(batch));
33585147189SShuo Chen      batch.clear();
33685147189SShuo Chen    }
33785147189SShuo Chen    */
3380ab2e892SShuo Chen
3390ab2e892SShuo Chen    if (keys.back().next())
3400ab2e892SShuo Chen    {
3410ab2e892SShuo Chen      std::push_heap(keys.begin(), keys.end());
3420ab2e892SShuo Chen    }
3430ab2e892SShuo Chen    else
3440ab2e892SShuo Chen    {
3450ab2e892SShuo Chen      keys.pop_back();
3460ab2e892SShuo Chen    }
3470ab2e892SShuo Chen  }
34885147189SShuo Chen  /*
34985147189SShuo Chen  queue.put(batch);
35085147189SShuo Chen  batch.clear();
35185147189SShuo Chen  queue.put(batch);
35285147189SShuo Chen  thr.join();
35385147189SShuo Chen  */
3544136e585SShuo Chen  }
355a251380aSShuo Chen  LOG_INFO << "Merging done " << timer.report(total);
3562a129a12SShuo Chen  return total;
3570ab2e892SShuo Chen}
3580ab2e892SShuo Chen
3590ab2e892SShuo Chenint main(int argc, char* argv[])
3600ab2e892SShuo Chen{
3610ab2e892SShuo Chen  /*
3620ab2e892SShuo Chen  int fd = open("shard-00000-of-00010", O_RDONLY);
36385147189SShuo Chen  double t = Timer::now();
3644136e585SShuo Chen  int64_t len = count_shard(0, fd);
36585147189SShuo Chen  double sec = Timer::now() - t;
3664136e585SShuo Chen  printf("count_shard %.3f sec %.2f MB/s\n", sec, len / sec / 1e6);
3674136e585SShuo Chen  */
3683e607da5SShuo Chen  setlocale(LC_NUMERIC, "");
3690ab2e892SShuo Chen
3702a129a12SShuo Chen  int opt;
37185147189SShuo Chen  int count_only = 0;
37285147189SShuo Chen  bool merge_only = false;
37385147189SShuo Chen  while ((opt = getopt(argc, argv, "c:kmo:p:s:t:v")) != -1)
3742a129a12SShuo Chen  {
3752a129a12SShuo Chen    switch (opt)
3762a129a12SShuo Chen    {
37785147189SShuo Chen      case 'c':
37885147189SShuo Chen        count_only = atoi(optarg);
37985147189SShuo Chen        break;
3802a129a12SShuo Chen      case 'k':
38185147189SShuo Chen        g_keep = true;
38285147189SShuo Chen        break;
38385147189SShuo Chen      case 'm':
38485147189SShuo Chen        merge_only = true;
3852a129a12SShuo Chen        break;
386a6693141SShuo Chen      case 'o':
38785147189SShuo Chen        g_output = optarg;
38885147189SShuo Chen        break;
38985147189SShuo Chen      case 'p':  // Path for temp shard files
39085147189SShuo Chen        shard_dir = optarg;
391a6693141SShuo Chen        break;
3922a129a12SShuo Chen      case 's':
3932a129a12SShuo Chen        kShards = atoi(optarg);
3942a129a12SShuo Chen        break;
395a6693141SShuo Chen      case 't':
39685147189SShuo Chen        kThreads = atoi(optarg);
397a6693141SShuo Chen        break;
3982a129a12SShuo Chen      case 'v':
39985147189SShuo Chen        g_verbose = true;
4002a129a12SShuo Chen        break;
4012a129a12SShuo Chen    }
4022a129a12SShuo Chen  }
4032a129a12SShuo Chen
40485147189SShuo Chen  if (count_only > 0 || merge_only)
40585147189SShuo Chen  {
40685147189SShuo Chen    g_keep = true;
40785147189SShuo Chen    g_verbose = true;
40885147189SShuo Chen    count_only = std::min(count_only, kShards);
40985147189SShuo Chen
41085147189SShuo Chen    if (count_only > 0)
41185147189SShuo Chen    {
41285147189SShuo Chen      count_shards(count_only);
41385147189SShuo Chen    }
41485147189SShuo Chen
41585147189SShuo Chen    if (merge_only)
41685147189SShuo Chen    {
41785147189SShuo Chen      merge();
41885147189SShuo Chen    }
41985147189SShuo Chen  }
42085147189SShuo Chen  else
42185147189SShuo Chen  {
42285147189SShuo Chen    // Run all three steps
42385147189SShuo Chen    Timer timer;
42485147189SShuo Chen    LOG_INFO << argc - optind << " input files, " << kShards << " shards, "
42585147189SShuo Chen             << "output " << g_output <<" , temp " << shard_dir;
42685147189SShuo Chen    int64_t input = 0;
42785147189SShuo Chen    input = shard_(argc, argv);
42885147189SShuo Chen    count_shards(kShards);
42985147189SShuo Chen    int64_t output_size = merge();
43085147189SShuo Chen    LOG_INFO << "All done " << timer.report(input) << " output " << output_size;
43185147189SShuo Chen  }
4320ab2e892SShuo Chen}
433