1#pragma once
2
3#include "Common.h"
4#include "TlsConfig.h"
5
// Internal class: owns a libtls context (struct tls*) and releases it with tls_free() on destruction.
7class TlsContext : noncopyable
8{
9 public:
10  enum Endpoint { kClient, kServer };
11
12  TlsContext(Endpoint type, TlsConfig* config)
13    : context_(type == kServer ? tls_server() : tls_client())
14  {
15    check(tls_configure(context_, config->get()));
16  }
17
18  TlsContext(TlsContext&& rhs)
19  {
20    swap(rhs);
21  }
22
23  ~TlsContext()
24  {
25    tls_free(context_);
26  }
27
28  TlsContext& operator=(TlsContext rhs)  // ???
29  {
30    swap(rhs);
31    return *this;
32  }
33
34  void swap(TlsContext& rhs)
35  {
36    std::swap(context_, rhs.context_);
37  }
38
39  // void reset(struct tls* ctx) { context_ = ctx; }
40
41  // struct tls* get() { return context_; }
42
43  const char* cipher() { return tls_conn_cipher(context_); }
44
45  // if there is no error, this will segfault.
46  const char* error() { return tls_error(context_); }
47
48  int connect(const char* hostport, const char* servername = nullptr)
49  {
50    return tls_connect_servername(context_, hostport, nullptr, servername);
51  }
52
53  TlsContext accept(int sockfd)
54  {
55    struct tls* conn_ctx = nullptr;
56    check(tls_accept_socket(context_, &conn_ctx, sockfd));
57    return TlsContext(conn_ctx);
58  }
59
60  int handshake()
61  {
62    int ret = -1;
63    do {
64      ret = tls_handshake(context_);
65    } while(ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT);
66    return ret;
67  }
68
69  int read(void* buf, int len)
70  {
71    return tls_read(context_, buf, len);
72  }
73
74  int write(const void* buf, int len)
75  {
76    return tls_write(context_, buf, len);
77  }
78
79 private:
80  explicit TlsContext(struct tls* context) : context_(context) {}
81
82  void check(int ret)
83  {
84    if (ret != 0)
85    {
86      LOG_FATAL << tls_error(context_);
87    }
88  }
89
90  struct tls* context_ = nullptr;
91};
92