1/* sort word by frequency, sharding version. 2 3 1. read input file, shard to N files: 4 word 5 2. assume each shard file fits in memory, read each shard file, count words and sort by count, then write to N count files: 6 count \t word 7 3. merge N count files using heap. 8 9Limits: each shard must fit in memory. 10*/ 11 12#include <assert.h> 13 14#include "file.h" 15#include "merge.h" 16#include "timer.h" 17 18#include "absl/container/flat_hash_map.h" 19#include "absl/hash/hash.h" 20#include "absl/strings/str_format.h" 21#include "muduo/base/Logging.h" 22#include "muduo/base/ThreadPool.h" 23 24#include <algorithm> 25#include <memory> 26#include <string> 27#include <unordered_map> 28#include <vector> 29 30#include <fcntl.h> 31#include <string.h> 32#include <sys/mman.h> 33#include <sys/stat.h> 34#include <unistd.h> 35 36using absl::string_view; 37using std::string; 38using std::vector; 39using std::unique_ptr; 40 41int kShards = 10, kThreads = 4; 42bool g_verbose = false, g_keep = false; 43const char* shard_dir = "."; 44const char* g_output = "output"; 45 46class Sharder // : boost::noncopyable 47{ 48 public: 49 Sharder() 50 : files_(kShards) 51 { 52 for (int i = 0; i < kShards; ++i) 53 { 54 char name[256]; 55 snprintf(name, sizeof name, "%s/shard-%05d-of-%05d", shard_dir, i, kShards); 56 files_[i].reset(new OutputFile(name)); 57 } 58 assert(files_.size() == static_cast<size_t>(kShards)); 59 } 60 61 void output(string_view word) 62 { 63 size_t shard = hash(word) % files_.size(); 64 files_[shard]->appendRecord(word); 65 } 66 67 void finish() 68 { 69 int shard = 0; 70 for (const auto& file : files_) 71 { 72 // if (g_verbose) 73 printf(" shard %d: %ld bytes, %ld items\n", shard, file->tell(), file->items()); 74 ++shard; 75 file->close(); 76 } 77 } 78 79 private: 80 absl::Hash<string_view> hash; 81 vector<unique_ptr<OutputFile>> files_; 82}; 83 84int64_t shard_(int argc, char* argv[]) 85{ 86 Sharder sharder; 87 Timer timer; 88 int64_t total = 0; 89 for (int i = optind; i < argc; ++i) 90 { 91 LOG_INFO << "Processing input file " << argv[i]; 92 double t = Timer::now(); 93 string line; 94 InputFile input(argv[i]); 95 while (input.getline(&line)) 96 { 97 sharder.output(line); 98 } 99 size_t len = input.tell(); 100 total += len; 101 double sec = Timer::now() - t; 102 LOG_INFO << "Done file " << argv[i] << absl::StrFormat(" %.3f sec %.2f MiB/s", sec, len / sec / 1024 / 1024); 103 } 104 sharder.finish(); 105 LOG_INFO << "Sharding done " << timer.report(total); 106 return total; 107} 108 109// ======= count_shards ======= 110 111void count_shard(int shard, int fd, size_t len) 112{ 113 Timer timer; 114 115 double t = Timer::now(); 116 LOG_INFO << absl::StrFormat("counting shard %d: input file size %ld", shard, len); 117 { 118 void* mapped = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0); 119 assert(mapped != MAP_FAILED); 120 const uint8_t* const start = static_cast<const uint8_t*>(mapped); 121 const uint8_t* const end = start + len; 122 123 // std::unordered_map<string_view, uint64_t> items; 124 absl::flat_hash_map<string_view, uint64_t> items; 125 int64_t count = 0; 126 for (const uint8_t* p = start; p < end;) 127 { 128 string_view s((const char*)p+1, *p); 129 items[s]++; 130 p += 1 + *p; 131 ++count; 132 } 133 LOG_INFO << "items " << count << " unique " << items.size(); 134 if (g_verbose) 135 printf(" count %.3f sec %ld items\n", Timer::now() - t, items.size()); 136 137 t = Timer::now(); 138 vector<std::pair<size_t, string_view>> counts; 139 for (const auto& it : items) 140 { 141 if (it.second > 1) 142 counts.push_back(std::make_pair(it.second, it.first)); 143 } 144 if (g_verbose) 145 printf(" select %.3f sec %ld\n", Timer::now() - t, counts.size()); 146 147 t = Timer::now(); 148 std::sort(counts.begin(), counts.end()); 149 if (g_verbose) 150 printf(" sort %.3f sec\n", Timer::now() - t); 151 152 t = Timer::now(); 153 int64_t out_len = 0; 154 { 155 char buf[256]; 156 snprintf(buf, sizeof buf, "count-%05d", shard); 157 OutputFile output(buf); 158 159 for (auto it = counts.rbegin(); it != counts.rend(); ++it) 160 { 161 output.write(absl::StrFormat("%d\t%s\n", it->first, it->second)); 162 } 163 164 for (const auto& it : items) 165 { 166 if (it.second == 1) 167 { 168 output.write(absl::StrFormat("1\t%s\n", it.first)); 169 } 170 } 171 out_len = output.tell(); 172 } 173 if (g_verbose) 174 printf(" output %.3f sec %lu\n", Timer::now() - t, out_len); 175 176 if (munmap(mapped, len)) 177 perror("munmap"); 178 } 179 ::close(fd); 180 LOG_INFO << "shard " << shard << " done " << timer.report(len); 181} 182 183void count_shards(int shards) 184{ 185 assert(shards <= kShards); 186 Timer timer; 187 int64_t total = 0; 188 muduo::ThreadPool threadPool; 189 threadPool.setMaxQueueSize(2*kThreads); 190 threadPool.start(kThreads); 191 192 for (int shard = 0; shard < shards; ++shard) 193 { 194 char buf[256]; 195 snprintf(buf, sizeof buf, "%s/shard-%05d-of-%05d", shard_dir, shard, kShards); 196 int fd = open(buf, O_RDONLY); 197 assert(fd >= 0); 198 if (!g_keep) 199 ::unlink(buf); 200 201 struct stat st; 202 if (::fstat(fd, &st) == 0) 203 { 204 size_t len = st.st_size; 205 total += len; 206 threadPool.run([shard, fd, len]{ count_shard(shard, fd, len); }); 207 } 208 } 209 while (threadPool.queueSize() > 0) 210 { 211 LOG_DEBUG << "waiting for ThreadPool " << threadPool.queueSize(); 212 muduo::CurrentThread::sleepUsec(1000*1000); 213 } 214 threadPool.stop(); 215 LOG_INFO << "Counting done "<< timer.report(total); 216} 217 218// ======= merge ======= 219 220int main(int argc, char* argv[]) 221{ 222 /* 223 int fd = open("shard-00000-of-00010", O_RDONLY); 224 double t = Timer::now(); 225 int64_t len = count_shard(0, fd); 226 double sec = Timer::now() - t; 227 printf("count_shard %.3f sec %.2f MB/s\n", sec, len / sec / 1e6); 228 */ 229 setlocale(LC_NUMERIC, ""); 230 231 int opt; 232 int count_only = 0; 233 int merge_only = 0; 234 while ((opt = getopt(argc, argv, "c:km:o:p:s:t:v")) != -1) 235 { 236 switch (opt) 237 { 238 case 'c': 239 count_only = atoi(optarg); 240 break; 241 case 'k': 242 g_keep = true; 243 break; 244 case 'm': 245 merge_only = atoi(optarg); 246 break; 247 case 'o': 248 g_output = optarg; 249 break; 250 case 'p': // Path for temp shard files 251 shard_dir = optarg; 252 break; 253 case 's': 254 kShards = atoi(optarg); 255 break; 256 case 't': 257 kThreads = atoi(optarg); 258 break; 259 case 'v': 260 g_verbose = true; 261 break; 262 } 263 } 264 265 if (count_only > 0 || merge_only) 266 { 267 g_keep = true; 268 //g_verbose = true; 269 count_only = std::min(count_only, kShards); 270 271 if (count_only > 0) 272 { 273 count_shards(count_only); 274 } 275 276 if (merge_only > 0) 277 { 278 merge(merge_only); 279 } 280 } 281 else 282 { 283 // Run all three steps 284 Timer timer; 285 LOG_INFO << argc - optind << " input files, " << kShards << " shards, " 286 << "output " << g_output <<" , temp " << shard_dir; 287 int64_t input = 0; 288 input = shard_(argc, argv); 289 count_shards(kShards); 290 int64_t output_size = merge(kShards); 291 LOG_INFO << "All done " << timer.report(input) << " output " << output_size; 292 } 293} 294