1#pragma once 2 3#include "Common.h" 4#include "TlsConfig.h" 5 6// Internal class 7class TlsContext : noncopyable 8{ 9 public: 10 enum Endpoint { kClient, kServer }; 11 12 TlsContext(Endpoint type, TlsConfig* config) 13 : context_(type == kServer ? tls_server() : tls_client()) 14 { 15 check(tls_configure(context_, config->get())); 16 } 17 18 TlsContext(TlsContext&& rhs) 19 { 20 swap(rhs); 21 } 22 23 ~TlsContext() 24 { 25 tls_free(context_); 26 } 27 28 TlsContext& operator=(TlsContext rhs) // ??? 29 { 30 swap(rhs); 31 return *this; 32 } 33 34 void swap(TlsContext& rhs) 35 { 36 std::swap(context_, rhs.context_); 37 } 38 39 // void reset(struct tls* ctx) { context_ = ctx; } 40 41 // struct tls* get() { return context_; } 42 43 const char* cipher() { return tls_conn_cipher(context_); } 44 45 // if there is no error, this will segfault. 46 const char* error() { return tls_error(context_); } 47 48 int connect(const char* hostport, const char* servername = nullptr) 49 { 50 return tls_connect_servername(context_, hostport, nullptr, servername); 51 } 52 53 TlsContext accept(int sockfd) 54 { 55 struct tls* conn_ctx = nullptr; 56 check(tls_accept_socket(context_, &conn_ctx, sockfd)); 57 return TlsContext(conn_ctx); 58 } 59 60 int handshake() 61 { 62 int ret = -1; 63 do { 64 ret = tls_handshake(context_); 65 } while(ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT); 66 return ret; 67 } 68 69 int read(void* buf, int len) 70 { 71 return tls_read(context_, buf, len); 72 } 73 74 int write(const void* buf, int len) 75 { 76 return tls_write(context_, buf, len); 77 } 78 79 private: 80 explicit TlsContext(struct tls* context) : context_(context) {} 81 82 void check(int ret) 83 { 84 if (ret != 0) 85 { 86 LOG_FATAL << tls_error(context_); 87 } 88 } 89 90 struct tls* context_ = nullptr; 91}; 92