1#include "faketcp.h"
2
3#include <stdio.h>
4#include <stdlib.h>
5#include <string.h>
6#include <unistd.h>
7#include <netinet/ip.h>
8#include <netinet/tcp.h>
9#include <linux/if_ether.h>
10
11void tcp_input(int fd, const void* input, const void* payload, int tot_len)
12{
13  const struct iphdr* iphdr = static_cast<const struct iphdr*>(input);
14  const struct tcphdr* tcphdr = static_cast<const struct tcphdr*>(payload);
15  const int iphdr_len = iphdr->ihl*4;
16  const int tcp_seg_len = tot_len - iphdr_len;
17  const int tcphdr_size = sizeof(*tcphdr);
18  if (tcp_seg_len >= tcphdr_size
19      && tcp_seg_len >= tcphdr->doff*4)
20  {
21    const int tcphdr_len = tcphdr->doff*4;
22    const int payload_len = tot_len - iphdr_len - tcphdr_len;
23
24    char source[INET_ADDRSTRLEN];
25    char dest[INET_ADDRSTRLEN];
26    inet_ntop(AF_INET, &iphdr->saddr, source, INET_ADDRSTRLEN);
27    inet_ntop(AF_INET, &iphdr->daddr, dest, INET_ADDRSTRLEN);
28    printf("IP %s.%d > %s.%d: ",
29           source, ntohs(tcphdr->source), dest, ntohs(tcphdr->dest));
30    printf("Flags [%c], seq %u, win %d, length %d%s\n",
31           tcphdr->syn ? 'S' : (tcphdr->fin ? 'F' : '.'),
32           ntohl(tcphdr->seq),
33           ntohs(tcphdr->window),
34           payload_len,
35           tcphdr_len > sizeof(struct tcphdr) ? " <>" : "");
36
37    union
38    {
39      unsigned char output[ETH_FRAME_LEN];
40      struct
41      {
42        struct iphdr iphdr;
43        struct tcphdr tcphdr;
44      } out;
45    };
46
47    static_assert(sizeof(out) == sizeof(struct iphdr) + sizeof(struct tcphdr), "");
48    int output_len = sizeof(out);
49    bzero(&out, output_len + 4);
50    memcpy(output, input, sizeof(struct iphdr));
51
52    out.tcphdr.source = tcphdr->dest;
53    out.tcphdr.dest = tcphdr->source;
54    out.tcphdr.doff = sizeof(struct tcphdr) / 4;
55    out.tcphdr.window = htons(65000);
56
57    bool response = false;
58    const uint32_t seq = ntohl(tcphdr->seq);
59    const uint32_t isn = 123456;
60    if (tcphdr->syn)
61    {
62      out.tcphdr.seq = htonl(isn);
63      out.tcphdr.ack_seq = htonl(seq+1);
64      out.tcphdr.syn = 1;
65      out.tcphdr.ack = 1;
66
67      // set mss=1000
68      unsigned char* mss = output + output_len;
69      *mss++ = 2;
70      *mss++ = 4;
71      *mss++ = 0x03;
72      *mss++ = 0xe8;  // 1000 == 0x03e8
73      out.tcphdr.doff += 1;
74      output_len += 4;
75
76      response = true;
77    }
78    else if (tcphdr->fin)
79    {
80      out.tcphdr.seq = htonl(isn+1);
81      out.tcphdr.ack_seq = htonl(seq+1);
82      out.tcphdr.fin = 1;
83      out.tcphdr.ack = 1;
84      response = true;
85    }
86    else if (payload_len > 0)
87    {
88      out.tcphdr.seq = htonl(isn+1);
89      out.tcphdr.ack_seq = htonl(seq+payload_len);
90      out.tcphdr.ack = 1;
91      response = true;
92    }
93
94    // build IP header
95    out.iphdr.tot_len = htons(output_len);
96    std::swap(out.iphdr.saddr, out.iphdr.daddr);
97    out.iphdr.check = 0;
98    out.iphdr.check = in_checksum(output, sizeof(struct iphdr));
99
100    unsigned char* pseudo = output + output_len;
101    pseudo[0] = 0;
102    pseudo[1] = IPPROTO_TCP;
103    pseudo[2] = 0;
104    pseudo[3] = output_len - sizeof(struct iphdr);
105    out.tcphdr.check = in_checksum(&out.iphdr.saddr, output_len - 8);
106    if (response)
107    {
108      write(fd, output, output_len);
109    }
110  }
111}
112
113int main(int argc, char* argv[])
114{
115  char ifname[IFNAMSIZ] = "tun%d";
116  bool offload = argc > 1 && strcmp(argv[1], "-K") == 0;
117  int fd = tun_alloc(ifname, offload);
118
119  if (fd < 0)
120  {
121    fprintf(stderr, "tunnel interface allocation failed\n");
122    exit(1);
123  }
124
125  printf("allocted tunnel interface %s\n", ifname);
126  sleep(1);
127
128  for (;;)
129  {
130    union
131    {
132      unsigned char buf[IP_MAXPACKET];
133      struct iphdr iphdr;
134    };
135
136    const int iphdr_size = sizeof iphdr;
137
138    int nread = read(fd, buf, sizeof(buf));
139    if (nread < 0)
140    {
141      perror("read");
142      close(fd);
143      exit(1);
144    }
145    else if (nread == sizeof(buf))
146    {
147      printf("possible message truncated.\n");
148    }
149    printf("read %d bytes from tunnel interface %s.\n", nread, ifname);
150
151    const int iphdr_len = iphdr.ihl*4;  // FIXME: check nread >= sizeof iphdr before accessing iphdr.ihl.
152    if (nread >= iphdr_size
153        && iphdr.version == 4
154        && iphdr_len >= iphdr_size
155        && iphdr_len <= nread
156        && iphdr.tot_len == htons(nread)
157        && in_checksum(buf, iphdr_len) == 0)
158    {
159      const void* payload = buf + iphdr_len;
160      if (iphdr.protocol == IPPROTO_ICMP)
161      {
162        icmp_input(fd, buf, payload, nread);
163      }
164      else if (iphdr.protocol == IPPROTO_TCP)
165      {
166        tcp_input(fd, buf, payload, nread);
167      }
168    }
169    else if (iphdr.version == 4)
170    {
171      printf("bad packet\n");
172      for (int i = 0; i < nread; ++i)
173      {
174        if (i % 4 == 0) printf("\n");
175        printf("%02x ", buf[i]);
176      }
177      printf("\n");
178    }
179  }
180
181  return 0;
182}
183