#include "faketcp.h"

#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <linux/if_ether.h>

#include <utility>
10
11void tcp_input(int fd, const void* input, const void* ippayload, int tot_len)
12{
13  const struct iphdr* iphdr = static_cast<const struct iphdr*>(input);
14  const struct tcphdr* tcphdr = static_cast<const struct tcphdr*>(ippayload);
15  const int iphdr_len = iphdr->ihl*4;
16  const int tcp_seg_len = tot_len - iphdr_len;
17  const int tcphdr_size = sizeof(*tcphdr);
18  if (tcp_seg_len >= tcphdr_size
19      && tcp_seg_len >= tcphdr->doff*4)
20  {
21    const int tcphdr_len = tcphdr->doff*4;
22    const void* payload = ippayload + tcphdr_len;
23    const int payload_len = tot_len - iphdr_len - tcphdr_len;
24
25    char source[INET_ADDRSTRLEN];
26    char dest[INET_ADDRSTRLEN];
27    inet_ntop(AF_INET, &iphdr->saddr, source, INET_ADDRSTRLEN);
28    inet_ntop(AF_INET, &iphdr->daddr, dest, INET_ADDRSTRLEN);
29    printf("IP %s.%d > %s.%d: ",
30           source, ntohs(tcphdr->source), dest, ntohs(tcphdr->dest));
31    printf("Flags [%c], seq %u, win %d, length %d\n",
32           tcphdr->syn ? 'S' : (tcphdr->fin ? 'F' : '.'),
33           ntohl(tcphdr->seq),
34           ntohs(tcphdr->window),
35           payload_len);
36
37    union
38    {
39      unsigned char output[ETH_FRAME_LEN];
40      struct
41      {
42        struct iphdr iphdr;
43        struct tcphdr tcphdr;
44      } out;
45    };
46
47    assert(sizeof(out) == sizeof(struct iphdr) + sizeof(struct tcphdr));
48    int output_len = sizeof(out);
49    bzero(&out, output_len + 4);
50    memcpy(output, input, sizeof(struct iphdr));
51
52    std::swap(out.iphdr.saddr, out.iphdr.daddr);
53    out.iphdr.check = 0;
54
55    out.tcphdr.source = tcphdr->dest;
56    out.tcphdr.dest = tcphdr->source;
57    out.tcphdr.doff = sizeof(struct tcphdr) / 4;
58    out.tcphdr.window = htons(5000);
59
60    bool response = false;
61    const uint32_t seq = ntohl(tcphdr->seq);
62    if (tcphdr->syn)
63    {
64      out.tcphdr.seq = htonl(seq);
65      out.tcphdr.ack_seq = htonl(seq+1);
66      out.tcphdr.syn = 1;
67      out.tcphdr.ack = 1;
68      response = true;
69    }
70    else if (tcphdr->fin)
71    {
72      out.tcphdr.seq = htonl(seq);
73      out.tcphdr.ack_seq = htonl(seq+1);
74      out.tcphdr.fin = 1;
75      out.tcphdr.ack = 1;
76      response = true;
77    }
78    else if (payload_len > 0)
79    {
80      out.tcphdr.seq = htonl(seq);
81      out.tcphdr.ack_seq = htonl(seq+payload_len);
82      out.tcphdr.psh = 1;
83      out.tcphdr.ack = 1;
84      assert(output + output_len + payload_len < output + sizeof(output));
85      memcpy(output + output_len, payload, payload_len);
86      output_len += payload_len;
87      response = true;
88    }
89
90    out.iphdr.tot_len = htons(output_len);
91    out.iphdr.check = in_checksum(output, sizeof(struct iphdr));
92
93    unsigned char* pseudo = output + output_len;
94    if (payload_len % 2 == 1)
95    {
96      *pseudo = 0;
97      ++pseudo;
98    }
99    unsigned int len = sizeof(struct tcphdr)+payload_len;
100    pseudo[0] = 0;
101    pseudo[1] = IPPROTO_TCP;
102    pseudo[2] = len / 256;
103    pseudo[3] = len % 256;
104    out.tcphdr.check = in_checksum(&out.iphdr.saddr, len + 12 + (payload_len % 2));
105    if (response)
106    {
107      write(fd, output, output_len);
108    }
109  }
110}
111
112int main()
113{
114  char ifname[IFNAMSIZ] = "tun%d";
115  int fd = tun_alloc(ifname);
116
117  if (fd < 0)
118  {
119    fprintf(stderr, "tunnel interface allocation failed\n");
120    exit(1);
121  }
122
123  printf("allocted tunnel interface %s\n", ifname);
124  sleep(1);
125
126  for (;;)
127  {
128    union
129    {
130      unsigned char buf[ETH_FRAME_LEN];
131      struct iphdr iphdr;
132    };
133
134    const int iphdr_size = sizeof iphdr;
135
136    int nread = read(fd, buf, sizeof(buf));
137    if (nread < 0)
138    {
139      perror("read");
140      close(fd);
141      exit(1);
142    }
143    printf("read %d bytes from tunnel interface %s.\n", nread, ifname);
144
145    const int iphdr_len = iphdr.ihl*4;
146    if (nread >= iphdr_size
147        && iphdr.version == 4
148        && iphdr_len >= iphdr_size
149        && iphdr_len <= nread
150        && iphdr.tot_len == htons(nread)
151        && in_checksum(buf, iphdr_len) == 0)
152    {
153      const void* payload = buf + iphdr_len;
154      if (iphdr.protocol == IPPROTO_ICMP)
155      {
156        icmp_input(fd, buf, payload, nread);
157      }
158      else if (iphdr.protocol == IPPROTO_TCP)
159      {
160        tcp_input(fd, buf, payload, nread);
161      }
162    }
163    else
164    {
165      printf("bad packet\n");
166      for (int i = 0; i < nread; ++i)
167      {
168        if (i % 4 == 0) printf("\n");
169        printf("%02x ", buf[i]);
170      }
171      printf("\n");
172    }
173  }
174
175  return 0;
176}
177