]> git.sesse.net Git - stockfish/blobdiff - src/syzygy/tbprobe.cpp
Shrink the hash table of tablebases back to 4096 entries
[stockfish] / src / syzygy / tbprobe.cpp
index cb9dcdd79d1f9b581ec154f48820d39143eb636b..235fe1953bf97505a8be256b617ef124d9647278 100644 (file)
@@ -34,6 +34,7 @@
 #include "../search.h"
 #include "../thread_win32.h"
 #include "../types.h"
+#include "../uci.h"
 
 #include "tbprobe.h"
 
@@ -54,13 +55,13 @@ int Tablebases::MaxCardinality;
 
 namespace {
 
-constexpr int TBPIECES = 6; // Max number of supported pieces
+constexpr int TBPIECES = 7; // Max number of supported pieces
 
 enum { BigEndian, LittleEndian };
 enum TBType { KEY, WDL, DTZ }; // Used as template parameter
 
 // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables
-enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, SingleValue = 128 };
+enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, SingleValue = 128 };
 
 inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); }
 inline Square operator^=(Square& s, int i) { return s = Square(int(s) ^ i); }
@@ -74,8 +75,8 @@ int MapA1D1D4[SQUARE_NB];
 int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB]
 
 int Binomial[6][SQUARE_NB];    // [k][n] k elements from a set of n elements
-int LeadPawnIdx[5][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB]
-int LeadPawnsSize[5][4];       // [leadPawnsCnt][FILE_A..FILE_D]
+int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB]
+int LeadPawnsSize[6][4];       // [leadPawnsCnt][FILE_A..FILE_D]
 
 // Comparison function to sort leading pawns in ascending MapPawns[] order
 bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; }
@@ -143,16 +144,15 @@ static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes");
 typedef uint16_t Sym; // Huffman symbol
 
 struct LR {
-    enum Side { Left, Right, Value };
+    enum Side { Left, Right };
 
     uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12
                    // bits is the right-hand symbol. If symbol has length 1,
-                   // then the first byte is the stored value.
+                   // then the left-hand symbol is the stored value.
     template<Side S>
     Sym get() {
         return S == Left  ? ((lr[1] & 0xF) << 8) | lr[0] :
-               S == Right ?  (lr[2] << 4) | (lr[1] >> 4) :
-               S == Value ?   lr[0] : (assert(false), Sym(-1));
+               S == Right ?  (lr[2] << 4) | (lr[1] >> 4) : (assert(false), Sym(-1));
     }
 };
 
@@ -384,22 +384,35 @@ class TBTables {
 
     typedef std::tuple<Key, TBTable<WDL>*, TBTable<DTZ>*> Entry;
 
-    static const int Size = 1 << 12; // 4K table, indexed by key's 12 lsb
+    static constexpr int Size = 1 << 12; // 4K table, indexed by key's 12 lsb
+    static constexpr int Overflow = 1;  // Number of elements allowed to map to the last bucket
 
-    Entry hashTable[Size];
+    Entry hashTable[Size + Overflow];
 
     std::deque<TBTable<WDL>> wdlTable;
     std::deque<TBTable<DTZ>> dtzTable;
 
     void insert(Key key, TBTable<WDL>* wdl, TBTable<DTZ>* dtz) {
-        Entry* entry = &hashTable[(uint32_t)key & (Size - 1)];
+        uint32_t homeBucket = (uint32_t)key & (Size - 1);
+        Entry entry = std::make_tuple(key, wdl, dtz);
 
         // Ensure last element is empty to avoid overflow when looking up
-        for ( ; entry - hashTable < Size - 1; ++entry)
-            if (std::get<KEY>(*entry) == key || !std::get<WDL>(*entry)) {
-                *entry = std::make_tuple(key, wdl, dtz);
+        for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket) {
+            Key otherKey = std::get<KEY>(hashTable[bucket]);
+            if (otherKey == key || !std::get<WDL>(hashTable[bucket])) {
+                hashTable[bucket] = entry;
                 return;
             }
+
+            // Robin Hood hashing: If we've probed for longer than this element,
+            // insert here and search for a new spot for the other element instead.
+            uint32_t otherHomeBucket = (uint32_t)otherKey & (Size - 1);
+            if (otherHomeBucket > homeBucket) {
+                swap(entry, hashTable[bucket]);
+                key = otherKey;
+                homeBucket = otherHomeBucket;
+            }
+        }
         std::cerr << "TB hash table size too low!" << std::endl;
         exit(1);
     }
@@ -511,7 +524,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
         offset -= d->blockLength[block++] + 1;
 
     // Finally, we find the start address of our block of canonical Huffman symbols
-    uint32_t* ptr = (uint32_t*)(d->data + block * d->sizeofBlock);
+    uint32_t* ptr = (uint32_t*)(d->data + ((uint64_t)block * d->sizeofBlock));
 
     // Read the first 64 bits in our block, this is a (truncated) sequence of
     // unknown number of symbols of unknown length but we know the first one
@@ -574,7 +587,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
         }
     }
 
