]> git.sesse.net Git - pistorm/blob - a314/a314.cc
Add Meson build files.
[pistorm] / a314 / a314.cc
1 /*
2  * Copyright 2020-2021 Niklas Ekström
3  * Based on a314d daemon for A314.
4  */
5
6 #include <arpa/inet.h>
7
8 #include <linux/spi/spidev.h>
9 #include <linux/types.h>
10
11 #include <netinet/in.h>
12 #include <netinet/tcp.h>
13
14 #include <sys/epoll.h>
15 #include <sys/ioctl.h>
16 #include <sys/socket.h>
17 #include <sys/stat.h>
18 #include <sys/types.h>
19
20 #include <ctype.h>
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <signal.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <time.h>
29 #include <unistd.h>
30 #include <pthread.h>
31
32 #include <algorithm>
33 #include <list>
34 #include <string>
35 #include <vector>
36
37 #include "a314.h"
38 // Silence stupid warning
39 #undef _GNU_SOURCE
40 #include "config_file/config_file.h"
41
// Emulator configuration, defined elsewhere with C linkage. Used below to find
// the mapped item (cfg->map_data / cfg->map_offset) that backs an Amiga address.
extern "C" emulator_config *cfg;

// Log levels, in increasing order of severity.
#define LOGGER_TRACE    1
#define LOGGER_DEBUG    2
#define LOGGER_INFO     3
#define LOGGER_WARN     4
#define LOGGER_ERROR    5

// Messages with a level below this threshold are not printed.
#define LOGGER_SHOW LOGGER_INFO

// printf-style logging macros. TRACE..WARN print to stdout, ERROR prints to
// stderr. The do { } while (0) wrapper makes each macro act as one statement.
#define logger_trace(...) do { if (LOGGER_TRACE >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_debug(...) do { if (LOGGER_DEBUG >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_info(...) do { if (LOGGER_INFO >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_warn(...) do { if (LOGGER_WARN >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_error(...) do { if (LOGGER_ERROR >= LOGGER_SHOW) fprintf(stderr, __VA_ARGS__); } while (0)
57
// Events that are communicated via IRQ from Amiga to Raspberry
// (bits tested in the `events` argument of handle_a314_irq()).
#define R_EVENT_A2R_TAIL        1   // Presumably: Amiga advanced a2r_tail (new A2R data) - TODO confirm.
#define R_EVENT_R2A_HEAD        2   // Presumably: Amiga advanced r2a_head (R2A data consumed) - TODO confirm.
#define R_EVENT_STARTED         4   // Amiga side (re)started; all logical channels are closed.

// Events that are communicated from Raspberry to Amiga. They are OR'd into
// ca.a_events by write_channel_status().
#define A_EVENT_R2A_TAIL        1   // r2a_tail was advanced (set by flush_send_queue()).
#define A_EVENT_A2R_HEAD        2   // a2r_head was advanced (set by receive_from_a2r()).

// Offset relative to communication area for queue pointers. In this file they
// are used as indices into channel_status[].
#define A2R_TAIL_OFFSET         0
#define R2A_HEAD_OFFSET         1
#define R2A_TAIL_OFFSET         2
#define A2R_HEAD_OFFSET         3

// Packets that are communicated across physical channels (A2R and R2A).
// In the ring buffers each packet is laid out as
// [data length][type][channel id][data length bytes of data].
#define PKT_CONNECT             4   // Data: service name.
#define PKT_CONNECT_RESPONSE    5   // Data: result code (CONNECT_*) in the first byte.
#define PKT_DATA                6
#define PKT_EOS                 7   // Sent without data.
#define PKT_RESET               8   // Sent without data.

// Valid responses for PKT_CONNECT_RESPONSE.
#define CONNECT_OK              0
#define CONNECT_UNKNOWN_SERVICE 3

// Messages that are communicated between driver and client. Every message is a
// MessageHeader followed by header.length payload bytes.
#define MSG_REGISTER_REQ        1   // Payload: service name.
#define MSG_REGISTER_RES        2   // Payload: one byte, MSG_SUCCESS or MSG_FAIL.
#define MSG_DEREGISTER_REQ      3   // Payload: service name.
#define MSG_DEREGISTER_RES      4   // Payload: one byte, MSG_SUCCESS or MSG_FAIL.
#define MSG_READ_MEM_REQ        5   // Payload: uint32 address, uint32 length.
#define MSG_READ_MEM_RES        6   // Payload: the bytes read.
#define MSG_WRITE_MEM_REQ       7   // Payload: uint32 address followed by the bytes to write.
#define MSG_WRITE_MEM_RES       8   // No payload.
#define MSG_CONNECT             9   // Sent to a client; payload: service name.
#define MSG_CONNECT_RESPONSE    10  // From a client; payload: result code (CONNECT_*).
#define MSG_DATA                11
#define MSG_EOS                 12
#define MSG_RESET               13

// Result codes carried by MSG_REGISTER_RES / MSG_DEREGISTER_RES.
#define MSG_SUCCESS             1
#define MSG_FAIL                0
101
// Saved signal mask. NOTE(review): set/used outside this chunk - TODO confirm.
static sigset_t original_sigset;

// Presumably the handle of the A314 event-loop thread created elsewhere - TODO confirm.
static pthread_t thread_id;
// Guards updates of ca.a_events (see write_channel_status()). Presumably the
// emulator side takes the same lock when it reads/clears the events.
static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;

// Listening TCP socket for clients (see init_server_socket()); -1 when closed.
static int server_socket = -1;

// epoll instance that client sockets are registered with (see handle_pkt_connect()).
static int epfd = -1;
// NOTE(review): not referenced in this chunk; presumably used to wake the event
// loop when an A314 IRQ occurs - TODO confirm.
static int irq_fds[2];

