]> git.sesse.net Git - greproxy/blob - tungre.cpp
Split Protocol out into a header.
[greproxy] / tungre.cpp
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <fcntl.h>
5 #include <unistd.h>
6 #include <sys/ioctl.h>
7 #include <netinet/in.h>
8 #include <arpa/inet.h>
9 #include <linux/if.h>
10 #include <linux/if_tun.h>
11
12 #include <map>
13 #include <string>
14 #include <queue>
15
16 #include "protocol.h"
17
18 using namespace std;
19
20 struct gre_header {
21         uint8_t reserved0_hi : 4;
22         uint8_t has_seq : 1;
23         uint8_t has_key : 1;
24         uint8_t unused : 1;
25         uint8_t has_checksum : 1;
26
27         uint8_t version : 3;
28         uint8_t reserved0_lo: 5;
29
30         uint16_t protocol_type;
31 };
32
33 int tun_open(const char *name) {
34         struct ifreq ifr;
35
36         int fd = open("/dev/net/tun", O_RDWR);
37         if (fd == -1) {
38                 perror("/dev/net/tun");
39                 exit(1);
40         }
41
42         memset(&ifr, 0, sizeof(ifr));
43         ifr.ifr_flags = IFF_TUN;
44         strncpy(ifr.ifr_name, name, IFNAMSIZ);
45
46         int err = ioctl(fd, TUNSETIFF, &ifr);
47         if (err == -1) {
48                 perror("ioctl(TUNSETIFF)");
49                 exit(-1);
50         }
51
52         return fd;
53 }
54
55 in6_addr get_addr(const char *str) {
56         in6_addr ret;
57         if (inet_pton(AF_INET6, str, &ret) != 1) {
58                 fprintf(stderr, "Could not parse %s\n", str);
59                 exit(1);
60         }
61         return ret;
62 }
63
64 struct GREPacket {
65         int seq;
66         uint16_t proto;
67         string data;
68
69         bool operator> (const GREPacket &other) const {
70                 return seq > other.seq;
71         }
72 };
73
74 class Reorderer;
75
76 class GREProtocol : public Protocol {
77 public:
78         GREProtocol(const in6_addr &myaddr, const in6_addr &dst);
79         virtual void send_packet(uint16_t proto, const string &data);
80         virtual int fd() const;
81         void read_packet(Reorderer* sender);
82
83 private:
84         int seq;
85         int sock;
86         sockaddr_in6 dstaddr;
87 };
88
89 class TUNProtocol : public Protocol {
90 public:
91         TUNProtocol(const char *devname);
92         virtual void send_packet(uint16_t proto, const string &data);
93         virtual int fd() const;
94         void read_packet(Protocol* sender);
95
96 private:
97         int tunfd;
98 };
99
100 class Reorderer {
101 public:
102         Reorderer(Protocol* sender);
103         void handle_packet(uint16_t proto, const string& data, int seq);
104
105 private:
106         void send_packet(uint16_t proto, const string &data, bool silence);
107
108         Protocol* sender;
109         int last_seq;
110
111         priority_queue<GREPacket, vector<GREPacket>, greater<GREPacket>> packet_buffer;
112         map<int, int> ccs;
113 };
114
115 GREProtocol::GREProtocol(const in6_addr &src, const in6_addr &dst)
116         : seq(0)
117 {
118         memset(&dstaddr, 0, sizeof(dstaddr));
119         dstaddr.sin6_family = AF_INET6;
120         dstaddr.sin6_addr = dst;
121
122         sock = socket(AF_INET6, SOCK_RAW, IPPROTO_GRE);
123         if (sock == -1) {
124                 perror("socket");
125                 exit(1);
126         }
127
128         sockaddr_in6 my_addr;
129         memset(&my_addr, 0, sizeof(my_addr));
130         my_addr.sin6_family = AF_INET6;
131         my_addr.sin6_addr = src;
132         if (bind(sock, (sockaddr *)&my_addr, sizeof(my_addr)) == -1) {
133                 perror("bind");
134                 exit(1);
135         }
136 }
137
138 void GREProtocol::send_packet(uint16_t proto, const string &data)
139 {
140         char buf[4096];
141         gre_header *gre = (gre_header *)buf;
142
143         memset(gre, 0, sizeof(*gre));
144         gre->has_seq = 1;
145         gre->version = 0;
146         gre->protocol_type = htons(proto);
147
148         char *ptr = buf + sizeof(*gre);
149         int seq_be = htonl(seq++);
150         memcpy(ptr, &seq_be, sizeof(seq_be));
151         ptr += sizeof(seq_be);
152
153         memcpy(ptr, data.data(), data.size());
154         
155         if (sendto(sock, buf, data.size() + sizeof(seq_be) + sizeof(*gre), 0, (sockaddr *)&dstaddr, sizeof(dstaddr)) == -1) {
156                 perror("sendto");
157                 return;
158         }
159 }
160
161 int GREProtocol::fd() const
162 {
163         return sock;
164 }
165         
166 TUNProtocol::TUNProtocol(const char *devname)
167         : tunfd(tun_open(devname)) {
168 }
169
170 void TUNProtocol::send_packet(uint16_t proto, const string &data)
171 {
172         char buf[4096];
173
174         char *ptr = buf;
175         uint16_t flags = 0;
176         memcpy(ptr, &flags, sizeof(flags));
177         ptr += sizeof(flags);
178
179         proto = htons(proto);
180         memcpy(ptr, &proto, sizeof(proto));
181         ptr += sizeof(proto);
182
183         memcpy(ptr, data.data(), data.size());
184
185         int len = sizeof(flags) + sizeof(proto) + data.size();
186         if (write(tunfd, buf, len) != len) {
187                 perror("write");
188                 return;
189         }
190 }
191
192 int TUNProtocol::fd() const
193 {
194         return tunfd;
195 }
196
197 Reorderer::Reorderer(Protocol* sender)
198         : sender(sender), last_seq(-1)
199 {
200 }
201
202 #define PACKET_BUFFER_SIZE 100
203
204 void Reorderer::handle_packet(uint16_t proto, const string& data, int seq)
205 {
206         bool silence = false;
207         if (packet_buffer.size() >= PACKET_BUFFER_SIZE) {
208                 printf("Gave up waiting for packets [%d,%d>\n",
209                         last_seq + 1, packet_buffer.top().seq);
210                 silence = true;
211                 last_seq = packet_buffer.top().seq - 1;
212         }
213
214         GREPacket packet;
215         packet.seq = seq;
216         packet.proto = proto;
217         packet.data = data;
218         packet_buffer.push(packet);
219
220         while (!packet_buffer.empty() &&
221                (last_seq == -1 || packet_buffer.top().seq <= last_seq + 1)) {
222                 int front_seq = packet_buffer.top().seq;
223                 if (front_seq < last_seq + 1) {
224                         printf("Duplicate packet or way out-of-order: seq=%d front_seq=%d\n",
225                                 front_seq, last_seq + 1);
226                         packet_buffer.pop();
227                         continue;
228                 }
229                 //if (packet_buffer.size() > 1) {
230                 //      printf("seq=%d (REORDER %d)\n", front_seq, int(packet_buffer.size()));
231                 //} else {
232                 //      printf("seq=%d\n", front_seq);
233                 //}
234                 const string &data = packet_buffer.top().data;
235                 send_packet(packet_buffer.top().proto, data, silence);
236                 packet_buffer.pop();
237                 last_seq = front_seq;
238                 if (!silence && !packet_buffer.empty()) {
239                         printf("Reordering with packet buffer size %d: seq=%d new_front_seq=%d\n", int(packet_buffer.size()), front_seq, packet_buffer.top().seq);
240                         silence = true;
241                 }
242         }
243 }
244
245 void Reorderer::send_packet(uint16_t proto, const string &data, bool silence)
246 {
247         if (data.size() == 1344) {
248                 for (int i = 0; i < 7; ++i) {
249                         const char *pkt = &data[i * 188 + 28];
250                         int pid = (ntohl(*(uint32_t *)(pkt)) & 0x1fff00) >> 8;
251                         if (pid == 8191) {
252                                 // stuffing, ignore
253                                 continue;
254                         }
255                         int has_payload = pkt[3] & 0x10;
256                         int cc = pkt[3] & 0xf;
257                         if (has_payload) {
258                                 int last_cc = ccs[pid];
259                                 if (!silence && cc != ((last_cc + 1) & 0xf)) {
260                                         printf("Pid %d discontinuity (expected %d, got %d)\n", pid, (last_cc + 1) & 0xf, cc);
261                                 }
262                                 ccs[pid] = cc;
263                         }
264                 }
265         }
266         sender->send_packet(proto, data);
267 }
268
269 void GREProtocol::read_packet(Reorderer *sender)
270 {
271         struct sockaddr_storage addr;
272         socklen_t addrlen = sizeof(addr);
273         char buf[4096];
274         int ret = recvfrom(sock, buf, sizeof(buf), 0, (struct sockaddr *)&addr, &addrlen);
275         if (ret == -1) {
276                 perror("recvfrom");
277                 exit(1);
278         }
279         if (addr.ss_family != AF_INET6) {
280                 return;
281         }
282         struct in6_addr *addr6 = &((struct sockaddr_in6 *)&addr)->sin6_addr;
283         if (memcmp(addr6, &dstaddr.sin6_addr, sizeof(*addr6)) != 0) {
284                 // ignore
285                 return;
286         }
287         gre_header* gre = (gre_header *)buf;
288
289         char* ptr = buf + sizeof(gre_header);
290         if (gre->has_checksum) {
291                 ptr += 4;
292         }
293         if (gre->has_key) {
294                 ptr += 4;
295         }
296         uint32_t seq;
297         if (gre->has_seq) {
298                 seq = ntohl(*(uint32_t *)ptr);
299                 ptr += 4;
300         }
301
302         //printf("gre packet: proto=%x\n", ntohs(gre->protocol_type));
303
304         sender->handle_packet(ntohs(gre->protocol_type), string(ptr, buf + ret), seq);
305 }
306
307 void TUNProtocol::read_packet(Protocol *sender)
308 {
309         char buf[4096];
310         int ret = read(tunfd, buf, sizeof(buf));
311         if (ret == -1) {
312                 perror("read");
313                 exit(1);
314         }
315         if (ret == 0) {
316                 fprintf(stderr, "tunfd EOF\n");
317                 exit(1);
318         }
319         
320         char *ptr = buf;
321         uint16_t flags = *(uint16_t *)ptr;
322         ptr += 2;
323         uint16_t proto = ntohs(*(uint16_t *)ptr);
324         ptr += 2;
325         //fprintf(stderr, "tun packet: flags=%x proto=%x len=%d\n",
326         //      flags, proto, ret - 4);
327         sender->send_packet(proto, string(ptr, buf + ret));
328 }
329
330 int main(int argc, char **argv)
331 {
332         in6_addr myaddr = get_addr(argv[1]);
333         in6_addr remoteaddr = get_addr(argv[2]);
334         GREProtocol gre(myaddr, remoteaddr);
335         TUNProtocol tun("tungre");
336
337         Reorderer tun_reorderer(&tun);
338
339         fd_set fds;
340         FD_ZERO(&fds);
341         for ( ;; ) {
342                 FD_SET(gre.fd(), &fds);
343                 FD_SET(tun.fd(), &fds);
344                 int ret = select(1024, &fds, NULL, NULL, NULL);
345                 if (ret == -1) {
346                         perror("select");
347                         continue;
348                 }
349
350                 if (FD_ISSET(gre.fd(), &fds)) {
351                         gre.read_packet(&tun_reorderer);
352                 }
353                 if (FD_ISSET(tun.fd(), &fds)) {
354                         tun.read_packet(&gre);
355                 }
356         }
357 }