md5_server.c revision 9a690580
1/* Copyright (c) 2017 - 2020 LiteSpeed Technologies Inc.  See LICENSE. */
2/*
3 * md5_server.c -- Read one or more streams from the client and return
4 *                 MD5 sum of the payload.
5 */
6
7#include <assert.h>
8#include <signal.h>
9#include <stdio.h>
10#include <stdlib.h>
11#include <string.h>
12#include <sys/queue.h>
13#include <time.h>
14#include <unistd.h>
15
16#include <openssl/md5.h>
17
18#include <event2/event.h>
19
20#include "lsquic.h"
21#include "test_common.h"
22#include "prog.h"
23
24#include "../src/liblsquic/lsquic_logger.h"
25
26
/* Non-zero (the default): compute a real MD5 digest of each stream's
 * payload.  Cleared by the -F option, in which case every reply is a
 * string of '0' characters instead.
 */
static int g_really_calculate_md5 = 1;

/* Premature-close test, used to check whether a stream reset is sent when
 * a stream is closed early.  Armed by "-p STREAM_ID:LIMIT"; inactive while
 * `limit' is zero.  Once more than `limit' bytes have been read from stream
 * `stream_id', lsquic_stream_shutdown(stream, 0) is called on it.
 */
static struct
{
    unsigned        stream_id;  /* ID of the stream to cut short */
    unsigned long   limit;      /* byte threshold; 0 disables the test */
    unsigned long   n_read;     /* bytes read so far from that stream */
} g_premature_close;
37
38struct lsquic_conn_ctx;
39
40struct server_ctx {
41    TAILQ_HEAD(, lsquic_conn_ctx)   conn_ctxs;
42    unsigned max_reqs;
43    int n_conn;
44    time_t expiry;
45    struct sport_head sports;
46    struct prog *prog;
47};
48
49struct lsquic_conn_ctx {
50    TAILQ_ENTRY(lsquic_conn_ctx)    next_connh;
51    lsquic_conn_t       *conn;
52    unsigned             n_reqs, n_closed;
53    struct server_ctx   *server_ctx;
54};
55
56
57static lsquic_conn_ctx_t *
58server_on_new_conn (void *stream_if_ctx, lsquic_conn_t *conn)
59{
60    struct server_ctx *server_ctx = stream_if_ctx;
61    lsquic_conn_ctx_t *conn_h = calloc(1, sizeof(*conn_h));
62    conn_h->conn = conn;
63    conn_h->server_ctx = server_ctx;
64    TAILQ_INSERT_TAIL(&server_ctx->conn_ctxs, conn_h, next_connh);
65    LSQ_NOTICE("New connection!");
66    print_conn_info(conn);
67    return conn_h;
68}
69
70
71static void
72server_on_conn_closed (lsquic_conn_t *conn)
73{
74    lsquic_conn_ctx_t *conn_h = lsquic_conn_get_ctx(conn);
75    int stopped;
76
77    if (conn_h->server_ctx->expiry && conn_h->server_ctx->expiry < time(NULL))
78    {
79        LSQ_NOTICE("reached engine expiration time, shut down");
80        prog_stop(conn_h->server_ctx->prog);
81        stopped = 1;
82    }
83    else
84        stopped = 0;
85
86    if (conn_h->server_ctx->n_conn)
87    {
88        --conn_h->server_ctx->n_conn;
89        LSQ_NOTICE("Connection closed, remaining: %d", conn_h->server_ctx->n_conn);
90        if (0 == conn_h->server_ctx->n_conn && !stopped)
91            prog_stop(conn_h->server_ctx->prog);
92    }
93    else
94        LSQ_NOTICE("Connection closed");
95    TAILQ_REMOVE(&conn_h->server_ctx->conn_ctxs, conn_h, next_connh);
96    free(conn_h);
97}
98
99
100struct lsquic_stream_ctx {
101    lsquic_stream_t     *stream;
102    struct server_ctx   *server_ctx;
103    MD5_CTX              md5ctx;
104    unsigned char        md5sum[MD5_DIGEST_LENGTH];
105    char                 md5str[MD5_DIGEST_LENGTH * 2 + 1];
106};
107
108
109static struct lsquic_conn_ctx *
110find_conn_h (const struct server_ctx *server_ctx, lsquic_stream_t *stream)
111{
112    struct lsquic_conn_ctx *conn_h;
113    lsquic_conn_t *conn;
114
115    conn = lsquic_stream_conn(stream);
116    TAILQ_FOREACH(conn_h, &server_ctx->conn_ctxs, next_connh)
117        if (conn_h->conn == conn)
118            return conn_h;
119    return NULL;
120}
121
122
123static lsquic_stream_ctx_t *
124server_md5_on_new_stream (void *stream_if_ctx, lsquic_stream_t *stream)
125{
126    struct lsquic_conn_ctx *conn_h;
127    lsquic_stream_ctx_t *st_h = malloc(sizeof(*st_h));
128    st_h->stream = stream;
129    st_h->server_ctx = stream_if_ctx;
130    lsquic_stream_wantread(stream, 1);
131    if (g_really_calculate_md5)
132        MD5_Init(&st_h->md5ctx);
133    conn_h = find_conn_h(st_h->server_ctx, stream);
134    assert(conn_h);
135    conn_h->n_reqs++;
136    LSQ_NOTICE("request #%u", conn_h->n_reqs);
137    if (st_h->server_ctx->max_reqs &&
138        conn_h->n_reqs >= st_h->server_ctx->max_reqs)
139    {
140        /* The assert guards the assumption that after the we mark the
141         * connection as going away, no new streams are opened and thus
142         * this callback is not called.
143         */
144        assert(conn_h->n_reqs == st_h->server_ctx->max_reqs);
145        LSQ_NOTICE("reached maximum requests: %u, going away",
146            st_h->server_ctx->max_reqs);
147        lsquic_conn_going_away(conn_h->conn);
148    }
149    return st_h;
150}
151
152
153static void
154server_md5_on_read (lsquic_stream_t *stream, lsquic_stream_ctx_t *st_h)
155{
156    char buf[0x1000];
157    ssize_t nr;
158
159    nr = lsquic_stream_read(stream, buf, sizeof(buf));
160    if (-1 == nr)
161    {
162        /* This should never return an error if we only call read() once
163         * per callback.
164         */
165        perror("lsquic_stream_read");
166        lsquic_stream_shutdown(stream, 0);
167        return;
168    }
169
170    if (g_premature_close.limit &&
171                g_premature_close.stream_id == lsquic_stream_id(stream))
172    {
173        g_premature_close.n_read += nr;
174        if (g_premature_close.n_read > g_premature_close.limit)
175        {
176            LSQ_WARN("Done after reading %lu bytes", g_premature_close.n_read);
177            lsquic_stream_shutdown(stream, 0);
178            return;
179        }
180    }
181
182    if (nr)
183    {
184        if (g_really_calculate_md5)
185            MD5_Update(&st_h->md5ctx, buf, nr);
186    }
187    else
188    {
189        lsquic_stream_wantread(stream, 0);
190        if (g_really_calculate_md5)
191        {
192            MD5_Final(st_h->md5sum, &st_h->md5ctx);
193            snprintf(st_h->md5str, sizeof(st_h->md5str),
194                "%02x%02x%02x%02x%02x%02x%02x%02x"
195                "%02x%02x%02x%02x%02x%02x%02x%02x"
196                , st_h->md5sum[0]
197                , st_h->md5sum[1]
198                , st_h->md5sum[2]
199                , st_h->md5sum[3]
200                , st_h->md5sum[4]
201                , st_h->md5sum[5]
202                , st_h->md5sum[6]
203                , st_h->md5sum[7]
204                , st_h->md5sum[8]
205                , st_h->md5sum[9]
206                , st_h->md5sum[10]
207                , st_h->md5sum[11]
208                , st_h->md5sum[12]
209                , st_h->md5sum[13]
210                , st_h->md5sum[14]
211                , st_h->md5sum[15]
212            );
213        }
214        else
215        {
216            memset(st_h->md5str, '0', sizeof(st_h->md5str) - 1);
217            st_h->md5str[sizeof(st_h->md5str) - 1] = '\0';
218        }
219        lsquic_stream_wantwrite(stream, 1);
220        lsquic_stream_shutdown(stream, 0);
221    }
222}
223
224
225static void
226server_md5_on_write (lsquic_stream_t *stream, lsquic_stream_ctx_t *st_h)
227{
228    ssize_t nw;
229    nw = lsquic_stream_write(stream, st_h->md5str, sizeof(st_h->md5str) - 1);
230    if (-1 == nw)
231    {
232        perror("lsquic_stream_write");
233        return;
234    }
235    lsquic_stream_wantwrite(stream, 0);
236    lsquic_stream_shutdown(stream, 1);
237}
238
239
240static void
241server_on_close (lsquic_stream_t *stream, lsquic_stream_ctx_t *st_h)
242{
243    struct lsquic_conn_ctx *conn_h;
244    LSQ_NOTICE("%s called", __func__);
245    conn_h = find_conn_h(st_h->server_ctx, stream);
246    conn_h->n_closed++;
247    if (st_h->server_ctx->max_reqs &&
248        conn_h->n_closed >= st_h->server_ctx->max_reqs)
249    {
250        assert(conn_h->n_closed == st_h->server_ctx->max_reqs);
251        LSQ_NOTICE("closing connection after completing %u requests",
252            conn_h->n_closed);
253        lsquic_conn_close(conn_h->conn);
254    }
255    free(st_h);
256}
257
258
259const struct lsquic_stream_if server_md5_stream_if = {
260    .on_new_conn            = server_on_new_conn,
261    .on_conn_closed         = server_on_conn_closed,
262    .on_new_stream          = server_md5_on_new_stream,
263    .on_read                = server_md5_on_read,
264    .on_write               = server_md5_on_write,
265    .on_close               = server_on_close,
266};
267
268
/* Print the server-specific command-line options to stdout.  Only the
 * basename of `prog' (the text after its last '/') is shown.
 */
