1/* Copyright (c) 2017 - 2022 LiteSpeed Technologies Inc.  See LICENSE. */
2/*
 * md5_server.c -- Read one or more streams from the client and, for each
 *                 stream, return the MD5 sum of that stream's payload.
5 */
6
7#include <assert.h>
8#include <signal.h>
9#include <stdio.h>
10#include <stdlib.h>
11#include <string.h>
12#include <sys/queue.h>
13#include <time.h>
14#include <unistd.h>
15
16#include <openssl/md5.h>
17
18#include <event2/event.h>
19
20#include "lsquic.h"
21#include "test_common.h"
22#include "../src/liblsquic/lsquic_hash.h"
23#include "test_cert.h"
24#include "prog.h"
25
26#include "../src/liblsquic/lsquic_logger.h"
27
28
/* Set to zero by the -F option: the server then skips MD5 computation and
 * replies with a string of '0' characters instead of a real digest.
 */
static int g_really_calculate_md5 = 1;

/* Premature-close test hook, configured by -p STREAM_ID:LIMIT.  Once more
 * than `limit' bytes have been read from stream `stream_id', the server
 * shuts down the read side of that stream early.  Per the original intent,
 * this is used to check whether a stream reset gets sent when a stream is
 * closed prematurely.  The hook is inactive while `limit' is zero.
 */
static struct {
    unsigned        stream_id;  /* stream to cut short */
    unsigned long   limit;      /* byte threshold; 0 disables the hook */
    unsigned long   n_read;     /* bytes read so far on that stream */
} g_premature_close = { 0, 0, 0 };
39
/* Forward declaration: server_ctx below keeps a list of these. */
struct lsquic_conn_ctx;

/* Server-wide state.  main() passes it to prog_init() as the stream_if
 * context, so every callback receives it as `stream_if_ctx'.
 */
struct server_ctx {
    /* Contexts of all live connections; searched by find_conn_h(). */
    TAILQ_HEAD(, lsquic_conn_ctx)   conn_ctxs;
    /* -r: per-connection request limit; 0 means no limit. */
    unsigned max_reqs;
    /* -n: remaining connection closures before the program stops;
     * 0 disables this check.
     */
    int n_conn;
    /* -e: absolute time after which the program stops.  It is checked
     * only when a connection closes; 0 disables the check.
     */
    time_t expiry;
    struct sport_head sports;   /* handed to prog_init() */
    struct prog *prog;          /* used to call prog_stop() */
};

/* Per-connection state, created in server_on_new_conn() and freed in
 * server_on_conn_closed().
 */
struct lsquic_conn_ctx {
    TAILQ_ENTRY(lsquic_conn_ctx)    next_connh;
    lsquic_conn_t       *conn;
    unsigned             n_reqs, n_closed;  /* streams opened / closed */
    struct server_ctx   *server_ctx;
};
57
58
59static lsquic_conn_ctx_t *
60server_on_new_conn (void *stream_if_ctx, lsquic_conn_t *conn)
61{
62    struct server_ctx *server_ctx = stream_if_ctx;
63    lsquic_conn_ctx_t *conn_h = calloc(1, sizeof(*conn_h));
64    conn_h->conn = conn;
65    conn_h->server_ctx = server_ctx;
66    TAILQ_INSERT_TAIL(&server_ctx->conn_ctxs, conn_h, next_connh);
67    LSQ_NOTICE("New connection!");
68    print_conn_info(conn);
69    return conn_h;
70}
71
72
73static void
74server_on_conn_closed (lsquic_conn_t *conn)
75{
76    lsquic_conn_ctx_t *conn_h = lsquic_conn_get_ctx(conn);
77    int stopped;
78
79    if (conn_h->server_ctx->expiry && conn_h->server_ctx->expiry < time(NULL))
80    {
81        LSQ_NOTICE("reached engine expiration time, shut down");
82        prog_stop(conn_h->server_ctx->prog);
83        stopped = 1;
84    }
85    else
86        stopped = 0;
87
88    if (conn_h->server_ctx->n_conn)
89    {
90        --conn_h->server_ctx->n_conn;
91        LSQ_NOTICE("Connection closed, remaining: %d", conn_h->server_ctx->n_conn);
92        if (0 == conn_h->server_ctx->n_conn && !stopped)
93            prog_stop(conn_h->server_ctx->prog);
94    }
95    else
96        LSQ_NOTICE("Connection closed");
97    TAILQ_REMOVE(&conn_h->server_ctx->conn_ctxs, conn_h, next_connh);
98    free(conn_h);
99}
100
101
/* Per-stream state, allocated in server_md5_on_new_stream() and freed in
 * server_on_close().
 */