// Amiga bus accessors implemented elsewhere (C linkage). Used for addresses that
// are not backed by a mapped item in cfg.
extern "C" unsigned int ps_read_8(unsigned int address);
extern "C" void ps_write_8(unsigned int address, unsigned int value);
extern "C" void ps_write_16(unsigned int address, unsigned int value);

// Not referenced in this chunk; presumably the Amiga address of the A314
// communication area and whether it has been assigned - TODO confirm.
unsigned int a314_base;
int a314_base_configured;

// Communication area shared with the Amiga-side driver.
struct ComArea
{
    uint8_t a_events;   // Events raised towards the Amiga (A_EVENT_* bits, OR'd in by write_channel_status()).
    uint8_t a_enable;   // Presumably an enable mask for a_events - TODO confirm.
    uint8_t r_events;   // Presumably pending R_EVENT_* bits raised by the Amiga - TODO confirm.
    uint8_t r_enable; // Unused.

    uint32_t mem_base;  // Not used in this chunk.
    uint32_t mem_size;  // Not used in this chunk.

    // Ring-buffer indices into a2r_buffer / r2a_buffer; arithmetic wraps at 256.
    uint8_t a2r_tail;
    uint8_t r2a_head;
    uint8_t r2a_tail;
    uint8_t a2r_head;

    uint8_t a2r_buffer[256];   // Amiga -> Raspberry packet ring.
    uint8_t r2a_buffer[256];   // Raspberry -> Amiga packet ring.
};

static ComArea ca;

// Set once R_EVENT_STARTED has been seen; handle_a314_irq() ignores other events until then.
static bool a314_device_started = false;

// Working copy of the four ring indices, indexed by the *_OFFSET constants.
static uint8_t channel_status[4];
// A_EVENT_* bits for the indices modified during the current IRQ pass.
static uint8_t channel_status_updated = 0;

// Staging buffers for unwrapped A2R data (receive) and serialized R2A packets (send).
static uint8_t recv_buf[256];
static uint8_t send_buf[256];
147
// Forward declarations: ClientConnection and LogicalChannel point at each other.
struct LogicalChannel;
struct ClientConnection;

// Header that precedes every message exchanged with a client. Packed to 9 bytes
// because it is read from and written to sockets as raw bytes (host byte order).
#pragma pack(push, 1)
struct MessageHeader
{
    uint32_t length;      // Number of payload bytes that follow the header.
    uint32_t stream_id;   // Stream the message belongs to (0 for *_RES replies).
    uint8_t type;         // One of the MSG_* constants.
}; //} __attribute__((packed));
#pragma pack(pop)

// One outgoing message (header + payload) waiting to be written to a client.
struct MessageBuffer
{
    int pos;                    // Bytes of `data` already written to the socket.
    std::vector<uint8_t> data;
};

// A service name registered by a client (or started on demand).
struct RegisteredService
{
    std::string name;
    ClientConnection *cc;       // Client that handles connections to this service.
};

// One packet waiting to be sent to the Amiga over the R2A ring.
struct PacketBuffer
{
    int type;                   // One of the PKT_* constants.
    std::vector<uint8_t> data;
};

// State for one connected client socket.
struct ClientConnection
{
    int fd;

    // Stream id handed out for the next channel routed to this client
    // (starts at 1 and advances by 2, so ids assigned here are odd).
    int next_stream_id;

    // Incoming message being assembled: while `payload` is empty the header is
    // being read, otherwise the payload is; bytes_read counts progress in that part.
    int bytes_read;
    MessageHeader header;
    std::vector<uint8_t> payload;

    // Outgoing messages that could not be written yet (socket would block).
    std::list<MessageBuffer> message_queue;

    // Logical channels currently routed to this client.
    std::list<LogicalChannel*> associations;
};

// A logical channel opened by the Amiga (PKT_CONNECT) and its routing state.
struct LogicalChannel
{
    int channel_id;             // Channel id used in the A2R/R2A packets.

    ClientConnection *association;  // Client handling this channel, or nullptr.
    int stream_id;                  // Stream id of this channel on that client (0 when detached).

    bool got_eos_from_ami;
    bool got_eos_from_client;

    // Packets waiting to be sent to the Amiga; non-empty iff the channel is in send_queue.
    std::list<PacketBuffer> packet_queue;
};
205
// Helpers used by the client message handlers before their definitions.
static void remove_association(LogicalChannel *ch);
static void clear_packet_queue(LogicalChannel *ch);
static void create_and_enqueue_packet(LogicalChannel *ch, uint8_t type, uint8_t *data, uint8_t length);

// Global registries. Raw pointers into these lists are stored elsewhere
// (RegisteredService::cc, LogicalChannel::association, send_queue), which relies
// on std::list keeping element addresses stable.
static std::list<ClientConnection> connections;
static std::list<RegisteredService> services;
static std::list<LogicalChannel> channels;
// Channels with pending R2A packets, served round-robin by flush_send_queue().
static std::list<LogicalChannel*> send_queue;

// A service that is started on demand (one line of the configuration file).
struct OnDemandStart
{
    std::string service_name;
    std::string program;
    std::vector<std::string> arguments;   // argv for execvp(); arguments[0] is the program.
};

// Filled by load_config_file().
std::vector<OnDemandStart> on_demand_services;

