1#include "timer.h"
2#include "thread/Thread.h"
3#include <boost/bind.hpp>
4
5#include <assert.h>
6#include <fcntl.h>
7#include <stdio.h>
8#include <sys/types.h>
9#include <sys/socket.h>
10
11#include <tls.h>
12
13struct tls* client(int sockfd)
14{
15  struct tls_config* cfg = tls_config_new();
16  assert(cfg != NULL);
17
18  tls_config_set_ca_file(cfg, "ca.pem");
19  // tls_config_insecure_noverifycert(cfg);
20  // tls_config_insecure_noverifyname(cfg);
21
22  struct tls* ctx = tls_client();
23  assert(ctx != NULL);
24
25  int ret = tls_configure(ctx, cfg);
26  assert(ret == 0);
27
28  ret = tls_connect_socket(ctx, sockfd, "Test Server Cert");
29  assert(ret == 0);
30
31  return ctx;
32}
33
34struct tls* server(int sockfd)
35{
36  struct tls_config* cfg = tls_config_new();
37  assert(cfg != NULL);
38
39  int ret = tls_config_set_cert_file(cfg, "server.pem");
40  assert(ret == 0);
41
42  ret = tls_config_set_key_file(cfg, "server.pem");
43  assert(ret == 0);
44
45  ret = tls_config_set_ecdhecurve(cfg, "prime256v1");
46  assert(ret == 0);
47
48  // tls_config_verify_client_optional(cfg);
49  struct tls* ctx = tls_server();
50  assert(ctx != NULL);
51
52  ret = tls_configure(ctx, cfg);
53  assert(ret == 0);
54
55  struct tls* sctx = NULL;
56  ret = tls_accept_socket(ctx, &sctx, sockfd);
57  assert(ret == 0 && sctx != NULL);
58
59  return sctx;
60}
61
62// only works for non-blocking sockets
63bool handshake(struct tls* cctx, struct tls* sctx)
64{
65  int client_done = false, server_done = false;
66
67  while (!(client_done && server_done))
68  {
69    if (!client_done)
70    {
71      int ret = tls_handshake(cctx);
72      // printf("c %d\n", ret);
73      if (ret == 0)
74        client_done = true;
75      else if (ret == -1)
76      {
77        printf("client handshake failed: %s\n", tls_error(cctx));
78        break;
79      }
80    }
81
82    if (!server_done)
83    {
84      int ret = tls_handshake(sctx);
85      // printf("s %d\n", ret);
86      if (ret == 0)
87        server_done = true;
88      else if (ret == -1)
89      {
90        printf("server handshake failed: %s\n", tls_error(sctx));
91        break;
92      }
93    }
94  }
95
96  return client_done && server_done;
97}
98
// Clears O_NONBLOCK on `fd` so later reads/writes on it block.
// fcntl() reports failure with -1; 0 is a perfectly valid flag set (e.g. a
// read-only descriptor), so the error check must be against -1, not "> 0".
// Failures are reported with perror() and otherwise ignored.
void setBlockingIO(int fd)
{
  int flags = fcntl(fd, F_GETFL, 0);
  if (flags == -1)
  {
    perror("fcntl(F_GETFL)");
    return;
  }

  printf("set blocking IO for %d\n", fd);
  if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == -1)
    perror("fcntl(F_SETFL)");
}
108
// Number of back-to-back tls_handshake() calls timed on each side
// (client_thread() and the server loop in main()).
static const int N = 500;
110
// Header for one throughput trial, written raw (sizeof(Trial) bytes) by
// send() and read back by client_thread(). It is followed on the wire by
// `blocks` messages of `block_size` bytes each; block_size == 0 tells the
// client to stop.
struct Trial
{
  int blocks;      // number of messages following this header
  int block_size;  // size of each message in bytes
};
115
116void client_thread(struct tls* ctx)
117{
118  Timer t;
119  t.start();
120  for (int i = 0; i < N; ++i)
121  {
122    int ret = tls_handshake(ctx);
123    if (ret != 0)
124      printf("client err = %d\n", ret);
125  }
126  t.stop();
127  printf("client %f secs, %f handshakes/sec\n", t.seconds(), N / t.seconds());
128  while (true)
129  {
130    Trial trial = { 0, 0 };
131
132    int nr = tls_read(ctx, &trial, sizeof trial);
133    if (nr == 0)
134      break;
135    assert(nr == sizeof trial);
136    // printf("client read bs %d nb %d\n", trial.block_size, trial.blocks);
137    if (trial.block_size == 0)
138      break;
139    char* buf = new char[trial.block_size];
140    for (int i = 0; i < trial.blocks; ++i)
141    {
142      nr = tls_read(ctx, buf, trial.block_size);
143      assert(nr == trial.block_size);
144    }
145    int64_t ack = static_cast<int64_t>(trial.blocks) * trial.block_size;
146    int nw = tls_write(ctx, &ack, sizeof ack);
147    assert(nw == sizeof ack);
148    delete[] buf;
149  }
150  printf("client done\n");
151  tls_close(ctx);
152  tls_free(ctx);
153}
154
155void send(int block_size, struct tls* ctx)
156{
157  double start = now();
158  int total = 0;
159  int blocks = 1024;
160  char* message = new char[block_size];
161  bzero(message, block_size);
162  Timer t;
163  while (now() - start < 10)
164  {
165    Trial trial = { blocks, block_size };
166    int nw = tls_write(ctx, &trial, sizeof trial);
167    assert(nw == sizeof trial);
168    t.start();
169    for (int i = 0; i < blocks; ++i)
170    {
171      nw = tls_write(ctx, message, block_size);
172      if (nw != block_size)
173        printf("bs %d nw %d\n", block_size, nw);
174      assert(nw == block_size);
175    }
176    t.stop();
177    int64_t ack = 0;
178    int nr = tls_read(ctx, &ack, sizeof ack);
179    assert(nr == sizeof ack);
180    assert(ack == static_cast<int64_t>(blocks) * block_size);
181    total += blocks;
182    blocks *= 2;
183  }
184  double secs = now() - start;
185  printf("bs %5d sec %.3f tot %d thr %.1fKB/s wr cpu %.3f\n", block_size, secs, total,
186         block_size / secs * total / 1024, t.seconds());
187  delete[] message;
188}
189
190int main(int argc, char* argv[])
191{
192  int ret = tls_init();
193  assert(ret == 0);
194
195  int fds[2];
196  socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, fds);
197
198  struct tls* cctx = client(fds[0]);
199  struct tls* sctx = server(fds[1]);
200
201  if (handshake(cctx, sctx))
202    printf("cipher %s\n", tls_conn_cipher(cctx));
203  else
204    return -1;
205
206  setBlockingIO(fds[0]);
207  setBlockingIO(fds[1]);
208  muduo::Thread thr(boost::bind(client_thread, cctx), "clientThread");
209  thr.start();
210
211  {
212  Timer t;
213  t.start();
214  for (int i = 0; i < N; ++i)
215  {
216    int ret = tls_handshake(sctx);
217    if (ret != 0)
218      printf("server err = %d\n", ret);
219  }
220  t.stop();
221  printf("server %f secs, %f handshakes/sec\n", t.seconds(), N / t.seconds());
222  }
223
224  for (int i = 1024 * 16; i >= 1; i /= 4)
225  {
226    send(i, sctx);
227  }
228  tls_close(sctx);
229  shutdown(fds[1], SHUT_RDWR);
230  tls_free(sctx);
231
232  thr.join();
233}
234