git.sesse.net Git - pistorm/blob - a314/a314.cc
Maybe make A314 emulation launch Python scripts automatically
[pistorm] / a314 / a314.cc
1 /*
2  * Copyright 2020-2021 Niklas Ekström
3  * Based on a314d daemon for A314.
4  */
5
6 #include <arpa/inet.h>
7
8 #include <linux/spi/spidev.h>
9 #include <linux/types.h>
10
11 #include <netinet/in.h>
12 #include <netinet/tcp.h>
13
14 #include <sys/epoll.h>
15 #include <sys/ioctl.h>
16 #include <sys/socket.h>
17 #include <sys/stat.h>
18 #include <sys/types.h>
19
20 #include <ctype.h>
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <signal.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <time.h>
29 #include <unistd.h>
30 #include <pthread.h>
31
32 #include <algorithm>
33 #include <list>
34 #include <string>
35 #include <vector>
36
37 #include "a314.h"
38 // Silence stupid warning
39 #undef _GNU_SOURCE
40 #include "config_file/config_file.h"
41
// Global emulator configuration, defined elsewhere with C linkage. Used below to
// look up host-mapped Amiga memory (get_mapped_item_by_address, map_data).
extern "C" emulator_config *cfg;

// Log levels, most verbose first. A message is printed when its level is
// greater than or equal to LOGGER_SHOW.
#define LOGGER_TRACE    1
#define LOGGER_DEBUG    2
#define LOGGER_INFO     3
#define LOGGER_WARN     4
#define LOGGER_ERROR    5

// Minimum level that is printed; lower levels become dead branches.
#define LOGGER_SHOW LOGGER_INFO

// printf-style logging helpers. Everything except errors goes to stdout;
// errors go to stderr.
#define logger_trace(...) do { if (LOGGER_TRACE >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_debug(...) do { if (LOGGER_DEBUG >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_info(...) do { if (LOGGER_INFO >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_warn(...) do { if (LOGGER_WARN >= LOGGER_SHOW) fprintf(stdout, __VA_ARGS__); } while (0)
#define logger_error(...) do { if (LOGGER_ERROR >= LOGGER_SHOW) fprintf(stderr, __VA_ARGS__); } while (0)
57
// Events that are communicated via IRQ from Amiga to Raspberry.
// (Bit flags; handle_a314_irq() tests them with '&'.)
#define R_EVENT_A2R_TAIL        1
#define R_EVENT_R2A_HEAD        2
#define R_EVENT_STARTED         4

// Events that are communicated from Raspberry to Amiga.
// (Bit flags OR-ed into ComArea::a_events by write_channel_status().)
#define A_EVENT_R2A_TAIL        1
#define A_EVENT_A2R_HEAD        2

// Offset relative to communication area for queue pointers.
// Also used as indices into the channel_status[] working copy.
#define A2R_TAIL_OFFSET         0
#define R2A_HEAD_OFFSET         1
#define R2A_TAIL_OFFSET         2
#define A2R_HEAD_OFFSET         3

// Packets that are communicated across physical channels (A2R and R2A).
// Packet framing in the rings: [data length][type][channel id][data...].
#define PKT_CONNECT             4
#define PKT_CONNECT_RESPONSE    5
#define PKT_DATA                6
#define PKT_EOS                 7
#define PKT_RESET               8

// Valid responses for PKT_CONNECT_RESPONSE.
#define CONNECT_OK              0
#define CONNECT_UNKNOWN_SERVICE 3

// Messages that are communicated between driver and client.
// Value of MessageHeader::type on client sockets.
#define MSG_REGISTER_REQ        1
#define MSG_REGISTER_RES        2
#define MSG_DEREGISTER_REQ      3
#define MSG_DEREGISTER_RES      4
#define MSG_READ_MEM_REQ        5
#define MSG_READ_MEM_RES        6
#define MSG_WRITE_MEM_REQ       7
#define MSG_WRITE_MEM_RES       8
#define MSG_CONNECT             9
#define MSG_CONNECT_RESPONSE    10
#define MSG_DATA                11
#define MSG_EOS                 12
#define MSG_RESET               13

// Single-byte result carried by MSG_REGISTER_RES / MSG_DEREGISTER_RES.
#define MSG_SUCCESS             1
#define MSG_FAIL                0
101
// Presumably the signal mask in effect before the A314 worker thread blocked
// signals -- confirm against its users.
static sigset_t original_sigset;

// Presumably the handle of the A314 worker thread.
static pthread_t thread_id;
// Serialises updates of ca.a_events (see write_channel_status()).
static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;

// Listening TCP socket for clients (port 7110, see init_server_socket());
// -1 while closed.
static int server_socket = -1;

// epoll instance that client sockets are registered with; -1 until created.
static int epfd = -1;
// NOTE(review): presumably a pipe used to wake the worker on Amiga interrupts -- confirm.
static int irq_fds[2];

// Amiga bus accessors defined elsewhere with C linkage. ps_read_8/ps_write_8
// serve memory requests for addresses that are not host-mapped.
extern "C" unsigned int ps_read_8(unsigned int address);
extern "C" void ps_write_8(unsigned int address, unsigned int value);
extern "C" void ps_write_16(unsigned int address, unsigned int value);

// NOTE(review): presumably the Amiga base address of the emulated A314 and a
// flag telling whether it has been configured -- confirm.
unsigned int a314_base;
int a314_base_configured;
118
// Communication area shared with the Amiga side. Presumably the emulated Amiga
// reaches it through register accessors defined elsewhere -- confirm.
// Two 256-byte rings carry packets: A2R (Amiga to Raspberry) and R2A
// (Raspberry to Amiga). Their indices wrap modulo 256.
struct ComArea
{
    uint8_t a_events;   // A_EVENT_* bits for the Amiga; OR-ed in under `mutex` by write_channel_status().
    uint8_t a_enable;   // Presumably the Amiga-side event enable mask -- confirm.
    uint8_t r_events;   // Presumably R_EVENT_* bits raised by the Amiga for handle_a314_irq() -- confirm.
    uint8_t r_enable; // Unused.

