]> git.sesse.net Git - stockfish/blobdiff - src/syzygy/tbprobe.cpp
Use emplace_back() in TB code
[stockfish] / src / syzygy / tbprobe.cpp
index 2b7f4497246e90a7105ebd202cfe0246191402a3..d0b59f056a9c02119c7ca3ce94e3c6a7dd1341bd 100644 (file)
@@ -55,13 +55,13 @@ int Tablebases::MaxCardinality;
 
 namespace {
 
-constexpr int TBPIECES = 6; // Max number of supported pieces
+constexpr int TBPIECES = 7; // Max number of supported pieces
 
 enum { BigEndian, LittleEndian };
 enum TBType { KEY, WDL, DTZ }; // Used as template parameter
 
 // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables
-enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, SingleValue = 128 };
+enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, SingleValue = 128 };
 
 inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); }
 inline Square operator^=(Square& s, int i) { return s = Square(int(s) ^ i); }
@@ -75,8 +75,8 @@ int MapA1D1D4[SQUARE_NB];
 int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB]
 
 int Binomial[6][SQUARE_NB];    // [k][n] k elements from a set of n elements
-int LeadPawnIdx[5][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB]
-int LeadPawnsSize[5][4];       // [leadPawnsCnt][FILE_A..FILE_D]
+int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB]
+int LeadPawnsSize[6][4];       // [leadPawnsCnt][FILE_A..FILE_D]
 
 // Comparison function to sort leading pawns in ascending MapPawns[] order
 bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; }
@@ -144,16 +144,15 @@ static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes");
 typedef uint16_t Sym; // Huffman symbol
 
 struct LR {
-    enum Side { Left, Right, Value };
+    enum Side { Left, Right };
 
     uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12
                    // bits is the right-hand symbol. If symbol has length 1,
-                   // then the first byte is the stored value.
+                   // then the left-hand symbol is the stored value.
     template<Side S>
     Sym get() {
         return S == Left  ? ((lr[1] & 0xF) << 8) | lr[0] :
-               S == Right ?  (lr[2] << 4) | (lr[1] >> 4) :
-               S == Value ?   lr[0] : (assert(false), Sym(-1));
+               S == Right ?  (lr[2] << 4) | (lr[1] >> 4) : (assert(false), Sym(-1));
     }
 };
 
@@ -217,6 +216,7 @@ public:
         fstat(fd, &statbuf);
         *mapping = statbuf.st_size;
         *baseAddress = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_SHARED, fd, 0);
+        madvise(*baseAddress, statbuf.st_size, MADV_RANDOM);
         ::close(fd);
 
         if (*baseAddress == MAP_FAILED) {
@@ -385,22 +385,35 @@ class TBTables {
 
     typedef std::tuple<Key, TBTable<WDL>*, TBTable<DTZ>*> Entry;
 
-    static const int Size = 1 << 12; // 4K table, indexed by key's 12 lsb
+    static constexpr int Size = 1 << 12; // 4K table, indexed by key's 12 lsb
+    static constexpr int Overflow = 1;  // Number of elements allowed to map to the last bucket
 
-    Entry hashTable[Size];
+    Entry hashTable[Size + Overflow];
 
     std::deque<TBTable<WDL>> wdlTable;
     std::deque<TBTable<DTZ>> dtzTable;
 
     void insert(Key key, TBTable<WDL>* wdl, TBTable<DTZ>* dtz) {
-        Entry* entry = &hashTable[(uint32_t)key & (Size - 1)];
+        uint32_t homeBucket = (uint32_t)key & (Size - 1);
+        Entry entry = std::make_tuple(key, wdl, dtz);
 
         // Ensure last element is empty to avoid overflow when looking up
-        for ( ; entry - hashTable < Size - 1; ++entry)
-            if (std::get<KEY>(*entry) == key || !std::get<WDL>(*entry)) {
-                *entry = std::make_tuple(key, wdl, dtz);
+        for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket) {
+            Key otherKey = std::get<KEY>(hashTable[bucket]);
+            if (otherKey == key || !std::get<WDL>(hashTable[bucket])) {
+                hashTable[bucket] = entry;
                 return;
             }
+
+            // Robin Hood hashing: If we've probed for longer than this element,
+            // insert here and search for a new spot for the other element instead.
+            uint32_t otherHomeBucket = (uint32_t)otherKey & (Size - 1);
+            if (otherHomeBucket > homeBucket) {
+                swap(entry, hashTable[bucket]);
+                key = otherKey;
+                homeBucket = otherHomeBucket;
+            }
+        }
         std::cerr << "TB hash table size too low!" << std::endl;
         exit(1);
     }
@@ -512,7 +525,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
         offset -= d->blockLength[block++] + 1;
 
     // Finally, we find the start address of our block of canonical Huffman symbols
-    uint32_t* ptr = (uint32_t*)(d->data + block * d->sizeofBlock);
+    uint32_t* ptr = (uint32_t*)(d->data + ((uint64_t)block * d->sizeofBlock));
 
     // Read the first 64 bits in our block, this is a (truncated) sequence of
     // unknown number of symbols of unknown length but we know the first one
@@ -575,7 +588,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
         }
     }
 