static void
usage (const char *prog)
{
    const char *const slash = strrchr(prog, '/');
    if (slash)
        prog = slash + 1;
    printf(
"Usage: %s [opts]\n"
"\n"
"Options:\n"
"   -e EXPIRY   Stop engine after this many seconds.  The expiration is\n"
"                 checked when connections are closed.\n"
"   -n NCONNS   Stop engine after this many connections have been closed.\n"
"   -r MAXREQS  Close each connection after this many of its requests\n"
"                 (streams) have completed.\n"
"   -F          Do not calculate MD5 sums; reply with a string of zeroes.\n"
"   -p STREAM_ID:LIMIT\n"
"               Shut down reading on stream STREAM_ID once more than LIMIT\n"
"                 bytes have been read (premature close test).\n"
"   -h          Print this help message and exit.\n"
                , prog);
}
283
284
285int
286main (int argc, char **argv)
287{
288    int opt, s;
289    struct prog prog;
290    struct server_ctx server_ctx;
291
292    memset(&server_ctx, 0, sizeof(server_ctx));
293    TAILQ_INIT(&server_ctx.conn_ctxs);
294    server_ctx.prog = &prog;
295    TAILQ_INIT(&server_ctx.sports);
296    prog_init(&prog, LSENG_SERVER, &server_ctx.sports,
297                                    &server_md5_stream_if, &server_ctx);
298
299    while (-1 != (opt = getopt(argc, argv, PROG_OPTS "hr:Fn:e:p:")))
300    {
301        switch (opt) {
302        case 'F':
303            g_really_calculate_md5 = 0;
304            break;
305        case 'p':
306            g_premature_close.stream_id = atoi(optarg);
307            g_premature_close.limit = atoi(strchr(optarg, ':') + 1);
308            break;
309        case 'r':
310            server_ctx.max_reqs = atoi(optarg);
311            break;
312        case 'e':
313            server_ctx.expiry = time(NULL) + atoi(optarg);
314            break;
315        case 'n':
316            server_ctx.n_conn = atoi(optarg);
317            break;
318        case 'h':
319            usage(argv[0]);
320            prog_print_common_options(&prog, stdout);
321            exit(0);
322        default:
323            if (0 != prog_set_opt(&prog, opt, optarg))
324                exit(1);
325        }
326    }
327
328    if (0 != prog_prep(&prog))
329    {
330        LSQ_ERROR("could not prep");
331        exit(EXIT_FAILURE);
332    }
333
334    LSQ_DEBUG("entering event loop");
335
336    s = prog_run(&prog);
337    prog_cleanup(&prog);
338
339    exit(0 == s ? EXIT_SUCCESS : EXIT_FAILURE);
340}
341