]> git.sesse.net Git - stockfish/blob - src/cluster.cpp
Always wait before posting the next call in _sync.
[stockfish] / src / cluster.cpp
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2008 Tord Romstad (Glaurung author)
4   Copyright (C) 2008-2015 Marco Costalba, Joona Kiiski, Tord Romstad
5   Copyright (C) 2015-2018 Marco Costalba, Joona Kiiski, Gary Linscott, Tord Romstad
6
7   Stockfish is free software: you can redistribute it and/or modify
8   it under the terms of the GNU General Public License as published by
9   the Free Software Foundation, either version 3 of the License, or
10   (at your option) any later version.
11
12   Stockfish is distributed in the hope that it will be useful,
13   but WITHOUT ANY WARRANTY; without even the implied warranty of
14   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15   GNU General Public License for more details.
16
17   You should have received a copy of the GNU General Public License
18   along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 */
20
21 #ifdef USE_MPI
22
23 #include <array>
24 #include <cstddef>
25 #include <cstdlib>
26 #include <iostream>
27 #include <istream>
28 #include <mpi.h>
29 #include <string>
30 #include <vector>
31
32 #include "cluster.h"
33 #include "thread.h"
34 #include "tt.h"
35
36 namespace Cluster {
37
38 static int world_rank = MPI_PROC_NULL;
39 static int world_size = 0;
40
41 static MPI_Request reqSignals = MPI_REQUEST_NULL;
42 static uint64_t signalsCallCounter = 0;
43
44 enum Signals : int { SIG_NODES = 0, SIG_STOP = 1, SIG_TB = 2, SIG_NB = 3};
45 static uint64_t signalsSend[SIG_NB] = {};
46 static uint64_t signalsRecv[SIG_NB] = {};
47
48 static uint64_t nodesSearchedOthers = 0;
49 static uint64_t tbHitsOthers = 0;
50 static uint64_t stopSignalsPosted = 0;
51
52 static MPI_Comm InputComm = MPI_COMM_NULL;
53 static MPI_Comm TTComm = MPI_COMM_NULL;
54 static MPI_Comm MoveComm = MPI_COMM_NULL;
55 static MPI_Comm signalsComm = MPI_COMM_NULL;
56
57 static std::vector<KeyedTTEntry> TTRecvBuff;
58 static MPI_Request reqGather = MPI_REQUEST_NULL;
59 static uint64_t gathersPosted = 0;
60
61 static std::atomic<uint64_t> TTCacheCounter = {};
62
63 static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
64
65
66 void init() {
67
68   int thread_support;
69   MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &thread_support);
70   if (thread_support < MPI_THREAD_MULTIPLE)
71   {
72       std::cerr << "Stockfish requires support for MPI_THREAD_MULTIPLE."
73                 << std::endl;
74       std::exit(EXIT_FAILURE);
75   }
76
77   MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
78   MPI_Comm_size(MPI_COMM_WORLD, &world_size);
79
80   const std::array<MPI_Aint, 4> MIdisps = {offsetof(MoveInfo, move),
81                                            offsetof(MoveInfo, depth),
82                                            offsetof(MoveInfo, score),
83                                            offsetof(MoveInfo, rank)};
84   MPI_Type_create_hindexed_block(4, 1, MIdisps.data(), MPI_INT, &MIDatatype);
85   MPI_Type_commit(&MIDatatype);
86
87   MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
88   MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
89   MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
90   MPI_Comm_dup(MPI_COMM_WORLD, &signalsComm);
91 }
92
93 void finalize() {
94
95
96   // free data tyes and communicators
97   MPI_Type_free(&MIDatatype);
98
99   MPI_Comm_free(&InputComm);
100   MPI_Comm_free(&TTComm);
101   MPI_Comm_free(&MoveComm);
102   MPI_Comm_free(&signalsComm);
103
104   MPI_Finalize();
105 }
106
107 int size() {
108
109   return world_size;
110 }
111
112 int rank() {
113
114   return world_rank;
115 }
116
117 void ttRecvBuff_resize(size_t nThreads) {
118
119   TTRecvBuff.resize(TTCacheSize * world_size * nThreads);
120   std::fill(TTRecvBuff.begin(), TTRecvBuff.end(), KeyedTTEntry());
121
122 }
123
124
125 bool getline(std::istream& input, std::string& str) {
126
127   int size;
128   std::vector<char> vec;
129   bool state;
130
131   if (is_root())
132   {
133       state = static_cast<bool>(std::getline(input, str));
134       vec.assign(str.begin(), str.end());
135       size = vec.size();
136   }
137
138   // Some MPI implementations use busy-wait polling, while we need yielding
139   static MPI_Request reqInput = MPI_REQUEST_NULL;
140   MPI_Ibcast(&size, 1, MPI_INT, 0, InputComm, &reqInput);
141   if (is_root())
142       MPI_Wait(&reqInput, MPI_STATUS_IGNORE);
143   else
144   {
145       while (true)
146       {
147           int flag;
148           MPI_Test(&reqInput, &flag, MPI_STATUS_IGNORE);
149           if (flag)
150               break;
151           else
152               std::this_thread::sleep_for(std::chrono::milliseconds(10));
153       }
154   }
155
156   if (!is_root())
157       vec.resize(size);
158   MPI_Bcast(vec.data(), size, MPI_CHAR, 0, InputComm);
159   if (!is_root())
160       str.assign(vec.begin(), vec.end());
161   MPI_Bcast(&state, 1, MPI_CXX_BOOL, 0, InputComm);
162
163   return state;
164 }
165
166 void signals_send() {
167
168   signalsSend[SIG_NODES] = Threads.nodes_searched();
169   signalsSend[SIG_TB] = Threads.tb_hits();
170   signalsSend[SIG_STOP] = Threads.stop;
171   MPI_Iallreduce(signalsSend, signalsRecv, SIG_NB, MPI_UINT64_T,
172                  MPI_SUM, signalsComm, &reqSignals);
173   ++signalsCallCounter;
174 }
175
176 void signals_process() {
177
178   nodesSearchedOthers = signalsRecv[SIG_NODES] - signalsSend[SIG_NODES];
179   tbHitsOthers = signalsRecv[SIG_TB] - signalsSend[SIG_TB];
180   stopSignalsPosted = signalsRecv[SIG_STOP];
181   if (signalsRecv[SIG_STOP] > 0)
182       Threads.stop = true;
183 }
184
185 void signals_sync() {
186
187   while(stopSignalsPosted < uint64_t(size()))
188       signals_poll();
189
190   // finalize outstanding messages of the signal loops. We might have issued one call less than needed on some ranks.
191   uint64_t globalCounter;
192   MPI_Allreduce(&signalsCallCounter, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm); // MoveComm needed
193   if (signalsCallCounter < globalCounter)
194   {
195       MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
196       signals_send();
197   }
198   assert(signalsCallCounter == globalCounter);
199   MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
200
201   signals_process();
202
203   // finalize outstanding messages in the gather loop
204   MPI_Allreduce(&gathersPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
205   if (gathersPosted < globalCounter)
206   {
207      size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
208      MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
209      MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
210                     TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
211                     TTComm, &reqGather);
212      ++gathersPosted;
213   }
214   assert(gathersPosted == globalCounter);
215   MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
216
217 }
218
219 void signals_init() {
220
221   stopSignalsPosted = tbHitsOthers = nodesSearchedOthers = 0;
222
223   signalsSend[SIG_NODES] = signalsRecv[SIG_NODES] = 0;
224   signalsSend[SIG_TB] = signalsRecv[SIG_TB] = 0;
225   signalsSend[SIG_STOP] = signalsRecv[SIG_STOP] = 0;
226
227 }
228
229 void signals_poll() {
230
231   int flag;
232   MPI_Test(&reqSignals, &flag, MPI_STATUS_IGNORE);
233   if (flag)
234   {
235      signals_process();
236      signals_send();
237   }
238 }
239
240 void save(Thread* thread, TTEntry* tte,
241           Key k, Value v, Bound b, Depth d, Move m, Value ev) {
242
243   tte->save(k, v, b, d, m, ev);
244
245   if (d > 5 * ONE_PLY)
246   {
247      // Try to add to thread's send buffer
248      {
249          std::lock_guard<Mutex> lk(thread->ttCache.mutex);
250          thread->ttCache.buffer.replace(KeyedTTEntry(k,*tte));
251          ++TTCacheCounter;
252      }
253
254      size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
255
256      // Communicate on main search thread
257      if (thread == Threads.main() && TTCacheCounter > size() * recvBuffPerRankSize)
258      {
259          // Test communication status
260          int flag;
261          MPI_Test(&reqGather, &flag, MPI_STATUS_IGNORE);
262
263          // Current communication is complete
264          if (flag)
265          {
266              // Save all received entries to TT, and store our TTCaches, ready for the next round of communication
267              for (size_t irank = 0; irank < size_t(size()) ; ++irank)
268              {
269                  if (irank == size_t(rank()))
270                  {
271                     // Copy from the thread caches to the right spot in the buffer
272                     size_t i = irank * recvBuffPerRankSize;
273                     for (auto&& th : Threads)
274                     {
275                         std::lock_guard<Mutex> lk(th->ttCache.mutex);
276
277                         for (auto&& e : th->ttCache.buffer)
278                             TTRecvBuff[i++] = e;
279
280                         // Reset thread's send buffer
281                         th->ttCache.buffer = {};
282                     }
283
284                     TTCacheCounter = 0;
285                  }
286                  else
287                     for (size_t i = irank * recvBuffPerRankSize; i < (irank + 1) * recvBuffPerRankSize; ++i)
288                     {
289                         auto&& e = TTRecvBuff[i];
290                         bool found;
291                         TTEntry* replace_tte;
292                         replace_tte = TT.probe(e.first, found);
293                         replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
294                                           e.second.move(), e.second.eval());
295                     }
296              }
297
298              // Start next communication
299              MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
300                             TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
301                             TTComm, &reqGather);
302              ++gathersPosted;
303
304              // Force check of time on the next occasion.
305              static_cast<MainThread*>(thread)->callsCnt = 0;
306
307          }
308      }
309   }
310 }
311
312
313 // TODO update to the scheme in master.. can this use aggregation of votes?
314 void pick_moves(MoveInfo& mi) {
315
316   MoveInfo* pMoveInfo = NULL;
317   if (is_root())
318   {
319       pMoveInfo = (MoveInfo*)malloc(sizeof(MoveInfo) * size());
320   }
321   MPI_Gather(&mi, 1, MIDatatype, pMoveInfo, 1, MIDatatype, 0, MoveComm);
322
323   if (is_root())
324   {
325       std::map<int, int> votes;
326       int minScore = pMoveInfo[0].score;
327       for (int i = 0; i < size(); ++i)
328       {
329           minScore = std::min(minScore, pMoveInfo[i].score);
330           votes[pMoveInfo[i].move] = 0;
331       }
332       for (int i = 0; i < size(); ++i)
333       {
334           votes[pMoveInfo[i].move] += pMoveInfo[i].score - minScore + pMoveInfo[i].depth;
335       }
336       int bestVote = votes[pMoveInfo[0].move];
337       for (int i = 0; i < size(); ++i)
338       {
339           if (votes[pMoveInfo[i].move] > bestVote)
340           {
341               bestVote = votes[pMoveInfo[i].move];
342               mi = pMoveInfo[i];
343           }
344       }
345       free(pMoveInfo);
346   }
347   MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
348 }
349
350 uint64_t nodes_searched() {
351
352   return nodesSearchedOthers + Threads.nodes_searched();
353 }
354
355 uint64_t tb_hits() {
356
357   return tbHitsOthers + Threads.tb_hits();
358 }
359
360 }
361
362 #else
363
364 #include "cluster.h"
365 #include "thread.h"
366
367 namespace Cluster {
368
369 uint64_t nodes_searched() {
370
371   return Threads.nodes_searched();
372 }
373
374 uint64_t tb_hits() {
375
376   return Threads.tb_hits();
377 }
378
379 }
380
381 #endif // USE_MPI