18d51ab70SShuo Chen#pragma once
28d51ab70SShuo Chen
#include <assert.h>

#include "Common.h"
#include "TlsConfig.h"
58d51ab70SShuo Chen
68d51ab70SShuo Chen// Internal class
78d51ab70SShuo Chenclass TlsContext : noncopyable
88d51ab70SShuo Chen{
98d51ab70SShuo Chen public:
108d51ab70SShuo Chen  enum Endpoint { kClient, kServer };
118d51ab70SShuo Chen
128d51ab70SShuo Chen  TlsContext(Endpoint type, TlsConfig* config)
138d51ab70SShuo Chen    : context_(type == kServer ? tls_server() : tls_client())
148d51ab70SShuo Chen  {
158d51ab70SShuo Chen    check(tls_configure(context_, config->get()));
168d51ab70SShuo Chen  }
178d51ab70SShuo Chen
188d51ab70SShuo Chen  TlsContext(TlsContext&& rhs)
198d51ab70SShuo Chen  {
208d51ab70SShuo Chen    swap(rhs);
218d51ab70SShuo Chen  }
228d51ab70SShuo Chen
238d51ab70SShuo Chen  ~TlsContext()
248d51ab70SShuo Chen  {
258d51ab70SShuo Chen    tls_free(context_);
268d51ab70SShuo Chen  }
278d51ab70SShuo Chen
288d51ab70SShuo Chen  TlsContext& operator=(TlsContext rhs)  // ???
298d51ab70SShuo Chen  {
308d51ab70SShuo Chen    swap(rhs);
318d51ab70SShuo Chen    return *this;
328d51ab70SShuo Chen  }
338d51ab70SShuo Chen
348d51ab70SShuo Chen  void swap(TlsContext& rhs)
358d51ab70SShuo Chen  {
368d51ab70SShuo Chen    std::swap(context_, rhs.context_);
378d51ab70SShuo Chen  }
388d51ab70SShuo Chen
397db0aea6SShuo Chen  // void reset(struct tls* ctx) { context_ = ctx; }
407db0aea6SShuo Chen
417db0aea6SShuo Chen  // struct tls* get() { return context_; }
428d51ab70SShuo Chen
437db0aea6SShuo Chen  const char* cipher() { return tls_conn_cipher(context_); }
448d51ab70SShuo Chen
458d51ab70SShuo Chen  // if there is no error, this will segfault.
468d51ab70SShuo Chen  const char* error() { return tls_error(context_); }
478d51ab70SShuo Chen
488d51ab70SShuo Chen  int connect(const char* hostport, const char* servername = nullptr)
498d51ab70SShuo Chen  {
508d51ab70SShuo Chen    return tls_connect_servername(context_, hostport, nullptr, servername);
518d51ab70SShuo Chen  }
528d51ab70SShuo Chen
537db0aea6SShuo Chen  TlsContext accept(int sockfd)
547db0aea6SShuo Chen  {
557db0aea6SShuo Chen    struct tls* conn_ctx = nullptr;
567db0aea6SShuo Chen    check(tls_accept_socket(context_, &conn_ctx, sockfd));
577db0aea6SShuo Chen    return TlsContext(conn_ctx);
587db0aea6SShuo Chen  }
597db0aea6SShuo Chen
608d51ab70SShuo Chen  int handshake()
618d51ab70SShuo Chen  {
628d51ab70SShuo Chen    int ret = -1;
638d51ab70SShuo Chen    do {
648d51ab70SShuo Chen      ret = tls_handshake(context_);
658d51ab70SShuo Chen    } while(ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT);
668d51ab70SShuo Chen    return ret;
678d51ab70SShuo Chen  }
688d51ab70SShuo Chen
6902cc483dSShuo Chen  int read(void* buf, int len)
7002cc483dSShuo Chen  {
7102cc483dSShuo Chen    return tls_read(context_, buf, len);
7202cc483dSShuo Chen  }
7302cc483dSShuo Chen
7402cc483dSShuo Chen  int write(const void* buf, int len)
7502cc483dSShuo Chen  {
7602cc483dSShuo Chen    return tls_write(context_, buf, len);
7702cc483dSShuo Chen  }
7802cc483dSShuo Chen
798d51ab70SShuo Chen private:
807db0aea6SShuo Chen  explicit TlsContext(struct tls* context) : context_(context) {}
817db0aea6SShuo Chen
828d51ab70SShuo Chen  void check(int ret)
838d51ab70SShuo Chen  {
848d51ab70SShuo Chen    if (ret != 0)
858d51ab70SShuo Chen    {
868d51ab70SShuo Chen      LOG_FATAL << tls_error(context_);
878d51ab70SShuo Chen    }
888d51ab70SShuo Chen  }
898d51ab70SShuo Chen
908d51ab70SShuo Chen  struct tls* context_ = nullptr;
918d51ab70SShuo Chen};
92