1116e48deSShuo Chen#include <openssl/aes.h>
2116e48deSShuo Chen#include <openssl/conf.h>
3116e48deSShuo Chen#include <openssl/err.h>
4116e48deSShuo Chen#include <openssl/ssl.h>
5116e48deSShuo Chen
6116e48deSShuo Chen#include <malloc.h>
7116e48deSShuo Chen#include <mcheck.h>
8116e48deSShuo Chen#include <stdio.h>
9116e48deSShuo Chen
10116e48deSShuo Chen#include "timer.h"
11116e48deSShuo Chen
123af4c543SShuo Chen#include <string>
133af4c543SShuo Chen#include <vector>
143af4c543SShuo Chen
// Signatures of glibc's __free_hook / __malloc_hook.
typedef void (*free_hook_fn)(void *, const void *);
typedef void *(*malloc_hook_fn)(size_t, const void *);

// Hooks that were active before ours were installed. my_malloc_hook and
// my_free_hook swap these back in around each real malloc()/free() call so
// the tracing hooks do not re-enter themselves.
free_hook_fn old_free_hook;
malloc_hook_fn old_malloc_hook;

// Forward declaration: my_malloc_hook reinstalls this hook before it is defined.
void my_free_hook (void *, const void *);
19116e48deSShuo Chen
20116e48deSShuo Chenvoid* my_malloc_hook(size_t size, const void* caller)
21116e48deSShuo Chen{
22116e48deSShuo Chen  void *result;
23116e48deSShuo Chen  /* Restore all old hooks */
24116e48deSShuo Chen  __malloc_hook = old_malloc_hook;
25116e48deSShuo Chen  __free_hook = old_free_hook;
26116e48deSShuo Chen  /* Call recursively */
27116e48deSShuo Chen  result = malloc (size);
28116e48deSShuo Chen  /* Save underlying hooks */
29116e48deSShuo Chen  old_malloc_hook = __malloc_hook;
30116e48deSShuo Chen  old_free_hook = __free_hook;
31116e48deSShuo Chen  /* printf might call malloc, so protect it too. */
32116e48deSShuo Chen  printf ("%p malloc (%u) returns %p\n", caller, (unsigned int) size, result);
33116e48deSShuo Chen  /* Restore our own hooks */
34116e48deSShuo Chen  __malloc_hook = my_malloc_hook;
35116e48deSShuo Chen  __free_hook = my_free_hook;
36116e48deSShuo Chen  return result;
37116e48deSShuo Chen}
38116e48deSShuo Chen
39116e48deSShuo Chenvoid my_free_hook (void *ptr, const void *caller)
40116e48deSShuo Chen{
41116e48deSShuo Chen  if (!ptr) return;
42116e48deSShuo Chen  /* Restore all old hooks */
43116e48deSShuo Chen  __malloc_hook = old_malloc_hook;
44116e48deSShuo Chen  __free_hook = old_free_hook;
45116e48deSShuo Chen  /* Call recursively */
46116e48deSShuo Chen  free (ptr);
47116e48deSShuo Chen  /* Save underlying hooks */
48116e48deSShuo Chen  old_malloc_hook = __malloc_hook;
49116e48deSShuo Chen  old_free_hook = __free_hook;
50116e48deSShuo Chen  /* printf might call free, so protect it too. */
51116e48deSShuo Chen  printf ("freed %p\n", ptr);
52116e48deSShuo Chen  /* Restore our own hooks */
53116e48deSShuo Chen  __malloc_hook = my_malloc_hook;
54116e48deSShuo Chen  __free_hook = my_free_hook;
55116e48deSShuo Chen}
56116e48deSShuo Chen
57116e48deSShuo Chenvoid init_hook()
58116e48deSShuo Chen{
59116e48deSShuo Chen  old_malloc_hook = __malloc_hook;
60116e48deSShuo Chen  old_free_hook = __free_hook;
61116e48deSShuo Chen  __malloc_hook = my_malloc_hook;
62116e48deSShuo Chen  __free_hook = my_free_hook;
63116e48deSShuo Chen}
64116e48deSShuo Chen
65116e48deSShuo Chenint main(int argc, char* argv[])
66116e48deSShuo Chen{
67116e48deSShuo Chen  SSL_load_error_strings();
68116e48deSShuo Chen  ERR_load_BIO_strings();
69116e48deSShuo Chen  SSL_library_init();
70116e48deSShuo Chen  OPENSSL_config(NULL);
71116e48deSShuo Chen
72116e48deSShuo Chen  SSL_CTX* ctx = SSL_CTX_new(TLSv1_2_server_method());
733af4c543SShuo Chen  SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION);
74116e48deSShuo Chen
75116e48deSShuo Chen  EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
76116e48deSShuo Chen  SSL_CTX_set_options(ctx, SSL_OP_SINGLE_ECDH_USE);
77116e48deSShuo Chen  SSL_CTX_set_tmp_ecdh(ctx, ecdh);
78116e48deSShuo Chen  EC_KEY_free(ecdh);
79116e48deSShuo Chen
80116e48deSShuo Chen  const char* CertFile = "server.pem";  // argv[1];
81116e48deSShuo Chen  const char* KeyFile = "server.pem";  // argv[2];
82116e48deSShuo Chen  SSL_CTX_use_certificate_file(ctx, CertFile, SSL_FILETYPE_PEM);
83116e48deSShuo Chen  SSL_CTX_use_PrivateKey_file(ctx, KeyFile, SSL_FILETYPE_PEM);
84116e48deSShuo Chen  if (!SSL_CTX_check_private_key(ctx))
85116e48deSShuo Chen    abort();
86116e48deSShuo Chen
87116e48deSShuo Chen  SSL_CTX* ctx_client = SSL_CTX_new(TLSv1_2_client_method());
88116e48deSShuo Chen
89116e48deSShuo Chen  init_hook();
90116e48deSShuo Chen
913af4c543SShuo Chen  const int N = 10;
92116e48deSShuo Chen  SSL *ssl, *ssl_client;
933af4c543SShuo Chen  std::vector<SSL*> ssls;
94116e48deSShuo Chen  for (int i = 0; i < N; ++i)
95116e48deSShuo Chen  {
96116e48deSShuo Chen    printf("=============================================== BIO_new_bio_pair %d\n", i);
97116e48deSShuo Chen    BIO *client, *server;
98116e48deSShuo Chen    BIO_new_bio_pair(&client, 0, &server, 0);
99116e48deSShuo Chen
100116e48deSShuo Chen    printf("=============================================== SSL_new server %d\n", i);
101116e48deSShuo Chen    ssl = SSL_new (ctx);
102116e48deSShuo Chen    printf("=============================================== SSL_new client %d\n", i);
103116e48deSShuo Chen    ssl_client = SSL_new (ctx_client);
104116e48deSShuo Chen    SSL_set_bio(ssl, server, server);
105116e48deSShuo Chen    SSL_set_bio(ssl_client, client, client);
106116e48deSShuo Chen
107116e48deSShuo Chen    printf("=============================================== SSL_connect client %d\n", i);
108116e48deSShuo Chen    int ret = SSL_connect(ssl_client);
109116e48deSShuo Chen    printf("=============================================== SSL_accept server %d\n", i);
110116e48deSShuo Chen    int ret2 = SSL_accept(ssl);
111116e48deSShuo Chen
112116e48deSShuo Chen    while (true)
113116e48deSShuo Chen    {
114116e48deSShuo Chen      printf("=============================================== SSL_handshake client %d\n", i);
115116e48deSShuo Chen      ret = SSL_do_handshake(ssl_client);
116116e48deSShuo Chen      printf("=============================================== SSL_handshake server %d\n", i);
117116e48deSShuo Chen      ret2 = SSL_do_handshake(ssl);
118116e48deSShuo Chen      if (ret == 1 && ret2 == 1)
119116e48deSShuo Chen        break;
120116e48deSShuo Chen    }
121116e48deSShuo Chen
122116e48deSShuo Chen    if (i == 0)
123116e48deSShuo Chen      printf ("SSL connection using %s %s\n", SSL_get_version(ssl_client), SSL_get_cipher (ssl_client));
1243af4c543SShuo Chen    /*
125116e48deSShuo Chen    if (i != N-1)
126116e48deSShuo Chen    {
127116e48deSShuo Chen      printf("=============================================== SSL_free server %d\n", i);
128116e48deSShuo Chen      SSL_free (ssl);
129116e48deSShuo Chen      printf("=============================================== SSL_free client %d\n", i);
130116e48deSShuo Chen      SSL_free (ssl_client);
131116e48deSShuo Chen    }
1323af4c543SShuo Chen    else
1333af4c543SShuo Chen    */
1343af4c543SShuo Chen    {
1353af4c543SShuo Chen      ssls.push_back(ssl);
1363af4c543SShuo Chen      ssls.push_back(ssl_client);
1373af4c543SShuo Chen    }
138116e48deSShuo Chen  }
139116e48deSShuo Chen
140116e48deSShuo Chen  printf("=============================================== data \n");
141116e48deSShuo Chen
142116e48deSShuo Chen  double start2 = now();
143116e48deSShuo Chen  const int M = 300;
144116e48deSShuo Chen  char buf[1024] = { 0 };
145116e48deSShuo Chen  for (int i = 0; i < M*1024; ++i)
146116e48deSShuo Chen  {
147116e48deSShuo Chen    int nw = SSL_write(ssl_client, buf, sizeof buf);
148116e48deSShuo Chen    if (nw != sizeof buf)
149116e48deSShuo Chen    {
150116e48deSShuo Chen      printf("nw = %d\n", nw);
151116e48deSShuo Chen    }
152116e48deSShuo Chen    int nr = SSL_read(ssl, buf, sizeof buf);
153116e48deSShuo Chen    if (nr != sizeof buf)
154116e48deSShuo Chen    {
155116e48deSShuo Chen      printf("nr = %d\n", nr);
156116e48deSShuo Chen    }
157116e48deSShuo Chen  }
158116e48deSShuo Chen  double elapsed = now() - start2;
159116e48deSShuo Chen  printf("%.2f %.1f MiB/s\n", elapsed, M / elapsed);
1603af4c543SShuo Chen  printf("=============================================== SSL_free\n");
1613af4c543SShuo Chen  for (int i = 0; i < ssls.size(); ++i)
1623af4c543SShuo Chen    SSL_free(ssls[i]);
163116e48deSShuo Chen
1643af4c543SShuo Chen  printf("=============================================== SSL_CTX_free\n");
165116e48deSShuo Chen  SSL_CTX_free (ctx);
166116e48deSShuo Chen  SSL_CTX_free (ctx_client);
167116e48deSShuo Chen  // OPENSSL_cleanup();  // only in 1.1.0
1683af4c543SShuo Chen  printf("=============================================== end\n");
169116e48deSShuo Chen}
170