2 Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3 Copyright (C) 2004-2008 Tord Romstad (Glaurung author)
4 Copyright (C) 2008-2015 Marco Costalba, Joona Kiiski, Tord Romstad
5 Copyright (C) 2015-2018 Marco Costalba, Joona Kiiski, Gary Linscott, Tord Romstad
7 Stockfish is free software: you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation, either version 3 of the License, or
10 (at your option) any later version.
12 Stockfish is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 You should have received a copy of the GNU General Public License
18 along with this program. If not, see <http://www.gnu.org/licenses/>.
// ---------------------------------------------------------------------------
// File-local state for the MPI (cluster) layer.
// ---------------------------------------------------------------------------

// This process' rank within MPI_COMM_WORLD, and the total number of ranks.
// world_rank stays MPI_PROC_NULL until init() queries MPI_Comm_rank below.
38 static int world_rank = MPI_PROC_NULL;
39 static int world_size = 0;
// Outstanding non-blocking allreduce for the periodic "signals" exchange,
// and a count of how many such allreduces this rank has started (used by
// signals_sync() to balance call counts across ranks).
41 static MPI_Request reqSignals = MPI_REQUEST_NULL;
42 static uint64_t signalsCallCounter = 0;
// Indices into the signal buffers: searched nodes, stop flag, TB hits.
// SIG_NB is the element count.
44 enum Signals : int { SIG_NODES = 0, SIG_STOP = 1, SIG_TB = 2, SIG_NB = 3};
// Send/receive buffers of the signals allreduce (summed with MPI_SUM).
45 static uint64_t signalsSend[SIG_NB] = {};
46 static uint64_t signalsRecv[SIG_NB] = {};
// Aggregated contributions of the *other* ranks, derived in
// signals_process() as (summed total) - (our own contribution).
48 static uint64_t nodesSearchedOthers = 0;
49 static uint64_t tbHitsOthers = 0;
// Number of ranks whose stop flag was set in the last completed allreduce.
50 static uint64_t stopSignalsPosted = 0;
// Separate duplicates of MPI_COMM_WORLD (created in init()) so that input,
// transposition-table, move and signal traffic cannot interfere.
52 static MPI_Comm InputComm = MPI_COMM_NULL;
53 static MPI_Comm TTComm = MPI_COMM_NULL;
54 static MPI_Comm MoveComm = MPI_COMM_NULL;
55 static MPI_Comm signalsComm = MPI_COMM_NULL;
// Receive buffer of the TT allgather: one slice of
// Threads.size() * TTCacheSize entries per rank (see ttRecvBuff_resize()),
// plus the outstanding gather request and a count of gathers started.
57 static std::vector<KeyedTTEntry> TTRecvBuff;
58 static MPI_Request reqGather = MPI_REQUEST_NULL;
59 static uint64_t gathersPosted = 0;
// Presumably counts entries accumulated in the per-thread TT send caches;
// save() compares it against size() * recvBuffPerRankSize to decide when to
// communicate. NOTE(review): the increment site is not visible here — confirm.
61 static std::atomic<uint64_t> TTCacheCounter = {};
// MPI datatype describing MoveInfo (four ints at the member offsets),
// built and committed in init().
63 static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
// Initialize MPI demanding full multi-thread support: several search threads
// may issue MPI calls concurrently (e.g. the TT save path vs. the input loop).
69 MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &thread_support);
// MPI_Init_thread reports the level actually provided; anything weaker than
// MPI_THREAD_MULTIPLE would make the concurrent calls below unsafe, so bail out.
70 if (thread_support < MPI_THREAD_MULTIPLE)
72 std::cerr << "Stockfish requires support for MPI_THREAD_MULTIPLE."
74 std::exit(EXIT_FAILURE);
// Cache this process' identity and the number of participating ranks.
77 MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
78 MPI_Comm_size(MPI_COMM_WORLD, &world_size);
// Describe MoveInfo to MPI as four ints located at the members' byte offsets,
// so whole MoveInfo structs can be gathered/broadcast directly (pick_moves()).
80 const std::array<MPI_Aint, 4> MIdisps = {offsetof(MoveInfo, move),
81 offsetof(MoveInfo, depth),
82 offsetof(MoveInfo, score),
83 offsetof(MoveInfo, rank)};
// hindexed_block: 4 blocks of 1 MPI_INT each, at the displacements above.
84 MPI_Type_create_hindexed_block(4, 1, MIdisps.data(), MPI_INT, &MIDatatype);
85 MPI_Type_commit(&MIDatatype);
// Give each class of traffic its own duplicate of MPI_COMM_WORLD so that
// messages on one channel can never match receives posted on another.
87 MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
88 MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
89 MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
90 MPI_Comm_dup(MPI_COMM_WORLD, &signalsComm);
96 // Free the MoveInfo datatype and the duplicated communicators created in init().
97 MPI_Type_free(&MIDatatype);
99 MPI_Comm_free(&InputComm);
100 MPI_Comm_free(&TTComm);
101 MPI_Comm_free(&MoveComm);
102 MPI_Comm_free(&signalsComm);
// Size the TT allgather receive buffer for the current thread count:
// every rank contributes TTCacheSize entries per search thread, so the
// buffer holds one such slice for each of the world_size ranks.
// Entries are reset to default-constructed (empty) KeyedTTEntry values.
117 void ttRecvBuff_resize(size_t nThreads) {
119 TTRecvBuff.resize(TTCacheSize * world_size * nThreads);
120 std::fill(TTRecvBuff.begin(), TTRecvBuff.end(), KeyedTTEntry());
// Cluster-aware replacement for std::getline: one rank reads a line from
// `input` and broadcasts it — together with the stream state — to all ranks
// over InputComm, so every rank sees the same UCI command stream.
// Returns the broadcast stream state (true while reading succeeds).
// NOTE(review): the rank-0 guard around the local read is not visible in
// this excerpt — presumably lines 131-134 run on the root rank only; confirm.
125 bool getline(std::istream& input, std::string& str) {
128 std::vector<char> vec;
133 state = static_cast<bool>(std::getline(input, str));
134 vec.assign(str.begin(), str.end());
// First broadcast the line length. Use a non-blocking Ibcast and poll with
138 // Some MPI implementations use busy-wait polling, while we need yielding
// while waiting for (possibly long-absent) stdin input, so non-root ranks
// sleep between MPI_Test calls instead of spinning inside MPI_Bcast.
139 static MPI_Request reqInput = MPI_REQUEST_NULL;
140 MPI_Ibcast(&size, 1, MPI_INT, 0, InputComm, &reqInput);
// Root completes immediately; the others test-and-sleep until it arrives.
142 MPI_Wait(&reqInput, MPI_STATUS_IGNORE);
148 MPI_Test(&reqInput, &flag, MPI_STATUS_IGNORE);
152 std::this_thread::sleep_for(std::chrono::milliseconds(10));
// With the size known, broadcast the characters and the stream state.
158 MPI_Bcast(vec.data(), size, MPI_CHAR, 0, InputComm);
160 str.assign(vec.begin(), vec.end());
161 MPI_Bcast(&state, 1, MPI_CXX_BOOL, 0, InputComm);
// Start one round of the periodic signal exchange: publish this rank's node
// count, TB hits and stop flag, and launch a non-blocking sum-allreduce over
// signalsComm. The matching completion/processing happens elsewhere
// (signals_poll()/signals_sync() test or wait on reqSignals).
166 void signals_send() {
168 signalsSend[SIG_NODES] = Threads.nodes_searched();
169 signalsSend[SIG_TB] = Threads.tb_hits();
// Threads.stop contributes 0 or 1, so the SIG_STOP sum counts stopping ranks.
170 signalsSend[SIG_STOP] = Threads.stop;
171 MPI_Iallreduce(signalsSend, signalsRecv, SIG_NB, MPI_UINT64_T,
172 MPI_SUM, signalsComm, &reqSignals);
// Track how many allreduces we have issued; signals_sync() uses this to
// balance the number of calls across ranks before tearing down.
173 ++signalsCallCounter;
// Digest the result of a completed signals allreduce: since signalsRecv holds
// the cluster-wide sums, subtracting our own contribution yields the totals
// of all other ranks. Also record how many ranks have posted a stop.
176 void signals_process() {
178 nodesSearchedOthers = signalsRecv[SIG_NODES] - signalsSend[SIG_NODES];
179 tbHitsOthers = signalsRecv[SIG_TB] - signalsSend[SIG_TB];
180 stopSignalsPosted = signalsRecv[SIG_STOP];
// If any rank wants to stop, react locally — presumably by setting
// Threads.stop (the then-branch is not visible in this excerpt; confirm).
181 if (signalsRecv[SIG_STOP] > 0)
// Bring all ranks to a common quiescent state at the end of a search:
// first make sure every rank has observed the stop, then drain the two
// non-blocking collective "loops" (signals allreduce and TT allgather),
// whose call counts may differ by one across ranks.
185 void signals_sync() {
// Keep exchanging/processing signals until every rank has posted its stop
// (loop body not visible in this excerpt — presumably polls/sends signals).
187 while(stopSignalsPosted < uint64_t(size()))
190 // finalize outstanding messages of the signal loops. We might have issued one call less than needed on some ranks.
191 uint64_t globalCounter;
// Use the MAX over all ranks as the target call count. This runs on
// MoveComm (not signalsComm) so it cannot match the Iallreduce traffic.
192 MPI_Allreduce(&signalsCallCounter, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm); // MoveComm needed
// Ranks that are one call behind complete their pending allreduce and then
// presumably issue the missing signals_send() (statement not visible here).
193 if (signalsCallCounter < globalCounter)
195 MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
// Now every rank has issued the same number of calls; complete the last one.
198 assert(signalsCallCounter == globalCounter);
199 MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
203 // finalize outstanding messages in the gather loop
// Same balancing trick for the TT allgather issued from save().
204 MPI_Allreduce(&gathersPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
205 if (gathersPosted < globalCounter)
// Each rank's slice of TTRecvBuff holds one TT cache entry per thread.
207 size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
// Complete the pending gather, then post the catch-up one in-place
// (MPI_IN_PLACE: our slice of TTRecvBuff doubles as the send buffer).
208 MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
209 MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
210 TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
// All ranks level again; wait for the final gather to finish.
214 assert(gathersPosted == globalCounter);
215 MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
// Reset all signal-related accumulators before a new search, so stale
// node/TB/stop counts from the previous search cannot leak into this one.
219 void signals_init() {
221 stopSignalsPosted = tbHitsOthers = nodesSearchedOthers = 0;
223 signalsSend[SIG_NODES] = signalsRecv[SIG_NODES] = 0;
224 signalsSend[SIG_TB] = signalsRecv[SIG_TB] = 0;
225 signalsSend[SIG_STOP] = signalsRecv[SIG_STOP] = 0;
// Non-blocking check of the outstanding signals allreduce; when it has
// completed, the caller presumably processes the result and starts the next
// round (the post-test handling is not visible in this excerpt — confirm).
229 void signals_poll() {
232 MPI_Test(&reqSignals, &flag, MPI_STATUS_IGNORE);
// Cluster-aware TT store: saves the entry into the local transposition table,
// queues it in the calling thread's send cache, and — on the main thread,
// once enough entries have accumulated — runs one round of the cluster-wide
// TT exchange via a non-blocking in-place allgather on TTComm.
240 void save(Thread* thread, TTEntry* tte,
241 Key k, Value v, Bound b, Depth d, Move m, Value ev) {
// Always store locally first, exactly like the non-cluster code path.
243 tte->save(k, v, b, d, m, ev);
247 // Try to add to thread's send buffer
// The per-thread cache is shared with the gathering main thread, hence the
// mutex. replace() presumably keeps the more valuable entry on collision —
// its policy is not visible in this excerpt; confirm in the TTCache type.
249 std::lock_guard<Mutex> lk(thread->ttCache.mutex);
250 thread->ttCache.buffer.replace(KeyedTTEntry(k,*tte));
// One slice of the gather buffer per rank: TTCacheSize entries per thread.
254 size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
256 // Communicate on main search thread
// Only communicate once roughly a full buffer's worth of entries is pending
// (TTCacheCounter tracks that; its increment is not visible in this excerpt).
257 if (thread == Threads.main() && TTCacheCounter > size() * recvBuffPerRankSize)
259 // Test communication status
// Never block the search: only proceed if the previous gather finished.
261 MPI_Test(&reqGather, &flag, MPI_STATUS_IGNORE);
263 // Current communication is complete
266 // Save all received entries to TT, and store our TTCaches, ready for the next round of communication
267 for (size_t irank = 0; irank < size_t(size()) ; ++irank)
// Our own slice: refill it from the thread caches for the NEXT gather
// (in-place allgather sends our slice of TTRecvBuff to everyone).
269 if (irank == size_t(rank()))
271 // Copy from the thread caches to the right spot in the buffer
272 size_t i = irank * recvBuffPerRankSize;
273 for (auto&& th : Threads)
// Lock each thread's cache while draining it into the buffer.
275 std::lock_guard<Mutex> lk(th->ttCache.mutex);
277 for (auto&& e : th->ttCache.buffer)
280 // Reset thread's send buffer
281 th->ttCache.buffer = {};
// Other ranks' slices: entries received in the gather that just
// completed — store each one into our local TT.
287 for (size_t i = irank * recvBuffPerRankSize; i < (irank + 1) * recvBuffPerRankSize; ++i)
289 auto&& e = TTRecvBuff[i];
291 TTEntry* replace_tte;
// probe() picks the slot to overwrite; e.first is the key,
// e.second the received TTEntry payload.
292 replace_tte = TT.probe(e.first, found);
293 replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
294 e.second.move(), e.second.eval());
298 // Start next communication
// In-place byte-wise allgather: every rank contributes its own slice of
// TTRecvBuff and receives all the others.
299 MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
300 TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
304 // Force check of time on the next occasion.
305 static_cast<MainThread*>(thread)->callsCnt = 0;
313 // TODO update to the scheme in master.. can this use aggregation of votes?
// Agree on a single best move across the cluster: gather every rank's
// MoveInfo at the root, let the root run a depth/score-weighted vote, and
// broadcast the winning MoveInfo back to all ranks through `mi`.
// NOTE(review): the rank-0 guard around the tallying, the selection of the
// winning entry into `mi`, and the free() of pMoveInfo are not visible in
// this excerpt — confirm against the full source.
314 void pick_moves(MoveInfo& mi) {
316 MoveInfo* pMoveInfo = NULL;
// Root allocates one MoveInfo slot per rank for the gather below.
319 pMoveInfo = (MoveInfo*)malloc(sizeof(MoveInfo) * size());
// Collect every rank's candidate (MIDatatype describes the struct layout).
321 MPI_Gather(&mi, 1, MIDatatype, pMoveInfo, 1, MIDatatype, 0, MoveComm);
// votes: move (encoded as int) -> accumulated weight.
325 std::map<int, int> votes;
// First pass: find the minimum score and register each candidate move.
326 int minScore = pMoveInfo[0].score;
327 for (int i = 0; i < size(); ++i)
329 minScore = std::min(minScore, pMoveInfo[i].score);
330 votes[pMoveInfo[i].move] = 0;
// Second pass: weight each rank's vote by its score margin over the
// worst score plus its search depth, so deeper/better results count more.
332 for (int i = 0; i < size(); ++i)
334 votes[pMoveInfo[i].move] += pMoveInfo[i].score - minScore + pMoveInfo[i].depth;
// Third pass: pick the move with the highest accumulated vote.
336 int bestVote = votes[pMoveInfo[0].move];
337 for (int i = 0; i < size(); ++i)
339 if (votes[pMoveInfo[i].move] > bestVote)
341 bestVote = votes[pMoveInfo[i].move];
// Everyone receives the root's winning MoveInfo in-place in `mi`.
347 MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
// Cluster-wide totals: local thread counters plus the aggregated
// contributions of the other ranks (maintained by signals_process()).
350 uint64_t nodes_searched() {
352 return nodesSearchedOthers + Threads.nodes_searched();
// Same aggregation for tablebase hits (enclosing tb_hits() header is not
// visible in this excerpt).
357 return tbHitsOthers + Threads.tb_hits();
// Non-cluster fallbacks (presumably the #else branch of a USE_MPI-style
// conditional — the preprocessor lines are not visible in this excerpt):
// just forward to the local thread pool counters.
369 uint64_t nodes_searched() {
371 return Threads.nodes_searched();
// Body of the matching tb_hits() stub (its header is not visible here).
376 return Threads.tb_hits();