static MPI_Comm MoveComm = MPI_COMM_NULL;
static MPI_Comm signalsComm = MPI_COMM_NULL;
-static std::vector<KeyedTTEntry> TTBuff;
+// Receive buffer of the non-blocking TT allgather: one segment of
+// Threads.size() * TTCacheSize entries per rank (sized in ttRecvBuff_resize()).
+static std::vector<KeyedTTEntry> TTRecvBuff;
+// Request handle of the outstanding MPI_Iallgather; polled with MPI_Test.
+static MPI_Request reqGather = MPI_REQUEST_NULL;
+// Number of gathers this rank has posted; used to keep the collective
+// call sequence matched across all ranks (see the catch-up in getline()).
+static uint64_t gathersPosted = 0;
+
+// TT entries buffered in the per-thread ttCaches since the last gather.
+// Atomic: search threads increment it concurrently in save().
+static std::atomic<uint64_t> TTCacheCounter = {};
static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
+
+// init(): query rank/size and (in elided code) build the MoveInfo datatype.
+// NOTE(review): this hunk elides most of the body. The visible change only
+// removes the old fixed-size TTBuff allocation; the receive buffer is now
+// sized later via ttRecvBuff_resize(), once the thread count is known.
void init() {
int thread_support;
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
- TTBuff.resize(TTSendBufferSize * world_size);
-
const std::array<MPI_Aint, 4> MIdisps = {offsetof(MoveInfo, move),
offsetof(MoveInfo, depth),
offsetof(MoveInfo, score),
return world_rank;
}
+// (Re)size the allgather receive buffer: one segment of
+// TTCacheSize * nThreads entries for each of the world_size ranks,
+// then clear any stale entries left over from a previous search.
+void ttRecvBuff_resize(size_t nThreads) {
+
+ TTRecvBuff.resize(TTCacheSize * world_size * nThreads);
+ std::fill(TTRecvBuff.begin(), TTRecvBuff.end(), KeyedTTEntry());
+
+}
+
bool getline(std::istream& input, std::string& str) {
signals_process();
+ // finalize outstanding messages in the gather loop
+ // Bring every rank up to the same number of posted gathers so that the
+ // collective MPI_Iallgather sequence stays matched across all ranks.
+ // NOTE(review): globalCounter is not declared anywhere in this hunk -
+ // presumably a local uint64_t; confirm against the full file.
+ MPI_Allreduce(&gathersPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
+ if (gathersPosted < globalCounter)
+ {
+ size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
+ // NOTE(review): re-posting on reqGather while a previous gather may still
+ // be in flight is invalid MPI - an active request must complete (MPI_Wait)
+ // before the handle is reused; check whether a wait was elided here.
+ // Also: this 'if' catches up by at most one gather, which is sufficient
+ // only if ranks can never lag by more than one - verify that invariant
+ // (the assert below would fire otherwise).
+ MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
+ TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
+ TTComm, &reqGather);
+ ++gathersPosted;
+ }
+ assert(gathersPosted == globalCounter);
+
}
void signals_init() {
+// NOTE(review): the 'void signals_init() {' line above looks like a stray
+// context line from hunk concatenation; the hunk below belongs to save(),
+// the TT-store hot path (variables thread/tte/k come from its signature).
{
// Try to add to thread's send buffer
{
- std::lock_guard<Mutex> lk(thread->ttBuffer.mutex);
- thread->ttBuffer.buffer.replace(KeyedTTEntry(k,*tte));
- ++thread->ttBuffer.counter;
+ std::lock_guard<Mutex> lk(thread->ttCache.mutex);
+ thread->ttCache.buffer.replace(KeyedTTEntry(k,*tte));
+ // One global atomic counter replaces the old per-thread counters.
+ ++TTCacheCounter;
}
+ size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
+
// Communicate on main search thread
+ // Trigger once the caches hold more entries than one exchange moves
+ // (cluster size * per-rank segment).
- if (thread == Threads.main() && thread->ttBuffer.counter * Threads.size() > TTSendBufferSize)
+ if (thread == Threads.main() && TTCacheCounter > size() * recvBuffPerRankSize)
{
- static MPI_Request req = MPI_REQUEST_NULL;
- static TTSendBuffer<TTSendBufferSize> send_buff = {};
- int flag;
-
// Test communication status
+ // (MPI_Test on MPI_REQUEST_NULL sets flag immediately, so the very
+ // first pass through here proceeds to start the initial gather.)
- MPI_Test(&req, &flag, MPI_STATUS_IGNORE);
+ int flag;
+ MPI_Test(&reqGather, &flag, MPI_STATUS_IGNORE);
// Current communication is complete
if (flag)
{
- // Save all received entries (except ours)
+ // Save all received entries to TT, and store our TTCaches, ready for the next round of communication
for (size_t irank = 0; irank < size_t(size()) ; ++irank)
{
if (irank == size_t(rank()))
- continue;
-
- for (size_t i = irank * TTSendBufferSize ; i < (irank + 1) * TTSendBufferSize; ++i)
{
- auto&& e = TTBuff[i];
- bool found;
- TTEntry* replace_tte;
- replace_tte = TT.probe(e.first, found);
- replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
- e.second.move(), e.second.eval());
- }
- }
+ // Copy from the thread caches to the right spot in the buffer
+ // (with MPI_IN_PLACE the allgather sends each rank's own segment
+ // of the receive buffer, so this fills our contribution).
+ size_t i = irank * recvBuffPerRankSize;
+ for (auto&& th : Threads)
+ {
+ std::lock_guard<Mutex> lk(th->ttCache.mutex);
- // Reset send buffer
- send_buff = {};
+ for (auto&& e : th->ttCache.buffer)
+ TTRecvBuff[i++] = e;
- // Build up new send buffer: best 16 found across all threads
- for (auto&& th : Threads)
- {
- std::lock_guard<Mutex> lk(th->ttBuffer.mutex);
- for (auto&& e : th->ttBuffer.buffer)
- send_buff.replace(e);
- // Reset thread's send buffer
- th->ttBuffer.buffer = {};
- th->ttBuffer.counter = 0;
+ // Reset thread's send buffer
+ th->ttCache.buffer = {};
+ }
+
+ TTCacheCounter = 0;
+ }
+ else
+ // Other ranks' segments hold data from the just-completed gather;
+ // merge them into the local TT.
+ // NOTE(review): a segment may contain default-constructed
+ // KeyedTTEntry slots if a cache was not full - presumably
+ // harmless to probe/save, but confirm.
+ for (size_t i = irank * recvBuffPerRankSize; i < (irank + 1) * recvBuffPerRankSize; ++i)
+ {
+ auto&& e = TTRecvBuff[i];
+ bool found;
+ TTEntry* replace_tte;
+ replace_tte = TT.probe(e.first, found);
+ replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
+ e.second.move(), e.second.eval());
+ }
}
// Start next communication
+ // In-place allgather: send buffer args are ignored, each rank's data
+ // is taken from its own segment of TTRecvBuff (filled above).
- MPI_Iallgather(send_buff.data(), send_buff.size() * sizeof(KeyedTTEntry), MPI_BYTE,
- TTBuff.data(), TTSendBufferSize * sizeof(KeyedTTEntry), MPI_BYTE,
- TTComm, &req);
+ MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
+ TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
+ TTComm, &reqGather);
+ ++gathersPosted;
+
+ // Force check of time on the next occasion.
+ // (Cast is safe: guarded by thread == Threads.main() above.)
+ static_cast<MainThread*>(thread)->callsCnt = 0;
+
}
}
}
#ifdef USE_MPI
using KeyedTTEntry = std::pair<Key, TTEntry>;
-constexpr std::size_t TTSendBufferSize = 32;
-template <std::size_t N> class TTSendBuffer : public std::array<KeyedTTEntry, N> {
+// Capacity of each thread-local cache: the best TTCacheSize entries are
+// kept between gathers; this also fixes the per-thread segment size of
+// the allgather receive buffer.
+constexpr std::size_t TTCacheSize = 32;
+// Fixed-size store of the N best keyed TT entries, ordered by the nested
+// Compare predicate (class body continues beyond this hunk).
+template <std::size_t N> class TTCache : public std::array<KeyedTTEntry, N> {
struct Compare {
inline bool operator()(const KeyedTTEntry& lhs, const KeyedTTEntry& rhs) {
inline bool is_root() { return rank() == 0; }
void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev);
void pick_moves(MoveInfo& mi);
+// Sizes the TT allgather receive buffer for the current thread count;
+// must be called whenever the number of search threads changes.
+void ttRecvBuff_resize(size_t nThreads);
uint64_t nodes_searched();
uint64_t tb_hits();
void signals_init();
constexpr bool is_root() { return true; }
inline void save(Thread*, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, b, d, m, ev); }
inline void pick_moves(MoveInfo&) { }
+// Non-MPI build: no receive buffer exists, so this is intentionally a no-op,
+// mirroring the other single-rank stubs in this section.
+inline void ttRecvBuff_resize(size_t) { }
uint64_t nodes_searched();
uint64_t tb_hits();
inline void signals_init() { }