1a98d478eSShuo Chen#include <polarssl/ctr_drbg.h>
2a98d478eSShuo Chen#include <polarssl/error.h>
3a98d478eSShuo Chen#include <polarssl/entropy.h>
4a98d478eSShuo Chen#include <polarssl/ssl.h>
5a98d478eSShuo Chen
6a98d478eSShuo Chen#include <polarssl/certs.h>
7a98d478eSShuo Chen
8a98d478eSShuo Chen#include <muduo/net/Buffer.h>
99acb42f4SShuo Chen#include <string>
10a98d478eSShuo Chen#include <stdio.h>
11a98d478eSShuo Chen#include <sys/time.h>
129acb42f4SShuo Chen#include "timer.h"
13a98d478eSShuo Chen
14a98d478eSShuo Chenmuduo::net::Buffer clientOut, serverOut;
15a98d478eSShuo Chen
16a98d478eSShuo Chenint net_recv(void* ctx, unsigned char* buf, size_t len)
17a98d478eSShuo Chen{
18a98d478eSShuo Chen  muduo::net::Buffer* in = static_cast<muduo::net::Buffer*>(ctx);
19a98d478eSShuo Chen  //printf("%s recv %zd\n", in == &clientOut ? "server" : "client", len);
20a98d478eSShuo Chen  if (in->readableBytes() > 0)
21a98d478eSShuo Chen  {
22a98d478eSShuo Chen    size_t n = std::min(in->readableBytes(), len);
23a98d478eSShuo Chen    memcpy(buf, in->peek(), n);
24a98d478eSShuo Chen    in->retrieve(n);
25a98d478eSShuo Chen
26a98d478eSShuo Chen    /*
27a98d478eSShuo Chen    if (n < len)
28a98d478eSShuo Chen      printf("got %zd\n", n);
29a98d478eSShuo Chen    else
30a98d478eSShuo Chen      printf("\n");
31a98d478eSShuo Chen      */
32a98d478eSShuo Chen    return n;
33a98d478eSShuo Chen  }
34a98d478eSShuo Chen  else
35a98d478eSShuo Chen  {
36a98d478eSShuo Chen    //printf("got 0\n");
37a98d478eSShuo Chen    return POLARSSL_ERR_NET_WANT_READ;
38a98d478eSShuo Chen  }
39a98d478eSShuo Chen}
40a98d478eSShuo Chen
41a98d478eSShuo Chenint net_send(void* ctx, const unsigned char* buf, size_t len)
42a98d478eSShuo Chen{
43a98d478eSShuo Chen  muduo::net::Buffer* out = static_cast<muduo::net::Buffer*>(ctx);
44a98d478eSShuo Chen  // printf("%s send %zd\n", out == &clientOut ? "client" : "server", len);
45a98d478eSShuo Chen  out->append(buf, len);
46a98d478eSShuo Chen  return len;
47a98d478eSShuo Chen}
48a98d478eSShuo Chen
49a98d478eSShuo Chenint main(int argc, char* argv[])
50a98d478eSShuo Chen{
51a98d478eSShuo Chen  entropy_context entropy;
52a98d478eSShuo Chen  entropy_init(&entropy);
53a98d478eSShuo Chen  ctr_drbg_context ctr_drbg;
54a98d478eSShuo Chen  ctr_drbg_init(&ctr_drbg, entropy_func, &entropy, NULL, 0);
55a98d478eSShuo Chen
56a98d478eSShuo Chen  ssl_context ssl;
57a98d478eSShuo Chen  bzero(&ssl, sizeof ssl);
58a98d478eSShuo Chen  ssl_init(&ssl);
59a98d478eSShuo Chen  ssl_set_rng(&ssl, ctr_drbg_random, &ctr_drbg);
60a98d478eSShuo Chen  ssl_set_bio(&ssl, &net_recv, &serverOut, &net_send, &clientOut);
61a98d478eSShuo Chen  ssl_set_endpoint(&ssl, SSL_IS_CLIENT);
62a98d478eSShuo Chen  ssl_set_authmode(&ssl, SSL_VERIFY_NONE);
63a98d478eSShuo Chen
64a98d478eSShuo Chen  const char* srv_cert = test_srv_crt_ec;
65a98d478eSShuo Chen  const char* srv_key = test_srv_key_ec;
669acb42f4SShuo Chen  std::string arg = argc > 1 ? argv[1] : "r";
679acb42f4SShuo Chen  bool useRSA = arg == "r" || arg == "er";
689acb42f4SShuo Chen  bool useECDHE = arg == "er" || arg == "ee";
69a98d478eSShuo Chen  if (useRSA)
70a98d478eSShuo Chen  {
71a98d478eSShuo Chen    srv_cert = test_srv_crt;
72a98d478eSShuo Chen    srv_key = test_srv_key;
73a98d478eSShuo Chen  }
74a98d478eSShuo Chen  x509_crt cert;
75a98d478eSShuo Chen  x509_crt_init(&cert);
76a98d478eSShuo Chen  // int ret = x509_crt_parse_file(&cert, argv[1]);
77a98d478eSShuo Chen  // printf("cert parse %d\n", ret);
78a98d478eSShuo Chen  x509_crt_parse(&cert, reinterpret_cast<const unsigned char*>(srv_cert), strlen(srv_cert));
79a98d478eSShuo Chen  x509_crt_parse(&cert, reinterpret_cast<const unsigned char*>(test_ca_list), strlen(test_ca_list));
80a98d478eSShuo Chen
81a98d478eSShuo Chen  pk_context pkey;
82a98d478eSShuo Chen  pk_init(&pkey);
83a98d478eSShuo Chen  pk_parse_key(&pkey, reinterpret_cast<const unsigned char*>(srv_key), strlen(srv_key), NULL, 0);
84a98d478eSShuo Chen  // ret = pk_parse_keyfile(&pkey, argv[2], NULL);
85a98d478eSShuo Chen  // printf("key parse %d\n", ret);
86a98d478eSShuo Chen
87a98d478eSShuo Chen  ssl_context ssl_server;
88a98d478eSShuo Chen  bzero(&ssl_server, sizeof ssl_server);
89a98d478eSShuo Chen  ssl_init(&ssl_server);
90a98d478eSShuo Chen  ssl_set_rng(&ssl_server, ctr_drbg_random, &ctr_drbg);
91a98d478eSShuo Chen  ssl_set_bio(&ssl_server, &net_recv, &clientOut, &net_send, &serverOut);
92a98d478eSShuo Chen  ssl_set_endpoint(&ssl_server, SSL_IS_SERVER);
93a98d478eSShuo Chen  ssl_set_authmode(&ssl_server, SSL_VERIFY_NONE);
94a98d478eSShuo Chen  //ssl_set_ca_chain(&ssl_server, cert.next, NULL, NULL);
95a98d478eSShuo Chen  ssl_set_own_cert(&ssl_server, &cert, &pkey);
96a98d478eSShuo Chen  ecp_group_id curves[] = { POLARSSL_ECP_DP_SECP256R1, POLARSSL_ECP_DP_SECP224K1, POLARSSL_ECP_DP_NONE };
97a98d478eSShuo Chen  ssl_set_curves(&ssl_server, curves);
989acb42f4SShuo Chen  if (useECDHE)
999acb42f4SShuo Chen  {
1009acb42f4SShuo Chen    int ciphersuites[] = { TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 0 };
1019acb42f4SShuo Chen    ssl_set_ciphersuites(&ssl_server, ciphersuites);
1029acb42f4SShuo Chen  }
1039acb42f4SShuo Chen  else
1049acb42f4SShuo Chen  {
1059acb42f4SShuo Chen    int ciphersuites[] = { TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA, 0 };
1069acb42f4SShuo Chen    ssl_set_ciphersuites(&ssl_server, ciphersuites);
1079acb42f4SShuo Chen  }
108a98d478eSShuo Chen
109a98d478eSShuo Chen  double start = now();
1109acb42f4SShuo Chen  Timer tc, ts;
1119acb42f4SShuo Chen  const int N = 500;
112a98d478eSShuo Chen  for (int i = 0; i < N; ++i)
113a98d478eSShuo Chen  {
114a98d478eSShuo Chen    ssl_session_reset(&ssl);
115a98d478eSShuo Chen    ssl_session_reset(&ssl_server);
116a98d478eSShuo Chen    while (true)
117a98d478eSShuo Chen    {
1189acb42f4SShuo Chen      tc.start();
119a98d478eSShuo Chen      int ret = ssl_handshake(&ssl);
1209acb42f4SShuo Chen      tc.stop();
121a98d478eSShuo Chen      //printf("ssl %d\n", ret);
122a98d478eSShuo Chen      if (ret < 0)
123a98d478eSShuo Chen      {
124a98d478eSShuo Chen        if (ret != POLARSSL_ERR_NET_WANT_READ)
125a98d478eSShuo Chen        {
126a98d478eSShuo Chen          char errbuf[512];
127a98d478eSShuo Chen          polarssl_strerror(ret, errbuf, sizeof errbuf);
128a98d478eSShuo Chen          printf("client error %d %s\n", ret, errbuf);
129a98d478eSShuo Chen          break;
130a98d478eSShuo Chen        }
131a98d478eSShuo Chen      }
132a98d478eSShuo Chen      else if (ret == 0 && i == 0)
133a98d478eSShuo Chen      {
134a98d478eSShuo Chen        printf("client done %s %s\n", ssl_get_version(&ssl), ssl_get_ciphersuite(&ssl));
135a98d478eSShuo Chen      }
136a98d478eSShuo Chen
1379acb42f4SShuo Chen      ts.start();
138a98d478eSShuo Chen      int ret2 = ssl_handshake(&ssl_server);
1399acb42f4SShuo Chen      ts.stop();
140a98d478eSShuo Chen      // printf("srv %d\n", ret2);
141a98d478eSShuo Chen      if (ret2 < 0)
142a98d478eSShuo Chen      {
143a98d478eSShuo Chen        if (ret != POLARSSL_ERR_NET_WANT_READ)
144a98d478eSShuo Chen        {
145a98d478eSShuo Chen          char errbuf[512];
146a98d478eSShuo Chen          polarssl_strerror(ret2, errbuf, sizeof errbuf);
147a98d478eSShuo Chen          printf("server error %d %s\n", ret2, errbuf);
148a98d478eSShuo Chen          break;
149a98d478eSShuo Chen        }
150a98d478eSShuo Chen      }
151a98d478eSShuo Chen      else if (ret2 == 0)
152a98d478eSShuo Chen      {
153a98d478eSShuo Chen        // printf("server done %s %s\n", ssl_get_version(&ssl_server), ssl_get_ciphersuite(&ssl_server));
154a98d478eSShuo Chen      }
155a98d478eSShuo Chen
156a98d478eSShuo Chen      if (ret == 0 && ret2 == 0)
157a98d478eSShuo Chen        break;
158a98d478eSShuo Chen    }
159a98d478eSShuo Chen  }
160a98d478eSShuo Chen  double elapsed = now() - start;
161a98d478eSShuo Chen  printf("%.2fs %.1f handshakes/s\n", elapsed, N / elapsed);
1629acb42f4SShuo Chen  printf("client %.3f %.1f\n", tc.seconds(), N / tc.seconds());
1639acb42f4SShuo Chen  printf("server %.3f %.1f\n", ts.seconds(), N / ts.seconds());
1649acb42f4SShuo Chen  printf("server/client %.2f\n", ts.seconds() / tc.seconds());
165a98d478eSShuo Chen
166a98d478eSShuo Chen  double start2 = now();
167a98d478eSShuo Chen  const int M = 200;
1689acb42f4SShuo Chen  unsigned char buf[16384] = { 0 };
169a98d478eSShuo Chen  for (int i = 0; i < M*1024; ++i)
170a98d478eSShuo Chen  {
171a98d478eSShuo Chen    int n = ssl_write(&ssl, buf, 1024);
172a98d478eSShuo Chen    if (n < 0)
173a98d478eSShuo Chen    {
174a98d478eSShuo Chen      char errbuf[512];
175a98d478eSShuo Chen      polarssl_strerror(n, errbuf, sizeof errbuf);
176a98d478eSShuo Chen      printf("%s\n", errbuf);
177a98d478eSShuo Chen    }
178a98d478eSShuo Chen    /*
179a98d478eSShuo Chen    n = ssl_read(&ssl_server, buf, 8192);
180a98d478eSShuo Chen    if (n != 1024)
181a98d478eSShuo Chen      break;
182a98d478eSShuo Chen    if (n < 0)
183a98d478eSShuo Chen    {
184a98d478eSShuo Chen      char errbuf[512];
185a98d478eSShuo Chen      polarssl_strerror(n, errbuf, sizeof errbuf);
186a98d478eSShuo Chen      printf("%s\n", errbuf);
187a98d478eSShuo Chen    }
188a98d478eSShuo Chen    */
189a98d478eSShuo Chen    clientOut.retrieveAll();
190a98d478eSShuo Chen  }
191a98d478eSShuo Chen  elapsed = now() - start2;
192a98d478eSShuo Chen  printf("%.2f %.1f MiB/s\n", elapsed, M / elapsed);
193a98d478eSShuo Chen
194a98d478eSShuo Chen  ssl_free(&ssl);
195a98d478eSShuo Chen  ssl_free(&ssl_server);
196a98d478eSShuo Chen  pk_free(&pkey);
197a98d478eSShuo Chen  x509_crt_free(&cert);
198a98d478eSShuo Chen  entropy_free(&entropy);
199a98d478eSShuo Chen}
200