// Default location of the on-demand service configuration. Presumably passed to
// load_config_file() elsewhere - TODO confirm.
std::string a314_config_file = "./a314/files_pi/a314d.conf";
// Environment entry installed with putenv() in spawned on-demand service processes.
std::string home_env = "HOME=./";
226
227 static void load_config_file(const char *filename)
228 {
229     FILE *f = fopen(filename, "rt");
230     if (f == nullptr) {
231         return;
232     }
233
234     char line[256];
235     std::vector<char *> parts;
236
237     while (fgets(line, 256, f) != nullptr)
238     {
239         char org_line[256];
240         strcpy(org_line, line);
241
242         bool in_quotes = false;
243
244         int start = 0;
245         for (int i = 0; i < 256; i++)
246         {
247             if (line[i] == 0)
248             {
249                 if (start < i)
250                     parts.push_back(&line[start]);
251                 break;
252             }
253             else if (line[i] == '"')
254             {
255                 line[i] = 0;
256                 if (in_quotes)
257                     parts.push_back(&line[start]);
258                 in_quotes = !in_quotes;
259                 start = i + 1;
260             }
261             else if (isspace(line[i]) && !in_quotes)
262             {
263                 line[i] = 0;
264                 if (start < i)
265                     parts.push_back(&line[start]);
266                 start = i + 1;
267             }
268         }
269
270         if (parts.size() >= 2)
271         {
272             on_demand_services.emplace_back();
273             auto &e = on_demand_services.back();
274             e.service_name = parts[0];
275             e.program = parts[1];
276             for (int i = 1; i < parts.size(); i++)
277                 e.arguments.push_back(std::string(parts[i]));
278         }
279         else if (parts.size() != 0)
280             logger_warn("Invalid number of columns in configuration file line: %s\n", org_line);
281
282         parts.clear();
283     }
284
285     fclose(f);
286
287     if (on_demand_services.empty())
288         logger_warn("No registered services\n");
289 }
290
291 static int init_server_socket()
292 {
293     server_socket = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0);
294     if (server_socket == -1)
295     {
296         logger_error("Failed to create server socket\n");
297         return -1;
298     }
299
300     struct sockaddr_in address;
301     address.sin_family = AF_INET;
302     address.sin_addr.s_addr = INADDR_ANY;
303     address.sin_port = htons(7110);
304
305     int res = bind(server_socket, (struct sockaddr *)&address, sizeof(address));
306     if (res < 0)
307     {
308         logger_error("Bind to localhost:7110 failed\n");
309         return -1;
310     }
311
312     listen(server_socket, 16);
313
314     return 0;
315 }
316
317 static void shutdown_server_socket()
318 {
319     if (server_socket != -1)
320         close(server_socket);
321     server_socket = -1;
322 }
323
324 void create_and_send_msg(ClientConnection *cc, int type, int stream_id, uint8_t *data, int length)
325 {
326     MessageBuffer mb;
327     mb.pos = 0;
328     mb.data.resize(sizeof(MessageHeader) + length);
329
330     MessageHeader *mh = (MessageHeader *)&mb.data[0];
331     mh->length = length;
332     mh->stream_id = stream_id;
333     mh->type = type;
334     if (length && data)
335         memcpy(&mb.data[sizeof(MessageHeader)], data, length);
336
337     if (!cc->message_queue.empty())
338     {
339         cc->message_queue.push_back(std::move(mb));
340         return;
341     }
342
343     while (1)
344     {
345         int left = mb.data.size() - mb.pos;
346         uint8_t *src = &mb.data[mb.pos];
347         ssize_t r = write(cc->fd, src, left);
348         if (r == -1)
349         {
350             if (errno == EAGAIN || errno == EWOULDBLOCK)
351             {
352                 cc->message_queue.push_back(std::move(mb));
353                 return;
354             }
355             else if (errno == ECONNRESET)
356             {
357                 // Do not close connection here; it will get done at some other place.
358                 return;
359             }
360             else
361             {
362                 logger_error("Write failed unexpectedly with errno = %d\n", errno);
363                 exit(-1);
364             }
365         }
366
367         mb.pos += r;
368         if (r == left)
369         {
370             return;
371         }
372     }
373 }
374
375 static void handle_msg_register_req(ClientConnection *cc)
376 {
377     uint8_t result = MSG_FAIL;
378
379     std::string service_name((char *)&cc->payload[0], cc->payload.size());
380
381     auto it = services.begin();
382     for (; it != services.end(); it++)
383         if (it->name == service_name)
384             break;
385
386     if (it == services.end())
387     {
388         services.emplace_back();
389
390         RegisteredService &srv = services.back();
391         srv.cc = cc;
392         srv.name = std::move(service_name);
393
394         result = MSG_SUCCESS;
395     }
396
397     create_and_send_msg(cc, MSG_REGISTER_RES, 0, &result, 1);
398 }
399
400 static void handle_msg_deregister_req(ClientConnection *cc)
401 {
402     uint8_t result = MSG_FAIL;
403
404     std::string service_name((char *)&cc->payload[0], cc->payload.size());
405
406     for (auto it = services.begin(); it != services.end(); it++)
407     {
408         if (it->name == service_name && it->cc == cc)
409         {
410             services.erase(it);
411             result = MSG_SUCCESS;
412             break;
413         }
414     }
415
416     create_and_send_msg(cc, MSG_DEREGISTER_RES, 0, &result, 1);
417 }
418
// Scratch buffer (64 * SIZE_KILO bytes) used by handle_msg_read_mem_req() to
// collect bytes read with ps_read_8() for addresses not backed by a mapped item.
uint8_t manual_read_buf[64 * SIZE_KILO];
420
421 static void handle_msg_read_mem_req(ClientConnection *cc)
422 {
423     uint32_t address = *(uint32_t *)&(cc->payload[0]);
424     uint32_t length = *(uint32_t *)&(cc->payload[4]);
425
426     if (get_mapped_item_by_address(cfg, address) != -1) {
427         int32_t index = get_mapped_item_by_address(cfg, address);
428         uint8_t *map = &cfg->map_data[index][address - cfg->map_offset[index]];
429         create_and_send_msg(cc, MSG_READ_MEM_RES, 0, map, length);
430     } else {
431         // No idea if this actually works.
432         for (int i = 0; i < length; i++) {
433             manual_read_buf[i] = (unsigned char)ps_read_8(address + i);
434         }
435         create_and_send_msg(cc, MSG_READ_MEM_RES, 0, manual_read_buf, length);
436     }
437     
438 }
439
440 static void handle_msg_write_mem_req(ClientConnection *cc)
441 {
442     uint32_t address = *(uint32_t *)&(cc->payload[0]);
443     uint32_t length = cc->payload.size() - 4;
444
445     if (get_mapped_item_by_address(cfg, address) != -1) {
446         int32_t index = get_mapped_item_by_address(cfg, address);
447         uint8_t *map = &cfg->map_data[index][address - cfg->map_offset[index]];
448         memcpy(map, &(cc->payload[4]), length);
449     } else {
450         // No idea if this actually works.
451         for (int i = 0; i < length; i++) {
452             ps_write_8(address + i, cc->payload[4 + i]);
453         }
454     }
455
456     create_and_send_msg(cc, MSG_WRITE_MEM_RES, 0, nullptr, 0);
457 }
458
459 static LogicalChannel *get_associated_channel_by_stream_id(ClientConnection *cc, int stream_id)
460 {
461     for (auto ch : cc->associations)
462     {
463         if (ch->stream_id == stream_id)
464             return ch;
465     }
466     return nullptr;
467 }
468
469 static void handle_msg_connect(ClientConnection *cc)
470 {
471     // We currently don't handle that a client tries to connect to a service on the Amiga.
472 }
473
474 static void handle_msg_connect_response(ClientConnection *cc)
475 {
476     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
477     if (!ch)
478         return;
479
480     create_and_enqueue_packet(ch, PKT_CONNECT_RESPONSE, &cc->payload[0], cc->payload.size());
481
482     if (cc->payload[0] != CONNECT_OK)
483         remove_association(ch);
484 }
485
486 static void handle_msg_data(ClientConnection *cc)
487 {
488     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
489     if (!ch)
490         return;
491
492     create_and_enqueue_packet(ch, PKT_DATA, &cc->payload[0], cc->header.length);
493 }
494
495 static void handle_msg_eos(ClientConnection *cc)
496 {
497     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
498     if (!ch || ch->got_eos_from_client)
499         return;
500
501     ch->got_eos_from_client = true;
502
503     create_and_enqueue_packet(ch, PKT_EOS, nullptr, 0);
504
505     if (ch->got_eos_from_ami)
506         remove_association(ch);
507 }
508
509 static void handle_msg_reset(ClientConnection *cc)
510 {
511     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
512     if (!ch)
513         return;
514
515     remove_association(ch);
516
517     clear_packet_queue(ch);
518     create_and_enqueue_packet(ch, PKT_RESET, nullptr, 0);
519 }
520
521 static void handle_received_message(ClientConnection *cc)
522 {
523     switch (cc->header.type)
524     {
525     case MSG_REGISTER_REQ:
526         handle_msg_register_req(cc);
527         break;
528     case MSG_DEREGISTER_REQ:
529         handle_msg_deregister_req(cc);
530         break;
531     case MSG_READ_MEM_REQ:
532         handle_msg_read_mem_req(cc);
533         break;
534     case MSG_WRITE_MEM_REQ:
535         handle_msg_write_mem_req(cc);
536         break;
537     case MSG_CONNECT:
538         handle_msg_connect(cc);
539         break;
540     case MSG_CONNECT_RESPONSE:
541         handle_msg_connect_response(cc);
542         break;
543     case MSG_DATA:
544         handle_msg_data(cc);
545         break;
546     case MSG_EOS:
547         handle_msg_eos(cc);
548         break;
549     case MSG_RESET:
550         handle_msg_reset(cc);
551         break;
552     default:
553         // This is bad, probably should disconnect from client.
554         logger_warn("Received a message of unknown type from client\n");
555         break;
556     }
557 }
558
559 static void close_and_remove_connection(ClientConnection *cc)
560 {
561     shutdown(cc->fd, SHUT_WR);
562     close(cc->fd);
563
564     {
565         auto it = services.begin();
566         while (it != services.end())
567         {
568             if (it->cc == cc)
569                 it = services.erase(it);
570             else
571                 it++;
572         }
573     }
574
575     {
576         auto it = cc->associations.begin();
577         while (it != cc->associations.end())
578         {
579             auto ch = *it;
580
581             clear_packet_queue(ch);
582             create_and_enqueue_packet(ch, PKT_RESET, nullptr, 0);
583
584             ch->association = nullptr;
585             ch->stream_id = 0;
586
587             it = cc->associations.erase(it);
588         }
589     }
590
591     for (auto it = connections.begin(); it != connections.end(); it++)
592     {
593         if (&(*it) == cc)
594         {
595             connections.erase(it);
596             break;
597         }
598     }
599 }
600
601 static void remove_association(LogicalChannel *ch)
602 {
603     auto &ass = ch->association->associations;
604     ass.erase(std::find(ass.begin(), ass.end(), ch));
605
606     ch->association = nullptr;
607     ch->stream_id = 0;
608 }
609
610 static void clear_packet_queue(LogicalChannel *ch)
611 {
612     if (!ch->packet_queue.empty())
613     {
614         ch->packet_queue.clear();
615         send_queue.erase(std::find(send_queue.begin(), send_queue.end(), ch));
616     }
617 }
618
619 static void create_and_enqueue_packet(LogicalChannel *ch, uint8_t type, uint8_t *data, uint8_t length)
620 {
621     if (ch->packet_queue.empty())
622         send_queue.push_back(ch);
623
624     ch->packet_queue.emplace_back();
625
626     PacketBuffer &pb = ch->packet_queue.back();
627     pb.type = type;
628     pb.data.resize(length);
629     if (data && length)
630         memcpy(&pb.data[0], data, length);
631 }
632
633 static void handle_pkt_connect(int channel_id, uint8_t *data, int plen)
634 {
635     for (auto &ch : channels)
636     {
637         if (ch.channel_id == channel_id)
638         {
639             // We should handle this in some constructive way.
640             // This signals that should reset all logical channels.
641             logger_error("Received a CONNECT packet on a channel that was believed to be previously allocated\n");
642             exit(-1);
643         }
644     }
645
646     channels.emplace_back();
647
648     auto &ch = channels.back();
649
650     ch.channel_id = channel_id;
651     ch.association = nullptr;
652     ch.stream_id = 0;
653     ch.got_eos_from_ami = false;
654     ch.got_eos_from_client = false;
655
656     std::string service_name((char *)data, plen);
657
658     for (auto &srv : services)
659     {
660         if (srv.name == service_name)
661         {
662             ClientConnection *cc = srv.cc;
663
664             ch.association = cc;
665             ch.stream_id = cc->next_stream_id;
666
667             cc->next_stream_id += 2;
668             cc->associations.push_back(&ch);
669
670             create_and_send_msg(ch.association, MSG_CONNECT, ch.stream_id, data, plen);
671             return;
672         }
673     }
674
675     for (auto &on_demand : on_demand_services)
676     {
677         if (on_demand.service_name == service_name)
678         {
679             int fds[2];
680             int status = socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
681             if (status != 0)
682             {
683                 logger_error("Unexpectedly not able to create socket pair.\n");
684                 exit(-1);
685             }
686
687             pid_t child = fork();
688             if (child == -1)
689             {
690                 logger_error("Unexpectedly was not able to fork.\n");
691                 exit(-1);
692             }
693             else if (child == 0)
694             {
695                 close(fds[0]);
696                 int fd = fds[1];
697
698                 // FIXE: The user should be configurable.
699                 setgid(1000);
700                 setuid(1000);
701                 putenv((char *)home_env.c_str());
702
703                 std::vector<std::string> args(on_demand.arguments);
704                 args.push_back("-ondemand");
705                 args.push_back(std::to_string(fd));
706                 std::vector<const char *> args_arr;
707                 for (auto &arg : args)
708                     args_arr.push_back(arg.c_str());
709                 args_arr.push_back(nullptr);
710
711                 execvp(on_demand.program.c_str(), (char* const*) &args_arr[0]);
712             }
713             else
714             {
715                 close(fds[1]);
716                 int fd = fds[0];
717
718                 int status = fcntl(fd, F_SETFD, fcntl(fd, F_GETFD, 0) | FD_CLOEXEC);
719                 if (status == -1)
720                 {
721                     logger_error("Unexpectedly unable to set close-on-exec flag on client socket descriptor; errno = %d\n", errno);
722                     exit(-1);
723                 }
724
725                 status = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
726                 if (status == -1)
727                 {
728                     logger_error("Unexpectedly unable to set client socket to non blocking; errno = %d\n", errno);
729                     exit(-1);
730                 }
731
732                 int flag = 1;
733                 setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
734
735                 connections.emplace_back();
736
737                 ClientConnection &cc = connections.back();
738                 cc.fd = fd;
739                 cc.next_stream_id = 1;
740                 cc.bytes_read = 0;
741
742                 struct epoll_event ev;
743                 ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
744                 ev.data.fd = fd;
745                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) != 0)
746                 {
747                     logger_error("epoll_ctl() failed unexpectedly with errno = %d\n", errno);
748                     exit(-1);
749                 }
750
751                 services.emplace_back();
752
753                 RegisteredService &srv = services.back();
754                 srv.cc = &cc;
755                 srv.name = std::move(service_name);
756
757                 ch.association = &cc;
758                 ch.stream_id = cc.next_stream_id;
759
760                 cc.next_stream_id += 2;
761                 cc.associations.push_back(&ch);
762
763                 create_and_send_msg(ch.association, MSG_CONNECT, ch.stream_id, data, plen);
764                 return;
765             }
766         }
767     }
768
769     uint8_t response = CONNECT_UNKNOWN_SERVICE;
770     create_and_enqueue_packet(&ch, PKT_CONNECT_RESPONSE, &response, 1);
771 }
772
773 static void handle_pkt_data(int channel_id, uint8_t *data, int plen)
774 {
775     for (auto &ch : channels)
776     {
777         if (ch.channel_id == channel_id)
778         {
779             if (ch.association != nullptr && !ch.got_eos_from_ami)
780                 create_and_send_msg(ch.association, MSG_DATA, ch.stream_id, data, plen);
781
782             break;
783         }
784     }
785 }
786
787 static void handle_pkt_eos(int channel_id)
788 {
789     for (auto &ch : channels)
790     {
791         if (ch.channel_id == channel_id)
792         {
793             if (ch.association != nullptr && !ch.got_eos_from_ami)
794             {
795                 ch.got_eos_from_ami = true;
796
797                 create_and_send_msg(ch.association, MSG_EOS, ch.stream_id, nullptr, 0);
798
799                 if (ch.got_eos_from_client)
800                     remove_association(&ch);
801             }
802             break;
803         }
804     }
805 }
806
807 static void handle_pkt_reset(int channel_id)
808 {
809     for (auto &ch : channels)
810     {
811         if (ch.channel_id == channel_id)
812         {
813             clear_packet_queue(&ch);
814
815             if (ch.association != nullptr)
816             {
817                 create_and_send_msg(ch.association, MSG_RESET, ch.stream_id, nullptr, 0);
818                 remove_association(&ch);
819             }
820
821             break;
822         }
823     }
824 }
825
826 static void remove_channel_if_not_associated_and_empty_pq(int channel_id)
827 {
828     for (auto it = channels.begin(); it != channels.end(); it++)
829     {
830         if (it->channel_id == channel_id)
831         {
832             if (it->association == nullptr && it->packet_queue.empty())
833                 channels.erase(it);
834
835             break;
836         }
837     }
838 }
839
840 static void handle_received_pkt(int ptype, int channel_id, uint8_t *data, int plen)
841 {
842     if (ptype == PKT_CONNECT)
843         handle_pkt_connect(channel_id, data, plen);
844     else if (ptype == PKT_DATA)
845         handle_pkt_data(channel_id, data, plen);
846     else if (ptype == PKT_EOS)
847         handle_pkt_eos(channel_id);
848     else if (ptype == PKT_RESET)
849         handle_pkt_reset(channel_id);
850
851     remove_channel_if_not_associated_and_empty_pq(channel_id);
852 }
853
854 static bool receive_from_a2r()
855 {
856     int head = channel_status[A2R_HEAD_OFFSET];
857     int tail = channel_status[A2R_TAIL_OFFSET];
858     int len = (tail - head) & 255;
859     if (len == 0)
860         return false;
861
862     if (head < tail)
863     {
864         memcpy(recv_buf, &ca.a2r_buffer[head], len);
865     }
866     else
867     {
868         memcpy(recv_buf, &ca.a2r_buffer[head], 256 - head);
869
870         if (tail != 0)
871         {
872             memcpy(&recv_buf[len - tail], &ca.a2r_buffer[0], tail);
873         }
874     }
875
876     uint8_t *p = recv_buf;
877     while (p < recv_buf + len)
878     {
879         uint8_t plen = *p++;
880         uint8_t ptype = *p++;
881         uint8_t channel_id = *p++;
882         handle_received_pkt(ptype, channel_id, p, plen);
883         p += plen;
884     }
885
886     channel_status[A2R_HEAD_OFFSET] = channel_status[A2R_TAIL_OFFSET];
887     channel_status_updated |= A_EVENT_A2R_HEAD;
888     return true;
889 }
890
891 static bool flush_send_queue()
892 {
893     int tail = channel_status[R2A_TAIL_OFFSET];
894     int head = channel_status[R2A_HEAD_OFFSET];
895     int len = (tail - head) & 255;
896     int left = 255 - len;
897
898     int pos = 0;
899
900     while (!send_queue.empty())
901     {
902         LogicalChannel *ch = send_queue.front();
903         PacketBuffer &pb = ch->packet_queue.front();
904
905         int ptype = pb.type;
906         int plen = 3 + pb.data.size();
907
908         if (left < plen)
909             break;
910
911         send_buf[pos++] = pb.data.size();
912         send_buf[pos++] = ptype;
913         send_buf[pos++] = ch->channel_id;
914         memcpy(&send_buf[pos], &pb.data[0], pb.data.size());
915         pos += pb.data.size();
916
917         ch->packet_queue.pop_front();
918
919         send_queue.pop_front();
920
921         if (!ch->packet_queue.empty())
922             send_queue.push_back(ch);
923         else
924             remove_channel_if_not_associated_and_empty_pq(ch->channel_id);
925
926         left -= plen;
927     }
928
929     int to_write = pos;
930     if (!to_write)
931         return false;
932
933     uint8_t *p = send_buf;
934     int at_end = 256 - tail;
935     if (at_end < to_write)
936     {
937         memcpy(&ca.r2a_buffer[tail], p, at_end);
938         p += at_end;
939         to_write -= at_end;
940         tail = 0;
941     }
942
943     memcpy(&ca.r2a_buffer[tail], p, to_write);
944     tail = (tail + to_write) & 255;
945
946     channel_status[R2A_TAIL_OFFSET] = tail;
947     channel_status_updated |= A_EVENT_R2A_TAIL;
948     return true;
949 }
950
951 static void read_channel_status()
952 {
953     channel_status[A2R_TAIL_OFFSET] = ca.a2r_tail;
954     channel_status[R2A_HEAD_OFFSET] = ca.r2a_head;
955     channel_status[R2A_TAIL_OFFSET] = ca.r2a_tail;
956     channel_status[A2R_HEAD_OFFSET] = ca.a2r_head;
957     channel_status_updated = 0;
958 }
959
960 static void write_channel_status()
961 {
962     if (channel_status_updated != 0)
963     {
964         ca.r2a_tail = channel_status[R2A_TAIL_OFFSET];
965         ca.a2r_head = channel_status[A2R_HEAD_OFFSET];
966
967         pthread_mutex_lock(&mutex);
968         ca.a_events |= channel_status_updated;
969         pthread_mutex_unlock(&mutex);
970
971         channel_status_updated = 0;
972     }
973 }
974
975 static void close_all_logical_channels()
976 {
977     send_queue.clear();
978
979     auto it = channels.begin();
980     while (it != channels.end())
981     {
982         LogicalChannel &ch = *it;
983
984         if (ch.association != nullptr)
985         {
986             create_and_send_msg(ch.association, MSG_RESET, ch.stream_id, nullptr, 0);
987             remove_association(&ch);
988         }
989
990         it = channels.erase(it);
991     }
992 }
993
994 static void handle_a314_irq(uint8_t events)
995 {
996     if (events == 0)
997         return;
998
999     if (events & R_EVENT_STARTED)
1000     {
1001         if (!channels.empty())
1002             logger_info("Received STARTED event while logical channels are open -- closing channels\n");
1003
1004         close_all_logical_channels();
1005         a314_device_started = true;
1006     }
1007
1008     if (!a314_device_started)
1009         return;
1010
1011     read_channel_status();
1012
1013     bool any_rcvd = receive_from_a2r();
1014     bool any_sent = flush_send_queue();
1015
1016     if (any_rcvd || any_sent)
1017         write_channel_status();
1018 }
1019
1020 static void handle_client_connection_event(ClientConnection *cc, struct epoll_event *ev)
1021 {
1022     if (ev->events & EPOLLERR)
1023     {
1024         logger_warn("Received EPOLLERR for client connection\n");
1025         close_and_remove_connection(cc);
1026         return;
1027     }
1028
1029     if (ev->events & EPOLLIN)
1030     {
1031         while (1)
1032         {
1033             int left;
1034             uint8_t *dst;
1035
1036             if (cc->payload.empty())
1037             {
1038                 left = sizeof(MessageHeader) - cc->bytes_read;
1039                 dst = (uint8_t *)&(cc->header) + cc->bytes_read;
1040             }
1041             else
1042             {
1043                 left = cc->header.length - cc->bytes_read;
1044                 dst = &cc->payload[cc->bytes_read];
1045             }
1046
1047             ssize_t r = read(cc->fd, dst, left);
1048             if (r == -1)
1049             {
1050                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1051                     break;
1052
1053                 logger_error("Read failed unexpectedly with errno = %d\n", errno);
1054                 exit(-1);
1055             }
1056
1057             if (r == 0)
1058             {
1059                 logger_info("Received End-of-File on client connection\n");
1060                 close_and_remove_connection(cc);
1061                 return;
1062             }
1063             else
1064             {
1065                 cc->bytes_read += r;
1066                 left -= r;
1067                 if (!left)
1068                 {
1069                     if (cc->payload.empty())
1070                     {
1071                         if (cc->header.length == 0)
1072                         {
1073                             logger_trace("header: length=%d, stream_id=%d, type=%d\n", cc->header.length, cc->header.stream_id, cc->header.type);
1074                             handle_received_message(cc);
1075                         }
1076                         else
1077                         {
1078                             cc->payload.resize(cc->header.length);
1079                         }
1080                     }
1081                     else
1082                     {
1083                         logger_trace("header: length=%d, stream_id=%d, type=%d\n", cc->header.length, cc->header.stream_id, cc->header.type);
1084                         handle_received_message(cc);
1085                         cc->payload.clear();
1086                     }
1087                     cc->bytes_read = 0;
1088                 }
1089             }
1090         }
1091     }
1092
1093     if (ev->events & EPOLLOUT)
1094     {
1095         while (!cc->message_queue.empty())
1096         {
1097             MessageBuffer &mb = cc->message_queue.front();
1098
1099             int left = mb.data.size() - mb.pos;
1100             uint8_t *src = &mb.data[mb.pos];
1101             ssize_t r = write(cc->fd, src, left);
1102             if (r == -1)
1103             {
1104                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1105                     break;
1106                 else if (errno == ECONNRESET)
1107                 {
1108                     close_and_remove_connection(cc);
1109                     return;
1110                 }
1111                 else
1112                 {
1113                     logger_error("Write failed unexpectedly with errno = %d\n", errno);
1114                     exit(-1);
1115                 }
1116             }
1117
1118             mb.pos += r;
1119             if (r == left)
1120                 cc->message_queue.pop_front();
1121         }
1122     }
1123 }
1124
1125 static void handle_server_socket_ready()
1126 {
1127     struct sockaddr_in address;
1128     int alen = sizeof(struct sockaddr_in);
1129
1130     int fd = accept(server_socket, (struct sockaddr *)&address, (socklen_t *)&alen);
1131     if (fd < 0)
1132     {
1133         logger_error("Accept failed unexpectedly with errno = %d\n", errno);
1134         exit(-1);
1135     }
1136
1137     int status = fcntl(fd, F_SETFD, fcntl(fd, F_GETFD, 0) | FD_CLOEXEC);
1138     if (status == -1)
1139     {
1140         logger_error("Unexpectedly unable to set close-on-exec flag on client socket descriptor; errno = %d\n", errno);
1141         exit(-1);
1142     }
1143
1144     status = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
1145     if (status == -1)
1146     {
1147         logger_error("Unexpectedly unable to set client socket to non blocking; errno = %d\n", errno);
1148         exit(-1);
1149     }
1150
1151     int flag = 1;
1152     setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
1153
1154     connections.emplace_back();
1155
1156     ClientConnection &cc = connections.back();
1157     cc.fd = fd;
1158     cc.next_stream_id = 1;
1159     cc.bytes_read = 0;
1160
1161     struct epoll_event ev;
1162     ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
1163     ev.data.fd = fd;
1164     if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) != 0)
1165     {
1166         logger_error("epoll_ctl() failed unexpectedly with errno = %d\n", errno);
1167         exit(-1);
1168     }
1169 }
1170
1171 static void main_loop()
1172 {
1173     bool shutting_down = false;
1174     bool done = false;
1175
1176     while (!done)
1177     {
1178         struct epoll_event ev;
1179         int timeout = shutting_down ? 10000 : -1;
1180         int n = epoll_pwait(epfd, &ev, 1, timeout, &original_sigset);
1181         if (n == -1)
1182         {
1183             if (errno == EINTR)
1184             {
1185                 logger_info("Received SIGTERM\n");
1186
1187                 shutdown_server_socket();
1188
1189                 while (!connections.empty())
1190                     close_and_remove_connection(&connections.front());
1191
1192                 if (flush_send_queue())
1193                     write_channel_status();
1194
1195                 if (!channels.empty())
1196                     shutting_down = true;
1197                 else
1198                     done = true;
1199             }
1200             else
1201             {
1202                 logger_error("epoll_pwait failed with unexpected errno = %d\n", errno);
1203                 exit(-1);
1204             }
1205         }
1206         else if (n == 0)
1207         {
1208             if (shutting_down)
1209                 done = true;
1210             else
1211             {
1212                 logger_error("epoll_pwait returned 0 which is unexpected since no timeout was set\n");
1213                 exit(-1);
1214             }
1215         }
1216         else
1217         {
1218             if (ev.data.fd == irq_fds[1])
1219             {
1220                 uint8_t events;
1221                 if (read(irq_fds[1], &events, 1) != 1)
1222                 {
1223                     logger_error("Read from interrupt socket pair, and unexpectedly didn't return 1 byte\n");
1224                     exit(-1);
1225                 }
1226
1227                 handle_a314_irq(events);
1228             }
1229             else if (ev.data.fd == server_socket)
1230             {
1231                 logger_trace("Epoll event: server socket is ready, events = %d\n", ev.events);
1232                 handle_server_socket_ready();
1233             }
1234             else
1235             {
1236                 logger_trace("Epoll event: client socket is ready, events = %d\n", ev.events);
1237
1238                 auto it = connections.begin();
1239                 for (; it != connections.end(); it++)
1240                 {
1241                     if (it->fd == ev.data.fd)
1242                         break;
1243                 }
1244
1245                 if (it == connections.end())
1246                 {
1247                     logger_error("Got notified about an event on a client connection that supposedly isn't currently open\n");
1248                     exit(-1);
1249                 }
1250
1251                 ClientConnection *cc = &(*it);
1252                 handle_client_connection_event(cc, &ev);
1253
1254                 if (flush_send_queue())
1255                     write_channel_status();
1256             }
1257         }
1258     }
1259 }
1260
1261 static int init_driver()
1262 {
1263     if (init_server_socket() != 0)
1264         return -1;
1265
1266     int err = socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0, irq_fds);
1267     if (err != 0)
1268     {
1269         logger_error("Unable to create socket pair, errno = %d\n", errno);
1270         return -1;
1271     }
1272
1273     epfd = epoll_create1(EPOLL_CLOEXEC);
1274     if (epfd == -1)
1275         return -1;
1276
1277     struct epoll_event ev;
1278     ev.events = EPOLLIN;
1279     ev.data.fd = irq_fds[1];
1280     if (epoll_ctl(epfd, EPOLL_CTL_ADD, irq_fds[1], &ev) != 0)
1281         return -1;
1282
1283     ev.events = EPOLLIN;
1284     ev.data.fd = server_socket;
1285     if (epoll_ctl(epfd, EPOLL_CTL_ADD, server_socket, &ev) != 0)
1286         return -1;
1287
1288     return 0;
1289 }
1290
1291 static void shutdown_driver()
1292 {
1293     if (epfd != -1)
1294         close(epfd);
1295
1296     shutdown_server_socket();
1297 }
1298
1299 static void *thread_start(void *arg)
1300 {
1301     main_loop();
1302     shutdown_driver();
1303     return NULL;
1304 }
1305
1306 static void write_r_events(uint8_t events)
1307 {
1308     if (write(irq_fds[0], &events, 1) != 1)
1309         logger_error("Write to interrupt socket pair did not return 1\n");
1310 }
1311
1312 int a314_init()
1313 {
1314     load_config_file(a314_config_file.c_str());
1315
1316     int err = init_driver();
1317     if (err < 0)
1318     {
1319         shutdown_driver();
1320         return -1;
1321     }
1322
1323     err = pthread_create(&thread_id, NULL, thread_start, NULL);
1324     if (err < 0)
1325     {
1326         logger_error("pthread_create failed with err = %d\n", err);
1327         return -2;
1328     }
1329
1330     return 0;
1331 }
1332
1333 void a314_set_mem_base_size(unsigned int base, unsigned int size)
1334 {
1335     ca.mem_base = htobe32(base);
1336     ca.mem_size = htobe32(size);
1337 }
1338
1339 void a314_process_events()
1340 {
1341     if (ca.a_events & ca.a_enable)
1342     {
1343         ps_write_16(0xdff09c, 0x8008);
1344         m68k_set_irq(2);
1345     }
1346 }
1347
1348 unsigned int a314_read_memory_8(unsigned int address)
1349 {
1350     if (address >= sizeof(ca))
1351         return 0;
1352
1353     uint8_t val;
1354     if (address == offsetof(ComArea, a_events))
1355     {
1356         pthread_mutex_lock(&mutex);
1357         val = ca.a_events;
1358         ca.a_events = 0;
1359         pthread_mutex_unlock(&mutex);
1360     }
1361     else
1362     {
1363         uint8_t *p = (uint8_t *)&ca;
1364         val = p[address];
1365     }
1366
1367     return val;
1368 }
1369
1370 unsigned int a314_read_memory_16(unsigned int address)
1371 {
1372     if (address >= sizeof(ca))
1373         return 0;
1374
1375     uint16_t *p = (uint16_t *)&ca;
1376     return be16toh(p[address >> 1]);
1377 }
1378
1379 unsigned int a314_read_memory_32(unsigned int address)
1380 {
1381     if (address >= sizeof(ca))
1382         return 0;
1383
1384     uint32_t *p = (uint32_t *)&ca;
1385     return be32toh(p[address >> 2]);
1386 }
1387
1388 void a314_write_memory_8(unsigned int address, unsigned int value)
1389 {
1390     if (address >= sizeof(ca))
1391         return;
1392
1393     switch (address)
1394     {
1395         case offsetof(ComArea, a_events):
1396             // a_events is not writable.
1397             break;
1398
1399         case offsetof(ComArea, r_events):
1400             if (value != 0)
1401                 write_r_events((uint8_t)value);
1402             break;
1403
1404         default:
1405         {
1406             uint8_t *p = (uint8_t *)&ca;
1407             p[address] = (uint8_t)value;
1408             break;
1409         }
1410     }
1411 }
1412
// 16-bit writes to the communication area are not supported; they are
// silently ignored.
void a314_write_memory_16(unsigned int address, unsigned int value)
{
    (void)address;
    (void)value;
}
1417
// 32-bit writes to the communication area are not supported; they are
// silently ignored.
void a314_write_memory_32(unsigned int address, unsigned int value)
{
    (void)address;
    (void)value;
}
1422
1423 void a314_set_config_file(char *filename)
1424 {
1425     printf ("[A314] Set A314 config filename to %s.\n", filename);
1426     a314_config_file = std::string(filename);
1427 }