#include "datetime/Timestamp.h"
#include "Acceptor.h"
#include "InetAddress.h"
#include "TcpStream.h"

#ifdef __linux
#include <linux/tcp.h>
#else
#include <netinet/tcp.h>
#endif
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <algorithm>
#include <string>
14
15using muduo::Timestamp;
16
// Prints throughput figures for one connected TCP socket.
//
// Each report line shows a timestamp (seconds), throughput in MB/s and
// Mbits/s, followed by socket statistics: for the sending side SO_SNDBUF and
// TCP_INFO congestion-control state, for the receiving side SO_RCVBUF.
class BandwidthReporter
{
 public:
  // fd:     connected socket queried with getsockopt() on every report.
  // sender: true to print sender-side statistics, false for receiver-side.
  BandwidthReporter(int fd, bool sender)
      : fd_(fd), sender_(sender)
  {
  }

  // Reports the throughput since the previous reportDelta() call (since
  // time 0 / byte 0 on the first call), then makes (now, total_bytes) the
  // baseline for the next call.
  void reportDelta(double now, int64_t total_bytes)
  {
    report(now, total_bytes - last_bytes_, now - last_time_);
    last_time_ = now;
    last_bytes_ = total_bytes;
  }

  // Prints a one-line summary of the whole transfer, then a report line
  // covering the interval [0, now].
  void reportAll(double now, int64_t total_bytes, int64_t syscalls)
  {
    // Guard the average: with zero successful syscalls the original
    // division printed "nan"/"inf".
    const double bytesPerSyscall =
        syscalls > 0 ? total_bytes * 1.0 / syscalls : 0.0;
    printf("Transferred %.3fMB %.3fMiB in %.3fs, %lld syscalls, %.1f Bytes/syscall\n",
           total_bytes / 1e6, total_bytes / (1024.0 * 1024), now, (long long)syscalls,
           bytesPerSyscall);
    report(now, total_bytes, now);
  }

 private:
  // Prints one line: `now`, the rate of `bytes` over `elapsed` seconds (0 if
  // elapsed is not positive), then the side-specific statistics, which end
  // the line.
  void report(double now, int64_t bytes, double elapsed)
  {
    const double mbps = elapsed > 0 ? bytes / 1e6 / elapsed : 0.0;
    printf("%6.3f  %6.2fMB/s  %6.1fMbits/s ", now, mbps, mbps*8);
    if (sender_)
      printSender();
    else
      printReceiver();
  }

  // Appends SO_SNDBUF and TCP_INFO figures (cwnd, ssthresh, peer window,
  // rtt/rttvar, and retransmissions since the previous call) and ends the line.
  void printSender()
  {
    int sndbuf = 0;
    socklen_t optlen = sizeof sndbuf;
    if (::getsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &sndbuf, &optlen) < 0)
      perror("getsockopt(SNDBUF)");

    struct tcp_info tcpi = {0};
    socklen_t len = sizeof(tcpi);
    if (::getsockopt(fd_, IPPROTO_TCP, TCP_INFO, &tcpi, &len) < 0)
      perror("getsockopt(TCP_INFO)");

    // bytes_in_flight = tcpi.tcpi_bytes_sent - tcpi.tcpi_bytes_acked;
    // tcpi.tcpi_notsent_bytes;

    // 64-bit so that scaling a segment count by the MSS cannot wrap an int.
    int64_t snd_cwnd = tcpi.tcpi_snd_cwnd;
    int64_t ssthresh = tcpi.tcpi_snd_ssthresh;
#ifdef __linux
    snd_cwnd *= tcpi.tcpi_snd_mss;  // Linux's cwnd is # of mss.
    // Values >= INT32_MAX are left unscaled (presumably the kernel's
    // "infinite ssthresh" sentinel).
    if (ssthresh < INT32_MAX)
      ssthresh *= tcpi.tcpi_snd_mss;
#endif

#if defined(__linux)
    int retrans = tcpi.tcpi_total_retrans;
#elif defined(__FreeBSD__)
    int retrans = tcpi.tcpi_snd_rexmitpack;
#else
    int retrans = 0;  // no retransmission counter known for this platform
#endif