struct lsquic_stream_ctx {
    lsquic_stream_t     *stream;
    struct server_ctx   *server_ctx;
    MD5_CTX              md5ctx;     /* running digest of the payload read */
    unsigned char        md5sum[MD5_DIGEST_LENGTH];  /* final binary digest */
    char                 md5str[MD5_DIGEST_LENGTH * 2 + 1]; /* hex reply text */
};
109
110
111static struct lsquic_conn_ctx *
112find_conn_h (const struct server_ctx *server_ctx, lsquic_stream_t *stream)
113{
114    struct lsquic_conn_ctx *conn_h;
115    lsquic_conn_t *conn;
116
117    conn = lsquic_stream_conn(stream);
118    TAILQ_FOREACH(conn_h, &server_ctx->conn_ctxs, next_connh)
119        if (conn_h->conn == conn)
120            return conn_h;
121    return NULL;
122}
123
124
125static lsquic_stream_ctx_t *
126server_md5_on_new_stream (void *stream_if_ctx, lsquic_stream_t *stream)
127{
128    struct lsquic_conn_ctx *conn_h;
129    lsquic_stream_ctx_t *st_h = malloc(sizeof(*st_h));
130    st_h->stream = stream;
131    st_h->server_ctx = stream_if_ctx;
132    lsquic_stream_wantread(stream, 1);
133    if (g_really_calculate_md5)
134        MD5_Init(&st_h->md5ctx);
135    conn_h = find_conn_h(st_h->server_ctx, stream);
136    assert(conn_h);
137    conn_h->n_reqs++;
138    LSQ_NOTICE("request #%u", conn_h->n_reqs);
139    if (st_h->server_ctx->max_reqs &&
140        conn_h->n_reqs >= st_h->server_ctx->max_reqs)
141    {
142        /* The assert guards the assumption that after the we mark the
143         * connection as going away, no new streams are opened and thus
144         * this callback is not called.
145         */
146        assert(conn_h->n_reqs == st_h->server_ctx->max_reqs);
147        LSQ_NOTICE("reached maximum requests: %u, going away",
148            st_h->server_ctx->max_reqs);
149        lsquic_conn_going_away(conn_h->conn);
150    }
151    return st_h;
152}
153
154
155static void
156server_md5_on_read (lsquic_stream_t *stream, lsquic_stream_ctx_t *st_h)
157{
158    char buf[0x1000];
159    ssize_t nr;
160
161    nr = lsquic_stream_read(stream, buf, sizeof(buf));
162    if (-1 == nr)
163    {
164        /* This should never return an error if we only call read() once
165         * per callback.
166         */
167        perror("lsquic_stream_read");
168        lsquic_stream_shutdown(stream, 0);
169        return;
170    }
171
172    if (g_premature_close.limit &&
173                g_premature_close.stream_id == lsquic_stream_id(stream))
174    {
175        g_premature_close.n_read += nr;
176        if (g_premature_close.n_read > g_premature_close.limit)
177        {
178            LSQ_WARN("Done after reading %lu bytes", g_premature_close.n_read);
179            lsquic_stream_shutdown(stream, 0);
180            return;
181        }
182    }
183
184    if (nr)
185    {
186        if (g_really_calculate_md5)
187            MD5_Update(&st_h->md5ctx, buf, nr);
188    }
189    else
190    {
191        lsquic_stream_wantread(stream, 0);
192        if (g_really_calculate_md5)
193        {
194            MD5_Final(st_h->md5sum, &st_h->md5ctx);
195            snprintf(st_h->md5str, sizeof(st_h->md5str),
196                "%02x%02x%02x%02x%02x%02x%02x%02x"
197                "%02x%02x%02x%02x%02x%02x%02x%02x"
198                , st_h->md5sum[0]
199                , st_h->md5sum[1]
200                , st_h->md5sum[2]
201                , st_h->md5sum[3]
202                , st_h->md5sum[4]
203                , st_h->md5sum[5]
204                , st_h->md5sum[6]
205                , st_h->md5sum[7]
206                , st_h->md5sum[8]
207                , st_h->md5sum[9]
208                , st_h->md5sum[10]
209                , st_h->md5sum[11]
210                , st_h->md5sum[12]
211                , st_h->md5sum[13]
212                , st_h->md5sum[14]
213                , st_h->md5sum[15]
214            );
215        }
216        else
217        {
218            memset(st_h->md5str, '0', sizeof(st_h->md5str) - 1);
219            st_h->md5str[sizeof(st_h->md5str) - 1] = '\0';
220        }
221        lsquic_stream_wantwrite(stream, 1);
222        lsquic_stream_shutdown(stream, 0);
223    }
224}
225
226
227static void
228server_md5_on_write (lsquic_stream_t *stream, lsquic_stream_ctx_t *st_h)
229{
230    ssize_t nw;
231    nw = lsquic_stream_write(stream, st_h->md5str, sizeof(st_h->md5str) - 1);
232    if (-1 == nw)
233    {
234        perror("lsquic_stream_write");
235        return;
236    }
237    lsquic_stream_wantwrite(stream, 0);
238    lsquic_stream_shutdown(stream, 1);
239}
240
241
242static void
243server_on_close (lsquic_stream_t *stream, lsquic_stream_ctx_t *st_h)
244{
245    struct lsquic_conn_ctx *conn_h;
246    LSQ_NOTICE("%s called", __func__);
247    conn_h = find_conn_h(st_h->server_ctx, stream);
248    conn_h->n_closed++;
249    if (st_h->server_ctx->max_reqs &&
250        conn_h->n_closed >= st_h->server_ctx->max_reqs)
251    {
252        assert(conn_h->n_closed == st_h->server_ctx->max_reqs);
253        LSQ_NOTICE("closing connection after completing %u requests",
254            conn_h->n_closed);
255        lsquic_conn_close(conn_h->conn);
256    }
257    free(st_h);
258}
259
260
261const struct lsquic_stream_if server_md5_stream_if = {
262    .on_new_conn            = server_on_new_conn,
263    .on_conn_closed         = server_on_conn_closed,
264    .on_new_stream          = server_md5_on_new_stream,
265    .on_read                = server_md5_on_read,
266    .on_write               = server_md5_on_write,
267    .on_close               = server_on_close,
268};
269
270
/* Print help for the options specific to this program.  `prog' is argv[0];
 * any leading directory components are stripped before printing.  The
 * options common to all lsquic test programs are printed separately by
 * main() via prog_print_common_options().
 *
 * Every option accepted in main()'s getopt() loop ("hr:Fn:e:p:") is listed;
 * previously only -e was documented.
 */
