git.sesse.net Git - nageru/blob - midi_mapper.cpp
Make sure MIDI mapping extrapolation doesn't give out-of-range values.
[nageru] / midi_mapper.cpp
1 #include "midi_mapper.h"
2 #include "midi_mapping.pb.h"
3
4 #include <alsa/asoundlib.h>
5 #include <google/protobuf/text_format.h>
6 #include <google/protobuf/io/zero_copy_stream.h>
7 #include <google/protobuf/io/zero_copy_stream_impl.h>
8 #include <fcntl.h>
9 #include <sys/eventfd.h>
10
11 #include <functional>
12 #include <thread>
13
14 using namespace google::protobuf;
15 using namespace std;
16 using namespace std::placeholders;
17
namespace {

// Converts a 7-bit MIDI controller value (nominally 0..127) to a float in
// [0.0, 1.0]. The two endpoints are pinned (0 -> 0.0, 127 -> 1.0), and interior
// values are offset by half a step so that 63 maps to (63 + 0.5) / 127 = 0.5
// exactly. Anything outside 1..126 is clamped to the nearest endpoint.
double map_controller_to_float(int val)
{
        const bool interior = (val > 0 && val < 127);
        if (interior) {
                return (val + 0.5) / 127.0;
        }
        return (val <= 0) ? 0.0 : 1.0;
}

}  // namespace
33
34 MIDIMapper::MIDIMapper(ControllerReceiver *receiver)
35         : receiver(receiver), mapping_proto(new MIDIMappingProto)
36 {
37         should_quit_fd = eventfd(/*initval=*/0, /*flags=*/0);
38         assert(should_quit_fd != -1);
39 }
40
41 MIDIMapper::~MIDIMapper()
42 {
43         should_quit = true;
44         const uint64_t one = 1;
45         write(should_quit_fd, &one, sizeof(one));
46         midi_thread.join();
47         close(should_quit_fd);
48 }
49
50 bool load_midi_mapping_from_file(const string &filename, MIDIMappingProto *new_mapping)
51 {
52         // Read and parse the protobuf from disk.
53         int fd = open(filename.c_str(), O_RDONLY);
54         if (fd == -1) {
55                 perror(filename.c_str());
56                 return false;
57         }
58         io::FileInputStream input(fd);  // Takes ownership of fd.
59         if (!TextFormat::Parse(&input, new_mapping)) {
60                 input.Close();
61                 return false;
62         }
63         input.Close();
64         return true;
65 }
66
67 bool save_midi_mapping_to_file(const MIDIMappingProto &mapping_proto, const string &filename)
68 {
69         // Save to disk. We use the text format because it's friendlier
70         // for a user to look at and edit.
71         int fd = open(filename.c_str(), O_WRONLY | O_TRUNC | O_CREAT, 0666);
72         if (fd == -1) {
73                 perror(filename.c_str());
74                 return false;
75         }
76         io::FileOutputStream output(fd);  // Takes ownership of fd.
77         if (!TextFormat::Print(mapping_proto, &output)) {
78                 // TODO: Don't overwrite the old file (if any) on error.
79                 output.Close();
80                 return false;
81         }
82
83         output.Close();
84         return true;
85 }
86
87 void MIDIMapper::set_midi_mapping(const MIDIMappingProto &new_mapping)
88 {
89         lock_guard<mutex> lock(mapping_mu);
90         if (mapping_proto) {
91                 mapping_proto->CopyFrom(new_mapping);
92         } else {
93                 mapping_proto.reset(new MIDIMappingProto(new_mapping));
94         }
95
96         num_controller_banks = min(max(mapping_proto->num_controller_banks(), 1), 5);
97         current_controller_bank = 0;
98 }
99
100 void MIDIMapper::start_thread()
101 {
102         midi_thread = thread(&MIDIMapper::thread_func, this);
103 }
104
105 const MIDIMappingProto &MIDIMapper::get_current_mapping() const
106 {
107         lock_guard<mutex> lock(mapping_mu);
108         return *mapping_proto;
109 }
110
111 ControllerReceiver *MIDIMapper::set_receiver(ControllerReceiver *new_receiver)
112 {
113         lock_guard<mutex> lock(mapping_mu);
114         swap(receiver, new_receiver);
115         return new_receiver;  // Now old receiver.
116 }
117
// Evaluates an ALSA call `expr` once. If it returns a negative error code,
// prints "<msg>: <snd_strerror text>" to stderr and returns from the enclosing
// void function. `msg` must be a string literal (it is pasted into the format).
#define RETURN_ON_ERROR(msg, expr) do {                                  \
        const int alsa_err = (expr);                                     \
        if (alsa_err < 0) {                                              \
                fprintf(stderr, msg ": %s\n", snd_strerror(alsa_err));   \
                return;                                                  \
        }                                                                \
} while (false)
125
126
127 void MIDIMapper::thread_func()
128 {
129         snd_seq_t *seq;
130         int err;
131
132         RETURN_ON_ERROR("snd_seq_open", snd_seq_open(&seq, "default", SND_SEQ_OPEN_DUPLEX, 0));
133         RETURN_ON_ERROR("snd_seq_nonblock", snd_seq_nonblock(seq, 1));
134         RETURN_ON_ERROR("snd_seq_client_name", snd_seq_set_client_name(seq, "nageru"));
135         RETURN_ON_ERROR("snd_seq_create_simple_port",
136                 snd_seq_create_simple_port(seq, "nageru",
137                         SND_SEQ_PORT_CAP_WRITE |
138                         SND_SEQ_PORT_CAP_SUBS_WRITE,
139                         SND_SEQ_PORT_TYPE_MIDI_GENERIC |
140                         SND_SEQ_PORT_TYPE_APPLICATION));
141
142         // Listen to the announce port (0:1), which will tell us about new ports.
143         RETURN_ON_ERROR("snd_seq_connect_from", snd_seq_connect_from(seq, 0, /*client=*/0, /*port=*/1));
144
145         // Now go through all ports and subscribe to them.
146         snd_seq_client_info_t *cinfo;
147         snd_seq_client_info_alloca(&cinfo);
148
149         snd_seq_client_info_set_client(cinfo, -1);
150         while (snd_seq_query_next_client(seq, cinfo) >= 0) {
151                 int client = snd_seq_client_info_get_client(cinfo);
152
153                 snd_seq_port_info_t *pinfo;
154                 snd_seq_port_info_alloca(&pinfo);
155
156                 snd_seq_port_info_set_client(pinfo, client);
157                 snd_seq_port_info_set_port(pinfo, -1);
158                 while (snd_seq_query_next_port(seq, pinfo) >= 0) {
159                         constexpr int mask = SND_SEQ_PORT_CAP_READ | SND_SEQ_PORT_CAP_SUBS_READ;
160                         if ((snd_seq_port_info_get_capability(pinfo) & mask) == mask) {
161                                 subscribe_to_port(seq, *snd_seq_port_info_get_addr(pinfo));
162                         }
163                 }
164         }
165
166         int num_alsa_fds = snd_seq_poll_descriptors_count(seq, POLLIN);
167         unique_ptr<pollfd[]> fds(new pollfd[num_alsa_fds + 1]);
168
169         while (!should_quit) {
170                 snd_seq_poll_descriptors(seq, fds.get(), num_alsa_fds, POLLIN);
171                 fds[num_alsa_fds].fd = should_quit_fd;
172                 fds[num_alsa_fds].events = POLLIN;
173                 fds[num_alsa_fds].revents = 0;
174
175                 err = poll(fds.get(), num_alsa_fds + 1, -1);
176                 if (err == 0 || (err == -1 && errno == EINTR)) {
177                         continue;
178                 }
179                 if (err == -1) {
180                         perror("poll");
181                         break;
182                 }
183                 if (fds[num_alsa_fds].revents) {
184                         // Activity on should_quit_fd.
185                         break;
186                 }
187
188                 // Seemingly we can get multiple events in a single poll,
189                 // and if we don't handle them all, poll will _not_ alert us!
190                 while (!should_quit) {
191                         snd_seq_event_t *event;
192                         err = snd_seq_event_input(seq, &event);
193                         if (err < 0) {
194                                 if (err == -EINTR) continue;
195                                 if (err == -EAGAIN) break;
196                                 fprintf(stderr, "snd_seq_event_input: %s\n", snd_strerror(err));
197                                 return;
198                         }
199                         if (event) {
200                                 handle_event(seq, event);
201                         }
202                 }
203         }
204 }
205
206 void MIDIMapper::handle_event(snd_seq_t *seq, snd_seq_event_t *event)
207 {
208         lock_guard<mutex> lock(mapping_mu);
209         switch (event->type) {
210         case SND_SEQ_EVENT_CONTROLLER: {
211                 printf("Controller %d changed to %d\n", event->data.control.param, event->data.control.value);
212
213                 const int controller = event->data.control.param;
214                 const float value = map_controller_to_float(event->data.control.value);
215
216                 receiver->controller_changed(controller);
217
218                 // Global controllers.
219                 match_controller(controller, MIDIMappingBusProto::kLocutFieldNumber, MIDIMappingProto::kLocutBankFieldNumber,
220                         value, bind(&ControllerReceiver::set_locut, receiver, _2));
221                 match_controller(controller, MIDIMappingBusProto::kLimiterThresholdFieldNumber, MIDIMappingProto::kLimiterThresholdBankFieldNumber,
222                         value, bind(&ControllerReceiver::set_limiter_threshold, receiver, _2));
223                 match_controller(controller, MIDIMappingBusProto::kMakeupGainFieldNumber, MIDIMappingProto::kMakeupGainBankFieldNumber,
224                         value, bind(&ControllerReceiver::set_makeup_gain, receiver, _2));
225
226                 // Bus controllers.
227                 match_controller(controller, MIDIMappingBusProto::kTrebleFieldNumber, MIDIMappingProto::kTrebleBankFieldNumber,
228                         value, bind(&ControllerReceiver::set_treble, receiver, _1, _2));
229                 match_controller(controller, MIDIMappingBusProto::kMidFieldNumber, MIDIMappingProto::kMidBankFieldNumber,
230                         value, bind(&ControllerReceiver::set_mid, receiver, _1, _2));
231                 match_controller(controller, MIDIMappingBusProto::kBassFieldNumber, MIDIMappingProto::kBassBankFieldNumber,
232                         value, bind(&ControllerReceiver::set_bass, receiver, _1, _2));
233                 match_controller(controller, MIDIMappingBusProto::kGainFieldNumber, MIDIMappingProto::kGainBankFieldNumber,
234                         value, bind(&ControllerReceiver::set_gain, receiver, _1, _2));
235                 match_controller(controller, MIDIMappingBusProto::kCompressorThresholdFieldNumber, MIDIMappingProto::kCompressorThresholdBankFieldNumber,
236                         value, bind(&ControllerReceiver::set_compressor_threshold, receiver, _1, _2));
237                 match_controller(controller, MIDIMappingBusProto::kFaderFieldNumber, MIDIMappingProto::kFaderBankFieldNumber,
238                         value, bind(&ControllerReceiver::set_fader, receiver, _1, _2));
239                 break;
240         }
241         case SND_SEQ_EVENT_NOTEON: {
242                 const int note = event->data.note.note;
243
244                 receiver->note_on(note);
245
246                 printf("Note: %d\n", note);
247
248                 // Bank change commands. TODO: Highlight the bank change in the UI.
249                 for (size_t bus_idx = 0; bus_idx < size_t(mapping_proto->bus_mapping_size()); ++bus_idx) {
250                         const MIDIMappingBusProto &bus_mapping = mapping_proto->bus_mapping(bus_idx);
251                         if (bus_mapping.has_prev_bank() &&
252                             bus_mapping.prev_bank().note_number() == note) {
253                                 current_controller_bank = (current_controller_bank + num_controller_banks - 1) % num_controller_banks;
254                         }
255                         if (bus_mapping.has_next_bank() &&
256                             bus_mapping.next_bank().note_number() == note) {
257                                 current_controller_bank = (current_controller_bank + 1) % num_controller_banks;
258                         }
259                         if (bus_mapping.has_select_bank_1() &&
260                             bus_mapping.select_bank_1().note_number() == note) {
261                                 current_controller_bank = 0;
262                         }
263                         if (bus_mapping.has_select_bank_2() &&
264                             bus_mapping.select_bank_2().note_number() == note &&
265                             num_controller_banks >= 2) {
266                                 current_controller_bank = 1;
267                         }
268                         if (bus_mapping.has_select_bank_3() &&
269                             bus_mapping.select_bank_3().note_number() == note &&
270                             num_controller_banks >= 3) {
271                                 current_controller_bank = 2;
272                         }
273                         if (bus_mapping.has_select_bank_4() &&
274                             bus_mapping.select_bank_4().note_number() == note &&
275                             num_controller_banks >= 4) {
276                                 current_controller_bank = 3;
277                         }
278                         if (bus_mapping.has_select_bank_5() &&
279                             bus_mapping.select_bank_5().note_number() == note &&
280                             num_controller_banks >= 5) {
281                                 current_controller_bank = 4;
282                         }
283                 }
284
285                 match_button(note, MIDIMappingBusProto::kToggleLocutFieldNumber, MIDIMappingProto::kToggleLocutBankFieldNumber,
286                         bind(&ControllerReceiver::toggle_locut, receiver, _1));
287                 match_button(note, MIDIMappingBusProto::kToggleAutoGainStagingFieldNumber, MIDIMappingProto::kToggleAutoGainStagingBankFieldNumber,
288                         bind(&ControllerReceiver::toggle_auto_gain_staging, receiver, _1));
289                 match_button(note, MIDIMappingBusProto::kToggleCompressorFieldNumber, MIDIMappingProto::kToggleCompressorBankFieldNumber,
290                         bind(&ControllerReceiver::toggle_compressor, receiver, _1));
291                 match_button(note, MIDIMappingBusProto::kClearPeakFieldNumber, MIDIMappingProto::kClearPeakBankFieldNumber,
292                         bind(&ControllerReceiver::clear_peak, receiver, _1));
293         }
294         case SND_SEQ_EVENT_PORT_START:
295                 subscribe_to_port(seq, event->data.addr);
296                 break;
297         case SND_SEQ_EVENT_PORT_EXIT:
298                 printf("MIDI port %d:%d went away.\n", event->data.addr.client, event->data.addr.port);
299                 break;
300         case SND_SEQ_EVENT_NOTEOFF:
301         case SND_SEQ_EVENT_CLIENT_START:
302         case SND_SEQ_EVENT_CLIENT_EXIT:
303         case SND_SEQ_EVENT_CLIENT_CHANGE:
304         case SND_SEQ_EVENT_PORT_CHANGE:
305         case SND_SEQ_EVENT_PORT_SUBSCRIBED:
306         case SND_SEQ_EVENT_PORT_UNSUBSCRIBED:
307                 break;
308         default:
309                 printf("Ignoring MIDI event of unknown type %d.\n", event->type);
310         }
311 }
312
313 void MIDIMapper::subscribe_to_port(snd_seq_t *seq, const snd_seq_addr_t &addr)
314 {
315         // Client 0 is basically the system; ignore it.
316         if (addr.client == 0) {
317                 return;
318         }
319
320         int err = snd_seq_connect_from(seq, 0, addr.client, addr.port);
321         if (err < 0) {
322                 // Just print out a warning (i.e., don't die); it could
323                 // very well just be e.g. another application.
324                 printf("Couldn't subscribe to MIDI port %d:%d (%s).\n",
325                         addr.client, addr.port, snd_strerror(err));
326         } else {
327                 printf("Subscribed to MIDI port %d:%d.\n", addr.client, addr.port);
328         }
329 }
330
331 void MIDIMapper::match_controller(int controller, int field_number, int bank_field_number, float value, function<void(unsigned, float)> func)
332 {
333         if (bank_mismatch(bank_field_number)) {
334                 return;
335         }
336
337         for (size_t bus_idx = 0; bus_idx < size_t(mapping_proto->bus_mapping_size()); ++bus_idx) {
338                 const MIDIMappingBusProto &bus_mapping = mapping_proto->bus_mapping(bus_idx);
339
340                 const FieldDescriptor *descriptor = bus_mapping.GetDescriptor()->FindFieldByNumber(field_number);
341                 const Reflection *bus_reflection = bus_mapping.GetReflection();
342                 if (!bus_reflection->HasField(bus_mapping, descriptor)) {
343                         continue;
344                 }
345                 const MIDIControllerProto &controller_proto =
346                         static_cast<const MIDIControllerProto &>(bus_reflection->GetMessage(bus_mapping, descriptor));
347                 if (controller_proto.controller_number() == controller) {
348                         func(bus_idx, value);
349                 }
350         }
351 }
352
353 void MIDIMapper::match_button(int note, int field_number, int bank_field_number, function<void(unsigned)> func)
354 {
355         if (bank_mismatch(bank_field_number)) {
356                 return;
357         }
358
359         for (size_t bus_idx = 0; bus_idx < size_t(mapping_proto->bus_mapping_size()); ++bus_idx) {
360                 const MIDIMappingBusProto &bus_mapping = mapping_proto->bus_mapping(bus_idx);
361
362                 const FieldDescriptor *descriptor = bus_mapping.GetDescriptor()->FindFieldByNumber(field_number);
363                 const Reflection *bus_reflection = bus_mapping.GetReflection();
364                 if (!bus_reflection->HasField(bus_mapping, descriptor)) {
365                         continue;
366                 }
367                 const MIDIButtonProto &button_proto =
368                         static_cast<const MIDIButtonProto &>(bus_reflection->GetMessage(bus_mapping, descriptor));
369                 if (button_proto.note_number() == note) {
370                         func(bus_idx);
371                 }
372         }
373 }
374
375 bool MIDIMapper::bank_mismatch(int bank_field_number)
376 {
377         const FieldDescriptor *bank_descriptor = mapping_proto->GetDescriptor()->FindFieldByNumber(bank_field_number);
378         const Reflection *reflection = mapping_proto->GetReflection();
379         return (reflection->HasField(*mapping_proto, bank_descriptor) &&
380                 reflection->GetInt32(*mapping_proto, bank_descriptor) != current_controller_bank);
381 }