    // NOTE(review): meaning not established here; presumably base and size of
    // memory shared with the Amiga driver.
    uint32_t mem_base;
    uint32_t mem_size;

    // Ring indices, copied into channel_status[] by read_channel_status().
    // r2a_tail and a2r_head are written back by write_channel_status().
    uint8_t a2r_tail;
    uint8_t r2a_head;
    uint8_t r2a_tail;
    uint8_t a2r_head;

    uint8_t a2r_buffer[256];
    uint8_t r2a_buffer[256];
};
137
// The communication area instance.
static ComArea ca;

// Set once R_EVENT_STARTED has been seen; ring traffic is ignored until then.
static bool a314_device_started = false;

// Working copy of the four ring indices, indexed by the *_OFFSET constants.
static uint8_t channel_status[4];
// A_EVENT_* bits to raise when the working copy is written back.
static uint8_t channel_status_updated = 0;

// recv_buf: linearised copy of the bytes received from the A2R ring.
// send_buf: staging area for bytes about to be copied into the R2A ring.
static uint8_t recv_buf[256];
static uint8_t send_buf[256];

struct LogicalChannel;
struct ClientConnection;
150
#pragma pack(push, 1)
// Framing header that precedes every message on a client socket. Packed so it
// has no padding (9 bytes). It is copied to and from the socket as raw bytes
// without htonl/ntohl, i.e. in host byte order.
struct MessageHeader
{
    uint32_t length;    // Number of payload bytes that follow the header.
    uint32_t stream_id; // Client stream the message belongs to (0 for register/deregister/memory replies).
    uint8_t type;       // One of the MSG_* constants.
}; //} __attribute__((packed));
#pragma pack(pop)
159
// A framed message (header + payload) waiting to be written to a client.
struct MessageBuffer
{
    int pos;                    // Bytes of `data` already written.
    std::vector<uint8_t> data;
};

// A service name registered by a client, either explicitly (MSG_REGISTER_REQ)
// or implicitly when an on-demand program is started for it.
struct RegisteredService
{
    std::string name;
    ClientConnection *cc;       // Client that provides the service.
};

// A packet queued for the R2A ring.
struct PacketBuffer
{
    int type;                   // PKT_* constant.
    std::vector<uint8_t> data;
};

// State of one client socket.
struct ClientConnection
{
    int fd;

    // Stream id handed to the next logical channel bound to this client; it
    // starts at 1 for on-demand clients and advances by 2. Presumably odd ids
    // are reserved for this side -- confirm.
    int next_stream_id;

    // Receive state: bytes of the current header (while `payload` is empty) or
    // of the current payload read so far.
    int bytes_read;
    MessageHeader header;
    std::vector<uint8_t> payload;

    // Messages that could not be written yet, flushed on EPOLLOUT.
    std::list<MessageBuffer> message_queue;

    // Logical channels currently bound to this client.
    std::list<LogicalChannel*> associations;
};

// A logical channel opened by the Amiga (PKT_CONNECT).
struct LogicalChannel
{
    int channel_id;             // Amiga-side channel id (one byte in packets).

    ClientConnection *association; // Bound client, or nullptr.
    int stream_id;              // Stream id on the bound client; 0 when unbound.

    bool got_eos_from_ami;
    bool got_eos_from_client;

    // Packets waiting to be copied into the R2A ring.
    std::list<PacketBuffer> packet_queue;
};
205
static void remove_association(LogicalChannel *ch);
static void clear_packet_queue(LogicalChannel *ch);
static void create_and_enqueue_packet(LogicalChannel *ch, uint8_t type, uint8_t *data, uint8_t length);

// std::list keeps element addresses stable, which matters because raw
// ClientConnection* and LogicalChannel* pointers are stored in other lists.
static std::list<ClientConnection> connections;
static std::list<RegisteredService> services;
static std::list<LogicalChannel> channels;
// Channels with a non-empty packet_queue; flush_send_queue() serves them
// round-robin (one packet, then the channel goes to the back).
static std::list<LogicalChannel*> send_queue;

// One configuration-file entry: start `program` when the Amiga connects to
// `service_name`. `arguments` is the full argv, starting with the program.
struct OnDemandStart
{
    std::string service_name;
    std::string program;
    std::vector<std::string> arguments;
};

std::vector<OnDemandStart> on_demand_services;

