echoall2.cc revision 2c431f66
1#include "faketcp.h"
2
3#include <map>
4
5#include <stdio.h>
6#include <stdlib.h>
7#include <unistd.h>
8#include <netinet/ip.h>
9#include <netinet/tcp.h>
10#include <linux/if_ether.h>
11
12struct TcpState
13{
14  uint32_t rcv_nxt;
15  uint32_t snd_una;
16};
17
18std::map<SocketAddr, TcpState> expectedSeqs;
19
20void tcp_input(int fd, const void* input, const void* ippayload, int tot_len)
21{
22  const struct iphdr* iphdr = static_cast<const struct iphdr*>(input);
23  const struct tcphdr* tcphdr = static_cast<const struct tcphdr*>(ippayload);
24  const int iphdr_len = iphdr->ihl*4;
25  const int tcp_seg_len = tot_len - iphdr_len;
26  const int tcphdr_size = sizeof(*tcphdr);
27  if (tcp_seg_len >= tcphdr_size
28      && tcp_seg_len >= tcphdr->doff*4)
29  {
30    const int tcphdr_len = tcphdr->doff*4;
31    const void* payload = ippayload + tcphdr_len;
32    const int payload_len = tot_len - iphdr_len - tcphdr_len;
33
34    char source[INET_ADDRSTRLEN];
35    char dest[INET_ADDRSTRLEN];
36    inet_ntop(AF_INET, &iphdr->saddr, source, INET_ADDRSTRLEN);
37    inet_ntop(AF_INET, &iphdr->daddr, dest, INET_ADDRSTRLEN);
38    printf("IP %s.%d > %s.%d: ",
39           source, ntohs(tcphdr->source), dest, ntohs(tcphdr->dest));
40    printf("Flags [%c], seq %u, win %d, length %d\n",
41           tcphdr->syn ? 'S' : (tcphdr->fin ? 'F' : '.'),
42           ntohl(tcphdr->seq),
43           ntohs(tcphdr->window),
44           payload_len);
45
46    union
47    {
48      unsigned char output[ETH_FRAME_LEN];
49      struct
50      {
51        struct iphdr iphdr;
52        struct tcphdr tcphdr;
53      } out;
54    };
55
56    assert(sizeof(out) == sizeof(struct iphdr) + sizeof(struct tcphdr));
57    int output_len = sizeof(out);
58    bzero(&out, output_len + 4);
59    memcpy(output, input, sizeof(struct iphdr));
60
61    std::swap(out.iphdr.saddr, out.iphdr.daddr);
62    out.iphdr.check = 0;
63
64    out.tcphdr.source = tcphdr->dest;
65    out.tcphdr.dest = tcphdr->source;
66    out.tcphdr.doff = sizeof(struct tcphdr) / 4;
67    out.tcphdr.window = htons(5000);
68
69    SocketAddr addr = { out.iphdr.saddr, out.iphdr.daddr, out.tcphdr.source, out.tcphdr.dest };
70    bool response = false;
71    const uint32_t seq = ntohl(tcphdr->seq);
72    if (tcphdr->syn)
73    {
74      out.tcphdr.seq = htonl(seq);
75      out.tcphdr.ack_seq = htonl(seq+1);
76      out.tcphdr.syn = 1;
77      out.tcphdr.ack = 1;
78      TcpState s = { seq + 1, seq + 1 };
79      expectedSeqs[addr] = s;
80      response = true;
81    }
82    else if (tcphdr->fin)
83    {
84      out.tcphdr.seq = htonl(seq);
85      out.tcphdr.ack_seq = htonl(seq+1);
86      out.tcphdr.fin = 1;
87      out.tcphdr.ack = 1;
88      expectedSeqs.erase(addr);
89      response = true;
90    }
91    else
92    {
93      out.tcphdr.seq = htonl(seq);
94      out.tcphdr.ack_seq = htonl(seq+payload_len);
95      out.tcphdr.psh = 1;
96      out.tcphdr.ack = 1;
97      assert(output + output_len + payload_len < output + sizeof(output));
98      memcpy(output + output_len, payload, payload_len);
99      output_len += payload_len;
100      auto it = expectedSeqs.find(addr);
101      if (it != expectedSeqs.end())
102      {
103        const uint32_t ack = ntohl(tcphdr->ack_seq);
104        printf("seq = %u, ack = %u\n", seq, ack);
105        printf("rcv_nxt = %u, snd_una = %u\n", it->second.rcv_nxt, it->second.snd_una);
106        if (payload_len > 0 && it->second.rcv_nxt == seq)
107        {
108          it->second.snd_una = seq+payload_len;
109          response = true;
110        }
111        if (tcphdr->ack && ack == it->second.snd_una)
112        {
113          it->second.rcv_nxt = ack;
114        }
115        printf("rcv_nxt = %u, snd_una = %u\n", it->second.rcv_nxt, it->second.snd_una);
116      }
117      else
118      {
119        // RST
120      }
121    }
122
123    out.iphdr.tot_len = htons(output_len);
124    out.iphdr.check = in_checksum(output, sizeof(struct iphdr));
125
126    unsigned char* pseudo = output + output_len;
127    if (payload_len % 2 == 1)
128    {
129      *pseudo = 0;
130      ++pseudo;
131    }
132    unsigned int len = sizeof(struct tcphdr)+payload_len;
133    pseudo[0] = 0;
134    pseudo[1] = IPPROTO_TCP;
135    pseudo[2] = len / 256;
136    pseudo[3] = len % 256;
137    out.tcphdr.check = in_checksum(&out.iphdr.saddr, len + 12 + (payload_len % 2));
138    if (response)
139    {
140      write(fd, output, output_len);
141    }
142  }
143}
144
145int main()
146{
147  char ifname[IFNAMSIZ] = "tun%d";
148  int fd = tun_alloc(ifname);
149
150  if (fd < 0)
151  {
152    fprintf(stderr, "tunnel interface allocation failed\n");
153    exit(1);
154  }
155
156  printf("allocted tunnel interface %s\n", ifname);
157  sleep(1);
158
159  for (;;)
160  {
161    union
162    {
163      unsigned char buf[ETH_FRAME_LEN];
164      struct iphdr iphdr;
165    };
166
167    const int iphdr_size = sizeof iphdr;
168
169    int nread = read(fd, buf, sizeof(buf));
170    if (nread < 0)
171    {
172      perror("read");
173      close(fd);
174      exit(1);
175    }
176    printf("read %d bytes from tunnel interface %s.\n", nread, ifname);
177
178    const int iphdr_len = iphdr.ihl*4;
179    if (nread >= iphdr_size
180        && iphdr.version == 4
181        && iphdr_len >= iphdr_size
182        && iphdr_len <= nread
183        && iphdr.tot_len == htons(nread)
184        && in_checksum(buf, iphdr_len) == 0)
185    {
186      const void* payload = buf + iphdr_len;
187      if (iphdr.protocol == IPPROTO_ICMP)
188      {
189        icmp_input(fd, buf, payload, nread);
190      }
191      else if (iphdr.protocol == IPPROTO_TCP)
192      {
193        tcp_input(fd, buf, payload, nread);
194      }
195    }
196    else
197    {
198      printf("bad packet\n");
199      for (int i = 0; i < nread; ++i)
200      {
201        if (i % 4 == 0) printf("\n");
202        printf("%02x ", buf[i]);
203      }
204      printf("\n");
205    }
206  }
207
208  return 0;
209}
210