X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fsyzygy%2Ftbprobe.cpp;h=03cffca17bc4756839443559ea4247bb13c05d6d;hp=da6dc49f7bbd3f07d77672ff789c265844f8a354;hb=8c0a2733924f59316003c576ee2064ded490678b;hpb=a0486ecb40513db8141fa27c026f64771b8ebb55 diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index da6dc49f..03cffca1 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -55,13 +55,13 @@ int Tablebases::MaxCardinality; namespace { -constexpr int TBPIECES = 6; // Max number of supported pieces +constexpr int TBPIECES = 7; // Max number of supported pieces enum { BigEndian, LittleEndian }; enum TBType { KEY, WDL, DTZ }; // Used as template parameter // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables -enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, SingleValue = 128 }; +enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, SingleValue = 128 }; inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); } inline Square operator^=(Square& s, int i) { return s = Square(int(s) ^ i); } @@ -74,9 +74,9 @@ int MapB1H1H7[SQUARE_NB]; int MapA1D1D4[SQUARE_NB]; int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB] -int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements -int LeadPawnIdx[5][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] -int LeadPawnsSize[5][4]; // [leadPawnsCnt][FILE_A..FILE_D] +int Binomial[7][SQUARE_NB]; // [k][n] k elements from a set of n elements +int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] +int LeadPawnsSize[6][4]; // [leadPawnsCnt][FILE_A..FILE_D] // Comparison function to sort leading pawns in ascending MapPawns[] order bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; } @@ -144,16 +144,15 @@ static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); typedef uint16_t Sym; // Huffman symbol struct LR { - enum Side { Left, Right, Value }; + enum Side { Left, Right }; uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 // bits is the right-hand symbol. If symbol has length 1, - // then the first byte is the stored value. + // then the left-hand symbol is the stored value. template Sym get() { return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : - S == Right ? (lr[2] << 4) | (lr[1] >> 4) : - S == Value ? lr[0] : (assert(false), Sym(-1)); + S == Right ? (lr[2] << 4) | (lr[1] >> 4) : (assert(false), Sym(-1)); } }; @@ -217,6 +216,7 @@ public: fstat(fd, &statbuf); *mapping = statbuf.st_size; *baseAddress = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_SHARED, fd, 0); + madvise(*baseAddress, statbuf.st_size, MADV_RANDOM); ::close(fd); if (*baseAddress == MAP_FAILED) { @@ -385,22 +385,35 @@ class TBTables { typedef std::tuple*, TBTable*> Entry; - static const int Size = 1 << 12; // 4K table, indexed by key's 12 lsb + static constexpr int Size = 1 << 12; // 4K table, indexed by key's 12 lsb + static constexpr int Overflow = 1; // Number of elements allowed to map to the last bucket - Entry hashTable[Size]; + Entry hashTable[Size + Overflow]; std::deque> wdlTable; std::deque> dtzTable; void insert(Key key, TBTable* wdl, TBTable* dtz) { - Entry* entry = &hashTable[(uint32_t)key & (Size - 1)]; + uint32_t homeBucket = (uint32_t)key & (Size - 1); + Entry entry = std::make_tuple(key, wdl, dtz); // Ensure last element is empty to avoid overflow when looking up - for ( ; entry - hashTable < Size - 1; ++entry) - if (std::get(*entry) == key || !std::get(*entry)) { - *entry = std::make_tuple(key, wdl, dtz); + for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket) { + Key otherKey = std::get(hashTable[bucket]); + if (otherKey == key || !std::get(hashTable[bucket])) { + hashTable[bucket] = entry; return; } + + // Robin Hood hashing: If we've probed for longer than this element, + // insert here and search for a new spot for the other element instead. + uint32_t otherHomeBucket = (uint32_t)otherKey & (Size - 1); + if (otherHomeBucket > homeBucket) { + swap(entry, hashTable[bucket]); + key = otherKey; + homeBucket = otherHomeBucket; + } + } std::cerr << "TB hash table size too low!" << std::endl; exit(1); } @@ -512,7 +525,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) { offset -= d->blockLength[block++] + 1; // Finally, we find the start address of our block of canonical Huffman symbols - uint32_t* ptr = (uint32_t*)(d->data + block * d->sizeofBlock); + uint32_t* ptr = (uint32_t*)(d->data + ((uint64_t)block * d->sizeofBlock)); // Read the first 64 bits in our block, this is a (truncated) sequence of // unknown number of symbols of unknown length but we know the first one @@ -575,7 +588,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) { } } - return d->btree[sym].get(); + return d->btree[sym].get(); } bool check_dtz_stm(TBTable*, int, File) { return true; } @@ -601,8 +614,12 @@ int map_score(TBTable* entry, File f, int value, WDLScore wdl) { uint8_t* map = entry->map; uint16_t* idx = entry->get(0, f)->map_idx; - if (flags & TBFlag::Mapped) - value = map[idx[WDLMap[wdl + 2]] + value]; + if (flags & TBFlag::Mapped) { + if (flags & TBFlag::Wide) + value = ((uint16_t *)map)[idx[WDLMap[wdl + 2]] + value]; + else + value = map[idx[WDLMap[wdl + 2]] + value]; + } // DTZ tables store distance to zero in number of moves or plies. We // want to return plies, so we have convert to plies when needed. @@ -994,11 +1011,22 @@ uint8_t* set_dtz_map(TBTable& e, uint8_t* data, File maxFile) { e.map = data; for (File f = FILE_A; f <= maxFile; ++f) { - if (e.get(0, f)->flags & TBFlag::Mapped) - for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x - e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1); - data += *data + 1; + auto flags = e.get(0, f)->flags; + if (flags & TBFlag::Mapped) { + if (flags & TBFlag::Wide) { + data += (uintptr_t)data & 1; // Word alignment, we may have a mixed table + for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x + e.get(0, f)->map_idx[i] = (uint16_t)((uint16_t *)data - (uint16_t *)e.map + 1); + data += 2 * number(data) + 2; + } + } + else { + for (int i = 0; i < 4; ++i) { + e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1); + data += *data + 1; + } } + } } return data += (uintptr_t)data & 1; // Word alignment @@ -1264,7 +1292,7 @@ void Tablebases::init(const std::string& paths) { Binomial[0][0] = 1; for (int n = 1; n < 64; n++) // Squares - for (int k = 0; k < 6 && k <= n; ++k) // Pieces + for (int k = 0; k < 7 && k <= n; ++k) // Pieces Binomial[k][n] = (k > 0 ? Binomial[k - 1][n - 1] : 0) + (k < n ? Binomial[k ][n - 1] : 0); @@ -1274,9 +1302,9 @@ void Tablebases::init(const std::string& paths) { // among pawns with same file, the one with lowest rank. int availableSquares = 47; // Available squares when lead pawn is in a2 - // Init the tables for the encoding of leading pawns group: with 6-men TB we - // can have up to 4 leading pawns (KPPPPK). - for (int leadPawnsCnt = 1; leadPawnsCnt <= 4; ++leadPawnsCnt) + // Init the tables for the encoding of leading pawns group: with 7-men TB we + // can have up to 5 leading pawns (KPPPPPK). + for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt) for (File f = FILE_A; f <= FILE_D; ++f) { // Restart the index at every file because TB table is splitted @@ -1320,11 +1348,22 @@ void Tablebases::init(const std::string& paths) { for (PieceType p3 = PAWN; p3 <= p2; ++p3) { TBTables.add({KING, p1, p2, p3, KING}); - for (PieceType p4 = PAWN; p4 <= p3; ++p4) + for (PieceType p4 = PAWN; p4 <= p3; ++p4) { TBTables.add({KING, p1, p2, p3, p4, KING}); - for (PieceType p4 = PAWN; p4 < KING; ++p4) + for (PieceType p5 = PAWN; p5 <= p4; ++p5) + TBTables.add({KING, p1, p2, p3, p4, p5, KING}); + + for (PieceType p5 = PAWN; p5 < KING; ++p5) + TBTables.add({KING, p1, p2, p3, p4, KING, p5}); + } + + for (PieceType p4 = PAWN; p4 < KING; ++p4) { TBTables.add({KING, p1, p2, p3, KING, p4}); + + for (PieceType p5 = PAWN; p5 <= p4; ++p5) + TBTables.add({KING, p1, p2, p3, KING, p4, p5}); + } } for (PieceType p3 = PAWN; p3 <= p1; ++p3)