1#include "faketcp.h" 2 3#include <stdio.h> 4#include <stdlib.h> 5#include <string.h> 6#include <unistd.h> 7#include <netinet/ip.h> 8#include <netinet/tcp.h> 9#include <linux/if_ether.h> 10 11void tcp_input(int fd, const void* input, const void* payload, int tot_len) 12{ 13 const struct iphdr* iphdr = static_cast<const struct iphdr*>(input); 14 const struct tcphdr* tcphdr = static_cast<const struct tcphdr*>(payload); 15 const int iphdr_len = iphdr->ihl*4; 16 const int tcp_seg_len = tot_len - iphdr_len; 17 const int tcphdr_size = sizeof(*tcphdr); 18 if (tcp_seg_len >= tcphdr_size 19 && tcp_seg_len >= tcphdr->doff*4) 20 { 21 const int tcphdr_len = tcphdr->doff*4; 22 const int payload_len = tot_len - iphdr_len - tcphdr_len; 23 24 char source[INET_ADDRSTRLEN]; 25 char dest[INET_ADDRSTRLEN]; 26 inet_ntop(AF_INET, &iphdr->saddr, source, INET_ADDRSTRLEN); 27 inet_ntop(AF_INET, &iphdr->daddr, dest, INET_ADDRSTRLEN); 28 printf("IP %s.%d > %s.%d: ", 29 source, ntohs(tcphdr->source), dest, ntohs(tcphdr->dest)); 30 printf("Flags [%c], seq %u, win %d, length %d%s\n", 31 tcphdr->syn ? 'S' : (tcphdr->fin ? 'F' : '.'), 32 ntohl(tcphdr->seq), 33 ntohs(tcphdr->window), 34 payload_len, 35 tcphdr_len > sizeof(struct tcphdr) ? " <>" : ""); 36 37 union 38 { 39 unsigned char output[ETH_FRAME_LEN]; 40 struct 41 { 42 struct iphdr iphdr; 43 struct tcphdr tcphdr; 44 } out; 45 }; 46 47 static_assert(sizeof(out) == sizeof(struct iphdr) + sizeof(struct tcphdr), ""); 48 int output_len = sizeof(out); 49 bzero(&out, output_len + 4); 50 memcpy(output, input, sizeof(struct iphdr)); 51 52 out.tcphdr.source = tcphdr->dest; 53 out.tcphdr.dest = tcphdr->source; 54 out.tcphdr.doff = sizeof(struct tcphdr) / 4; 55 out.tcphdr.window = htons(65000); 56 57 bool response = false; 58 const uint32_t seq = ntohl(tcphdr->seq); 59 const uint32_t isn = 123456; 60 if (tcphdr->syn) 61 { 62 out.tcphdr.seq = htonl(isn); 63 out.tcphdr.ack_seq = htonl(seq+1); 64 out.tcphdr.syn = 1; 65 out.tcphdr.ack = 1; 66 67 // set mss=1000 68 unsigned char* mss = output + output_len; 69 *mss++ = 2; 70 *mss++ = 4; 71 *mss++ = 0x03; 72 *mss++ = 0xe8; // 1000 == 0x03e8 73 out.tcphdr.doff += 1; 74 output_len += 4; 75 76 response = true; 77 } 78 else if (tcphdr->fin) 79 { 80 out.tcphdr.seq = htonl(isn+1); 81 out.tcphdr.ack_seq = htonl(seq+1); 82 out.tcphdr.fin = 1; 83 out.tcphdr.ack = 1; 84 response = true; 85 } 86 else if (payload_len > 0) 87 { 88 out.tcphdr.seq = htonl(isn+1); 89 out.tcphdr.ack_seq = htonl(seq+payload_len); 90 out.tcphdr.ack = 1; 91 response = true; 92 } 93 94 // build IP header 95 out.iphdr.tot_len = htons(output_len); 96 std::swap(out.iphdr.saddr, out.iphdr.daddr); 97 out.iphdr.check = 0; 98 out.iphdr.check = in_checksum(output, sizeof(struct iphdr)); 99 100 unsigned char* pseudo = output + output_len; 101 pseudo[0] = 0; 102 pseudo[1] = IPPROTO_TCP; 103 pseudo[2] = 0; 104 pseudo[3] = output_len - sizeof(struct iphdr); 105 out.tcphdr.check = in_checksum(&out.iphdr.saddr, output_len - 8); 106 if (response) 107 { 108 write(fd, output, output_len); 109 } 110 } 111} 112 113int main(int argc, char* argv[]) 114{ 115 char ifname[IFNAMSIZ] = "tun%d"; 116 bool offload = argc > 1 && strcmp(argv[1], "-K") == 0; 117 int fd = tun_alloc(ifname, offload); 118 119 if (fd < 0) 120 { 121 fprintf(stderr, "tunnel interface allocation failed\n"); 122 exit(1); 123 } 124 125 printf("allocted tunnel interface %s\n", ifname); 126 sleep(1); 127 128 for (;;) 129 { 130 union 131 { 132 unsigned char buf[IP_MAXPACKET]; 133 struct iphdr iphdr; 134 }; 135 136 const int iphdr_size = sizeof iphdr; 137 138 int nread = read(fd, buf, sizeof(buf)); 139 if (nread < 0) 140 { 141 perror("read"); 142 close(fd); 143 exit(1); 144 } 145 else if (nread == sizeof(buf)) 146 { 147 printf("possible message truncated.\n"); 148 } 149 printf("read %d bytes from tunnel interface %s.\n", nread, ifname); 150 151 const int iphdr_len = iphdr.ihl*4; // FIXME: check nread >= sizeof iphdr before accessing iphdr.ihl. 152 if (nread >= iphdr_size 153 && iphdr.version == 4 154 && iphdr_len >= iphdr_size 155 && iphdr_len <= nread 156 && iphdr.tot_len == htons(nread) 157 && in_checksum(buf, iphdr_len) == 0) 158 { 159 const void* payload = buf + iphdr_len; 160 if (iphdr.protocol == IPPROTO_ICMP) 161 { 162 icmp_input(fd, buf, payload, nread); 163 } 164 else if (iphdr.protocol == IPPROTO_TCP) 165 { 166 tcp_input(fd, buf, payload, nread); 167 } 168 } 169 else if (iphdr.version == 4) 170 { 171 printf("bad packet\n"); 172 for (int i = 0; i < nread; ++i) 173 { 174 if (i % 4 == 0) printf("\n"); 175 printf("%02x ", buf[i]); 176 } 177 printf("\n"); 178 } 179 } 180 181 return 0; 182} 183