git.sesse.net Git - greproxy/blob - rsdecoder.cpp
Fix another crippling RS decoding bug.
[greproxy] / rsdecoder.cpp
1 #include <stdio.h>
2 #include <string.h>
3 #include <arpa/inet.h>
4
5 #include <map>
6 #include "rsdecoder.h"
7 #include "rs_parm.h"
8
9 extern "C" {
10 #include <fec.h>
11 }
12
13 #define RS_GROUP_HISTORY 3
14
15 using namespace std;
16
17 RSDecoder::RSDecoder(Sender *sender)
18         : sender(sender)
19 {
20         rs = init_rs_char(RS_SYM_SIZE, RS_GF_POLY, 1, 1, RS_PARITY_SIZE, RS_PAD);
21 }
22
23 void RSDecoder::send_packet(uint16_t proto, const std::string &data, int incoming_seq)
24 {
25         int rs_group;
26         if (proto == 0xffff) {
27                 // RS packet
28                 rs_group = (incoming_seq + RS_PAYLOAD_SIZE - 1) / RS_PAYLOAD_SIZE;
29         } else {
30                 // Regular packet
31                 rs_group = incoming_seq / RS_PAYLOAD_SIZE;
32         }
33
34         if (rs_groups.size() >= RS_GROUP_HISTORY &&
35             rs_group < rs_groups.begin()->first) {
36                 // Older than the oldest group.
37                 return;
38         }
39
40         auto group_it = rs_groups.find(rs_group);
41         if (group_it == rs_groups.end()) {
42                 RSGroup group;
43                 group.done = false;
44                 group_it = rs_groups.insert(make_pair(rs_group, group)).first;
45         }
46
47         RSGroup &group = group_it->second;
48         if (group.done) {
49                 // This RS group was already sent.
50                 return;
51         }
52         if (group.packets.count(incoming_seq)) {
53                 // Already seen this packet.
54                 return;
55         }
56
57         if (proto != 0xffff) {
58                 sender->send_packet(proto, data, incoming_seq);
59         }
60
61         GREPacket packet;
62         packet.seq = incoming_seq;
63         packet.proto = proto;
64         packet.data = data;
65         // Don't care about ts.
66
67         group.packets.insert(make_pair(incoming_seq, packet));
68         if (group.packets.size() >= RS_PAYLOAD_SIZE) {
69                 // Enough to reconstruct all missing packets.
70
71                 // Reconstruction always happens on the longest packet;
72                 // we will truncate them later.
73                 int max_length = 0;
74                 int num_regular = 0;
75                 for (const auto &it : group.packets) {
76                         if (it.first >= rs_group * RS_PAYLOAD_SIZE) {
77                                 // Regular packet.
78                                 max_length = max<int>(max_length, it.second.data.size() + 4);
79                                 ++num_regular;
80                         } else {
81                                 // RS packet.
82                                 max_length = max<int>(max_length, it.second.data.size());
83                         }
84                 }
85
86                 if (num_regular < RS_PAYLOAD_SIZE) {
87                         // Piece the data back into the different RS groups.
88                         vector<string> padded_packets;
89                         vector<int> missing_packets;
90                         for (int i = 0; i < RS_GROUP_SIZE; ++i) {
91                                 int packet_num = (i < RS_PAYLOAD_SIZE) ? rs_group * RS_PAYLOAD_SIZE + i :
92                                         rs_group * RS_PAYLOAD_SIZE - 1 - (i - RS_PAYLOAD_SIZE);
93                                 string p;
94                                 p.resize(max_length);
95                                 const auto it = group.packets.find(packet_num);
96                                 if (it == group.packets.end()) {
97                                         missing_packets.push_back(i);
98                                 } else if (i < RS_PAYLOAD_SIZE) {
99                                         // Regular packet.
100                                         const GREPacket &packet = it->second;
101                                         uint16_t proto_be = htons(packet.proto);
102                                         memcpy(&p[0], &proto_be, sizeof(uint16_t));
103                                         uint16_t len_be = htons(packet.data.size());
104                                         memcpy(&p[2], &len_be, sizeof(uint16_t));
105                                         memcpy(&p[4], packet.data.data(), packet.data.size());
106                                 } else {
107                                         // RS packet.
108                                         const GREPacket &packet = it->second;
109                                         memcpy(&p[0], packet.data.data(), packet.data.size());
110                                 }
111                                 padded_packets.push_back(p);
112                         }
113
114                         // Now reconstruct the missing pieces.
115                         unsigned char ch[RS_GROUP_SIZE];
116                         for (int i = 0; i < max_length; ++i) {
117                                 for (int j = 0; j < RS_GROUP_SIZE; ++j) {
118                                         ch[j] = padded_packets[j][i];
119                                 }
120                                 int ret = decode_rs_char(rs, ch, &missing_packets[0], missing_packets.size());
121                                 if (ret == -1) {
122                                         printf("Failed reconstruction!\n");
123                                         // We might get more data later, so don't remove it.
124                                         return;
125                                 }
126                                 for (int j = 0; j < RS_GROUP_SIZE; ++j) {
127                                         padded_packets[j][i] = ch[j];
128                                 }
129                         }
130
131                         // Output all packets we didn't have before. They will come
132                         // out-of-order, which will be the job of the Reorderer to fix.
133                         for (int i = 0; i < RS_PAYLOAD_SIZE; ++i) {
134                                 int packet_num = rs_group * RS_PAYLOAD_SIZE + i;
135                                 if (group.packets.count(packet_num) != 0) {
136                                         // Already had this packet.
137                                         continue;
138                                 }
139                                 const string &p = padded_packets[i];
140                                 uint16_t proto_be, len_be;
141                                 memcpy(&proto_be, &p[0], sizeof(uint16_t));
142                                 memcpy(&len_be, &p[2], sizeof(uint16_t));
143                                 string s(&p[4], &p[4 + ntohs(len_be)]);  // TODO: security
144                                 sender->send_packet(ntohs(proto_be), s, packet_num);
145                                 printf("Reconstructed packet %d\n", packet_num);
146                         }
147                 }
148                 
149                 group.done = true;
150         }
151
152         if (rs_groups.size() > RS_GROUP_HISTORY) {
153                 const auto &it = rs_groups.begin();
154                 if (!it->second.done) {
155                         printf("Giving up reconstruction within group [%d,%d> (only got %d/%d packets, needed %d)\n",
156                                it->first * RS_PAYLOAD_SIZE,
157                                (it->first + 1) * RS_PAYLOAD_SIZE,
158                                int(it->second.packets.size()),
159                                RS_GROUP_SIZE,
160                                RS_PAYLOAD_SIZE);
161                 }
162                 rs_groups.erase(it);
163         }
164 }
165