// Reproduces the race condition in Factory.cc when compiled with -DREPRODUCE_BUG.
2
3#include "../Mutex.h"
4
5#include <boost/noncopyable.hpp>
6
7#include <memory>
8#include <unordered_map>
9
10#include <assert.h>
11#include <stdio.h>
12#include <unistd.h>
13
14using std::string;
15
// Blocks the calling thread for roughly `ms` milliseconds.
//
// Non-positive durations return immediately: a negative value handed to
// usleep() would be converted to a huge unsigned microsecond count and
// sleep for more than an hour.
// Whole seconds go to sleep() and only the sub-second remainder goes to
// usleep(), because POSIX permits usleep() to fail with EINVAL when asked for
// 1,000,000 microseconds or more; splitting also keeps the microsecond
// product from overflowing int for large `ms`.
// Like the underlying calls, the wait may end early if a signal arrives.
void sleepMs(int ms)
{
  if (ms <= 0)
  {
    return;
  }

  const unsigned int wholeSeconds = static_cast<unsigned int>(ms / 1000);
  if (wholeSeconds > 0)
  {
    sleep(wholeSeconds);
  }

  // Always below 1,000,000, so within the range POSIX guarantees.
  const useconds_t remainderUs = static_cast<useconds_t>(ms % 1000) * 1000;
  if (remainderUs > 0)
  {
    usleep(remainderUs);
  }
}
20
21class Stock : boost::noncopyable
22{
23 public:
24  Stock(const string& name)
25    : name_(name)
26  {
27    printf("%s: Stock[%p] %s\n", muduo::CurrentThread::name(), this, name_.c_str());
28  }
29
30  ~Stock()
31  {
32    printf("%s: ~Stock[%p] %s\n", muduo::CurrentThread::name(), this, name_.c_str());
33  }
34
35  const string& key() const { return name_; }
36
37 private:
38  string name_;
39};
40
41
42class StockFactory : boost::noncopyable
43{
44 public:
45
46  std::shared_ptr<Stock> get(const string& key)
47  {
48    std::shared_ptr<Stock> pStock;
49    muduo::MutexLockGuard lock(mutex_);
50    std::weak_ptr<Stock>& wkStock = stocks_[key];
51    pStock = wkStock.lock();
52    if (!pStock)
53    {
54      pStock.reset(new Stock(key),
55                   [this] (Stock* stock) { deleteStock(stock); });
56      wkStock = pStock;
57    }
58    return pStock;
59  }
60
61 private:
62
63  void deleteStock(Stock* stock)
64  {
65    printf("%s: deleteStock[%p]\n", muduo::CurrentThread::name(), stock);
66    if (stock)
67    {
68      sleepMs(500);
69      muduo::MutexLockGuard lock(mutex_);
70#ifdef REPRODUCE_BUG
71      printf("%s: erase %zd\n", muduo::CurrentThread::name(), stocks_.erase(stock->key()));
72#else
73      auto it = stocks_.find(stock->key());
74      assert(it != stocks_.end());
75      if (it->second.expired())
76      {
77        stocks_.erase(it);
78      }
79      else
80      {
81        printf("%s: %s is not expired\n", muduo::CurrentThread::name(), stock->key().c_str());
82      }
83#endif
84    }
85    delete stock;  // sorry, I lied
86  }
87
88  mutable muduo::MutexLock mutex_;
89  std::unordered_map<string, std::weak_ptr<Stock> > stocks_;
90};
91
92void threadB(StockFactory* factory)
93{
94  sleepMs(250);
95  auto stock = factory->get("MS");
96  printf("%s: stock %p\n", muduo::CurrentThread::name(), stock.get());
97
98  sleepMs(500);
99  auto stock2 = factory->get("MS");
100  printf("%s: stock2 %p\n", muduo::CurrentThread::name(), stock2.get());
101  if (stock != stock2)
102  {
103    printf("WARNING: stock != stock2\n");
104  }
105}
106
107int main()
108{
109  StockFactory factory;
110  muduo::Thread thr([&factory] { threadB(&factory); }, "thrB");
111  thr.start();
112  {
113  auto stock = factory.get("MS");
114  printf("%s: stock %p\n", muduo::CurrentThread::name(), stock.get());
115  }
116  thr.join();
117}
118