-    return d->btree[sym].get<LR::Value>();
+    return d->btree[sym].get<LR::Left>();
 }
 
 bool check_dtz_stm(TBTable<WDL>*, int, File) { return true; }
@@ -600,8 +613,12 @@ int map_score(TBTable<DTZ>* entry, File f, int value, WDLScore wdl) {
 
     uint8_t* map = entry->map;
     uint16_t* idx = entry->get(0, f)->map_idx;
-    if (flags & TBFlag::Mapped)
-        value = map[idx[WDLMap[wdl + 2]] + value];
+    if (flags & TBFlag::Mapped) {
+        if (flags & TBFlag::Wide)
+            value = ((uint16_t *)map)[idx[WDLMap[wdl + 2]] + value];
+        else
+            value = map[idx[WDLMap[wdl + 2]] + value];
+    }
 
     // DTZ tables store distance to zero in number of moves or plies. We
     // want to return plies, so we have convert to plies when needed.
@@ -972,7 +989,7 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
     d->symlen.resize(number<uint16_t, LittleEndian>(data)); data += sizeof(uint16_t);
     d->btree = (LR*)data;
 
-    // The comrpession scheme used is "Recursive Pairing", that replaces the most
+    // The compression scheme used is "Recursive Pairing", that replaces the most
     // frequent adjacent pair of symbols in the source message by a new symbol,
     // reevaluating the frequencies of all of the symbol pairs with respect to
     // the extended alphabet, and then repeating the process.
@@ -993,11 +1010,22 @@ uint8_t* set_dtz_map(TBTable<DTZ>& e, uint8_t* data, File maxFile) {
     e.map = data;
 
     for (File f = FILE_A; f <= maxFile; ++f) {
-        if (e.get(0, f)->flags & TBFlag::Mapped)
-            for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x
-                e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1);
-                data += *data + 1;
+        auto flags = e.get(0, f)->flags;
+        if (flags & TBFlag::Mapped) {
+            if (flags & TBFlag::Wide) {
+                data += (uintptr_t)data & 1;  // Word alignment, we may have a mixed table
+                for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x
+                    e.get(0, f)->map_idx[i] = (uint16_t)((uint16_t *)data - (uint16_t *)e.map + 1);
+                    data += 2 * number<uint16_t, LittleEndian>(data) + 2;
+                }
             }
+            else {
+                for (int i = 0; i < 4; ++i) {
+                    e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1);
+                    data += *data + 1;
+                }
+            }
+        }
     }
 
     return data += (uintptr_t)data & 1; // Word alignment
@@ -1130,11 +1158,11 @@ Ret probe_table(const Position& pos, ProbeState* result, WDLScore wdl = WDLDraw)
 // All of this means that during probing, the engine must look at captures and probe
 // their results and must probe the position itself. The "best" result of these
 // probes is the correct result for the position.
-// DTZ table don't store values when a following move is a zeroing winning move
+// DTZ tables do not store values when a following move is a zeroing winning move
 // (winning capture or winning pawn move). Also DTZ store wrong values for positions
 // where the best move is an ep-move (even if losing). So in all these cases set
 // the state to ZEROING_BEST_MOVE.
