word_freq_shards_basic.cc revision c377920e
/* Sort words by frequency, sharding version.

  1. Read the input files and shard them into N files:
       word
  2. Assuming each shard file fits in memory, read each shard file, count the words and sort by count, then write N count files:
       count \t word
  3. Merge the N count files using a heap.

Limits: each shard must fit in memory.
*/
110ab2e892SShuo Chen
120ab2e892SShuo Chen#include <assert.h>
130ab2e892SShuo Chen
1485147189SShuo Chen#include "file.h"
1585147189SShuo Chen#include "timer.h"
1685147189SShuo Chen
174136e585SShuo Chen#include "absl/container/flat_hash_map.h"
182a129a12SShuo Chen#include "absl/strings/str_format.h"
1985147189SShuo Chen#include "muduo/base/BoundedBlockingQueue.h"
202a129a12SShuo Chen#include "muduo/base/Logging.h"
21a251380aSShuo Chen#include "muduo/base/ThreadPool.h"
224136e585SShuo Chen
230ab2e892SShuo Chen#include <algorithm>
240ab2e892SShuo Chen#include <memory>
250ab2e892SShuo Chen#include <string>
260ab2e892SShuo Chen#include <unordered_map>
270ab2e892SShuo Chen#include <vector>
280ab2e892SShuo Chen
292a129a12SShuo Chen#include <boost/program_options.hpp>
302a129a12SShuo Chen
310ab2e892SShuo Chen#include <fcntl.h>
320ab2e892SShuo Chen#include <string.h>
330ab2e892SShuo Chen#include <sys/mman.h>
344136e585SShuo Chen#include <sys/stat.h>
350ab2e892SShuo Chen#include <unistd.h>
360ab2e892SShuo Chen
370ab2e892SShuo Chenusing std::string;
380ab2e892SShuo Chenusing std::string_view;
390ab2e892SShuo Chenusing std::vector;
400ab2e892SShuo Chenusing std::unique_ptr;
410ab2e892SShuo Chen
// Number of shard files the input is split into (set by -s).
int kShards = 10;
// Number of worker threads used for counting shards (set by -t).
int kThreads = 4;
// Print per-stage timing details while counting (set by -v).
bool g_verbose = false;
// Keep the intermediate shard/count files instead of unlinking them (set by -k).
bool g_keep = false;
// Directory that holds the temporary shard files (set by -p).
const char* shard_dir = ".";
// Path of the final merged output file (set by -o).
const char* g_output = "output";
46270b6cceSShuo Chen
470ab2e892SShuo Chenclass Sharder // : boost::noncopyable
480ab2e892SShuo Chen{
490ab2e892SShuo Chen public:
500ab2e892SShuo Chen  Sharder()
510ab2e892SShuo Chen    : files_(kShards)
520ab2e892SShuo Chen  {
530ab2e892SShuo Chen    for (int i = 0; i < kShards; ++i)
540ab2e892SShuo Chen    {
550ab2e892SShuo Chen      char name[256];
56a6693141SShuo Chen      snprintf(name, sizeof name, "%s/shard-%05d-of-%05d", shard_dir, i, kShards);
570ab2e892SShuo Chen      files_[i].reset(new OutputFile(name));
580ab2e892SShuo Chen    }
590ab2e892SShuo Chen    assert(files_.size() == static_cast<size_t>(kShards));
600ab2e892SShuo Chen  }
610ab2e892SShuo Chen
620ab2e892SShuo Chen  void output(string_view word)
630ab2e892SShuo Chen  {
640ab2e892SShuo Chen    size_t shard = hash(word) % files_.size();
65270b6cceSShuo Chen    files_[shard]->appendRecord(word);
660ab2e892SShuo Chen  }
670ab2e892SShuo Chen
680ab2e892SShuo Chen  void finish()
690ab2e892SShuo Chen  {
704136e585SShuo Chen    int shard = 0;
714136e585SShuo Chen    for (const auto& file : files_)
720ab2e892SShuo Chen    {
7385147189SShuo Chen      // if (g_verbose)
744136e585SShuo Chen      printf("  shard %d: %ld bytes, %ld items\n", shard, file->tell(), file->items());
754136e585SShuo Chen      ++shard;
764136e585SShuo Chen      file->close();
770ab2e892SShuo Chen    }
780ab2e892SShuo Chen  }
790ab2e892SShuo Chen
800ab2e892SShuo Chen private:
810ab2e892SShuo Chen  std::hash<string_view> hash;
820ab2e892SShuo Chen  vector<unique_ptr<OutputFile>> files_;
830ab2e892SShuo Chen};
840ab2e892SShuo Chen
854136e585SShuo Chenint64_t shard_(int argc, char* argv[])
860ab2e892SShuo Chen{
870ab2e892SShuo Chen  Sharder sharder;
884136e585SShuo Chen  Timer timer;
894136e585SShuo Chen  int64_t total = 0;
902a129a12SShuo Chen  for (int i = optind; i < argc; ++i)
910ab2e892SShuo Chen  {
922a129a12SShuo Chen    LOG_INFO << "Processing input file " << argv[i];
9385147189SShuo Chen    double t = Timer::now();
9485147189SShuo Chen    string line;
9585147189SShuo Chen    InputFile input(argv[i]);
9685147189SShuo Chen    while (input.getline(&line))
970ab2e892SShuo Chen    {
9885147189SShuo Chen      sharder.output(line);
990ab2e892SShuo Chen    }
10085147189SShuo Chen    size_t len = input.tell();
1014136e585SShuo Chen    total += len;
10285147189SShuo Chen    double sec = Timer::now() - t;
1032a129a12SShuo Chen    LOG_INFO << "Done file " << argv[i] << absl::StrFormat(" %.3f sec %.2f MiB/s", sec, len / sec / 1024 / 1024);
1040ab2e892SShuo Chen  }
1050ab2e892SShuo Chen  sharder.finish();
1062a129a12SShuo Chen  LOG_INFO << "Sharding done " << timer.report(total);
1074136e585SShuo Chen  return total;
1080ab2e892SShuo Chen}
1090ab2e892SShuo Chen
1100ab2e892SShuo Chen// ======= count_shards =======
1110ab2e892SShuo Chen
112ecd7048bSShuo Chenvoid count_shard(int shard, int fd, size_t len)
1130ab2e892SShuo Chen{
114ecd7048bSShuo Chen  Timer timer;
115ecd7048bSShuo Chen
11685147189SShuo Chen  double t = Timer::now();
1172a129a12SShuo Chen  LOG_INFO << absl::StrFormat("counting shard %d: input file size %ld", shard, len);
1184136e585SShuo Chen  {
1190ab2e892SShuo Chen  void* mapped = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
1200ab2e892SShuo Chen  assert(mapped != MAP_FAILED);
1210ab2e892SShuo Chen  const uint8_t* const start = static_cast<const uint8_t*>(mapped);
1220ab2e892SShuo Chen  const uint8_t* const end = start + len;
1230ab2e892SShuo Chen
1244136e585SShuo Chen  // std::unordered_map<string_view, uint64_t> items;
1254136e585SShuo Chen  absl::flat_hash_map<string_view, uint64_t> items;
1262a129a12SShuo Chen  int64_t count = 0;
1270ab2e892SShuo Chen  for (const uint8_t* p = start; p < end;)
1280ab2e892SShuo Chen  {
1290ab2e892SShuo Chen    string_view s((const char*)p+1, *p);
1300ab2e892SShuo Chen    items[s]++;
1310ab2e892SShuo Chen    p += 1 + *p;
1322a129a12SShuo Chen    ++count;
1330ab2e892SShuo Chen  }
134270b6cceSShuo Chen  LOG_INFO << "items " << count << " unique " << items.size();
13585147189SShuo Chen  if (g_verbose)
13685147189SShuo Chen  printf("  count %.3f sec %ld items\n", Timer::now() - t, items.size());
1370ab2e892SShuo Chen
13885147189SShuo Chen  t = Timer::now();
1390ab2e892SShuo Chen  vector<std::pair<size_t, string_view>> counts;
1400ab2e892SShuo Chen  for (const auto& it : items)
1410ab2e892SShuo Chen  {
1420ab2e892SShuo Chen    if (it.second > 1)
1430ab2e892SShuo Chen      counts.push_back(make_pair(it.second, it.first));
1440ab2e892SShuo Chen  }
14585147189SShuo Chen  if (g_verbose)
14685147189SShuo Chen  printf("  select %.3f sec %ld\n", Timer::now() - t, counts.size());
1470ab2e892SShuo Chen
14885147189SShuo Chen  t = Timer::now();
1490ab2e892SShuo Chen  std::sort(counts.begin(), counts.end());
15085147189SShuo Chen  if (g_verbose)
15185147189SShuo Chen  printf("  sort %.3f sec\n", Timer::now() - t);
1520ab2e892SShuo Chen
15385147189SShuo Chen  t = Timer::now();
154c377920eSShuo Chen  int64_t out_len = 0;
1550ab2e892SShuo Chen  {
156ecd7048bSShuo Chen    char buf[256];
157ecd7048bSShuo Chen    snprintf(buf, sizeof buf, "count-%05d-of-%05d", shard, kShards);
158ecd7048bSShuo Chen    OutputFile output(buf);
159ecd7048bSShuo Chen
1604136e585SShuo Chen    for (auto it = counts.rbegin(); it != counts.rend(); ++it)
1610ab2e892SShuo Chen    {
162c377920eSShuo Chen      output.write(absl::StrFormat("%d\t%s\n", it->first, it->second));
1634136e585SShuo Chen    }
164270b6cceSShuo Chen
1654136e585SShuo Chen    for (const auto& it : items)
1664136e585SShuo Chen    {
1674136e585SShuo Chen      if (it.second == 1)
1684136e585SShuo Chen      {
169c377920eSShuo Chen        output.write(absl::StrFormat("1\t%s\n", it.first));
1704136e585SShuo Chen      }
1710ab2e892SShuo Chen    }
172c377920eSShuo Chen    out_len = output.tell();
1730ab2e892SShuo Chen  }
174c377920eSShuo Chen  if (g_verbose)
175c377920eSShuo Chen  printf("  output %.3f sec %lu\n", Timer::now() - t, out_len);
1760ab2e892SShuo Chen
1770ab2e892SShuo Chen  if (munmap(mapped, len))
1780ab2e892SShuo Chen    perror("munmap");
1794136e585SShuo Chen  }
180ecd7048bSShuo Chen  ::close(fd);
181ecd7048bSShuo Chen  LOG_INFO << "shard " << shard << " done " << timer.report(len);
1820ab2e892SShuo Chen}
1830ab2e892SShuo Chen
18485147189SShuo Chenvoid count_shards(int shards)
1850ab2e892SShuo Chen{
18685147189SShuo Chen  assert(shards <= kShards);
1874136e585SShuo Chen  Timer timer;
1884136e585SShuo Chen  int64_t total = 0;
189a251380aSShuo Chen  muduo::ThreadPool threadPool;
19085147189SShuo Chen  threadPool.setMaxQueueSize(2*kThreads);
19185147189SShuo Chen  threadPool.start(kThreads);
19285147189SShuo Chen
19385147189SShuo Chen  for (int shard = 0; shard < shards; ++shard)
1940ab2e892SShuo Chen  {
1950ab2e892SShuo Chen    char buf[256];
196a6693141SShuo Chen    snprintf(buf, sizeof buf, "%s/shard-%05d-of-%05d", shard_dir, shard, kShards);
1970ab2e892SShuo Chen    int fd = open(buf, O_RDONLY);
198ecd7048bSShuo Chen    assert(fd >= 0);
19985147189SShuo Chen    if (!g_keep)
200ecd7048bSShuo Chen      ::unlink(buf);
2012a129a12SShuo Chen
202ecd7048bSShuo Chen    struct stat st;
203ecd7048bSShuo Chen    if (::fstat(fd, &st) == 0)
204ecd7048bSShuo Chen    {
205ecd7048bSShuo Chen      size_t len = st.st_size;
206ecd7048bSShuo Chen      total += len;
207ecd7048bSShuo Chen      threadPool.run([shard, fd, len]{ count_shard(shard, fd, len); });
208ecd7048bSShuo Chen    }
209a251380aSShuo Chen  }
210a251380aSShuo Chen  while (threadPool.queueSize() > 0)
211a251380aSShuo Chen  {
21285147189SShuo Chen    LOG_DEBUG << "waiting for ThreadPool " << threadPool.queueSize();
213ecd7048bSShuo Chen    muduo::CurrentThread::sleepUsec(1000*1000);
2140ab2e892SShuo Chen  }
215a251380aSShuo Chen  threadPool.stop();
216270b6cceSShuo Chen  LOG_INFO << "Counting done "<< timer.report(total);
2170ab2e892SShuo Chen}
2180ab2e892SShuo Chen
2190ab2e892SShuo Chen// ======= merge =======
2200ab2e892SShuo Chen
2210ab2e892SShuo Chenclass Source  // copyable
2220ab2e892SShuo Chen{
2230ab2e892SShuo Chen public:
224270b6cceSShuo Chen  explicit Source(InputFile* in)
2250ab2e892SShuo Chen    : in_(in),
2260ab2e892SShuo Chen      count_(0),
2270ab2e892SShuo Chen      word_()
2280ab2e892SShuo Chen  {
2290ab2e892SShuo Chen  }
2300ab2e892SShuo Chen
2310ab2e892SShuo Chen  bool next()
2320ab2e892SShuo Chen  {
2330ab2e892SShuo Chen    string line;
234270b6cceSShuo Chen    if (in_->getline(&line))
2350ab2e892SShuo Chen    {
2360ab2e892SShuo Chen      size_t tab = line.find('\t');
2370ab2e892SShuo Chen      if (tab != string::npos)
2380ab2e892SShuo Chen      {
2390ab2e892SShuo Chen        count_ = strtol(line.c_str(), NULL, 10);
2400ab2e892SShuo Chen        if (count_ > 0)
2410ab2e892SShuo Chen        {
2420ab2e892SShuo Chen          word_ = line.substr(tab+1);
2430ab2e892SShuo Chen          return true;
2440ab2e892SShuo Chen        }
2450ab2e892SShuo Chen      }
2460ab2e892SShuo Chen    }
2470ab2e892SShuo Chen    return false;
2480ab2e892SShuo Chen  }
2490ab2e892SShuo Chen
2500ab2e892SShuo Chen  bool operator<(const Source& rhs) const
2510ab2e892SShuo Chen  {
2520ab2e892SShuo Chen    return count_ < rhs.count_;
2530ab2e892SShuo Chen  }
2540ab2e892SShuo Chen
255270b6cceSShuo Chen  void outputTo(OutputFile* out) const
2560ab2e892SShuo Chen  {
25785147189SShuo Chen    //char buf[1024];
25885147189SShuo Chen    //snprintf(buf, sizeof buf, "%ld\t%s\n", count_, word_.c_str());
25985147189SShuo Chen    //out->write(buf);
260270b6cceSShuo Chen    out->write(absl::StrFormat("%d\t%s\n", count_, word_));
2610ab2e892SShuo Chen  }
2620ab2e892SShuo Chen
26385147189SShuo Chen  std::pair<int64_t, string> item()
26485147189SShuo Chen  {
26585147189SShuo Chen    return make_pair(count_, std::move(word_));
26685147189SShuo Chen  }
26785147189SShuo Chen
2680ab2e892SShuo Chen private:
269270b6cceSShuo Chen  InputFile* in_;  // not owned
2700ab2e892SShuo Chen  int64_t count_;
2710ab2e892SShuo Chen  string word_;
2720ab2e892SShuo Chen};
2730ab2e892SShuo Chen
27485147189SShuo Chenint64_t merge()
2750ab2e892SShuo Chen{
2764136e585SShuo Chen  Timer timer;
277270b6cceSShuo Chen  vector<unique_ptr<InputFile>> inputs;
2780ab2e892SShuo Chen  vector<Source> keys;
2790ab2e892SShuo Chen
2804136e585SShuo Chen  int64_t total = 0;
2810ab2e892SShuo Chen  for (int i = 0; i < kShards; ++i)
2820ab2e892SShuo Chen  {
2830ab2e892SShuo Chen    char buf[256];
2840ab2e892SShuo Chen    snprintf(buf, sizeof buf, "count-%05d-of-%05d", i, kShards);
2854136e585SShuo Chen    struct stat st;
286a6693141SShuo Chen    if (::stat(buf, &st) == 0)
2870ab2e892SShuo Chen    {
288a6693141SShuo Chen      total += st.st_size;
28985147189SShuo Chen      // TODO: select buffer size based on kShards.
29085147189SShuo Chen      inputs.push_back(std::make_unique<InputFile>(buf, 32 * 1024 * 1024));
291a6693141SShuo Chen      Source rec(inputs.back().get());
292a6693141SShuo Chen      if (rec.next())
293a6693141SShuo Chen      {
294a6693141SShuo Chen        keys.push_back(rec);
295a6693141SShuo Chen      }
29685147189SShuo Chen      if (!g_keep)
297a6693141SShuo Chen        ::unlink(buf);
298a6693141SShuo Chen    }
299a6693141SShuo Chen    else
300a6693141SShuo Chen    {
301a6693141SShuo Chen      perror("Unable to stat file:");
3020ab2e892SShuo Chen    }
3030ab2e892SShuo Chen  }
3042a129a12SShuo Chen  LOG_INFO << "merging " << inputs.size() << " files of " << total << " bytes in total";
3050ab2e892SShuo Chen
3064136e585SShuo Chen  {
30785147189SShuo Chen  OutputFile out(g_output);
30885147189SShuo Chen  /*
30985147189SShuo Chen  muduo::BoundedBlockingQueue<vector<std::pair<int64_t, string>>> queue(1024);
31085147189SShuo Chen  muduo::Thread thr([&queue] {
31185147189SShuo Chen    OutputFile out(g_output);
31285147189SShuo Chen    while (true) {
31385147189SShuo Chen      auto vec = queue.take();
31485147189SShuo Chen      if (vec.size() == 0)
31585147189SShuo Chen        break;
31685147189SShuo Chen      for (const auto& x : vec)
31785147189SShuo Chen        out.write(absl::StrFormat("%d\t%s\n", x.first, x.second));
31885147189SShuo Chen    }
31985147189SShuo Chen  });
32085147189SShuo Chen  thr.start();
32185147189SShuo Chen
32285147189SShuo Chen  vector<std::pair<int64_t, string>> batch;
32385147189SShuo Chen  */
3240ab2e892SShuo Chen  std::make_heap(keys.begin(), keys.end());
3250ab2e892SShuo Chen  while (!keys.empty())
3260ab2e892SShuo Chen  {
3270ab2e892SShuo Chen    std::pop_heap(keys.begin(), keys.end());
328270b6cceSShuo Chen    keys.back().outputTo(&out);
32985147189SShuo Chen    /*
33085147189SShuo Chen    batch.push_back(std::move(keys.back().item()));
33185147189SShuo Chen    if (batch.size() >= 10*1024*1024)
33285147189SShuo Chen    {
33385147189SShuo Chen      queue.put(std::move(batch));
33485147189SShuo Chen      batch.clear();
33585147189SShuo Chen    }
33685147189SShuo Chen    */
3370ab2e892SShuo Chen
3380ab2e892SShuo Chen    if (keys.back().next())
3390ab2e892SShuo Chen    {
3400ab2e892SShuo Chen      std::push_heap(keys.begin(), keys.end());
3410ab2e892SShuo Chen    }
3420ab2e892SShuo Chen    else
3430ab2e892SShuo Chen    {
3440ab2e892SShuo Chen      keys.pop_back();
3450ab2e892SShuo Chen    }
3460ab2e892SShuo Chen  }
34785147189SShuo Chen  /*
34885147189SShuo Chen  queue.put(batch);
34985147189SShuo Chen  batch.clear();
35085147189SShuo Chen  queue.put(batch);
35185147189SShuo Chen  thr.join();
35285147189SShuo Chen  */
3534136e585SShuo Chen  }
354a251380aSShuo Chen  LOG_INFO << "Merging done " << timer.report(total);
3552a129a12SShuo Chen  return total;
3560ab2e892SShuo Chen}
3570ab2e892SShuo Chen
3580ab2e892SShuo Chenint main(int argc, char* argv[])
3590ab2e892SShuo Chen{
3600ab2e892SShuo Chen  /*
3610ab2e892SShuo Chen  int fd = open("shard-00000-of-00010", O_RDONLY);
36285147189SShuo Chen  double t = Timer::now();
3634136e585SShuo Chen  int64_t len = count_shard(0, fd);
36485147189SShuo Chen  double sec = Timer::now() - t;
3654136e585SShuo Chen  printf("count_shard %.3f sec %.2f MB/s\n", sec, len / sec / 1e6);
3664136e585SShuo Chen  */
3673e607da5SShuo Chen  setlocale(LC_NUMERIC, "");
3680ab2e892SShuo Chen
3692a129a12SShuo Chen  int opt;
37085147189SShuo Chen  int count_only = 0;
37185147189SShuo Chen  bool merge_only = false;
37285147189SShuo Chen  while ((opt = getopt(argc, argv, "c:kmo:p:s:t:v")) != -1)
3732a129a12SShuo Chen  {
3742a129a12SShuo Chen    switch (opt)
3752a129a12SShuo Chen    {
37685147189SShuo Chen      case 'c':
37785147189SShuo Chen        count_only = atoi(optarg);
37885147189SShuo Chen        break;
3792a129a12SShuo Chen      case 'k':
38085147189SShuo Chen        g_keep = true;
38185147189SShuo Chen        break;
38285147189SShuo Chen      case 'm':
38385147189SShuo Chen        merge_only = true;
3842a129a12SShuo Chen        break;
385a6693141SShuo Chen      case 'o':
38685147189SShuo Chen        g_output = optarg;
38785147189SShuo Chen        break;
38885147189SShuo Chen      case 'p':  // Path for temp shard files
38985147189SShuo Chen        shard_dir = optarg;
390a6693141SShuo Chen        break;
3912a129a12SShuo Chen      case 's':
3922a129a12SShuo Chen        kShards = atoi(optarg);
3932a129a12SShuo Chen        break;
394a6693141SShuo Chen      case 't':
39585147189SShuo Chen        kThreads = atoi(optarg);
396a6693141SShuo Chen        break;
3972a129a12SShuo Chen      case 'v':
39885147189SShuo Chen        g_verbose = true;
3992a129a12SShuo Chen        break;
4002a129a12SShuo Chen    }
4012a129a12SShuo Chen  }
4022a129a12SShuo Chen
40385147189SShuo Chen  if (count_only > 0 || merge_only)
40485147189SShuo Chen  {
40585147189SShuo Chen    g_keep = true;
40685147189SShuo Chen    g_verbose = true;
40785147189SShuo Chen    count_only = std::min(count_only, kShards);
40885147189SShuo Chen
40985147189SShuo Chen    if (count_only > 0)
41085147189SShuo Chen    {
41185147189SShuo Chen      count_shards(count_only);
41285147189SShuo Chen    }
41385147189SShuo Chen
41485147189SShuo Chen    if (merge_only)
41585147189SShuo Chen    {
41685147189SShuo Chen      merge();
41785147189SShuo Chen    }
41885147189SShuo Chen  }
41985147189SShuo Chen  else
42085147189SShuo Chen  {
42185147189SShuo Chen    // Run all three steps
42285147189SShuo Chen    Timer timer;
42385147189SShuo Chen    LOG_INFO << argc - optind << " input files, " << kShards << " shards, "
42485147189SShuo Chen             << "output " << g_output <<" , temp " << shard_dir;
42585147189SShuo Chen    int64_t input = 0;
42685147189SShuo Chen    input = shard_(argc, argv);
42785147189SShuo Chen    count_shards(kShards);
42885147189SShuo Chen    int64_t output_size = merge();
42985147189SShuo Chen    LOG_INFO << "All done " << timer.report(input) << " output " << output_size;
43085147189SShuo Chen  }
4310ab2e892SShuo Chen}
432