// Presumably the default path handed to load_config_file() -- confirm.
std::string a314_config_file = "./a314/files_pi/a314d.conf";
// Environment entry passed to putenv() in on-demand child processes. putenv()
// keeps a pointer to this string rather than copying it.
std::string home_env = "HOME=./";
226
227 static void load_config_file(const char *filename)
228 {
229     FILE *f = fopen(filename, "rt");
230     if (f == nullptr) {
231         return;
232     }
233
234     char line[256];
235     std::vector<char *> parts;
236
237     while (fgets(line, 256, f) != nullptr)
238     {
239         char org_line[256];
240         strcpy(org_line, line);
241
242         bool in_quotes = false;
243
244         int start = 0;
245         for (int i = 0; i < 256; i++)
246         {
247             if (line[i] == 0)
248             {
249                 if (start < i)
250                     parts.push_back(&line[start]);
251                 break;
252             }
253             else if (line[i] == '"')
254             {
255                 line[i] = 0;
256                 if (in_quotes)
257                     parts.push_back(&line[start]);
258                 in_quotes = !in_quotes;
259                 start = i + 1;
260             }
261             else if (isspace(line[i]) && !in_quotes)
262             {
263                 line[i] = 0;
264                 if (start < i)
265                     parts.push_back(&line[start]);
266                 start = i + 1;
267             }
268         }
269
270         if (parts.size() >= 2)
271         {
272             on_demand_services.emplace_back();
273             auto &e = on_demand_services.back();
274             e.service_name = parts[0];
275             e.program = parts[1];
276             for (int i = 1; i < parts.size(); i++)
277                 e.arguments.push_back(std::string(parts[i]));
278         }
279         else if (parts.size() != 0)
280             logger_warn("Invalid number of columns in configuration file line: %s\n", org_line);
281
282         parts.clear();
283     }
284
285     fclose(f);
286
287     if (on_demand_services.empty())
288         logger_warn("No registered services\n");
289 }
290
291 static int init_server_socket()
292 {
293     server_socket = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0);
294     if (server_socket == -1)
295     {
296         logger_error("Failed to create server socket\n");
297         return -1;
298     }
299
300     struct sockaddr_in address;
301     address.sin_family = AF_INET;
302     address.sin_addr.s_addr = INADDR_ANY;
303     address.sin_port = htons(7110);
304
305     int res = bind(server_socket, (struct sockaddr *)&address, sizeof(address));
306     if (res < 0)
307     {
308         logger_error("Bind to localhost:7110 failed\n");
309         return -1;
310     }
311
312     listen(server_socket, 16);
313
314     return 0;
315 }
316
317 static void shutdown_server_socket()
318 {
319     if (server_socket != -1)
320         close(server_socket);
321     server_socket = -1;
322 }
323
324 void create_and_send_msg(ClientConnection *cc, int type, int stream_id, uint8_t *data, int length)
325 {
326     MessageBuffer mb;
327     mb.pos = 0;
328     mb.data.resize(sizeof(MessageHeader) + length);
329
330     MessageHeader *mh = (MessageHeader *)&mb.data[0];
331     mh->length = length;
332     mh->stream_id = stream_id;
333     mh->type = type;
334     if (length && data)
335         memcpy(&mb.data[sizeof(MessageHeader)], data, length);
336
337     if (!cc->message_queue.empty())
338     {
339         cc->message_queue.push_back(std::move(mb));
340         return;
341     }
342
343     while (1)
344     {
345         int left = mb.data.size() - mb.pos;
346         uint8_t *src = &mb.data[mb.pos];
347         ssize_t r = write(cc->fd, src, left);
348         if (r == -1)
349         {
350             if (errno == EAGAIN || errno == EWOULDBLOCK)
351             {
352                 cc->message_queue.push_back(std::move(mb));
353                 return;
354             }
355             else if (errno == ECONNRESET)
356             {
357                 // Do not close connection here; it will get done at some other place.
358                 return;
359             }
360             else
361             {
362                 logger_error("Write failed unexpectedly with errno = %d\n", errno);
363                 exit(-1);
364             }
365         }
366
367         mb.pos += r;
368         if (r == left)
369         {
370             return;
371         }
372     }
373 }
374
375 static void handle_msg_register_req(ClientConnection *cc)
376 {
377     uint8_t result = MSG_FAIL;
378
379     std::string service_name((char *)&cc->payload[0], cc->payload.size());
380
381     auto it = services.begin();
382     for (; it != services.end(); it++)
383         if (it->name == service_name)
384             break;
385
386     if (it == services.end())
387     {
388         services.emplace_back();
389
390         RegisteredService &srv = services.back();
391         srv.cc = cc;
392         srv.name = std::move(service_name);
393
394         result = MSG_SUCCESS;
395     }
396
397     create_and_send_msg(cc, MSG_REGISTER_RES, 0, &result, 1);
398 }
399
400 static void handle_msg_deregister_req(ClientConnection *cc)
401 {
402     uint8_t result = MSG_FAIL;
403
404     std::string service_name((char *)&cc->payload[0], cc->payload.size());
405
406     for (auto it = services.begin(); it != services.end(); it++)
407     {
408         if (it->name == service_name && it->cc == cc)
409         {
410             services.erase(it);
411             result = MSG_SUCCESS;
412             break;
413         }
414     }
415
416     create_and_send_msg(cc, MSG_DEREGISTER_RES, 0, &result, 1);
417 }
418
// Scratch buffer for MSG_READ_MEM_REQ reads of addresses that are not
// host-mapped; those bytes are fetched one at a time with ps_read_8().
// SIZE_KILO comes from a project header.
uint8_t manual_read_buf[64 * SIZE_KILO];
420
421 static void handle_msg_read_mem_req(ClientConnection *cc)
422 {
423     uint32_t address = *(uint32_t *)&(cc->payload[0]);
424     uint32_t length = *(uint32_t *)&(cc->payload[4]);
425
426     if (get_mapped_item_by_address(cfg, address) != -1) {
427         int32_t index = get_mapped_item_by_address(cfg, address);
428         uint8_t *map = &cfg->map_data[index][address - cfg->map_offset[index]];
429         create_and_send_msg(cc, MSG_READ_MEM_RES, 0, map, length);
430     } else {
431         // No idea if this actually works.
432         for (int i = 0; i < length; i++) {
433             manual_read_buf[i] = (unsigned char)ps_read_8(address + i);
434         }
435         create_and_send_msg(cc, MSG_READ_MEM_RES, 0, manual_read_buf, length);
436     }
437     
438 }
439
440 static void handle_msg_write_mem_req(ClientConnection *cc)
441 {
442     uint32_t address = *(uint32_t *)&(cc->payload[0]);
443     uint32_t length = cc->payload.size() - 4;
444
445
446     if (get_mapped_item_by_address(cfg, address) != -1) {
447         int32_t index = get_mapped_item_by_address(cfg, address);
448         uint8_t *map = &cfg->map_data[index][address - cfg->map_offset[index]];
449         memcpy(map, &(cc->payload[4]), length);
450     } else {
451         // No idea if this actually works.
452         for (int i = 0; i < length; i++) {
453             ps_write_8(address + i, cc->payload[4 + i]);
454         }
455     }
456
457     create_and_send_msg(cc, MSG_WRITE_MEM_RES, 0, nullptr, 0);
458 }
459
460 static LogicalChannel *get_associated_channel_by_stream_id(ClientConnection *cc, int stream_id)
461 {
462     for (auto ch : cc->associations)
463     {
464         if (ch->stream_id == stream_id)
465             return ch;
466     }
467     return nullptr;
468 }
469
470 static void handle_msg_connect(ClientConnection *cc)
471 {
472     // We currently don't handle that a client tries to connect to a service on the Amiga.
473 }
474
475 static void handle_msg_connect_response(ClientConnection *cc)
476 {
477     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
478     if (!ch)
479         return;
480
481     create_and_enqueue_packet(ch, PKT_CONNECT_RESPONSE, &cc->payload[0], cc->payload.size());
482
483     if (cc->payload[0] != CONNECT_OK)
484         remove_association(ch);
485 }
486
487 static void handle_msg_data(ClientConnection *cc)
488 {
489     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
490     if (!ch)
491         return;
492
493     create_and_enqueue_packet(ch, PKT_DATA, &cc->payload[0], cc->header.length);
494 }
495
496 static void handle_msg_eos(ClientConnection *cc)
497 {
498     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
499     if (!ch || ch->got_eos_from_client)
500         return;
501
502     ch->got_eos_from_client = true;
503
504     create_and_enqueue_packet(ch, PKT_EOS, nullptr, 0);
505
506     if (ch->got_eos_from_ami)
507         remove_association(ch);
508 }
509
510 static void handle_msg_reset(ClientConnection *cc)
511 {
512     LogicalChannel *ch = get_associated_channel_by_stream_id(cc, cc->header.stream_id);
513     if (!ch)
514         return;
515
516     remove_association(ch);
517
518     clear_packet_queue(ch);
519     create_and_enqueue_packet(ch, PKT_RESET, nullptr, 0);
520 }
521
522 static void handle_received_message(ClientConnection *cc)
523 {
524     switch (cc->header.type)
525     {
526     case MSG_REGISTER_REQ:
527         handle_msg_register_req(cc);
528         break;
529     case MSG_DEREGISTER_REQ:
530         handle_msg_deregister_req(cc);
531         break;
532     case MSG_READ_MEM_REQ:
533         handle_msg_read_mem_req(cc);
534         break;
535     case MSG_WRITE_MEM_REQ:
536         handle_msg_write_mem_req(cc);
537         break;
538     case MSG_CONNECT:
539         handle_msg_connect(cc);
540         break;
541     case MSG_CONNECT_RESPONSE:
542         handle_msg_connect_response(cc);
543         break;
544     case MSG_DATA:
545         handle_msg_data(cc);
546         break;
547     case MSG_EOS:
548         handle_msg_eos(cc);
549         break;
550     case MSG_RESET:
551         handle_msg_reset(cc);
552         break;
553     default:
554         // This is bad, probably should disconnect from client.
555         logger_warn("Received a message of unknown type from client\n");
556         break;
557     }
558 }
559
560 static void close_and_remove_connection(ClientConnection *cc)
561 {
562     shutdown(cc->fd, SHUT_WR);
563     close(cc->fd);
564
565     {
566         auto it = services.begin();
567         while (it != services.end())
568         {
569             if (it->cc == cc)
570                 it = services.erase(it);
571             else
572                 it++;
573         }
574     }
575
576     {
577         auto it = cc->associations.begin();
578         while (it != cc->associations.end())
579         {
580             auto ch = *it;
581
582             clear_packet_queue(ch);
583             create_and_enqueue_packet(ch, PKT_RESET, nullptr, 0);
584
585             ch->association = nullptr;
586             ch->stream_id = 0;
587
588             it = cc->associations.erase(it);
589         }
590     }
591
592     for (auto it = connections.begin(); it != connections.end(); it++)
593     {
594         if (&(*it) == cc)
595         {
596             connections.erase(it);
597             break;
598         }
599     }
600 }
601
602 static void remove_association(LogicalChannel *ch)
603 {
604     auto &ass = ch->association->associations;
605     ass.erase(std::find(ass.begin(), ass.end(), ch));
606
607     ch->association = nullptr;
608     ch->stream_id = 0;
609 }
610
611 static void clear_packet_queue(LogicalChannel *ch)
612 {
613     if (!ch->packet_queue.empty())
614     {
615         ch->packet_queue.clear();
616         send_queue.erase(std::find(send_queue.begin(), send_queue.end(), ch));
617     }
618 }
619
620 static void create_and_enqueue_packet(LogicalChannel *ch, uint8_t type, uint8_t *data, uint8_t length)
621 {
622     if (ch->packet_queue.empty())
623         send_queue.push_back(ch);
624
625     ch->packet_queue.emplace_back();
626
627     PacketBuffer &pb = ch->packet_queue.back();
628     pb.type = type;
629     pb.data.resize(length);
630     if (data && length)
631         memcpy(&pb.data[0], data, length);
632 }
633
634 static void handle_pkt_connect(int channel_id, uint8_t *data, int plen)
635 {
636     for (auto &ch : channels)
637     {
638         if (ch.channel_id == channel_id)
639         {
640             // We should handle this in some constructive way.
641             // This signals that should reset all logical channels.
642             logger_error("Received a CONNECT packet on a channel that was believed to be previously allocated\n");
643             exit(-1);
644         }
645     }
646
647     channels.emplace_back();
648
649     auto &ch = channels.back();
650
651     ch.channel_id = channel_id;
652     ch.association = nullptr;
653     ch.stream_id = 0;
654     ch.got_eos_from_ami = false;
655     ch.got_eos_from_client = false;
656
657     std::string service_name((char *)data, plen);
658
659     for (auto &srv : services)
660     {
661         if (srv.name == service_name)
662         {
663             ClientConnection *cc = srv.cc;
664
665             ch.association = cc;
666             ch.stream_id = cc->next_stream_id;
667
668             cc->next_stream_id += 2;
669             cc->associations.push_back(&ch);
670
671             create_and_send_msg(ch.association, MSG_CONNECT, ch.stream_id, data, plen);
672             return;
673         }
674     }
675
676     for (auto &on_demand : on_demand_services)
677     {
678         if (on_demand.service_name == service_name)
679         {
680             int fds[2];
681             int status = socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
682             if (status != 0)
683             {
684                 logger_error("Unexpectedly not able to create socket pair.\n");
685                 exit(-1);
686             }
687
688             pid_t child = fork();
689             if (child == -1)
690             {
691                 logger_error("Unexpectedly was not able to fork.\n");
692                 exit(-1);
693             }
694             else if (child == 0)
695             {
696                 close(fds[0]);
697                 int fd = fds[1];
698
699                 // FIXE: The user should be configurable.
700                 setgid(1000);
701                 setuid(1000);
702                 putenv((char *)home_env.c_str());
703
704                 std::vector<std::string> args(on_demand.arguments);
705                 args.push_back("-ondemand");
706                 args.push_back(std::to_string(fd));
707                 std::vector<const char *> args_arr;
708                 for (auto &arg : args)
709                     args_arr.push_back(arg.c_str());
710                 args_arr.push_back(nullptr);
711
712                 execvp(on_demand.program.c_str(), (char* const*) &args_arr[0]);
713             }
714             else
715             {
716                 close(fds[1]);
717                 int fd = fds[0];
718
719                 int status = fcntl(fd, F_SETFD, fcntl(fd, F_GETFD, 0) | FD_CLOEXEC);
720                 if (status == -1)
721                 {
722                     logger_error("Unexpectedly unable to set close-on-exec flag on client socket descriptor; errno = %d\n", errno);
723                     exit(-1);
724                 }
725
726                 status = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
727                 if (status == -1)
728                 {
729                     logger_error("Unexpectedly unable to set client socket to non blocking; errno = %d\n", errno);
730                     exit(-1);
731                 }
732
733                 int flag = 1;
734                 setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
735
736                 connections.emplace_back();
737
738                 ClientConnection &cc = connections.back();
739                 cc.fd = fd;
740                 cc.next_stream_id = 1;
741                 cc.bytes_read = 0;
742
743                 struct epoll_event ev;
744                 ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
745                 ev.data.fd = fd;
746                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) != 0)
747                 {
748                     logger_error("epoll_ctl() failed unexpectedly with errno = %d\n", errno);
749                     exit(-1);
750                 }
751
752                 services.emplace_back();
753
754                 RegisteredService &srv = services.back();
755                 srv.cc = &cc;
756                 srv.name = std::move(service_name);
757
758                 ch.association = &cc;
759                 ch.stream_id = cc.next_stream_id;
760
761                 cc.next_stream_id += 2;
762                 cc.associations.push_back(&ch);
763
764                 create_and_send_msg(ch.association, MSG_CONNECT, ch.stream_id, data, plen);
765                 return;
766             }
767         }
768     }
769
770     uint8_t response = CONNECT_UNKNOWN_SERVICE;
771     create_and_enqueue_packet(&ch, PKT_CONNECT_RESPONSE, &response, 1);
772 }
773
774 static void handle_pkt_data(int channel_id, uint8_t *data, int plen)
775 {
776     for (auto &ch : channels)
777     {
778         if (ch.channel_id == channel_id)
779         {
780             if (ch.association != nullptr && !ch.got_eos_from_ami)
781                 create_and_send_msg(ch.association, MSG_DATA, ch.stream_id, data, plen);
782
783             break;
784         }
785     }
786 }
787
788 static void handle_pkt_eos(int channel_id)
789 {
790     for (auto &ch : channels)
791     {
792         if (ch.channel_id == channel_id)
793         {
794             if (ch.association != nullptr && !ch.got_eos_from_ami)
795             {
796                 ch.got_eos_from_ami = true;
797
798                 create_and_send_msg(ch.association, MSG_EOS, ch.stream_id, nullptr, 0);
799
800                 if (ch.got_eos_from_client)
801                     remove_association(&ch);
802             }
803             break;
804         }
805     }
806 }
807
808 static void handle_pkt_reset(int channel_id)
809 {
810     for (auto &ch : channels)
811     {
812         if (ch.channel_id == channel_id)
813         {
814             clear_packet_queue(&ch);
815
816             if (ch.association != nullptr)
817             {
818                 create_and_send_msg(ch.association, MSG_RESET, ch.stream_id, nullptr, 0);
819                 remove_association(&ch);
820             }
821
822             break;
823         }
824     }
825 }
826
827 static void remove_channel_if_not_associated_and_empty_pq(int channel_id)
828 {
829     for (auto it = channels.begin(); it != channels.end(); it++)
830     {
831         if (it->channel_id == channel_id)
832         {
833             if (it->association == nullptr && it->packet_queue.empty())
834                 channels.erase(it);
835
836             break;
837         }
838     }
839 }
840
841 static void handle_received_pkt(int ptype, int channel_id, uint8_t *data, int plen)
842 {
843     if (ptype == PKT_CONNECT)
844         handle_pkt_connect(channel_id, data, plen);
845     else if (ptype == PKT_DATA)
846         handle_pkt_data(channel_id, data, plen);
847     else if (ptype == PKT_EOS)
848         handle_pkt_eos(channel_id);
849     else if (ptype == PKT_RESET)
850         handle_pkt_reset(channel_id);
851
852     remove_channel_if_not_associated_and_empty_pq(channel_id);
853 }
854
855 static bool receive_from_a2r()
856 {
857     int head = channel_status[A2R_HEAD_OFFSET];
858     int tail = channel_status[A2R_TAIL_OFFSET];
859     int len = (tail - head) & 255;
860     if (len == 0)
861         return false;
862
863     if (head < tail)
864     {
865         memcpy(recv_buf, &ca.a2r_buffer[head], len);
866     }
867     else
868     {
869         memcpy(recv_buf, &ca.a2r_buffer[head], 256 - head);
870
871         if (tail != 0)
872         {
873             memcpy(&recv_buf[len - tail], &ca.a2r_buffer[0], tail);
874         }
875     }
876
877     uint8_t *p = recv_buf;
878     while (p < recv_buf + len)
879     {
880         uint8_t plen = *p++;
881         uint8_t ptype = *p++;
882         uint8_t channel_id = *p++;
883         handle_received_pkt(ptype, channel_id, p, plen);
884         p += plen;
885     }
886
887     channel_status[A2R_HEAD_OFFSET] = channel_status[A2R_TAIL_OFFSET];
888     channel_status_updated |= A_EVENT_A2R_HEAD;
889     return true;
890 }
891
892 static bool flush_send_queue()
893 {
894     int tail = channel_status[R2A_TAIL_OFFSET];
895     int head = channel_status[R2A_HEAD_OFFSET];
896     int len = (tail - head) & 255;
897     int left = 255 - len;
898
899     int pos = 0;
900
901     while (!send_queue.empty())
902     {
903         LogicalChannel *ch = send_queue.front();
904         PacketBuffer &pb = ch->packet_queue.front();
905
906         int ptype = pb.type;
907         int plen = 3 + pb.data.size();
908
909         if (left < plen)
910             break;
911
912         send_buf[pos++] = pb.data.size();
913         send_buf[pos++] = ptype;
914         send_buf[pos++] = ch->channel_id;
915         memcpy(&send_buf[pos], &pb.data[0], pb.data.size());
916         pos += pb.data.size();
917
918         ch->packet_queue.pop_front();
919
920         send_queue.pop_front();
921
922         if (!ch->packet_queue.empty())
923             send_queue.push_back(ch);
924         else
925             remove_channel_if_not_associated_and_empty_pq(ch->channel_id);
926
927         left -= plen;
928     }
929
930     int to_write = pos;
931     if (!to_write)
932         return false;
933
934     uint8_t *p = send_buf;
935     int at_end = 256 - tail;
936     if (at_end < to_write)
937     {
938         memcpy(&ca.r2a_buffer[tail], p, at_end);
939         p += at_end;
940         to_write -= at_end;
941         tail = 0;
942     }
943
944     memcpy(&ca.r2a_buffer[tail], p, to_write);
945     tail = (tail + to_write) & 255;
946
947     channel_status[R2A_TAIL_OFFSET] = tail;
948     channel_status_updated |= A_EVENT_R2A_TAIL;
949     return true;
950 }
951
/*
 * Snapshot the four ring indices from the communication area into
 * channel_status (indexed by the *_OFFSET constants) and clear the pending
 * A_EVENT_* mask. The read order is kept as-is because `ca` may be updated
 * concurrently.
 */
