From: Marco Costalba Date: Thu, 12 Apr 2018 07:22:40 +0000 (+0200) Subject: Further documentation and coding style on TB code X-Git-Url: https://git.sesse.net/?p=stockfish;a=commitdiff_plain;h=6413d9b1f965ed73154fdffb36476c82a2c66c96;hp=62619fa228ff9b1f1adfbe023ce4c417807fdeba Further documentation and coding style on TB code This patch adds some documentation and code cleanup to tablebase code. It took me some time to understand the relation among the differrent structs, although I have rewrote them fully in the past. So I wrote some detailed documentation to avoid the same efforts for future readers. Also noteworthy is the use a standard hash table implementation with a more efficient 1D array instead of a 2D array. This reduces the average lookup steps of 90% (from 343 to 38 in a bench 128 1 16 run) and reduces also the table from 5K to 4K entries. I have tested on 5-men and no functional and no slowdown reported. It should be verified on 6-men that the new hash does not overflow. It is enough to run ./stockfish with 6-men available: if it does not assert at startup it means everything is ok with 6-men too. EDIT: verified for 6-men tablebase by Jörg Oster. Thanks! No functional change. --- diff --git a/src/main.cpp b/src/main.cpp index aad09cec..b5067b9b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -43,7 +43,7 @@ int main(int argc, char* argv[]) { Bitbases::init(); Search::init(); Pawns::init(); - Tablebases::init(Options["SyzygyPath"]); + Tablebases::init(Options["SyzygyPath"]); // After Bitboards are set TT.resize(Options["Hash"]); Threads.set(Options["Threads"]); Search::clear(); // After threads are up diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index 9b00bac2..cb9dcdd7 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -56,7 +56,8 @@ namespace { constexpr int TBPIECES = 6; // Max number of supported pieces -enum TBType { WDL, DTZ }; // Used as template parameter +enum { BigEndian, LittleEndian }; +enum TBType { KEY, WDL, DTZ }; // Used as template parameter // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, SingleValue = 128 }; @@ -65,147 +66,17 @@ inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); } inline Square operator^=(Square& s, int i) { return s = Square(int(s) ^ i); } inline Square operator^(Square s, int i) { return Square(int(s) ^ i); } -// DTZ tables don't store valid scores for moves that reset the rule50 counter -// like captures and pawn moves but we can easily recover the correct dtz of the -// previous move if we know the position's WDL score. -int dtz_before_zeroing(WDLScore wdl) { - return wdl == WDLWin ? 1 : - wdl == WDLCursedWin ? 101 : - wdl == WDLBlessedLoss ? -101 : - wdl == WDLLoss ? -1 : 0; -} - -// Return the sign of a number (-1, 0, 1) -template int sign_of(T val) { - return (T(0) < val) - (val < T(0)); -} - -// Numbers in little endian used by sparseIndex[] to point into blockLength[] -struct SparseEntry { - char block[4]; // Number of block - char offset[2]; // Offset within the block -}; - -static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); - -typedef uint16_t Sym; // Huffman symbol - -struct LR { - enum Side { Left, Right, Value }; - - uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 - // bits is the right-hand symbol. If symbol has length 1, - // then the first byte is the stored value. - template - Sym get() { - return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : - S == Right ? (lr[2] << 4) | (lr[1] >> 4) : - S == Value ? lr[0] : (assert(false), Sym(-1)); - } -}; - -static_assert(sizeof(LR) == 3, "LR tree entry must be 3 bytes"); - -struct PairsData { - int flags; - size_t sizeofBlock; // Block size in bytes - size_t span; // About every span values there is a SparseIndex[] entry - int blocksNum; // Number of blocks in the TB file - int maxSymLen; // Maximum length in bits of the Huffman symbols - int minSymLen; // Minimum length in bits of the Huffman symbols - Sym* lowestSym; // lowestSym[l] is the symbol of length l with the lowest value - LR* btree; // btree[sym] stores the left and right symbols that expand sym - uint16_t* blockLength; // Number of stored positions (minus one) for each block: 1..65536 - int blockLengthSize; // Size of blockLength[] table: padded so it's bigger than blocksNum - SparseEntry* sparseIndex; // Partial indices into blockLength[] - size_t sparseIndexSize; // Size of SparseIndex[] table - uint8_t* data; // Start of Huffman compressed data - std::vector base64; // base64[l - min_sym_len] is the 64bit-padded lowest symbol of length l - std::vector symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256 - Piece pieces[TBPIECES]; // Position pieces: the order of pieces defines the groups - uint64_t groupIdx[TBPIECES+1]; // Start index used for the encoding of the group's pieces - int groupLen[TBPIECES+1]; // Number of pieces in a given group: KRKN -> (3, 1) - uint16_t map_idx[4]; // WDLWin, WDLLoss, WDLCursedWin, WDLBlessedLoss (used in DTZ) -}; - -template -struct TBEntry { - typedef typename std::conditional::type Result; - - static constexpr int Sides = Type == WDL ? 2 : 1; - - std::atomic_bool ready; - void* baseAddress; - uint8_t* map; - uint64_t mapping; - Key key; - Key key2; - int pieceCount; - bool hasPawns; - bool hasUniquePieces; - uint8_t pawnCount[2]; // [Lead color / other color] - PairsData items[Sides][4]; // [wtm / btm][FILE_A..FILE_D or 0] - - PairsData* get(int stm, int f) { - return &items[stm % Sides][hasPawns ? f : 0]; - } - - TBEntry() : ready(false), baseAddress(nullptr) {} - explicit TBEntry(const std::string& code); - explicit TBEntry(const TBEntry& wdl); - ~TBEntry(); -}; - -template<> -TBEntry::TBEntry(const std::string& code) : TBEntry() { - - StateInfo st; - Position pos; - - key = pos.set(code, WHITE, &st).material_key(); - pieceCount = popcount(pos.pieces()); - hasPawns = pos.pieces(PAWN); - - hasUniquePieces = false; - for (Color c = WHITE; c <= BLACK; ++c) - for (PieceType pt = PAWN; pt < KING; ++pt) - if (popcount(pos.pieces(c, pt)) == 1) - hasUniquePieces = true; - - if (hasPawns) { - // Set the leading color. In case both sides have pawns the leading color - // is the side with less pawns because this leads to better compression. - bool c = !pos.count(BLACK) - || ( pos.count(WHITE) - && pos.count(BLACK) >= pos.count(WHITE)); - - pawnCount[0] = pos.count(c ? WHITE : BLACK); - pawnCount[1] = pos.count(c ? BLACK : WHITE); - } - - key2 = pos.set(code, BLACK, &st).material_key(); -} - -template<> -TBEntry::TBEntry(const TBEntry& wdl) : TBEntry() { - - key = wdl.key; - key2 = wdl.key2; - pieceCount = wdl.pieceCount; - hasPawns = wdl.hasPawns; - hasUniquePieces = wdl.hasUniquePieces; - - if (hasPawns) { - pawnCount[0] = wdl.pawnCount[0]; - pawnCount[1] = wdl.pawnCount[1]; - } -} +const std::string PieceToChar = " PNBRQK pnbrqk"; int MapPawns[SQUARE_NB]; int MapB1H1H7[SQUARE_NB]; int MapA1D1D4[SQUARE_NB]; int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB] +int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements +int LeadPawnIdx[5][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] +int LeadPawnsSize[5][4]; // [leadPawnsCnt][FILE_A..FILE_D] + // Comparison function to sort leading pawns in ascending MapPawns[] order bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; } int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); } @@ -218,27 +89,21 @@ constexpr Value WDL_to_value[] = { VALUE_MATE - MAX_PLY - 1 }; -const std::string PieceToChar = " PNBRQK pnbrqk"; - -int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements -int LeadPawnIdx[5][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] -int LeadPawnsSize[5][4]; // [leadPawnsCnt][FILE_A..FILE_D] - -enum { BigEndian, LittleEndian }; - template -inline void swap_byte(T& x) +inline void swap_endian(T& x) { - char tmp, *c = (char*)&x; + static_assert(std::is_unsigned::value, "Argument of swap_endian not unsigned"); + + uint8_t tmp, *c = (uint8_t*)&x; for (int i = 0; i < Half; ++i) tmp = c[i], c[i] = c[End - i], c[End - i] = tmp; } -template<> inline void swap_byte(uint8_t&) {} +template<> inline void swap_endian(uint8_t&) {} template T number(void* addr) { - const union { uint32_t i; char c[4]; } Le = { 0x01020304 }; - const bool IsLittleEndian = (Le.c[0] == 4); + static const union { uint32_t i; char c[4]; } Le = { 0x01020304 }; + static const bool IsLittleEndian = (Le.c[0] == 4); T v; @@ -248,55 +113,60 @@ template T number(void* addr) v = *((T*)addr); if (LE != IsLittleEndian) - swap_byte(v); + swap_endian(v); return v; } -class HashTable { - - typedef std::pair*, TBEntry*> EntryPair; - typedef std::pair Entry; - - static constexpr int TBHASHBITS = 10; - static constexpr int HSHMAX = 5; - - Entry hashTable[1 << TBHASHBITS][HSHMAX]; +// DTZ tables don't store valid scores for moves that reset the rule50 counter +// like captures and pawn moves but we can easily recover the correct dtz of the +// previous move if we know the position's WDL score. +int dtz_before_zeroing(WDLScore wdl) { + return wdl == WDLWin ? 1 : + wdl == WDLCursedWin ? 101 : + wdl == WDLBlessedLoss ? -101 : + wdl == WDLLoss ? -1 : 0; +} - std::deque> wdlTable; - std::deque> dtzTable; +// Return the sign of a number (-1, 0, 1) +template int sign_of(T val) { + return (T(0) < val) - (val < T(0)); +} - void insert(Key key, TBEntry* wdl, TBEntry* dtz) { - for (Entry& entry : hashTable[key >> (64 - TBHASHBITS)]) - if (!entry.second.first || entry.first == key) { - entry = std::make_pair(key, std::make_pair(wdl, dtz)); - return; - } +// Numbers in little endian used by sparseIndex[] to point into blockLength[] +struct SparseEntry { + char block[4]; // Number of block + char offset[2]; // Offset within the block +}; - std::cerr << "HSHMAX too low!" << std::endl; - exit(1); - } +static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); -public: - template - TBEntry* get(Key key) { - for (Entry& entry : hashTable[key >> (64 - TBHASHBITS)]) - if (entry.first == key) - return std::get(entry.second); +typedef uint16_t Sym; // Huffman symbol - return nullptr; - } +struct LR { + enum Side { Left, Right, Value }; - void clear() { - memset(hashTable, 0, sizeof(hashTable)); - wdlTable.clear(); - dtzTable.clear(); + uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 + // bits is the right-hand symbol. If symbol has length 1, + // then the first byte is the stored value. + template + Sym get() { + return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : + S == Right ? (lr[2] << 4) | (lr[1] >> 4) : + S == Value ? lr[0] : (assert(false), Sym(-1)); } - size_t size() const { return wdlTable.size(); } - void insert(const std::vector& pieces); }; -HashTable EntryTable; +static_assert(sizeof(LR) == 3, "LR tree entry must be 3 bytes"); + +// Tablebases data layout is structured as following: +// +// TBFile: memory maps/unmaps the physical .rtbw and .rtbz files +// TBTable: one object for each file with corresponding indexing information +// TBTables: has ownership of TBTable objects, keeping a list and a hash +// class TBFile memory maps/unmaps the single .rtbw and .rtbz files. Files are +// memory mapped for best performance. Files are mapped at first access: at init +// time only existence of the file is checked. class TBFile : public std::ifstream { std::string fname; @@ -330,7 +200,7 @@ public: // Memory map the file and check it. File should be already open and will be // closed after mapping. - uint8_t* map(void** baseAddress, uint64_t* mapping, const uint8_t* TB_MAGIC) { + uint8_t* map(void** baseAddress, uint64_t* mapping, TBType type) { assert(is_open()); @@ -380,16 +250,16 @@ public: #endif uint8_t* data = (uint8_t*)*baseAddress; - if ( *data++ != *TB_MAGIC++ - || *data++ != *TB_MAGIC++ - || *data++ != *TB_MAGIC++ - || *data++ != *TB_MAGIC) { + constexpr uint8_t Magics[][4] = { { 0xD7, 0x66, 0x0C, 0xA5 }, + { 0x71, 0xE8, 0x23, 0x5D } }; + + if (memcmp(data, Magics[type == WDL], 4)) { std::cerr << "Corrupted table in file " << fname << std::endl; unmap(*baseAddress, *mapping); return *baseAddress = nullptr, nullptr; } - return data; + return data + 4; // Skip Magics's header } static void unmap(void* baseAddress, uint64_t mapping) { @@ -405,13 +275,158 @@ public: std::string TBFile::Paths; +// struct PairsData contains low level indexing information to access TB data. +// There are 8, 4 or 2 PairsData records for each TBTable, according to type of +// table and if positions have pawns or not. It is populated at first access. +struct PairsData { + uint8_t flags; // Table flags, see enum TBFlag + uint8_t maxSymLen; // Maximum length in bits of the Huffman symbols + uint8_t minSymLen; // Minimum length in bits of the Huffman symbols + uint32_t blocksNum; // Number of blocks in the TB file + size_t sizeofBlock; // Block size in bytes + size_t span; // About every span values there is a SparseIndex[] entry + Sym* lowestSym; // lowestSym[l] is the symbol of length l with the lowest value + LR* btree; // btree[sym] stores the left and right symbols that expand sym + uint16_t* blockLength; // Number of stored positions (minus one) for each block: 1..65536 + uint32_t blockLengthSize; // Size of blockLength[] table: padded so it's bigger than blocksNum + SparseEntry* sparseIndex; // Partial indices into blockLength[] + size_t sparseIndexSize; // Size of SparseIndex[] table + uint8_t* data; // Start of Huffman compressed data + std::vector base64; // base64[l - min_sym_len] is the 64bit-padded lowest symbol of length l + std::vector symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256 + Piece pieces[TBPIECES]; // Position pieces: the order of pieces defines the groups + uint64_t groupIdx[TBPIECES+1]; // Start index used for the encoding of the group's pieces + int groupLen[TBPIECES+1]; // Number of pieces in a given group: KRKN -> (3, 1) + uint16_t map_idx[4]; // WDLWin, WDLLoss, WDLCursedWin, WDLBlessedLoss (used in DTZ) +}; + +// struct TBTable contains indexing information to access the corresponding TBFile. +// There are 2 types of TBTable, corresponding to a WDL or a DTZ file. TBTable +// is populated at init time but the nested PairsData records are populated at +// first access, when the corresponding file is memory mapped. template -TBEntry::~TBEntry() { - if (baseAddress) - TBFile::unmap(baseAddress, mapping); +struct TBTable { + typedef typename std::conditional::type Ret; + + static constexpr int Sides = Type == WDL ? 2 : 1; + + std::atomic_bool ready; + void* baseAddress; + uint8_t* map; + uint64_t mapping; + Key key; + Key key2; + int pieceCount; + bool hasPawns; + bool hasUniquePieces; + uint8_t pawnCount[2]; // [Lead color / other color] + PairsData items[Sides][4]; // [wtm / btm][FILE_A..FILE_D or 0] + + PairsData* get(int stm, int f) { + return &items[stm % Sides][hasPawns ? f : 0]; + } + + TBTable() : ready(false), baseAddress(nullptr) {} + explicit TBTable(const std::string& code); + explicit TBTable(const TBTable& wdl); + + ~TBTable() { + if (baseAddress) + TBFile::unmap(baseAddress, mapping); + } +}; + +template<> +TBTable::TBTable(const std::string& code) : TBTable() { + + StateInfo st; + Position pos; + + key = pos.set(code, WHITE, &st).material_key(); + pieceCount = pos.count(); + hasPawns = pos.pieces(PAWN); + + hasUniquePieces = false; + for (Color c = WHITE; c <= BLACK; ++c) + for (PieceType pt = PAWN; pt < KING; ++pt) + if (popcount(pos.pieces(c, pt)) == 1) + hasUniquePieces = true; + + // Set the leading color. In case both sides have pawns the leading color + // is the side with less pawns because this leads to better compression. + bool c = !pos.count(BLACK) + || ( pos.count(WHITE) + && pos.count(BLACK) >= pos.count(WHITE)); + + pawnCount[0] = pos.count(c ? WHITE : BLACK); + pawnCount[1] = pos.count(c ? BLACK : WHITE); + + key2 = pos.set(code, BLACK, &st).material_key(); +} + +template<> +TBTable::TBTable(const TBTable& wdl) : TBTable() { + + // Use the corresponding WDL table to avoid recalculating all from scratch + key = wdl.key; + key2 = wdl.key2; + pieceCount = wdl.pieceCount; + hasPawns = wdl.hasPawns; + hasUniquePieces = wdl.hasUniquePieces; + pawnCount[0] = wdl.pawnCount[0]; + pawnCount[1] = wdl.pawnCount[1]; } -void HashTable::insert(const std::vector& pieces) { +// class TBTables creates and keeps ownership of the TBTable objects, one for +// each TB file found. It supports a fast, hash based, table lookup. Populated +// at init time, accessed at probe time. +class TBTables { + + typedef std::tuple*, TBTable*> Entry; + + static const int Size = 1 << 12; // 4K table, indexed by key's 12 lsb + + Entry hashTable[Size]; + + std::deque> wdlTable; + std::deque> dtzTable; + + void insert(Key key, TBTable* wdl, TBTable* dtz) { + Entry* entry = &hashTable[(uint32_t)key & (Size - 1)]; + + // Ensure last element is empty to avoid overflow when looking up + for ( ; entry - hashTable < Size - 1; ++entry) + if (std::get(*entry) == key || !std::get(*entry)) { + *entry = std::make_tuple(key, wdl, dtz); + return; + } + std::cerr << "TB hash table size too low!" << std::endl; + exit(1); + } + +public: + template + TBTable* get(Key key) { + for (const Entry* entry = &hashTable[(uint32_t)key & (Size - 1)]; ; ++entry) { + if (std::get(*entry) == key || !std::get(*entry)) + return std::get(*entry); + } + } + + void clear() { + memset(hashTable, 0, sizeof(hashTable)); + wdlTable.clear(); + dtzTable.clear(); + } + size_t size() const { return wdlTable.size(); } + void add(const std::vector& pieces); +}; + +TBTables TBTables; + +// If the corresponding file exists two new objects TBTable and TBTable +// are created and added to the lists and hash table. Called at init time. +void TBTables::add(const std::vector& pieces) { std::string code; @@ -430,6 +445,7 @@ void HashTable::insert(const std::vector& pieces) { wdlTable.emplace_back(code); dtzTable.emplace_back(wdlTable.back()); + // Insert into the hash keys for both colors: KRvK with KR white and black insert(wdlTable.back().key , &wdlTable.back(), &dtzTable.back()); insert(wdlTable.back().key2, &wdlTable.back(), &dtzTable.back()); } @@ -561,11 +577,11 @@ int decompress_pairs(PairsData* d, uint64_t idx) { return d->btree[sym].get(); } -bool check_dtz_stm(TBEntry*, int, File) { return true; } +bool check_dtz_stm(TBTable*, int, File) { return true; } -bool check_dtz_stm(TBEntry* entry, int stm, File f) { +bool check_dtz_stm(TBTable* entry, int stm, File f) { - int flags = entry->get(stm, f)->flags; + auto flags = entry->get(stm, f)->flags; return (flags & TBFlag::STM) == stm || ((entry->key == entry->key2) && !entry->hasPawns); } @@ -574,13 +590,13 @@ bool check_dtz_stm(TBEntry* entry, int stm, File f) { // values 0, 1, 2, ... in order of decreasing frequency. This is done for each // of the four WDLScore values. The mapping information necessary to reconstruct // the original values is stored in the TB file and read during map[] init. -WDLScore map_score(TBEntry*, File, int value, WDLScore) { return WDLScore(value - 2); } +WDLScore map_score(TBTable*, File, int value, WDLScore) { return WDLScore(value - 2); } -int map_score(TBEntry* entry, File f, int value, WDLScore wdl) { +int map_score(TBTable* entry, File f, int value, WDLScore wdl) { constexpr int WDLMap[] = { 1, 3, 0, 2, 0 }; - int flags = entry->get(0, f)->flags; + auto flags = entry->get(0, f)->flags; uint8_t* map = entry->map; uint16_t* idx = entry->get(0, f)->map_idx; @@ -604,8 +620,8 @@ int map_score(TBEntry* entry, File f, int value, WDLScore wdl) { // // idx = Binomial[1][s1] + Binomial[2][s2] + ... + Binomial[k][sk] // -template::Result> -T do_probe_table(const Position& pos, TBEntry* entry, WDLScore wdl, ProbeState* result) { +template +Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* result) { Square squares[TBPIECES]; Piece pieces[TBPIECES]; @@ -659,8 +675,8 @@ T do_probe_table(const Position& pos, TBEntry* entry, WDLScore wdl, ProbeS // DTZ tables are one-sided, i.e. they store positions only for white to // move or only for black to move, so check for side to move to be stm, // early exit otherwise. - if (Type == DTZ && !check_dtz_stm(entry, stm, tbFile)) - return *result = CHANGE_STM, T(); + if (!check_dtz_stm(entry, stm, tbFile)) + return *result = CHANGE_STM, Ret(); // Now we are ready to get all the position pieces (but the lead pawns) and // directly map them to the correct color and square. @@ -829,8 +845,8 @@ encode_remaining: // // The actual grouping depends on the TB generator and can be inferred from the // sequence of pieces in piece[] array. -template -void set_groups(TBEntry& e, PairsData* d, int order[], File f) { +template +void set_groups(T& e, PairsData* d, int order[], File f) { int n = 0, firstLen = e.hasPawns ? 0 : e.hasUniquePieces ? 3 : 2; d->groupLen[n] = 1; @@ -923,7 +939,7 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) { d->sizeofBlock = 1ULL << *data++; d->span = 1ULL << *data++; d->sparseIndexSize = (tbSize + d->span - 1) / d->span; // Round up - int padding = number(data++); + auto padding = number(data++); d->blocksNum = number(data); data += sizeof(uint32_t); d->blockLengthSize = d->blocksNum + padding; // Padded to ensure SparseIndex[] // does not point out of range. @@ -970,9 +986,9 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) { return data + d->symlen.size() * sizeof(LR) + (d->symlen.size() & 1); } -uint8_t* set_dtz_map(TBEntry&, uint8_t*, File) { return nullptr; } +uint8_t* set_dtz_map(TBTable&, uint8_t* data, File) { return data; } -uint8_t* set_dtz_map(TBEntry& e, uint8_t* data, File maxFile) { +uint8_t* set_dtz_map(TBTable& e, uint8_t* data, File maxFile) { e.map = data; @@ -987,8 +1003,10 @@ uint8_t* set_dtz_map(TBEntry& e, uint8_t* data, File maxFile) { return data += (uintptr_t)data & 1; // Word alignment } -template -void do_init(TBEntry& e, uint8_t* data) { +// Populate entry's PairsData records with data from the just memory mapped file. +// Called at first access. +template +void set(T& e, uint8_t* data) { PairsData* d; @@ -999,7 +1017,7 @@ void do_init(TBEntry& e, uint8_t* data) { data++; // First byte stores flags - const int sides = Type == WDL && (e.key != e.key2) ? 2 : 1; + const int sides = T::Sides == 2 && (e.key != e.key2) ? 2 : 1; const File maxFile = e.hasPawns ? FILE_D : FILE_A; bool pp = e.hasPawns && e.pawnCount[1]; // Pawns on both sides @@ -1029,8 +1047,7 @@ void do_init(TBEntry& e, uint8_t* data) { for (int i = 0; i < sides; i++) data = set_sizes(e.get(i, f), data); - if (Type == DTZ) - data = set_dtz_map(e, data, maxFile); + data = set_dtz_map(e, data, maxFile); for (File f = FILE_A; f <= maxFile; ++f) for (int i = 0; i < sides; i++) { @@ -1052,15 +1069,19 @@ void do_init(TBEntry& e, uint8_t* data) { } } +// If the TB file corresponding to the given position is already memory mapped +// then return its base address, otherwise try to memory map and init it. Called +// at every probe, memory map and init only at first access. Function is thread +// safe and can be called concurrently. template -void* init(TBEntry& e, const Position& pos) { +void* mapped(TBTable& e, const Position& pos) { static Mutex mutex; - // Avoid a thread reads 'ready' == true while another is still in do_init(), - // this could happen due to compiler reordering. + // Use 'aquire' to avoid a thread reads 'ready' == true while another is + // still working, this could happen due to compiler reordering. if (e.ready.load(std::memory_order_acquire)) - return e.baseAddress; + return e.baseAddress; // Could be nullptr if file does not exsist std::unique_lock lk(mutex); @@ -1074,31 +1095,28 @@ void* init(TBEntry& e, const Position& pos) { b += std::string(popcount(pos.pieces(BLACK, pt)), PieceToChar[pt]); } - constexpr uint8_t TB_MAGIC[][4] = { { 0xD7, 0x66, 0x0C, 0xA5 }, - { 0x71, 0xE8, 0x23, 0x5D } }; - fname = (e.key == pos.material_key() ? w + 'v' + b : b + 'v' + w) + (Type == WDL ? ".rtbw" : ".rtbz"); - uint8_t* data = TBFile(fname).map(&e.baseAddress, &e.mapping, - TB_MAGIC[Type == WDL]); + uint8_t* data = TBFile(fname).map(&e.baseAddress, &e.mapping, Type); + if (data) - do_init(e, data); + set(e, data); e.ready.store(true, std::memory_order_release); return e.baseAddress; } -template::Result> -T probe_table(const Position& pos, ProbeState* result, WDLScore wdl = WDLDraw) { +template::Ret> +Ret probe_table(const Position& pos, ProbeState* result, WDLScore wdl = WDLDraw) { - if (!(pos.pieces() ^ pos.pieces(KING))) - return T(WDLDraw); // KvK + if (pos.count() == 2) // KvK + return Ret(WDLDraw); - TBEntry* entry = EntryTable.get(pos.material_key()); + TBTable* entry = TBTables.get(pos.material_key()); - if (!entry || !init(*entry, pos)) - return *result = FAIL, T(); + if (!entry || !mapped(*entry, pos)) + return *result = FAIL, Ret(); return do_probe_table(pos, entry, wdl, result); } @@ -1180,9 +1198,13 @@ WDLScore search(Position& pos, ProbeState* result) { } // namespace + +/// Tablebases::init() is called at startup and after every change to +/// "SyzygyPath" UCI option to (re)create the various tables. It is not thread +/// safe, nor it needs to be. void Tablebases::init(const std::string& paths) { - EntryTable.clear(); + TBTables.clear(); MaxCardinality = 0; TBFile::Paths = paths; @@ -1283,33 +1305,34 @@ void Tablebases::init(const std::string& paths) { LeadPawnsSize[leadPawnsCnt][f] = idx; } + // Add entries in TB tables if the corresponding ".rtbw" file exsists for (PieceType p1 = PAWN; p1 < KING; ++p1) { - EntryTable.insert({KING, p1, KING}); + TBTables.add({KING, p1, KING}); for (PieceType p2 = PAWN; p2 <= p1; ++p2) { - EntryTable.insert({KING, p1, p2, KING}); - EntryTable.insert({KING, p1, KING, p2}); + TBTables.add({KING, p1, p2, KING}); + TBTables.add({KING, p1, KING, p2}); for (PieceType p3 = PAWN; p3 < KING; ++p3) - EntryTable.insert({KING, p1, p2, KING, p3}); + TBTables.add({KING, p1, p2, KING, p3}); for (PieceType p3 = PAWN; p3 <= p2; ++p3) { - EntryTable.insert({KING, p1, p2, p3, KING}); + TBTables.add({KING, p1, p2, p3, KING}); for (PieceType p4 = PAWN; p4 <= p3; ++p4) - EntryTable.insert({KING, p1, p2, p3, p4, KING}); + TBTables.add({KING, p1, p2, p3, p4, KING}); for (PieceType p4 = PAWN; p4 < KING; ++p4) - EntryTable.insert({KING, p1, p2, p3, KING, p4}); + TBTables.add({KING, p1, p2, p3, KING, p4}); } for (PieceType p3 = PAWN; p3 <= p1; ++p3) for (PieceType p4 = PAWN; p4 <= (p1 == p3 ? p2 : p3); ++p4) - EntryTable.insert({KING, p1, p2, KING, p3, p4}); + TBTables.add({KING, p1, p2, KING, p3, p4}); } } - sync_cout << "info string Found " << EntryTable.size() << " tablebases" << sync_endl; + sync_cout << "info string Found " << TBTables.size() << " tablebases" << sync_endl; } // Probe the WDL table for a particular position.