static void
usage (const char *prog)
{
    const char *const slash = strrchr(prog, '/');
    if (slash)
        prog = slash + 1;
    printf(
"Usage: %s [opts]\n"
"\n"
"Options:\n"
"   -e EXPIRY   Stop engine after this many seconds.  The expiration is\n"
"                 checked when connections are closed.\n"
"   -n NCONN    Stop engine after NCONN connections have been closed.\n"
"   -r MAXREQS  After MAXREQS streams on a connection, mark it as going\n"
"                 away; close it once MAXREQS streams have completed.\n"
"   -F          Do not calculate MD5; reply with a string of zeroes.\n"
"   -p ID:LIMIT Stop reading stream ID after more than LIMIT bytes have\n"
"                 been read (tests premature stream close).\n"
"   -h          Print this help and exit.\n"
                , prog);
}
285
286
287int
288main (int argc, char **argv)
289{
290    int opt, s;
291    struct prog prog;
292    struct server_ctx server_ctx;
293
294    memset(&server_ctx, 0, sizeof(server_ctx));
295    TAILQ_INIT(&server_ctx.conn_ctxs);
296    server_ctx.prog = &prog;
297    TAILQ_INIT(&server_ctx.sports);
298    prog_init(&prog, LSENG_SERVER, &server_ctx.sports,
299                                    &server_md5_stream_if, &server_ctx);
300
301    while (-1 != (opt = getopt(argc, argv, PROG_OPTS "hr:Fn:e:p:")))
302    {
303        switch (opt) {
304        case 'F':
305            g_really_calculate_md5 = 0;
306            break;
307        case 'p':
308            g_premature_close.stream_id = atoi(optarg);
309            g_premature_close.limit = atoi(strchr(optarg, ':') + 1);
310            break;
311        case 'r':
312            server_ctx.max_reqs = atoi(optarg);
313            break;
314        case 'e':
315            server_ctx.expiry = time(NULL) + atoi(optarg);
316            break;
317        case 'n':
318            server_ctx.n_conn = atoi(optarg);
319            break;
320        case 'h':
321            usage(argv[0]);
322            prog_print_common_options(&prog, stdout);
323            exit(0);
324        default:
325            if (0 != prog_set_opt(&prog, opt, optarg))
326                exit(1);
327        }
328    }
329
330    add_alpn("md5");
331    if (0 != prog_prep(&prog))
332    {
333        LSQ_ERROR("could not prep");
334        exit(EXIT_FAILURE);
335    }
336
337    LSQ_DEBUG("entering event loop");
338
339    s = prog_run(&prog);
340    prog_cleanup(&prog);
341
342    exit(0 == s ? EXIT_SUCCESS : EXIT_FAILURE);
343}
344