+#include <stdio.h>
+#include <string.h>
+#include <arpa/inet.h>
+
+#include <map>
+#include "rsdecoder.h"
+#include "rs_parm.h"
+
+extern "C" {
+#include <fec.h>
+}
+
+#define RS_GROUP_HISTORY 3
+
+using namespace std;
+
+RSDecoder::RSDecoder(Sender *sender)
+ : sender(sender) {}
+
+void RSDecoder::send_packet(uint16_t proto, const std::string &data, int incoming_seq)
+{
+ int rs_group;
+ if (proto == 0xffff) {
+ // RS packet
+ rs_group = (incoming_seq + RS_PAYLOAD_SIZE - 1) / RS_PAYLOAD_SIZE;
+ } else {
+ // Regular packet
+ rs_group = incoming_seq / RS_PAYLOAD_SIZE;
+ }
+
+ if (rs_groups.size() >= RS_GROUP_HISTORY &&
+ rs_group < rs_groups.begin()->first) {
+ // Older than the oldest group.
+ return;
+ }
+
+ auto group_it = rs_groups.find(rs_group);
+ if (group_it == rs_groups.end()) {
+ RSGroup group;
+ group.done = false;
+ group_it = rs_groups.insert(make_pair(rs_group, group)).first;
+ }
+
+ RSGroup &group = group_it->second;
+ if (group.done) {
+ // This RS group was already sent.
+ return;
+ }
+ if (group.packets.count(incoming_seq)) {
+ // Already seen this packet.
+ return;
+ }
+
+ if (proto != 0xffff) {
+ sender->send_packet(proto, data, incoming_seq);
+ }
+
+ GREPacket packet;
+ packet.seq = incoming_seq;
+ packet.proto = proto;
+ packet.data = data;
+ // Don't care about ts.
+
+ group.packets.insert(make_pair(incoming_seq, packet));
+ if (group.packets.size() >= RS_PAYLOAD_SIZE) {
+ // Enough to reconstruct all missing packets.
+
+ // Reconstruction always happens on the longest packet;
+ // we will truncate them later.
+ int max_length = 0;
+ int num_regular = 0;
+ for (const auto &it : group.packets) {
+ if (it.first >= rs_group * RS_PAYLOAD_SIZE) {
+ // Regular packet.
+ max_length = max<int>(max_length, it.second.data.size() + 4);
+ ++num_regular;
+ } else {
+ // RS packet.
+ max_length = max<int>(max_length, it.second.data.size());
+ }
+ }
+
+ if (num_regular < RS_PAYLOAD_SIZE) {
+ // Piece the data back into the different RS groups.
+ vector<string> padded_packets;
+ vector<int> missing_packets;
+ for (int i = 0; i < RS_GROUP_SIZE; ++i) {
+ int packet_num = (i < RS_PAYLOAD_SIZE) ? rs_group * RS_PAYLOAD_SIZE + i :
+ rs_group * RS_PAYLOAD_SIZE - 1 - (i - RS_PAYLOAD_SIZE);
+ string p;
+ p.resize(max_length);
+ const auto it = group.packets.find(packet_num);
+ if (it == group.packets.end()) {
+ missing_packets.push_back(i);
+ } else {
+ const GREPacket &packet = it->second;
+ uint16_t proto_be = htons(packet.proto);
+ memcpy(&p[0], &proto_be, sizeof(uint16_t));
+ uint16_t len_be = htons(packet.data.size());
+ memcpy(&p[2], &len_be, sizeof(uint16_t));
+ memcpy(&p[4], packet.data.data(), packet.data.size());
+ }
+ padded_packets.push_back(p);
+ }
+
+ // Now reconstruct the missing pieces.
+ unsigned char ch[RS_GROUP_SIZE];
+ for (int i = 0; i < max_length; ++i) {
+ for (int j = 0; j < RS_GROUP_SIZE; ++j) {
+ ch[j] = padded_packets[j][i];
+ }
+ int ret = decode_rs_8(ch, &missing_packets[0], missing_packets.size(),
+ RS_PAD);
+ if (ret == -1) {
+ printf("Failed reconstruction!\n");
+ // We might get more data later, so don't remove it.
+ return;
+ }
+ for (int j = 0; j < RS_GROUP_SIZE; ++j) {
+ padded_packets[j][i] = ch[j];
+ }
+ }
+
+ // Output all packets we didn't have before. They will come
+ // out-of-order, which will be the job of the Reorderer to fix.
+ for (int i = 0; i < RS_PAYLOAD_SIZE; ++i) {
+ int packet_num = rs_group * RS_PAYLOAD_SIZE + i;
+ if (group.packets.count(packet_num) != 0) {
+ // Already had this packet.
+ continue;
+ }
+ const string &p = padded_packets[i];
+ uint16_t proto_be, len_be;
+ memcpy(&proto_be, &p[0], sizeof(uint16_t));
+ memcpy(&len_be, &p[2], sizeof(uint16_t));
+ string s(&p[4], &p[4 + ntohs(len_be)]); // TODO: security
+ sender->send_packet(ntohs(proto_be), s, packet_num);
+ printf("Reconstructed packet %d\n", packet_num);
+ }
+ }
+
+ group.done = true;
+ }
+
+ if (rs_groups.size() > RS_GROUP_HISTORY) {
+ const auto &it = rs_groups.begin();
+ if (!it->second.done) {
+ printf("Giving up reconstruction within group [%d,%d> (only got %d/%d packets, needed %d)\n",
+ it->first * RS_PAYLOAD_SIZE,
+ (it->first + 1) * RS_PAYLOAD_SIZE,
+ int(it->second.packets.size()),
+ RS_GROUP_SIZE,
+ RS_PAYLOAD_SIZE);
+ }
+ rs_groups.erase(it);
+ }
+}
+