]> git.sesse.net Git - stockfish/blob - src/cluster.cpp
[cluster] Add depth condition to cluster TT saves.
[stockfish] / src / cluster.cpp
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2008 Tord Romstad (Glaurung author)
4   Copyright (C) 2008-2015 Marco Costalba, Joona Kiiski, Tord Romstad
5   Copyright (C) 2015-2018 Marco Costalba, Joona Kiiski, Gary Linscott, Tord Romstad
6
7   Stockfish is free software: you can redistribute it and/or modify
8   it under the terms of the GNU General Public License as published by
9   the Free Software Foundation, either version 3 of the License, or
10   (at your option) any later version.
11
12   Stockfish is distributed in the hope that it will be useful,
13   but WITHOUT ANY WARRANTY; without even the implied warranty of
14   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15   GNU General Public License for more details.
16
17   You should have received a copy of the GNU General Public License
18   along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 */
20
21 #ifdef USE_MPI
22
23 #include <array>
24 #include <cstddef>
25 #include <cstdlib>
26 #include <iostream>
27 #include <istream>
28 #include <mpi.h>
29 #include <string>
30 #include <vector>
31
32 #include "cluster.h"
33 #include "thread.h"
34 #include "tt.h"
35
36 namespace Cluster {
37
38 static int world_rank = MPI_PROC_NULL;
39 static int world_size = 0;
40 static bool stop_signal = false;
41 static MPI_Request reqStop = MPI_REQUEST_NULL;
42
43 static MPI_Comm InputComm = MPI_COMM_NULL;
44 static MPI_Comm TTComm = MPI_COMM_NULL;
45 static MPI_Comm MoveComm = MPI_COMM_NULL;
46 static MPI_Comm StopComm = MPI_COMM_NULL;
47
48 static MPI_Datatype TTEntryDatatype = MPI_DATATYPE_NULL;
49 static std::vector<KeyedTTEntry> TTBuff;
50
51 static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
52
53 void init() {
54   int thread_support;
55   constexpr std::array<int, 7> TTblocklens = {1, 1, 1, 1, 1, 1, 1};
56   const std::array<MPI_Aint, 7> TTdisps = {offsetof(KeyedTTEntry, first),
57                                            offsetof(KeyedTTEntry, second) + offsetof(TTEntry, key16),
58                                            offsetof(KeyedTTEntry, second) + offsetof(TTEntry, move16),
59                                            offsetof(KeyedTTEntry, second) + offsetof(TTEntry, value16),
60                                            offsetof(KeyedTTEntry, second) + offsetof(TTEntry, eval16),
61                                            offsetof(KeyedTTEntry, second) + offsetof(TTEntry, genBound8),
62                                            offsetof(KeyedTTEntry, second) + offsetof(TTEntry, depth8)};
63   const std::array<MPI_Datatype, 7> TTtypes = {MPI_UINT64_T,
64                                                MPI_UINT16_T,
65                                                MPI_UINT16_T,
66                                                MPI_INT16_T,
67                                                MPI_INT16_T,
68                                                MPI_UINT8_T,
69                                                MPI_INT8_T};
70   const std::array<MPI_Aint, 4> MIdisps = {offsetof(MoveInfo, move),
71                                            offsetof(MoveInfo, depth),
72                                            offsetof(MoveInfo, score),
73                                            offsetof(MoveInfo, rank)};
74
75   MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &thread_support);
76   if (thread_support < MPI_THREAD_MULTIPLE)
77   {
78       std::cerr << "Stockfish requires support for MPI_THREAD_MULTIPLE."
79                 << std::endl;
80       std::exit(EXIT_FAILURE);
81   }
82
83   MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
84   MPI_Comm_size(MPI_COMM_WORLD, &world_size);
85
86   TTBuff.resize(TTSendBufferSize * world_size);
87
88   MPI_Type_create_struct(7, TTblocklens.data(), TTdisps.data(), TTtypes.data(),
89                          &TTEntryDatatype);
90   MPI_Type_commit(&TTEntryDatatype);
91
92   MPI_Type_create_hindexed_block(4, 1, MIdisps.data(), MPI_INT, &MIDatatype);
93   MPI_Type_commit(&MIDatatype);
94
95   MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
96   MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
97   MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
98   MPI_Comm_dup(MPI_COMM_WORLD, &StopComm);
99 }
100
101 void finalize() {
102   MPI_Finalize();
103 }
104
105 bool getline(std::istream& input, std::string& str) {
106   int size;
107   std::vector<char> vec;
108   bool state;
109
110   if (is_root())
111   {
112       state = static_cast<bool>(std::getline(input, str));
113       vec.assign(str.begin(), str.end());
114       size = vec.size();
115   }
116
117   // Some MPI implementations use busy-wait pooling, while we need yielding
118   static MPI_Request reqInput = MPI_REQUEST_NULL;
119   MPI_Ibcast(&size, 1, MPI_INT, 0, InputComm, &reqInput);
120   if (is_root())
121       MPI_Wait(&reqInput, MPI_STATUS_IGNORE);
122   else {
123       while (true) {
124           int flag;
125           MPI_Test(&reqInput, &flag, MPI_STATUS_IGNORE);
126           if (flag)
127               break;
128           else {
129               std::this_thread::sleep_for(std::chrono::milliseconds(10));
130           }
131       }
132   }
133   if (!is_root())
134       vec.resize(size);
135   MPI_Bcast(vec.data(), size, MPI_CHAR, 0, InputComm);
136   if (!is_root())
137       str.assign(vec.begin(), vec.end());
138   MPI_Bcast(&state, 1, MPI_CXX_BOOL, 0, InputComm);
139   return state;
140 }
141
142 void sync_start() {
143   stop_signal = false;
144
145   // Start listening to stop signal
146   if (!is_root())
147       MPI_Ibarrier(StopComm, &reqStop);
148 }
149
150 void sync_stop() {
151   if (is_root()) {
152       if (!stop_signal && Threads.stop) {
153           // Signal the cluster about stopping
154           stop_signal = true;
155           MPI_Ibarrier(StopComm, &reqStop);
156           MPI_Wait(&reqStop, MPI_STATUS_IGNORE);
157       }
158   }
159   else {
160       int flagStop;
161       // Check if we've received any stop signal
162       MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE);
163       if (flagStop)
164           Threads.stop = true;
165   }
166 }
167
168 int size() {
169   return world_size;
170 }
171
172 int rank() {
173   return world_rank;
174 }
175
176 void save(Thread* thread, TTEntry* tte,
177           Key k, Value v, Bound b, Depth d, Move m, Value ev) {
178
179   tte->save(k, v, b, d, m, ev);
180
181   if (d > 5 * ONE_PLY)
182   {
183      // Try to add to thread's send buffer
184      {
185          std::lock_guard<Mutex> lk(thread->ttBuffer.mutex);
186          thread->ttBuffer.buffer.replace(KeyedTTEntry(k,*tte));
187      }
188
189      // Communicate on main search thread
190      if (thread == Threads.main()) {
191          static MPI_Request req = MPI_REQUEST_NULL;
192          static TTSendBuffer<TTSendBufferSize> send_buff = {};
193          int flag;
194          bool found;
195          TTEntry* replace_tte;
196
197          // Test communication status
198          MPI_Test(&req, &flag, MPI_STATUS_IGNORE);
199
200          // Current communication is complete
201          if (flag) {
202              // Save all recieved entries
203              for (auto&& e : TTBuff) {
204                  replace_tte = TT.probe(e.first, found);
205                  replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
206                                    e.second.move(), e.second.eval());
207              }
208
209              // Reset send buffer
210              send_buff = {};
211
212              // Build up new send buffer: best 16 found across all threads
213              for (auto&& th : Threads) {
214                  std::lock_guard<Mutex> lk(th->ttBuffer.mutex);
215                  for (auto&& e : th->ttBuffer.buffer)
216                      send_buff.replace(e);
217                  // Reset thread's send buffer
218                  th->ttBuffer.buffer = {};
219              }
220
221              // Start next communication
222              MPI_Iallgather(send_buff.data(), send_buff.size(), TTEntryDatatype,
223                             TTBuff.data(), TTSendBufferSize, TTEntryDatatype,
224                             TTComm, &req);
225          }
226      }
227   }
228 }
229
230 void pick_moves(MoveInfo& mi) {
231   MoveInfo* pMoveInfo = NULL;
232   if (is_root()) {
233       pMoveInfo = (MoveInfo*)malloc(sizeof(MoveInfo) * size());
234   }
235   MPI_Gather(&mi, 1, MIDatatype, pMoveInfo, 1, MIDatatype, 0, MoveComm);
236   if (is_root()) {
237       std::map<int, int> votes;
238       int minScore = pMoveInfo[0].score;
239       for (int i = 0; i < size(); i++) {
240           minScore = std::min(minScore, pMoveInfo[i].score);
241           votes[pMoveInfo[i].move] = 0;
242       }
243       for (int i = 0; i < size(); i++) {
244           votes[pMoveInfo[i].move] += pMoveInfo[i].score - minScore + pMoveInfo[i].depth;
245       }
246       int bestVote = votes[pMoveInfo[0].move];
247       for (int i = 0; i < size(); i++) {
248           if (votes[pMoveInfo[i].move] > bestVote) {
249               bestVote = votes[pMoveInfo[i].move];
250               mi = pMoveInfo[i];
251           }
252       }
253       free(pMoveInfo);
254   }
255   MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
256 }
257
258 }
259
260 #endif // USE_MPI