-template<bool CheckZeroingMoves = false>
+template<bool CheckZeroingMoves>
 WDLScore search(Position& pos, ProbeState* result) {
 
     WDLScore value, bestValue = WDLLoss;
@@ -1152,7 +1180,7 @@ WDLScore search(Position& pos, ProbeState* result) {
         moveCount++;
 
         pos.do_move(move, st);
-        value = -search(pos, result);
+        value = -search<false>(pos, result);
         pos.undo_move(move);
 
         if (*result == FAIL)
@@ -1273,9 +1301,9 @@ void Tablebases::init(const std::string& paths) {
     // among pawns with same file, the one with lowest rank.
     int availableSquares = 47; // Available squares when lead pawn is in a2
 
-    // Init the tables for the encoding of leading pawns group: with 6-men TB we
-    // can have up to 4 leading pawns (KPPPPK).
-    for (int leadPawnsCnt = 1; leadPawnsCnt <= 4; ++leadPawnsCnt)
+    // Init the tables for the encoding of leading pawns group: with 7-men TB we
+    // can have up to 5 leading pawns (KPPPPPK).
+    for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt)
         for (File f = FILE_A; f <= FILE_D; ++f)
         {
             // Restart the index at every file because TB table is splitted
@@ -1319,11 +1347,22 @@ void Tablebases::init(const std::string& paths) {
             for (PieceType p3 = PAWN; p3 <= p2; ++p3) {
                 TBTables.add({KING, p1, p2, p3, KING});
 
-                for (PieceType p4 = PAWN; p4 <= p3; ++p4)
+                for (PieceType p4 = PAWN; p4 <= p3; ++p4) {
                     TBTables.add({KING, p1, p2, p3, p4, KING});
 
-                for (PieceType p4 = PAWN; p4 < KING; ++p4)
+                    for (PieceType p5 = PAWN; p5 <= p4; ++p5)
+                        TBTables.add({KING, p1, p2, p3, p4, p5, KING});
+
+                    for (PieceType p5 = PAWN; p5 < KING; ++p5)
+                        TBTables.add({KING, p1, p2, p3, p4, KING, p5});
+                }
+
+                for (PieceType p4 = PAWN; p4 < KING; ++p4) {
                     TBTables.add({KING, p1, p2, p3, KING, p4});
+
+                    for (PieceType p5 = PAWN; p5 <= p4; ++p5)
+                        TBTables.add({KING, p1, p2, p3, KING, p4, p5});
+                }
             }
 
             for (PieceType p3 = PAWN; p3 <= p1; ++p3)
@@ -1346,7 +1385,7 @@ void Tablebases::init(const std::string& paths) {
 WDLScore Tablebases::probe_wdl(Position& pos, ProbeState* result) {
 
     *result = OK;
-    return search(pos, result);
+    return search<false>(pos, result);
 }
 
 // Probe the DTZ table for a particular position.
@@ -1354,6 +1393,7 @@ WDLScore Tablebases::probe_wdl(Position& pos, ProbeState* result) {
 // The return value is from the point of view of the side to move:
 //         n < -100 : loss, but draw under 50-move rule
 // -100 <= n < -1   : loss in n ply (assuming 50-move counter == 0)
+//        -1        : loss, the side to move is mated
 //         0        : draw
 //     1 < n <= 100 : win in n ply (assuming 50-move counter == 0)
 //   100 < n        : win, but draw under 50-move rule
@@ -1410,13 +1450,12 @@ int Tablebases::probe_dtz(Position& pos, ProbeState* result) {
         // otherwise we will get the dtz of the next move sequence. Search the
         // position after the move to get the score sign (because even in a
         // winning position we could make a losing capture or going for a draw).
-        dtz = zeroing ? -dtz_before_zeroing(search(pos, result))
+        dtz = zeroing ? -dtz_before_zeroing(search<false>(pos, result))
                       : -probe_dtz(pos, result);
 
-        pos.undo_move(move);
-
-        if (*result == FAIL)
-            return 0;
+        // If the move mates, force minDTZ to 1
+        if (dtz == 1 && pos.checkers() && MoveList<LEGAL>(pos).size() == 0)
+            minDTZ = 1;
 
         // Convert result from 1-ply search. Zeroing moves are already accounted
         // by dtz_before_zeroing() that returns the DTZ of the previous move.
@@ -1426,217 +1465,118 @@ int Tablebases::probe_dtz(Position& pos, ProbeState* result) {
         // Skip the draws and if we are winning only pick positive dtz
         if (dtz < minDTZ && sign_of(dtz) == sign_of(wdl))
             minDTZ = dtz;
-    }
 
-    // Special handle a mate position, when there are no legal moves, in this
-    // case return value is somewhat arbitrary, so stick to the original TB code
-    // that returns -1 in this case.
-    return minDTZ == 0xFFFF ? -1 : minDTZ;
-}
-
-// Check whether there has been at least one repetition of positions
-// since the last capture or pawn move.
-static int has_repeated(StateInfo *st)
-{
-    while (1) {
-        int i = 4, e = std::min(st->rule50, st->pliesFromNull);
+        pos.undo_move(move);
 
-        if (e < i)
+        if (*result == FAIL)
             return 0;
-
-        StateInfo *stp = st->previous->previous;
-
-        do {
-            stp = stp->previous->previous;
-
-            if (stp->key == st->key)
-                return 1;
-
-            i += 2;
-        } while (i <= e);
-
-        st = st->previous;
     }
+
+    // When there are no legal moves, the position is mate: we return -1
+    return minDTZ == 0xFFFF ? -1 : minDTZ;
 }
 
-// Use the DTZ tables to filter out moves that don't preserve the win or draw.
-// If the position is lost, but DTZ is fairly high, only keep moves that
-// maximise DTZ.
+
+// Use the DTZ tables to rank root moves.
 //
-// A return value false indicates that not all probes were successful and that
-// no moves were filtered out.
-bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves, Value& score)
-{
-    assert(rootMoves.size());
+// A return value false indicates that not all probes were successful.
+bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
 
     ProbeState result;
-    int dtz = probe_dtz(pos, &result);
-
-    if (result == FAIL)
-        return false;
-
     StateInfo st;
 
-    // Probe each move
-    for (size_t i = 0; i < rootMoves.size(); ++i) {
-        Move move = rootMoves[i].pv[0];
-        pos.do_move(move, st);
-        int v = 0;
-
-        if (pos.checkers() && dtz > 0) {
-            ExtMove s[MAX_MOVES];
-
-            if (generate<LEGAL>(pos, s) == s)
-                v = 1;
-        }
-
-        if (!v) {
-            if (st.rule50 != 0) {
-                v = -probe_dtz(pos, &result);
-
-                if (v > 0)
-                    ++v;
-                else if (v < 0)
-                    --v;
-            } else {
-                v = -probe_wdl(pos, &result);
-                v = dtz_before_zeroing(WDLScore(v));
-            }
-        }
-
-        pos.undo_move(move);
-
-        if (result == FAIL)
-            return false;
-
-        rootMoves[i].score = (Value)v;
-    }
-
-    // Obtain 50-move counter for the root position.
-    // In Stockfish there seems to be no clean way, so we do it like this:
-    int cnt50 = st.previous ? st.previous->rule50 : 0;
-
-    // Use 50-move counter to determine whether the root position is
-    // won, lost or drawn.
-    WDLScore wdl = WDLDraw;
+    // Obtain 50-move counter for the root position
+    int cnt50 = pos.rule50_count();
 
-    if (dtz > 0)
-        wdl = (dtz + cnt50 <= 100) ? WDLWin : WDLCursedWin;
-    else if (dtz < 0)
-        wdl = (-dtz + cnt50 <= 100) ? WDLLoss : WDLBlessedLoss;
+    // Check whether a position was repeated since the last zeroing move.
+    bool rep = pos.has_repeated();
 
-    // Determine the score to report to the user.
-    score = WDL_to_value[wdl + 2];
+    int dtz, bound = Options["Syzygy50MoveRule"] ? 900 : 1;
 
-    // If the position is winning or losing, but too few moves left, adjust the
-    // score to show how close it is to winning or losing.
-    // NOTE: int(PawnValueEg) is used as scaling factor in score_to_uci().
-    if (wdl == WDLCursedWin && dtz <= 100)
-        score = (Value)(((200 - dtz - cnt50) * int(PawnValueEg)) / 200);
-    else if (wdl == WDLBlessedLoss && dtz >= -100)
-        score = -(Value)(((200 + dtz - cnt50) * int(PawnValueEg)) / 200);
-
-    // Now be a bit smart about filtering out moves.
-    size_t j = 0;
-
-    if (dtz > 0) { // winning (or 50-move rule draw)
-        int best = 0xffff;
-
-        for (size_t i = 0; i < rootMoves.size(); ++i) {
-            int v = rootMoves[i].score;
+    // Probe and rank each move
+    for (auto& m : rootMoves)
+    {
+        pos.do_move(m.pv[0], st);
 
-            if (v > 0 && v < best)
-                best = v;
+        // Calculate dtz for the current move counting from the root position
+        if (pos.rule50_count() == 0)
+        {
+            // In case of a zeroing move, dtz is one of -101/-1/0/1/101
+            WDLScore wdl = -probe_wdl(pos, &result);
+            dtz = dtz_before_zeroing(wdl);
         }
-
-        int max = best;
-
-        // If the current phase has not seen repetitions, then try all moves
-        // that stay safely within the 50-move budget, if there are any.
-        if (!has_repeated(st.previous) && best + cnt50 <= 99)
-            max = 99 - cnt50;
-
-        for (size_t i = 0; i < rootMoves.size(); ++i) {
-            int v = rootMoves[i].score;
-
-            if (v > 0 && v <= max)
-                rootMoves[j++] = rootMoves[i];
+        else
+        {
+            // Otherwise, take dtz for the new position and correct by 1 ply
+            dtz = -probe_dtz(pos, &result);
+            dtz =  dtz > 0 ? dtz + 1
+                 : dtz < 0 ? dtz - 1 : dtz;
         }
-    } else if (dtz < 0) { // losing (or 50-move rule draw)
-        int best = 0;
 
-        for (size_t i = 0; i < rootMoves.size(); ++i) {
-            int v = rootMoves[i].score;
+        // Make sure that a mating move is assigned a dtz value of 1
+        if (   pos.checkers()
+            && dtz == 2
+            && MoveList<LEGAL>(pos).size() == 0)
+            dtz = 1;
 
-            if (v < best)
-                best = v;
-        }
+        pos.undo_move(m.pv[0]);
 
-        // Try all moves, unless we approach or have a 50-move rule draw.
-        if (-best * 2 + cnt50 < 100)
-            return true;
+        if (result == FAIL)
+            return false;
 
-        for (size_t i = 0; i < rootMoves.size(); ++i) {
-            if (rootMoves[i].score == best)
-                rootMoves[j++] = rootMoves[i];
-        }
-    } else { // drawing
-        // Try all moves that preserve the draw.
-        for (size_t i = 0; i < rootMoves.size(); ++i) {
-            if (rootMoves[i].score == 0)
-                rootMoves[j++] = rootMoves[i];
-        }
+        // Better moves are ranked higher. Certain wins are ranked equally.
+        // Losing moves are ranked equally unless a 50-move draw is in sight.
+        int r =  dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? 1000 : 1000 - (dtz + cnt50))
+               : dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -1000 : -1000 + (-dtz + cnt50))
+               : 0;
+        m.tbRank = r;
+
+        // Determine the score to be displayed for this move. Assign at least
+        // 1 cp to cursed wins and let it grow to 49 cp as the positions gets
+        // closer to a real win.
+        m.tbScore =  r >= bound ? VALUE_MATE - MAX_PLY - 1
+                   : r >  0     ? Value((std::max( 3, r - 800) * int(PawnValueEg)) / 200)
+                   : r == 0     ? VALUE_DRAW
+                   : r > -bound ? Value((std::min(-3, r + 800) * int(PawnValueEg)) / 200)
+                   :             -VALUE_MATE + MAX_PLY + 1;
     }
 
-    rootMoves.resize(j, Search::RootMove(MOVE_NONE));
-
     return true;
 }
 
-// Use the WDL tables to filter out moves that don't preserve the win or draw.
+
+// Use the WDL tables to rank root moves.
 // This is a fallback for the case that some or all DTZ tables are missing.
 //
-// A return value false indicates that not all probes were successful and that
-// no moves were filtered out.
-bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves, Value& score)
-{
-    ProbeState result;
+// A return value false indicates that not all probes were successful.
+bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) {
 
-    WDLScore wdl = Tablebases::probe_wdl(pos, &result);
+    static const int WDL_to_rank[] = { -1000, -899, 0, 899, 1000 };
 
-    if (result == FAIL)
-        return false;
+    ProbeState result;
+    StateInfo st;
 
-    score = WDL_to_value[wdl + 2];
+    bool rule50 = Options["Syzygy50MoveRule"];
 
-    StateInfo st;
+    // Probe and rank each move
+    for (auto& m : rootMoves)
+    {
+        pos.do_move(m.pv[0], st);
 
-    int best = WDLLoss;
+        WDLScore wdl = -probe_wdl(pos, &result);
 
-    // Probe each move
-    for (size_t i = 0; i < rootMoves.size(); ++i) {
-        Move move = rootMoves[i].pv[0];
-        pos.do_move(move, st);
-        WDLScore v = -Tablebases::probe_wdl(pos, &result);
-        pos.undo_move(move);
+        pos.undo_move(m.pv[0]);
 
         if (result == FAIL)
             return false;
 
-        rootMoves[i].score = (Value)v;
+        m.tbRank = WDL_to_rank[wdl + 2];
 
-        if (v > best)
-            best = v;
+        if (!rule50)
+            wdl =  wdl > WDLDraw ? WDLWin
+                 : wdl < WDLDraw ? WDLLoss : WDLDraw;
+        m.tbScore = WDL_to_value[wdl + 2];
     }
 
-    size_t j = 0;
-
-    for (size_t i = 0; i < rootMoves.size(); ++i) {
-        if (rootMoves[i].score == best)
-            rootMoves[j++] = rootMoves[i];
-    }
-
-    rootMoves.resize(j, Search::RootMove(MOVE_NONE));
-
     return true;
 }