    printf(" sndbuf=%.1fK snd_cwnd=%.1fK ssthresh=%.1fK snd_wnd=%.1fK rtt=%u/%u",
           sndbuf / 1024.0, snd_cwnd / 1024.0, ssthresh / 1024.0,
           tcpi.tcpi_snd_wnd / 1024.0,
           (unsigned)tcpi.tcpi_rtt, (unsigned)tcpi.tcpi_rttvar);
    if (retrans - last_retrans_ > 0) {
      printf(" retrans=%d", retrans - last_retrans_);
    }
    printf("\n");
    last_retrans_ = retrans;
  }

  // Appends the SO_RCVBUF size and ends the line.
  void printReceiver() const
  {
    int rcvbuf = 0;
    socklen_t optlen = sizeof rcvbuf;
    if (::getsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &rcvbuf, &optlen) < 0)
      perror("getsockopt(RCVBUF)");

    printf(" rcvbuf=%.1fK\n", rcvbuf / 1024.0);
  }

  const int fd_ = 0;          // socket being measured (not owned)
  const bool sender_ = false; // which statistics to print
  double last_time_ = 0;      // baseline time for reportDelta()
  int64_t last_bytes_ = 0;    // baseline byte count for reportDelta()
  int last_retrans_ = 0;      // retransmission count at the previous sender report
};
105
106void runClient(const InetAddress& serverAddr, int64_t bytes_limit, double duration)
107{
108  TcpStreamPtr stream(TcpStream::connect(serverAddr));
109  if (!stream) {
110    printf("Unable to connect %s\n", serverAddr.toIpPort().c_str());
111    perror("");
112    return;
113  }
114  char cong[64] = "";
115  socklen_t optlen = sizeof cong;
116  if (::getsockopt(stream->fd(), IPPROTO_TCP, TCP_CONGESTION, cong, &optlen) < 0)
117      perror("getsockopt(TCP_CONGESTION)");
118  printf("Connected %s -> %s, congestion control: %s\n",
119         stream->getLocalAddr().toIpPort().c_str(),
120         stream->getPeerAddr().toIpPort().c_str(), cong);
121
122  const Timestamp start = Timestamp::now();
123  const int block_size = 64 * 1024;
124  std::string message(block_size, 'S');
125  int seconds = 1;
126  int64_t total_bytes = 0;
127  int64_t syscalls = 0;
128  double elapsed = 0;
129  BandwidthReporter rpt(stream->fd(), true);
130  rpt.reportDelta(0, 0);
131
132  while (total_bytes < bytes_limit) {
133    int bytes = std::min<int64_t>(message.size(), bytes_limit - total_bytes);
134    int nw = stream->sendSome(message.data(), bytes);
135    if (nw <= 0)
136      break;
137    total_bytes += nw;
138    syscalls++;
139    elapsed = timeDifference(Timestamp::now(), start);
140
141    if (elapsed >= duration)
142      break;
143
144    if (elapsed >= seconds) {
145      rpt.reportDelta(elapsed, total_bytes);
146      while (elapsed >= seconds)
147        ++seconds;
148    }
149  }
150
151  stream->shutdownWrite();
152  Timestamp shutdown = Timestamp::now();
153  elapsed = timeDifference(shutdown, start);
154  rpt.reportDelta(elapsed, total_bytes);
155
156  char buf[1024];
157  int nr = stream->receiveSome(buf, sizeof buf);
158  if (nr != 0)
159    printf("nr = %d\n", nr);
160  Timestamp end = Timestamp::now();
161  elapsed = timeDifference(end, start);
162  rpt.reportAll(elapsed, total_bytes, syscalls);
163}
164
165void runServer(int port)
166{
167  InetAddress listenAddr(port);
168  Acceptor acceptor(listenAddr);
169  int count = 0;
170  while (true) {
171    printf("Accepting on port %d ... Ctrl-C to exit\n", port);
172    TcpStreamPtr stream = acceptor.accept();
173    ++count;
174    printf("accepted no. %d client %s <- %s\n", count,
175           stream->getLocalAddr().toIpPort().c_str(),
176           stream->getPeerAddr().toIpPort().c_str());
177
178    const Timestamp start = Timestamp::now();
179    int seconds = 1;
180    int64_t bytes = 0;
181    int64_t syscalls = 0;
182    double elapsed = 0;
183    BandwidthReporter rpt(stream->fd(), false);
184    rpt.reportDelta(elapsed, bytes);
185
186    char buf[65536];
187    while (true) {
188      int nr = stream->receiveSome(buf, sizeof buf);
189      if (nr <= 0)
190        break;
191      bytes += nr;
192      syscalls++;
193
194      elapsed = timeDifference(Timestamp::now(), start);
195      if (elapsed >= seconds) {
196        rpt.reportDelta(elapsed, bytes);
197        while (elapsed >= seconds)
198          ++seconds;
199      }
200    }
201    elapsed = timeDifference(Timestamp::now(), start);
202    rpt.reportAll(elapsed, bytes, syscalls);
203    printf("Client no. %d done\n", count);
204  }
205}
206
// Parses a decimal byte count with an optional unit suffix:
//   k = 1000, K = 1024, m = 1000^2, M = 1024^2, g = 1000^3, G = 1024^3.
// Characters after the suffix are ignored (so "10MB" means 10 * 1024^2).
// Returns 0 for an unrecognised suffix. Results that do not fit in int64_t
// saturate to INT64_MAX / INT64_MIN, mirroring strtoll()'s own clamping.
int64_t parseBytes(const char* arg)
{
  char* end = nullptr;
  const int64_t bytes = strtoll(arg, &end, 10);
  int64_t multiplier = 1;
  switch (*end) {
    case '\0':
      multiplier = 1;
      break;
    case 'k':
      multiplier = 1000LL;
      break;
    case 'K':
      multiplier = 1024LL;
      break;
    case 'm':
      multiplier = 1000LL * 1000;
      break;
    case 'M':
      multiplier = 1024LL * 1024;
      break;
    case 'g':
      multiplier = 1000LL * 1000 * 1000;
      break;
    case 'G':
      multiplier = 1024LL * 1024 * 1024;
      break;
    default:
      return 0;
  }
  // Check before multiplying: signed overflow is undefined behaviour.
  if (bytes > INT64_MAX / multiplier)
    return INT64_MAX;
  if (bytes < INT64_MIN / multiplier)
    return INT64_MIN;
  return bytes * multiplier;
}
230
231int main(int argc, char* argv[])
232{
233  int opt;
234  bool client = false, server = false;
235  std::string serverAddr;
236  int port = 2009;
237  const int64_t kGigaBytes = 1024 * 1024 * 1024;
238  int64_t bytes_limit = 10 * kGigaBytes;
239  double duration = 10;
240
241  while ((opt = getopt(argc, argv, "sc:t:b:p:")) != -1) {
242    switch (opt) {
243      case 's':
244        server = true;
245        break;
246      case 'c':
247        client = true;
248        serverAddr = optarg;
249        break;
250      case 't':
251        duration = strtod(optarg, NULL);
252        break;
253      case 'b':
254        bytes_limit = parseBytes(optarg);
255        break;
256      case 'p':
257        port = strtol(optarg, NULL, 10);
258        break;
259      default:
260        fprintf(stderr, "Usage: %s FIXME\n", argv[0]);
261        break;
262    }
263  }
264
265  if (client)
266    runClient(InetAddress(serverAddr, port), bytes_limit, duration);
267  else if (server)
268    runServer(port);
269}
270