-    return d->btree[sym].get<LR::Value>();
+    return d->btree[sym].get<LR::Left>();
 }
 
 bool check_dtz_stm(TBTable<WDL>*, int, File) { return true; }
@@ -601,8 +614,12 @@ int map_score(TBTable<DTZ>* entry, File f, int value, WDLScore wdl) {
 
     uint8_t* map = entry->map;
     uint16_t* idx = entry->get(0, f)->map_idx;
-    if (flags & TBFlag::Mapped)
-        value = map[idx[WDLMap[wdl + 2]] + value];
+    if (flags & TBFlag::Mapped) {
+        if (flags & TBFlag::Wide)
+            value = ((uint16_t *)map)[idx[WDLMap[wdl + 2]] + value];
+        else
+            value = map[idx[WDLMap[wdl + 2]] + value];
+    }
 
     // DTZ tables store distance to zero in number of moves or plies. We
     // want to return plies, so we have convert to plies when needed.
@@ -973,7 +990,7 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
     d->symlen.resize(number<uint16_t, LittleEndian>(data)); data += sizeof(uint16_t);
     d->btree = (LR*)data;
 
-    // The comrpession scheme used is "Recursive Pairing", that replaces the most
+    // The compression scheme used is "Recursive Pairing", that replaces the most
     // frequent adjacent pair of symbols in the source message by a new symbol,
     // reevaluating the frequencies of all of the symbol pairs with respect to
     // the extended alphabet, and then repeating the process.
@@ -994,11 +1011,22 @@ uint8_t* set_dtz_map(TBTable<DTZ>& e, uint8_t* data, File maxFile) {
     e.map = data;
 
     for (File f = FILE_A; f <= maxFile; ++f) {
-        if (e.get(0, f)->flags & TBFlag::Mapped)
-            for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x
-                e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1);
-                data += *data + 1;
+        auto flags = e.get(0, f)->flags;
+        if (flags & TBFlag::Mapped) {
+            if (flags & TBFlag::Wide) {
+                data += (uintptr_t)data & 1;  // Word alignment, we may have a mixed table
+                for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x
+                    e.get(0, f)->map_idx[i] = (uint16_t)((uint16_t *)data - (uint16_t *)e.map + 1);
+                    data += 2 * number<uint16_t, LittleEndian>(data) + 2;
+                }
+            }
+            else {
+                for (int i = 0; i < 4; ++i) {
+                    e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1);
+                    data += *data + 1;
+                }
             }
+        }
     }
 
     return data += (uintptr_t)data & 1; // Word alignment
