]> git.sesse.net Git - stockfish/blob - src/cluster.cpp
Implement yielding loop while waiting for input
[stockfish] / src / cluster.cpp
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2008 Tord Romstad (Glaurung author)
4   Copyright (C) 2008-2015 Marco Costalba, Joona Kiiski, Tord Romstad
5   Copyright (C) 2015-2018 Marco Costalba, Joona Kiiski, Gary Linscott, Tord Romstad
6
7   Stockfish is free software: you can redistribute it and/or modify
8   it under the terms of the GNU General Public License as published by
9   the Free Software Foundation, either version 3 of the License, or
10   (at your option) any later version.
11
12   Stockfish is distributed in the hope that it will be useful,
13   but WITHOUT ANY WARRANTY; without even the implied warranty of
14   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15   GNU General Public License for more details.
16
17   You should have received a copy of the GNU General Public License
18   along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 */
20
21 #ifdef USE_MPI
22
23 #include <array>
24 #include <cstddef>
25 #include <cstdlib>
26 #include <iostream>
27 #include <istream>
28 #include <mpi.h>
29 #include <string>
30 #include <vector>
31
32 #include "cluster.h"
33 #include "thread.h"
34 #include "tt.h"
35
36 namespace Cluster {
37
38 static int world_rank = MPI_PROC_NULL;
39 static int world_size = 0;
40 static bool stop_signal = false;
41 static MPI_Request reqStop = MPI_REQUEST_NULL;
42
43 static MPI_Comm InputComm = MPI_COMM_NULL;
44 static MPI_Comm TTComm = MPI_COMM_NULL;
45 static MPI_Comm MoveComm = MPI_COMM_NULL;
46 static MPI_Comm StopComm = MPI_COMM_NULL;
47
48 static MPI_Datatype TTEntryDatatype = MPI_DATATYPE_NULL;
49 static std::vector<TTEntry> TTBuff;
50
51 static MPI_Op BestMoveOp = MPI_OP_NULL;
52 static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
53
54 static void BestMove(void* in, void* inout, int* len, MPI_Datatype* datatype) {
55   if (*datatype != MIDatatype)
56       MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
57   MoveInfo* l = static_cast<MoveInfo*>(in);
58   MoveInfo* r = static_cast<MoveInfo*>(inout);
59   for (int i=0; i < *len; ++i)
60   {
61       if (l[i].depth >= r[i].depth && l[i].score >= r[i].score)
62          r[i] = l[i];
63   }
64 }
65
66 void init() {
67   int thread_support;
68   constexpr std::array<int, 6> TTblocklens = {1, 1, 1, 1, 1, 1};
69   const std::array<MPI_Aint, 6> TTdisps = {offsetof(TTEntry, key16),
70                                            offsetof(TTEntry, move16),
71                                            offsetof(TTEntry, value16),
72                                            offsetof(TTEntry, eval16),
73                                            offsetof(TTEntry, genBound8),
74                                            offsetof(TTEntry, depth8)};
75   const std::array<MPI_Datatype, 6> TTtypes = {MPI_UINT16_T,
76                                                MPI_UINT16_T,
77                                                MPI_INT16_T,
78                                                MPI_INT16_T,
79                                                MPI_UINT8_T,
80                                                MPI_INT8_T};
81   const std::array<MPI_Aint, 3> MIdisps = {offsetof(MoveInfo, depth),
82                                            offsetof(MoveInfo, score),
83                                            offsetof(MoveInfo, rank)};
84
85   MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &thread_support);
86   if (thread_support < MPI_THREAD_MULTIPLE)
87   {
88       std::cerr << "Stockfish requires support for MPI_THREAD_MULTIPLE."
89                 << std::endl;
90       std::exit(EXIT_FAILURE);
91   }
92
93   MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
94   MPI_Comm_size(MPI_COMM_WORLD, &world_size);
95
96   TTBuff.resize(TTSendBufferSize * world_size);
97
98   MPI_Type_create_struct(6, TTblocklens.data(), TTdisps.data(), TTtypes.data(),
99                          &TTEntryDatatype);
100   MPI_Type_commit(&TTEntryDatatype);
101
102   MPI_Type_create_hindexed_block(3, 1, MIdisps.data(), MPI_INT, &MIDatatype);
103   MPI_Type_commit(&MIDatatype);
104   MPI_Op_create(BestMove, false, &BestMoveOp);
105
106   MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
107   MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
108   MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
109   MPI_Comm_dup(MPI_COMM_WORLD, &StopComm);
110 }
111
112 void finalize() {
113   MPI_Finalize();
114 }
115
116 bool getline(std::istream& input, std::string& str) {
117   int size;
118   std::vector<char> vec;
119   bool state;
120
121   if (is_root())
122   {
123       state = static_cast<bool>(std::getline(input, str));
124       vec.assign(str.begin(), str.end());
125       size = vec.size();
126   }
127
128   // Some MPI implementations use busy-wait pooling, while we need yielding
129   static MPI_Request reqInput = MPI_REQUEST_NULL;
130   MPI_Ibarrier(InputComm, &reqInput);
131   if (is_root())
132       MPI_Wait(&reqInput, MPI_STATUS_IGNORE);
133   else {
134       while (true) {
135           static int flag;
136           MPI_Test(&reqInput, &flag, MPI_STATUS_IGNORE);
137           if (flag)
138               break;
139           else
140               std::this_thread::sleep_for(std::chrono::milliseconds(10));
141       }
142   }
143
144   MPI_Bcast(&size, 1, MPI_UNSIGNED_LONG, 0, InputComm);
145   if (!is_root())
146       vec.resize(size);
147   MPI_Bcast(vec.data(), size, MPI_CHAR, 0, InputComm);
148   if (!is_root())
149       str.assign(vec.begin(), vec.end());
150   MPI_Bcast(&state, 1, MPI_CXX_BOOL, 0, InputComm);
151   return state;
152 }
153
154 void sync_start() {
155   stop_signal = false;
156
157   // Start listening to stop signal
158   if (!is_root())
159       MPI_Ibarrier(StopComm, &reqStop);
160 }
161
162 void sync_stop() {
163   if (is_root()) {
164       if (!stop_signal && Threads.stop) {
165           // Signal the cluster about stopping
166           stop_signal = true;
167           MPI_Ibarrier(StopComm, &reqStop);
168           MPI_Wait(&reqStop, MPI_STATUS_IGNORE);
169       }
170   }
171   else {
172       int flagStop;
173       // Check if we've received any stop signal
174       MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE);
175       if (flagStop)
176           Threads.stop = true;
177   }
178 }
179
180 int size() {
181   return world_size;
182 }
183
184 int rank() {
185   return world_rank;
186 }
187
188 void save(Thread* thread, TTEntry* tte,
189           Key k, Value v, Bound b, Depth d, Move m, Value ev, uint8_t g) {
190   tte->save(k, v, b, d, m, ev, g);
191   // Try to add to thread's send buffer
192   {
193       std::lock_guard<Mutex> lk(thread->ttBuffer.mutex);
194       thread->ttBuffer.buffer.replace(*tte);
195   }
196
197   // Communicate on main search thread
198   if (thread == Threads.main()) {
199       static MPI_Request req = MPI_REQUEST_NULL;
200       static TTSendBuffer<TTSendBufferSize> send_buff = {};
201       int flag;
202       bool found;
203       TTEntry* replace_tte;
204
205       // Test communication status
206       MPI_Test(&req, &flag, MPI_STATUS_IGNORE);
207
208       // Current communication is complete
209       if (flag) {
210           // Save all recieved entries
211           for (auto&& e : TTBuff) {
212               replace_tte = TT.probe(e.key(), found);
213               replace_tte->save(e.key(), e.value(), e.bound(), e.depth(),
214                                 e.move(), e.eval(), e.gen());
215           }
216
217           // Reset send buffer
218           send_buff = {};
219
220           // Build up new send buffer: best 16 found across all threads
221           for (auto&& th : Threads) {
222               std::lock_guard<Mutex> lk(th->ttBuffer.mutex);
223               for (auto&& e : th->ttBuffer.buffer)
224                   send_buff.replace(e);
225               // Reset thread's send buffer
226               th->ttBuffer.buffer = {};
227           }
228
229           // Start next communication
230           MPI_Iallgather(send_buff.data(), send_buff.size(), TTEntryDatatype,
231                          TTBuff.data(), TTSendBufferSize, TTEntryDatatype,
232                          TTComm, &req);
233       }
234   }
235 }
236
237 void reduce_moves(MoveInfo& mi) {
238   MPI_Allreduce(MPI_IN_PLACE, &mi, 1, MIDatatype, BestMoveOp, MoveComm);
239 }
240
241 }
242
243 #endif // USE_MPI