static void read_channel_status()
{
    channel_status[A2R_TAIL_OFFSET] = ca.a2r_tail;
    channel_status[R2A_HEAD_OFFSET] = ca.r2a_head;
    channel_status[R2A_TAIL_OFFSET] = ca.r2a_tail;
    channel_status[A2R_HEAD_OFFSET] = ca.a2r_head;
    channel_status_updated = 0;
}
960
/*
 * If anything changed, publish the indices this side owns (R2A tail, A2R
 * head) back to the communication area, then raise the pending A_EVENT_* bits.
 * The indices are stored before the event bits are set. a_events is modified
 * under `mutex`; presumably it is shared with the thread serving Amiga
 * register accesses -- confirm.
 */
static void write_channel_status()
{
    if (channel_status_updated != 0)
    {
        ca.r2a_tail = channel_status[R2A_TAIL_OFFSET];
        ca.a2r_head = channel_status[A2R_HEAD_OFFSET];

        pthread_mutex_lock(&mutex);
        ca.a_events |= channel_status_updated;
        pthread_mutex_unlock(&mutex);

        channel_status_updated = 0;
    }
}
975
976 static void close_all_logical_channels()
977 {
978     send_queue.clear();
979
980     auto it = channels.begin();
981     while (it != channels.end())
982     {
983         LogicalChannel &ch = *it;
984
985         if (ch.association != nullptr)
986         {
987             create_and_send_msg(ch.association, MSG_RESET, ch.stream_id, nullptr, 0);
988             remove_association(&ch);
989         }
990
991         it = channels.erase(it);
992     }
993 }
994
995 static void handle_a314_irq(uint8_t events)
996 {
997     if (events == 0)
998         return;
999
1000     if (events & R_EVENT_STARTED)
1001     {
1002         if (!channels.empty())
1003             logger_info("Received STARTED event while logical channels are open -- closing channels\n");
1004
1005         close_all_logical_channels();
1006         a314_device_started = true;
1007     }
1008
1009     if (!a314_device_started)
1010         return;
1011
1012     read_channel_status();
1013
1014     bool any_rcvd = receive_from_a2r();
1015     bool any_sent = flush_send_queue();
1016
1017     if (any_rcvd || any_sent)
1018         write_channel_status();
1019 }
1020
1021 static void handle_client_connection_event(ClientConnection *cc, struct epoll_event *ev)
1022 {
1023     if (ev->events & EPOLLERR)
1024     {
1025         logger_warn("Received EPOLLERR for client connection\n");
1026         close_and_remove_connection(cc);
1027         return;
1028     }
1029
1030     if (ev->events & EPOLLIN)
1031     {
1032         while (1)
1033         {
1034             int left;
1035             uint8_t *dst;
1036
1037             if (cc->payload.empty())
1038             {
1039                 left = sizeof(MessageHeader) - cc->bytes_read;
1040                 dst = (uint8_t *)&(cc->header) + cc->bytes_read;
1041             }
1042             else
1043             {
1044                 left = cc->header.length - cc->bytes_read;
1045                 dst = &cc->payload[cc->bytes_read];
1046             }
1047
1048             ssize_t r = read(cc->fd, dst, left);
1049             if (r == -1)
1050             {
1051                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1052                     break;
1053
1054                 logger_error("Read failed unexpectedly with errno = %d\n", errno);
1055                 exit(-1);
1056             }
1057
1058             if (r == 0)
1059             {
1060                 logger_info("Received End-of-File on client connection\n");
1061                 close_and_remove_connection(cc);
1062                 return;
1063             }
1064             else
1065             {
1066                 cc->bytes_read += r;
1067                 left -= r;
1068                 if (!left)
1069                 {
1070                     if (cc->payload.empty())
1071                     {
1072                         if (cc->header.length == 0)
1073                         {
1074                             logger_trace("header: length=%d, stream_id=%d, type=%d\n", cc->header.length, cc->header.stream_id, cc->header.type);
1075                             handle_received_message(cc);
1076                         }
1077                         else
1078                         {
1079                             cc->payload.resize(cc->header.length);
1080                         }
1081                     }
1082                     else
1083                     {
1084                         logger_trace("header: length=%d, stream_id=%d, type=%d\n", cc->header.length, cc->header.stream_id, cc->header.type);
1085                         handle_received_message(cc);
1086                         cc->payload.clear();
1087                     }
1088                     cc->bytes_read = 0;
1089                 }
1090             }
1091         }
1092     }
1093
1094     if (ev->events & EPOLLOUT)
1095     {
1096         while (!cc->message_queue.empty())
1097         {
1098             MessageBuffer &mb = cc->message_queue.front();
1099
1100             int left = mb.data.size() - mb.pos;
1101             uint8_t *src = &mb.data[mb.pos];
1102             ssize_t r = write(cc->fd, src, left);
1103             if (r == -1)
1104             {
1105                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1106                     break;
1107                 else if (errno == ECONNRESET)
1108                 {
1109                     close_and_remove_connection(cc);
1110                     return;
1111                 }
1112                 else
1113                 {
1114                     logger_error("Write failed unexpectedly with errno = %d\n", errno);
1115                     exit(-1);
1116                 }
1117             }
1118
1119             mb.pos += r;
1120             if (r == left)
1121                 cc->message_queue.pop_front();
1122         }
1123     }
1124 }
1125
1126 static void handle_server_socket_ready()
1127 {
1128     struct sockaddr_in address;
1129     int alen = sizeof(struct sockaddr_in);
1130
1131     int fd = accept(server_socket, (struct sockaddr *)&address, (socklen_t *)&alen);
1132     if (fd < 0)
1133     {
1134         logger_error("Accept failed unexpectedly with errno = %d\n", errno);
1135         exit(-1);
1136     }
1137
1138     int status = fcntl(fd, F_SETFD, fcntl(fd, F_GETFD, 0) | FD_CLOEXEC);
1139     if (status == -1)
1140     {
1141         logger_error("Unexpectedly unable to set close-on-exec flag on client socket descriptor; errno = %d\n", errno);
1142         exit(-1);
1143     }
1144
1145     status = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
1146     if (status == -1)
1147     {
1148         logger_error("Unexpectedly unable to set client socket to non blocking; errno = %d\n", errno);
1149         exit(-1);
1150     }
1151
1152     int flag = 1;
1153     setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
1154
1155     connections.emplace_back();
1156
1157     ClientConnection &cc = connections.back();
1158     cc.fd = fd;
1159     cc.next_stream_id = 1;
1160     cc.bytes_read = 0;
1161
1162     struct epoll_event ev;
1163     ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
1164     ev.data.fd = fd;
1165     if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) != 0)
1166     {
1167         logger_error("epoll_ctl() failed unexpectedly with errno = %d\n", errno);
1168         exit(-1);
1169     }
1170 }
1171
1172 static void main_loop()
1173 {
1174     bool shutting_down = false;
1175     bool done = false;
1176
1177     while (!done)
1178     {
1179         struct epoll_event ev;
1180         int timeout = shutting_down ? 10000 : -1;
1181         int n = epoll_pwait(epfd, &ev, 1, timeout, &original_sigset);
1182         if (n == -1)
1183         {
1184             if (errno == EINTR)
1185             {
1186                 logger_info("Received SIGTERM\n");
1187
1188                 shutdown_server_socket();
1189
1190                 while (!connections.empty())
1191                     close_and_remove_connection(&connections.front());
1192
1193                 if (flush_send_queue())
1194                     write_channel_status();
1195
1196                 if (!channels.empty())
1197                     shutting_down = true;
1198                 else
1199                     done = true;
1200             }
1201             else
1202             {
1203                 logger_error("epoll_pwait failed with unexpected errno = %d\n", errno);
1204                 exit(-1);
1205             }
1206         }
1207         else if (n == 0)
1208         {
1209             if (shutting_down)
1210                 done = true;
1211             else
1212             {
1213                 logger_error("epoll_pwait returned 0 which is unexpected since no timeout was set\n");
1214                 exit(-1);
1215             }
1216         }
1217         else
1218         {
1219             if (ev.data.fd == irq_fds[1])
1220             {
1221                 uint8_t events;
1222                 if (read(irq_fds[1], &events, 1) != 1)
1223                 {
1224                     logger_error("Read from interrupt socket pair, and unexpectedly didn't return 1 byte\n");
1225                     exit(-1);
1226                 }
1227
1228                 handle_a314_irq(events);
1229             }
1230             else if (ev.data.fd == server_socket)
1231             {
1232                 logger_trace("Epoll event: server socket is ready, events = %d\n", ev.events);
1233                 handle_server_socket_ready();
1234             }
1235             else
1236             {
1237                 logger_trace("Epoll event: client socket is ready, events = %d\n", ev.events);
1238
1239                 auto it = connections.begin();
1240                 for (; it != connections.end(); it++)
1241                 {
1242                     if (it->fd == ev.data.fd)
1243                         break;
1244                 }
1245
1246                 if (it == connections.end())
1247                 {
1248                     logger_error("Got notified about an event on a client connection that supposedly isn't currently open\n");
1249                     exit(-1);
1250                 }
1251
1252                 ClientConnection *cc = &(*it);
1253                 handle_client_connection_event(cc, &ev);
1254
1255                 if (flush_send_queue())
1256                     write_channel_status();
1257             }
1258         }
1259     }
1260 }
1261
1262 static int init_driver()
1263 {
1264     if (init_server_socket() != 0)
1265         return -1;
1266
1267     int err = socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0, irq_fds);
1268     if (err != 0)
1269     {
1270         logger_error("Unable to create socket pair, errno = %d\n", errno);
1271         return -1;
1272     }
1273
1274     epfd = epoll_create1(EPOLL_CLOEXEC);
1275     if (epfd == -1)
1276         return -1;
1277
1278     struct epoll_event ev;
1279     ev.events = EPOLLIN;
1280     ev.data.fd = irq_fds[1];
1281     if (epoll_ctl(epfd, EPOLL_CTL_ADD, irq_fds[1], &ev) != 0)
1282         return -1;
1283
1284     ev.events = EPOLLIN;
1285     ev.data.fd = server_socket;
1286     if (epoll_ctl(epfd, EPOLL_CTL_ADD, server_socket, &ev) != 0)
1287         return -1;
1288
1289     return 0;
1290 }
1291
1292 static void shutdown_driver()
1293 {
1294     if (epfd != -1)
1295         close(epfd);
1296
1297     shutdown_server_socket();
1298 }
1299
1300 static void *thread_start(void *arg)
1301 {
1302     main_loop();
1303     shutdown_driver();
1304     return NULL;
1305 }
1306
1307 static void write_r_events(uint8_t events)
1308 {
1309     if (write(irq_fds[0], &events, 1) != 1)
1310         logger_error("Write to interrupt socket pair did not return 1\n");
1311 }
1312
1313 int a314_init()
1314 {
1315     load_config_file(a314_config_file.c_str());
1316
1317     int err = init_driver();
1318     if (err < 0)
1319     {
1320         shutdown_driver();
1321         return -1;
1322     }
1323
1324     err = pthread_create(&thread_id, NULL, thread_start, NULL);
1325     if (err < 0)
1326     {
1327         logger_error("pthread_create failed with err = %d\n", err);
1328         return -2;
1329     }
1330
1331     return 0;
1332 }
1333
1334 void a314_set_mem_base_size(unsigned int base, unsigned int size)
1335 {
1336     ca.mem_base = htobe32(base);
1337     ca.mem_size = htobe32(size);
1338 }
1339
1340 void a314_process_events()
1341 {
1342     if (ca.a_events & ca.a_enable)
1343     {
1344         ps_write_16(0xdff09c, 0x8008);
1345         m68k_set_irq(2);
1346     }
1347 }
1348
1349 unsigned int a314_read_memory_8(unsigned int address)
1350 {
1351     if (address >= sizeof(ca))
1352         return 0;
1353
1354     uint8_t val;
1355     if (address == offsetof(ComArea, a_events))
1356     {
1357         pthread_mutex_lock(&mutex);
1358         val = ca.a_events;
1359         ca.a_events = 0;
1360         pthread_mutex_unlock(&mutex);
1361     }
1362     else
1363     {
1364         uint8_t *p = (uint8_t *)&ca;
1365         val = p[address];
1366     }
1367
1368     return val;
1369 }
1370
1371 unsigned int a314_read_memory_16(unsigned int address)
1372 {
1373     if (address >= sizeof(ca))
1374         return 0;
1375
1376     uint16_t *p = (uint16_t *)&ca;
1377     return be16toh(p[address >> 1]);
1378 }
1379
1380 unsigned int a314_read_memory_32(unsigned int address)
1381 {
1382     if (address >= sizeof(ca))
1383         return 0;
1384
1385     uint32_t *p = (uint32_t *)&ca;
1386     return be32toh(p[address >> 2]);
1387 }
1388
1389 void a314_write_memory_8(unsigned int address, unsigned int value)
1390 {
1391     if (address >= sizeof(ca))
1392         return;
1393
1394     switch (address)
1395     {
1396         case offsetof(ComArea, a_events):
1397             // a_events is not writable.
1398             break;
1399
1400         case offsetof(ComArea, r_events):
1401             if (value != 0)
1402                 write_r_events((uint8_t)value);
1403             break;
1404
1405         default:
1406         {
1407             uint8_t *p = (uint8_t *)&ca;
1408             p[address] = (uint8_t)value;
1409             break;
1410         }
1411     }
1412 }
1413
// 16-bit writes to the com area are not implemented; they are silently
// discarded.
void a314_write_memory_16(unsigned int address, unsigned int value)
{
    (void)address;
    (void)value;
}
1418
// 32-bit writes to the com area are not implemented; they are silently
// discarded.
void a314_write_memory_32(unsigned int address, unsigned int value)
{
    (void)address;
    (void)value;
}
1423
1424 void a314_set_config_file(char *filename)
1425 {
1426     printf ("[A314] Set A314 config filename to %s.\n", filename);
1427     a314_config_file = std::string(filename);
1428 }