@@ -1249,7 +1277,7 @@ void Tablebases::init(const std::string& paths) {
                         continue; // First on diagonal, second above
 
                     else if (!off_A1H8(s1) && !off_A1H8(s2))
-                        bothOnDiagonal.push_back(std::make_pair(idx, s2));
+                        bothOnDiagonal.emplace_back(idx, s2);
 
                     else
                         MapKK[idx][s2] = code++;
@@ -1274,9 +1302,9 @@ void Tablebases::init(const std::string& paths) {
     // among pawns with same file, the one with lowest rank.
     int availableSquares = 47; // Available squares when lead pawn is in a2
 
-    // Init the tables for the encoding of leading pawns group: with 6-men TB we
-    // can have up to 4 leading pawns (KPPPPK).
-    for (int leadPawnsCnt = 1; leadPawnsCnt <= 4; ++leadPawnsCnt)
+    // Init the tables for the encoding of leading pawns group: with 7-men TB we
+    // can have up to 5 leading pawns (KPPPPPK).
+    for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt)
         for (File f = FILE_A; f <= FILE_D; ++f)
         {
             // Restart the index at every file because TB table is splitted
@@ -1320,11 +1348,22 @@ void Tablebases::init(const std::string& paths) {
             for (PieceType p3 = PAWN; p3 <= p2; ++p3) {
                 TBTables.add({KING, p1, p2, p3, KING});
 
-                for (PieceType p4 = PAWN; p4 <= p3; ++p4)
+                for (PieceType p4 = PAWN; p4 <= p3; ++p4) {
                     TBTables.add({KING, p1, p2, p3, p4, KING});
 
-                for (PieceType p4 = PAWN; p4 < KING; ++p4)
+                    for (PieceType p5 = PAWN; p5 <= p4; ++p5)
+                        TBTables.add({KING, p1, p2, p3, p4, p5, KING});
+
+                    for (PieceType p5 = PAWN; p5 < KING; ++p5)
+                        TBTables.add({KING, p1, p2, p3, p4, KING, p5});
+                }
+
+                for (PieceType p4 = PAWN; p4 < KING; ++p4) {
                     TBTables.add({KING, p1, p2, p3, KING, p4});
+
+                    for (PieceType p5 = PAWN; p5 <= p4; ++p5)
+                        TBTables.add({KING, p1, p2, p3, KING, p4, p5});
+                }
             }
 
             for (PieceType p3 = PAWN; p3 <= p1; ++p3)
@@ -1491,12 +1530,12 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
         int r =  dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? 1000 : 1000 - (dtz + cnt50))
                : dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -1000 : -1000 + (-dtz + cnt50))
                : 0;
-        m.TBRank = r;
+        m.tbRank = r;
 
         // Determine the score to be displayed for this move. Assign at least
         // 1 cp to cursed wins and let it grow to 49 cp as the positions gets
         // closer to a real win.
-        m.TBScore =  r >= bound ? VALUE_MATE - MAX_PLY - 1
+        m.tbScore =  r >= bound ? VALUE_MATE - MAX_PLY - 1
                    : r >  0     ? Value((std::max( 3, r - 800) * int(PawnValueEg)) / 200)
                    : r == 0     ? VALUE_DRAW
                    : r > -bound ? Value((std::min(-3, r + 800) * int(PawnValueEg)) / 200)
@@ -1532,12 +1571,12 @@ bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) {
         if (result == FAIL)
             return false;
 
-        m.TBRank = WDL_to_rank[wdl + 2];
+        m.tbRank = WDL_to_rank[wdl + 2];
 
         if (!rule50)
             wdl =  wdl > WDLDraw ? WDLWin
                  : wdl < WDLDraw ? WDLLoss : WDLDraw;
-        m.TBScore = WDL_to_value[wdl + 2];
+        m.tbScore = WDL_to_value[wdl + 2];
     }
 
     return true;