1#include "faketcp.h"
2
3#include <map>
4
5#include <stdio.h>
6#include <stdlib.h>
7#include <unistd.h>
8#include <netinet/ip.h>
9#include <netinet/tcp.h>
10#include <linux/if_ether.h>
11
// Receive-side bookkeeping for a single fake TCP connection.
struct TcpState
{
  // Next sequence number expected from the peer (initialised to the
  // peer's SYN sequence number + 1 when the connection is opened).
  uint32_t rcv_nxt;
  // Send-side state (snd_una) is not tracked yet.
};
17
// Receive state of every connection seen so far. The key is the SocketAddr
// that tcp_input builds from the reply's addresses and ports. An entry is
// created (or reset) when a SYN arrives and erased when a FIN arrives.
std::map<SocketAddr, TcpState> expectedSeqs;
19
20void tcp_input(int fd, const void* input, const void* payload, int tot_len)
21{
22  const struct iphdr* iphdr = static_cast<const struct iphdr*>(input);
23  const struct tcphdr* tcphdr = static_cast<const struct tcphdr*>(payload);
24  const int iphdr_len = iphdr->ihl*4;
25  const int tcp_seg_len = tot_len - iphdr_len;
26  const int tcphdr_size = sizeof(*tcphdr);
27  if (tcp_seg_len >= tcphdr_size
28      && tcp_seg_len >= tcphdr->doff*4)
29  {
30    const int tcphdr_len = tcphdr->doff*4;
31    const int payload_len = tot_len - iphdr_len - tcphdr_len;
32
33    char source[INET_ADDRSTRLEN];
34    char dest[INET_ADDRSTRLEN];
35    inet_ntop(AF_INET, &iphdr->saddr, source, INET_ADDRSTRLEN);
36    inet_ntop(AF_INET, &iphdr->daddr, dest, INET_ADDRSTRLEN);
37    printf("IP %s.%d > %s.%d: ",
38           source, ntohs(tcphdr->source), dest, ntohs(tcphdr->dest));
39    printf("Flags [%c], seq %u, win %d, length %d\n",
40           tcphdr->syn ? 'S' : (tcphdr->fin ? 'F' : '.'),
41           ntohl(tcphdr->seq),
42           ntohs(tcphdr->window),
43           payload_len);
44
45    union
46    {
47      unsigned char output[ETH_FRAME_LEN];
48      struct
49      {
50        struct iphdr iphdr;
51        struct tcphdr tcphdr;
52      } out;
53    };
54
55    static_assert(sizeof(out) == sizeof(struct iphdr) + sizeof(struct tcphdr), "");
56    int output_len = sizeof(out);
57    bzero(&out, output_len + 4);
58    memcpy(output, input, sizeof(struct iphdr));
59
60    out.iphdr.tot_len = htons(output_len);
61    std::swap(out.iphdr.saddr, out.iphdr.daddr);
62    out.iphdr.check = 0;
63    out.iphdr.check = in_checksum(output, sizeof(struct iphdr));
64
65    out.tcphdr.source = tcphdr->dest;
66    out.tcphdr.dest = tcphdr->source;
67    out.tcphdr.doff = sizeof(struct tcphdr) / 4;
68    out.tcphdr.window = htons(5000);
69
70    SocketAddr addr = { out.iphdr.saddr, out.iphdr.daddr, out.tcphdr.source, out.tcphdr.dest };
71    bool response = false;
72    const uint32_t seq = ntohl(tcphdr->seq);
73    if (tcphdr->syn)
74    {
75      out.tcphdr.seq = htonl(123456);
76      out.tcphdr.ack_seq = htonl(seq+1);
77      out.tcphdr.syn = 1;
78      out.tcphdr.ack = 1;
79      TcpState s = { seq + 1 };
80      expectedSeqs[addr] = s;
81      response = true;
82    }
83    else if (tcphdr->fin)
84    {
85      out.tcphdr.seq = htonl(123457);
86      out.tcphdr.ack_seq = htonl(seq+1);
87      out.tcphdr.fin = 1;
88      out.tcphdr.ack = 1;
89      expectedSeqs.erase(addr);
90      response = true;
91    }
92    else if (payload_len > 0)
93    {
94      out.tcphdr.seq = htonl(123457);
95      out.tcphdr.ack_seq = htonl(seq+payload_len);
96      out.tcphdr.ack = 1;
97      auto it = expectedSeqs.find(addr);
98      if (it != expectedSeqs.end())
99      {
100        if (it->second.rcv_nxt >= seq)  // FIXME: wrap!
101        {
102          if (it->second.rcv_nxt == seq)
103          {
104            printf("received, seq = %u:%u, len = %u\n", seq, seq+payload_len, payload_len);
105          }
106          else
107          {
108            printf("retransmit, rcv_nxt = %u, seq = %u\n", it->second.rcv_nxt, seq);
109          }
110
111          const uint32_t ack = ntohl(out.tcphdr.ack_seq);
112          it->second.rcv_nxt = ack;
113          response = true;
114        }
115        else
116        {
117          printf("packet loss, rcv_nxt = %u, seq = %u\n", it->second.rcv_nxt, seq);
118        }
119      }
120      else
121      {
122        // RST
123      }
124    }
125
126    unsigned char* pseudo = output + output_len;
127    pseudo[0] = 0;
128    pseudo[1] = IPPROTO_TCP;
129    pseudo[2] = 0;
130    pseudo[3] = sizeof(struct tcphdr);
131    out.tcphdr.check = in_checksum(&out.iphdr.saddr, sizeof(struct tcphdr)+12);
132    if (response)
133    {
134      write(fd, output, output_len);
135    }
136  }
137}
138
139int main()
140{
141  char ifname[IFNAMSIZ] = "tun%d";
142  int fd = tun_alloc(ifname);
143
144  if (fd < 0)
145  {
146    fprintf(stderr, "tunnel interface allocation failed\n");
147    exit(1);
148  }
149
150  printf("allocted tunnel interface %s\n", ifname);
151  sleep(1);
152
153  for (;;)
154  {
155    union
156    {
157      unsigned char buf[IP_MAXPACKET];
158      struct iphdr iphdr;
159    };
160
161    const int iphdr_size = sizeof iphdr;
162
163    int nread = read(fd, buf, sizeof(buf));
164    if (nread < 0)
165    {
166      perror("read");
167      close(fd);
168      exit(1);
169    }
170    else if (nread == sizeof(buf))
171    {
172      printf("possible message truncated.\n");
173    }
174    printf("read %d bytes from tunnel interface %s.\n", nread, ifname);
175
176    const int iphdr_len = iphdr.ihl*4;  // FIXME: check nread >= sizeof iphdr before accessing iphdr.ihl.
177    if (nread >= iphdr_size
178        && iphdr.version == 4
179        && iphdr_len >= iphdr_size
180        && iphdr_len <= nread
181        && iphdr.tot_len == htons(nread)
182        && in_checksum(buf, iphdr_len) == 0)
183    {
184      const void* payload = buf + iphdr_len;
185      if (iphdr.protocol == IPPROTO_ICMP)
186      {
187        icmp_input(fd, buf, payload, nread);
188      }
189      else if (iphdr.protocol == IPPROTO_TCP)
190      {
191        tcp_input(fd, buf, payload, nread);
192      }
193    }
194    else if (iphdr.version == 4)
195    {
196      printf("bad packet\n");
197      for (int i = 0; i < nread; ++i)
198      {
199        if (i % 4 == 0) printf("\n");
200        printf("%02x ", buf[i]);
201      }
202      printf("\n");
203    }
204  }
205
206  return 0;
207}
208