]> git.sesse.net Git - greproxy/blob - tungre.cpp
Move more stuff into the classes. Not fully orthogonal yet.
[greproxy] / tungre.cpp
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <fcntl.h>
5 #include <unistd.h>
6 #include <sys/ioctl.h>
7 #include <netinet/in.h>
8 #include <arpa/inet.h>
9 #include <linux/if.h>
10 #include <linux/if_tun.h>
11
12 #include <map>
13 #include <string>
14 #include <queue>
15
16 using namespace std;
17
18 struct gre_header {
19         uint8_t reserved0_hi : 4;
20         uint8_t has_seq : 1;
21         uint8_t has_key : 1;
22         uint8_t unused : 1;
23         uint8_t has_checksum : 1;
24
25         uint8_t version : 3;
26         uint8_t reserved0_lo: 5;
27
28         uint16_t protocol_type;
29 };
30
31 int tun_open(const char *name) {
32         struct ifreq ifr;
33
34         int fd = open("/dev/net/tun", O_RDWR);
35         if (fd == -1) {
36                 perror("/dev/net/tun");
37                 exit(1);
38         }
39
40         memset(&ifr, 0, sizeof(ifr));
41         ifr.ifr_flags = IFF_TUN;
42         strncpy(ifr.ifr_name, name, IFNAMSIZ);
43
44         int err = ioctl(fd, TUNSETIFF, &ifr);
45         if (err == -1) {
46                 perror("ioctl(TUNSETIFF)");
47                 exit(-1);
48         }
49
50         return fd;
51 }
52
53 in6_addr get_addr(const char *str) {
54         in6_addr ret;
55         if (inet_pton(AF_INET6, str, &ret) != 1) {
56                 fprintf(stderr, "Could not parse %s\n", str);
57                 exit(1);
58         }
59         return ret;
60 }
61
62 struct GREPacket {
63         int seq;
64         uint16_t proto;
65         string data;
66
67         bool operator> (const GREPacket &other) const {
68                 return seq > other.seq;
69         }
70 };
71
72 class Protocol {
73 public:
74         virtual void send_packet(uint16_t proto, const string &data) = 0;
75         virtual int fd() const = 0;
76 };
77
78 class Reorderer;
79 class Protocol;
80
81 class GREProtocol : public Protocol {
82 public:
83         GREProtocol(const in6_addr &myaddr, const in6_addr &dst);
84         virtual void send_packet(uint16_t proto, const string &data);
85         virtual int fd() const;
86         void read_packet(Reorderer* sender);
87
88 private:
89         int seq;
90         int sock;
91         sockaddr_in6 dstaddr;
92 };
93
94 class TUNProtocol : public Protocol {
95 public:
96         TUNProtocol(const char *devname);
97         virtual void send_packet(uint16_t proto, const string &data);
98         virtual int fd() const;
99         void read_packet(Protocol* sender);
100
101 private:
102         int tunfd;
103 };
104
105 class Reorderer {
106 public:
107         Reorderer(Protocol* sender);
108         void handle_packet(uint16_t proto, const string& data, int seq);
109
110 private:
111         void send_packet(uint16_t proto, const string &data, bool silence);
112
113         Protocol* sender;
114         int last_seq;
115
116         priority_queue<GREPacket, vector<GREPacket>, greater<GREPacket>> packet_buffer;
117         map<int, int> ccs;
118 };
119
120 GREProtocol::GREProtocol(const in6_addr &src, const in6_addr &dst)
121         : seq(0)
122 {
123         memset(&dstaddr, 0, sizeof(dstaddr));
124         dstaddr.sin6_family = AF_INET6;
125         dstaddr.sin6_addr = dst;
126
127         sock = socket(AF_INET6, SOCK_RAW, IPPROTO_GRE);
128         if (sock == -1) {
129                 perror("socket");
130                 exit(1);
131         }
132
133         sockaddr_in6 my_addr;
134         memset(&my_addr, 0, sizeof(my_addr));
135         my_addr.sin6_family = AF_INET6;
136         my_addr.sin6_addr = src;
137         if (bind(sock, (sockaddr *)&my_addr, sizeof(my_addr)) == -1) {
138                 perror("bind");
139                 exit(1);
140         }
141 }
142
143 void GREProtocol::send_packet(uint16_t proto, const string &data)
144 {
145         char buf[4096];
146         gre_header *gre = (gre_header *)buf;
147
148         memset(gre, 0, sizeof(*gre));
149         gre->has_seq = 1;
150         gre->version = 0;
151         gre->protocol_type = htons(proto);
152
153         char *ptr = buf + sizeof(*gre);
154         int seq_be = htonl(seq++);
155         memcpy(ptr, &seq_be, sizeof(seq_be));
156         ptr += sizeof(seq_be);
157
158         memcpy(ptr, data.data(), data.size());
159         
160         if (sendto(sock, buf, data.size() + sizeof(seq_be) + sizeof(*gre), 0, (sockaddr *)&dstaddr, sizeof(dstaddr)) == -1) {
161                 perror("sendto");
162                 return;
163         }
164 }
165
166 int GREProtocol::fd() const
167 {
168         return sock;
169 }
170         
171 TUNProtocol::TUNProtocol(const char *devname)
172         : tunfd(tun_open(devname)) {
173 }
174
175 void TUNProtocol::send_packet(uint16_t proto, const string &data)
176 {
177         char buf[4096];
178
179         char *ptr = buf;
180         uint16_t flags = 0;
181         memcpy(ptr, &flags, sizeof(flags));
182         ptr += sizeof(flags);
183
184         proto = htons(proto);
185         memcpy(ptr, &proto, sizeof(proto));
186         ptr += sizeof(proto);
187
188         memcpy(ptr, data.data(), data.size());
189
190         int len = sizeof(flags) + sizeof(proto) + data.size();
191         if (write(tunfd, buf, len) != len) {
192                 perror("write");
193                 return;
194         }
195 }
196
197 int TUNProtocol::fd() const
198 {
199         return tunfd;
200 }
201
202 Reorderer::Reorderer(Protocol* sender)
203         : sender(sender), last_seq(-1)
204 {
205 }
206
207 #define PACKET_BUFFER_SIZE 100
208
209 void Reorderer::handle_packet(uint16_t proto, const string& data, int seq)
210 {
211         bool silence = false;
212         if (packet_buffer.size() >= PACKET_BUFFER_SIZE) {
213                 printf("Gave up waiting for packets [%d,%d>\n",
214                         last_seq + 1, packet_buffer.top().seq);
215                 silence = true;
216                 last_seq = packet_buffer.top().seq - 1;
217         }
218
219         GREPacket packet;
220         packet.seq = seq;
221         packet.proto = proto;
222         packet.data = data;
223         packet_buffer.push(packet);
224
225         while (!packet_buffer.empty() &&
226                (last_seq == -1 || packet_buffer.top().seq <= last_seq + 1)) {
227                 int front_seq = packet_buffer.top().seq;
228                 if (front_seq < last_seq + 1) {
229                         printf("Duplicate packet or way out-of-order: seq=%d front_seq=%d\n",
230                                 front_seq, last_seq + 1);
231                         packet_buffer.pop();
232                         continue;
233                 }
234                 //if (packet_buffer.size() > 1) {
235                 //      printf("seq=%d (REORDER %d)\n", front_seq, int(packet_buffer.size()));
236                 //} else {
237                 //      printf("seq=%d\n", front_seq);
238                 //}
239                 const string &data = packet_buffer.top().data;
240                 send_packet(packet_buffer.top().proto, data, silence);
241                 packet_buffer.pop();
242                 last_seq = front_seq;
243                 if (!silence && !packet_buffer.empty()) {
244                         printf("Reordering with packet buffer size %d: seq=%d new_front_seq=%d\n", int(packet_buffer.size()), front_seq, packet_buffer.top().seq);
245                         silence = true;
246                 }
247         }
248 }
249
250 void Reorderer::send_packet(uint16_t proto, const string &data, bool silence)
251 {
252         if (data.size() == 1344) {
253                 for (int i = 0; i < 7; ++i) {
254                         const char *pkt = &data[i * 188 + 28];
255                         int pid = (ntohl(*(uint32_t *)(pkt)) & 0x1fff00) >> 8;
256                         if (pid == 8191) {
257                                 // stuffing, ignore
258                                 continue;
259                         }
260                         int has_payload = pkt[3] & 0x10;
261                         int cc = pkt[3] & 0xf;
262                         if (has_payload) {
263                                 int last_cc = ccs[pid];
264                                 if (!silence && cc != ((last_cc + 1) & 0xf)) {
265                                         printf("Pid %d discontinuity (expected %d, got %d)\n", pid, (last_cc + 1) & 0xf, cc);
266                                 }
267                                 ccs[pid] = cc;
268                         }
269                 }
270         }
271         sender->send_packet(proto, data);
272 }
273
274 void GREProtocol::read_packet(Reorderer *sender)
275 {
276         struct sockaddr_storage addr;
277         socklen_t addrlen = sizeof(addr);
278         char buf[4096];
279         int ret = recvfrom(sock, buf, sizeof(buf), 0, (struct sockaddr *)&addr, &addrlen);
280         if (ret == -1) {
281                 perror("recvfrom");
282                 exit(1);
283         }
284         if (addr.ss_family != AF_INET6) {
285                 return;
286         }
287         struct in6_addr *addr6 = &((struct sockaddr_in6 *)&addr)->sin6_addr;
288         if (memcmp(addr6, &dstaddr.sin6_addr, sizeof(*addr6)) != 0) {
289                 // ignore
290                 return;
291         }
292         gre_header* gre = (gre_header *)buf;
293
294         char* ptr = buf + sizeof(gre_header);
295         if (gre->has_checksum) {
296                 ptr += 4;
297         }
298         if (gre->has_key) {
299                 ptr += 4;
300         }
301         uint32_t seq;
302         if (gre->has_seq) {
303                 seq = ntohl(*(uint32_t *)ptr);
304                 ptr += 4;
305         }
306
307         //printf("gre packet: proto=%x\n", ntohs(gre->protocol_type));
308
309         sender->handle_packet(ntohs(gre->protocol_type), string(ptr, buf + ret), seq);
310 }
311
312 void TUNProtocol::read_packet(Protocol *sender)
313 {
314         char buf[4096];
315         int ret = read(tunfd, buf, sizeof(buf));
316         if (ret == -1) {
317                 perror("read");
318                 exit(1);
319         }
320         if (ret == 0) {
321                 fprintf(stderr, "tunfd EOF\n");
322                 exit(1);
323         }
324         
325         char *ptr = buf;
326         uint16_t flags = *(uint16_t *)ptr;
327         ptr += 2;
328         uint16_t proto = ntohs(*(uint16_t *)ptr);
329         ptr += 2;
330         //fprintf(stderr, "tun packet: flags=%x proto=%x len=%d\n",
331         //      flags, proto, ret - 4);
332         sender->send_packet(proto, string(ptr, buf + ret));
333 }
334
335 int main(int argc, char **argv)
336 {
337         in6_addr myaddr = get_addr(argv[1]);
338         in6_addr remoteaddr = get_addr(argv[2]);
339         GREProtocol gre(myaddr, remoteaddr);
340         TUNProtocol tun("tungre");
341
342         Reorderer tun_reorderer(&tun);
343
344         fd_set fds;
345         FD_ZERO(&fds);
346         for ( ;; ) {
347                 FD_SET(gre.fd(), &fds);
348                 FD_SET(tun.fd(), &fds);
349                 int ret = select(1024, &fds, NULL, NULL, NULL);
350                 if (ret == -1) {
351                         perror("select");
352                         continue;
353                 }
354
355                 if (FD_ISSET(gre.fd(), &fds)) {
356                         gre.read_packet(&tun_reorderer);
357                 }
358                 if (FD_ISSET(tun.fd(), &fds)) {
359                         tun.read_packet(&gre);
360                 }
361         }
362 }