]> git.sesse.net Git - stockfish/blobdiff - src/syzygy/tbprobe.cpp
Make casting styles consistent
[stockfish] / src / syzygy / tbprobe.cpp
index 115815e121abbe57114e9ae00527c4e8cb0c8078..13d271fce8a69c32e79cc265414806a826c228ea 100644 (file)
@@ -1,6 +1,6 @@
 /*
   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
-  Copyright (C) 2004-2021 The Stockfish developers (see AUTHORS file)
+  Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
 
   Stockfish is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
 
+#include "tbprobe.h"
+
+#include <sys/stat.h>
 #include <algorithm>
 #include <atomic>
+#include <cassert>
 #include <cstdint>
-#include <cstring>   // For std::memset and std::memcpy
+#include <cstdlib>
+#include <cstring>
 #include <deque>
 #include <fstream>
+#include <initializer_list>
 #include <iostream>
-#include <list>
+#include <mutex>
 #include <sstream>
+#include <string_view>
 #include <type_traits>
-#include <mutex>
+#include <utility>
+#include <vector>
 
 #include "../bitboard.h"
+#include "../misc.h"
 #include "../movegen.h"
 #include "../position.h"
 #include "../search.h"
 #include "../types.h"
 #include "../uci.h"
 
-#include "tbprobe.h"
-
 #ifndef _WIN32
 #include <fcntl.h>
-#include <unistd.h>
 #include <sys/mman.h>
-#include <sys/stat.h>
+#include <unistd.h>
 #else
 #define WIN32_LEAN_AND_MEAN
 #ifndef NOMINMAX
 #include <windows.h>
 #endif
 
-using namespace Tablebases;
+using namespace Stockfish::Tablebases;
 
-int Tablebases::MaxCardinality;
+int Stockfish::Tablebases::MaxCardinality;
+
+namespace Stockfish {
 
 namespace {
 
 constexpr int TBPIECES = 7; // Max number of supported pieces
+constexpr int MAX_DTZ = 1 << 18; // Max DTZ supported, large enough to deal with the syzygy TB limit.
 
 enum { BigEndian, LittleEndian };
 enum TBType { WDL, DTZ }; // Used as template parameter
@@ -67,7 +76,7 @@ enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, Singl
 inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); }
 inline Square operator^(Square s, int i) { return Square(int(s) ^ i); }
 
-const std::string PieceToChar = " PNBRQK  pnbrqk";
+constexpr std::string_view PieceToChar = " PNBRQK  pnbrqk";
 
 int MapPawns[SQUARE_NB];
 int MapB1H1H7[SQUARE_NB];
@@ -103,12 +112,9 @@ template<> inline void swap_endian<uint8_t>(uint8_t&) {}
 
 template<typename T, int LE> T number(void* addr)
 {
-    static const union { uint32_t i; char c[4]; } Le = { 0x01020304 };
-    static const bool IsLittleEndian = (Le.c[0] == 4);
-
     T v;
 
-    if ((uintptr_t)addr & (alignof(T) - 1)) // Unaligned pointer (very rare)
+    if (uintptr_t(addr) & (alignof(T) - 1)) // Unaligned pointer (very rare)
         std::memcpy(&v, addr, sizeof(T));
     else
         v = *((T*)addr);
@@ -141,7 +147,7 @@ struct SparseEntry {
 
 static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes");
 
-typedef uint16_t Sym; // Huffman symbol
+using Sym = uint16_t; // Huffman symbol
 
 struct LR {
     enum Side { Left, Right };
@@ -190,7 +196,8 @@ public:
         std::stringstream ss(Paths);
         std::string path;
 
-        while (std::getline(ss, path, SepChar)) {
+        while (std::getline(ss, path, SepChar))
+        {
             fname = path + "/" + f;
             std::ifstream::open(fname);
             if (is_open())
@@ -198,13 +205,10 @@ public:
         }
     }
 
-    // Memory map the file and check it. File should be already open and will be
-    // closed after mapping.
+    // Memory map the file and check it.
     uint8_t* map(void** baseAddress, uint64_t* mapping, TBType type) {
-
-        assert(is_open());
-
-        close(); // Need to re-open to get native file descriptor
+        if (is_open())
+            close(); // Need to re-open to get native file descriptor
 
 #ifndef _WIN32
         struct stat statbuf;
@@ -235,7 +239,7 @@ public:
         }
 #else
         // Note FILE_FLAG_RANDOM_ACCESS is only a hint to Windows and as such may get ignored.
-        HANDLE fd = CreateFile(fname.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr,
+        HANDLE fd = CreateFileA(fname.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr,
                                OPEN_EXISTING, FILE_FLAG_RANDOM_ACCESS, nullptr);
 
         if (fd == INVALID_HANDLE_VALUE)
@@ -259,7 +263,7 @@ public:
             exit(EXIT_FAILURE);
         }
 
-        *mapping = (uint64_t)mmap;
+        *mapping = uint64_t(mmap);
         *baseAddress = MapViewOfFile(mmap, FILE_MAP_READ, 0, 0, 0);
 
         if (!*baseAddress)
@@ -328,7 +332,7 @@ struct PairsData {
 // first access, when the corresponding file is memory mapped.
 template<TBType Type>
 struct TBTable {
-    typedef typename std::conditional<Type == WDL, WDLScore, int>::type Ret;
+    using Ret = typename std::conditional<Type == WDL, WDLScore, int>::type;
 
     static constexpr int Sides = Type == WDL ? 2 : 1;
 
@@ -425,7 +429,7 @@ class TBTables {
     std::deque<TBTable<DTZ>> dtzTable;
 
     void insert(Key key, TBTable<WDL>* wdl, TBTable<DTZ>* dtz) {
-        uint32_t homeBucket = (uint32_t)key & (Size - 1);
+        uint32_t homeBucket = uint32_t(key) & (Size - 1);
         Entry entry{ key, wdl, dtz };
 
         // Ensure last element is empty to avoid overflow when looking up
@@ -438,7 +442,7 @@ class TBTables {
 
             // Robin Hood hashing: If we've probed for longer than this element,
             // insert here and search for a new spot for the other element instead.
-            uint32_t otherHomeBucket = (uint32_t)otherKey & (Size - 1);
+            uint32_t otherHomeBucket = uint32_t(otherKey) & (Size - 1);
             if (otherHomeBucket > homeBucket) {
                 std::swap(entry, hashTable[bucket]);
                 key = otherKey;
@@ -452,7 +456,7 @@ class TBTables {
 public:
     template<TBType Type>
     TBTable<Type>* get(Key key) {
-        for (const Entry* entry = &hashTable[(uint32_t)key & (Size - 1)]; ; ++entry) {
+        for (const Entry* entry = &hashTable[uint32_t(key) & (Size - 1)]; ; ++entry) {
             if (entry->key == key || !entry->get<Type>())
                 return entry->get<Type>();
         }
@@ -485,7 +489,7 @@ void TBTables::add(const std::vector<PieceType>& pieces) {
 
     file.close();
 
-    MaxCardinality = std::max((int)pieces.size(), MaxCardinality);
+    MaxCardinality = std::max(int(pieces.size()), MaxCardinality);
 
     wdlTable.emplace_back(code);
     dtzTable.emplace_back(wdlTable.back());
@@ -556,7 +560,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
         offset -= d->blockLength[block++] + 1;
 
     // Finally, we find the start address of our block of canonical Huffman symbols
-    uint32_t* ptr = (uint32_t*)(d->data + ((uint64_t)block * d->sizeofBlock));
+    uint32_t* ptr = (uint32_t*)(d->data + (uint64_t(block) * d->sizeofBlock));
 
     // Read the first 64 bits in our block, this is a (truncated) sequence of
     // unknown number of symbols of unknown length but we know the first one
@@ -565,7 +569,8 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
     int buf64Size = 64;
     Sym sym;
 
-    while (true) {
+    while (true)
+    {
         int len = 0; // This is the symbol length - d->min_sym_len
 
         // Now get the symbol length. For any symbol s64 of length l right-padded
@@ -595,7 +600,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
 
         if (buf64Size <= 32) { // Refill the buffer
             buf64Size += 32;
-            buf64 |= (uint64_t)number<uint32_t, BigEndian>(ptr++) << (64 - buf64Size);
+            buf64 |= uint64_t(number<uint32_t, BigEndian>(ptr++)) << (64 - buf64Size);
         }
     }
 
@@ -603,8 +608,8 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
     // We binary-search for our value recursively expanding into the left and
     // right child symbols until we reach a leaf node where symlen[sym] + 1 == 1
     // that will store the value we need.
-    while (d->symlen[sym]) {
-
+    while (d->symlen[sym])
+    {
         Sym left = d->btree[sym].get<LR::Left>();
 
         // If a symbol contains 36 sub-symbols (d->symlen[sym] + 1 = 36) and
@@ -709,7 +714,7 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
 
         leadPawns = b = pos.pieces(color_of(pc), PAWN);
         do
-            squares[size++] = pop_lsb(&b) ^ flipSquares;
+            squares[size++] = pop_lsb(b) ^ flipSquares;
         while (b);
 
         leadPawnsCnt = size;
@@ -729,7 +734,7 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
     // directly map them to the correct color and square.
     b = pos.pieces() ^ leadPawns;
     do {
-        Square s = pop_lsb(&b);
+        Square s = pop_lsb(b);
         squares[size] = s ^ flipSquares;
         pieces[size++] = Piece(pos.piece_on(s) ^ flipColor);
     } while (b);
@@ -768,7 +773,7 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
         goto encode_remaining; // With pawns we have finished special treatments
     }
 
-    // In positions withouth pawns, we further flip the squares to ensure leading
+    // In positions without pawns, we further flip the squares to ensure leading
     // piece is below RANK_5.
     if (rank_of(squares[0]) > RANK_4)
         for (int i = 0; i < size; ++i)
@@ -811,7 +816,7 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
     // Rs "together" in 62 * 61 / 2 ways (we divide by 2 because rooks can be
     // swapped and still get the same position.)
     //
-    // In case we have at least 3 unique pieces (inlcuded kings) we encode them
+    // In case we have at least 3 unique pieces (included kings) we encode them
     // together.
     if (entry->hasUniquePieces) {
 
@@ -826,7 +831,7 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
                    + (squares[1] - adjust1)) * 62
                    +  squares[2] - adjust2;
 
-        // First piece is on a1-h8 diagonal, second below: map this occurence to
+        // First piece is on a1-h8 diagonal, second below: map this occurrence to
         // 6 to differentiate from the above case, rank_of() maps a1-d4 diagonal
         // to 0...3 and finally MapB1H1H7[] maps the b1-h1-h7 triangle to 0..27.
         else if (off_A1H8(squares[1]))
@@ -856,7 +861,7 @@ encode_remaining:
     idx *= d->groupIdx[0];
     Square* groupSq = squares + d->groupLen[0];
 
-    // Encode remainig pawns then pieces according to square, in ascending order
+    // Encode remaining pawns then pieces according to square, in ascending order
     bool remainingPawns = entry->hasPawns && entry->pawnCount[1];
 
     while (d->groupLen[++next])
@@ -884,7 +889,7 @@ encode_remaining:
 
 // Group together pieces that will be encoded together. The general rule is that
 // a group contains pieces of same type and color. The exception is the leading
-// group that, in case of positions withouth pawns, can be formed by 3 different
+// group that, in case of positions without pawns, can be formed by 3 different
 // pieces (default) or by the king pair when there is not a unique piece apart
 // from the kings. When there are pawns, pawns are always first in pieces[].
 //
@@ -916,7 +921,7 @@ void set_groups(T& e, PairsData* d, int order[], File f) {
     //
     // This ensures unique encoding for the whole position. The order of the
     // groups is a per-table parameter and could not follow the canonical leading
-    // pawns/pieces -> remainig pawns -> remaining pieces. In particular the
+    // pawns/pieces -> remaining pawns -> remaining pieces. In particular the
     // first group is at order[0] position and the remaining pawns, when present,
     // are at order[1] position.
     bool pp = e.hasPawns && e.pawnCount[1]; // Pawns on both sides
@@ -936,7 +941,7 @@ void set_groups(T& e, PairsData* d, int order[], File f) {
             d->groupIdx[1] = idx;
             idx *= Binomial[d->groupLen[1]][48 - d->groupLen[0]];
         }
-        else // Remainig pieces
+        else // Remaining pieces
         {
             d->groupIdx[next] = idx;
             idx *= Binomial[d->groupLen[next]][freeSquares];
@@ -946,7 +951,7 @@ void set_groups(T& e, PairsData* d, int order[], File f) {
     d->groupIdx[n] = idx;
 }
 
-// In Recursive Pairing each symbol represents a pair of childern symbols. So
+// In Recursive Pairing each symbol represents a pair of children symbols. So
 // read d->btree[] symbols data and expand each one in his left and right child
 // symbol until reaching the leafs that represent the symbol value.
 uint8_t set_symlen(PairsData* d, Sym s, std::vector<bool>& visited) {
@@ -995,13 +1000,19 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
     d->lowestSym = (Sym*)data;
     d->base64.resize(d->maxSymLen - d->minSymLen + 1);
 
+    // See https://en.wikipedia.org/wiki/Huffman_coding
     // The canonical code is ordered such that longer symbols (in terms of
     // the number of bits of their Huffman code) have lower numeric value,
     // so that d->lowestSym[i] >= d->lowestSym[i+1] (when read as LittleEndian).
     // Starting from this we compute a base64[] table indexed by symbol length
     // and containing 64 bit values so that d->base64[i] >= d->base64[i+1].
-    // See https://en.wikipedia.org/wiki/Huffman_coding
-    for (int i = d->base64.size() - 2; i >= 0; --i) {
+
+    // Implementation note: we first cast the unsigned size_t "base64.size()"
+    // to a signed int "base64_size" variable and then we are able to subtract 2,
+    // avoiding unsigned overflow warnings.
+
+    int base64_size = static_cast<int>(d->base64.size());
+    for (int i = base64_size - 2; i >= 0; --i) {
         d->base64[i] = (d->base64[i + 1] + number<Sym, LittleEndian>(&d->lowestSym[i])
                                          - number<Sym, LittleEndian>(&d->lowestSym[i + 1])) / 2;
 
@@ -1012,10 +1023,10 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
     // than d->base64[i+1] and given the above assert condition, we ensure that
     // d->base64[i] >= d->base64[i+1]. Moreover for any symbol s64 of length i
     // and right-padded to 64 bits holds d->base64[i-1] >= s64 >= d->base64[i].
-    for (size_t i = 0; i < d->base64.size(); ++i)
+    for (int i = 0; i < base64_size; ++i)
         d->base64[i] <<= 64 - i - d->minSymLen; // Right-padding to 64 bits
 
-    data += d->base64.size() * sizeof(Sym);
+    data += base64_size * sizeof(Sym);
     d->symlen.resize(number<uint16_t, LittleEndian>(data)); data += sizeof(uint16_t);
     d->btree = (LR*)data;
 
@@ -1023,7 +1034,7 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
     // frequent adjacent pair of symbols in the source message by a new symbol,
     // reevaluating the frequencies of all of the symbol pairs with respect to
     // the extended alphabet, and then repeating the process.
-    // See http://www.larsson.dogma.net/dcc99.pdf
+    // See https://web.archive.org/web/20201106232444/http://www.larsson.dogma.net/dcc99.pdf
     std::vector<bool> visited(d->symlen.size());
 
     for (Sym sym = 0; sym < d->symlen.size(); ++sym)
@@ -1043,22 +1054,22 @@ uint8_t* set_dtz_map(TBTable<DTZ>& e, uint8_t* data, File maxFile) {
         auto flags = e.get(0, f)->flags;
         if (flags & TBFlag::Mapped) {
             if (flags & TBFlag::Wide) {
-                data += (uintptr_t)data & 1;  // Word alignment, we may have a mixed table
+                data += uintptr_t(data) & 1;  // Word alignment, we may have a mixed table
                 for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x
-                    e.get(0, f)->map_idx[i] = (uint16_t)((uint16_t *)data - (uint16_t *)e.map + 1);
+                    e.get(0, f)->map_idx[i] = uint16_t((uint16_t*)data - (uint16_t*)e.map + 1);
                     data += 2 * number<uint16_t, LittleEndian>(data) + 2;
                 }
             }
             else {
                 for (int i = 0; i < 4; ++i) {
-                    e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1);
+                    e.get(0, f)->map_idx[i] = uint16_t(data - e.map + 1);
                     data += *data + 1;
                 }
             }
         }
     }
 
-    return data += (uintptr_t)data & 1; // Word alignment
+    return data += uintptr_t(data) & 1; // Word alignment
 }
 
 // Populate entry's PairsData records with data from the just memory mapped file.
@@ -1099,7 +1110,7 @@ void set(T& e, uint8_t* data) {
             set_groups(e, e.get(i, f), order[i], f);
     }
 
-    data += (uintptr_t)data & 1; // Word alignment
+    data += uintptr_t(data) & 1; // Word alignment
 
     for (File f = FILE_A; f <= maxFile; ++f)
         for (int i = 0; i < sides; i++)
@@ -1121,7 +1132,7 @@ void set(T& e, uint8_t* data) {
 
     for (File f = FILE_A; f <= maxFile; ++f)
         for (int i = 0; i < sides; i++) {
-            data = (uint8_t*)(((uintptr_t)data + 0x3F) & ~0x3F); // 64 byte alignment
+            data = (uint8_t*)((uintptr_t(data) + 0x3F) & ~0x3F); // 64 byte alignment
             (d = e.get(i, f))->data = data;
             data += d->blocksNum * d->sizeofBlock;
         }
@@ -1289,7 +1300,7 @@ void Tablebases::init(const std::string& paths) {
     for (auto s : diagonal)
         MapA1D1D4[s] = code++;
 
-    // MapKK[] encodes all the 461 possible legal positions of two kings where
+    // MapKK[] encodes all the 462 possible legal positions of two kings where
     // the first is in the a1-d1-d4 triangle. If the first king is on the a1-d4
     // diagonal, the other one shall not to be above the a1-h8 diagonal.
     std::vector<std::pair<int, Square>> bothOnDiagonal;
@@ -1316,7 +1327,7 @@ void Tablebases::init(const std::string& paths) {
     for (auto p : bothOnDiagonal)
         MapKK[p.first][p.second] = code++;
 
-    // Binomial[] stores the Binomial Coefficents using Pascal rule. There
+    // Binomial[] stores the Binomial Coefficients using Pascal rule. There
     // are Binomial[k][n] ways to choose k elements from a set of n elements.
     Binomial[0][0] = 1;
 
@@ -1336,7 +1347,7 @@ void Tablebases::init(const std::string& paths) {
     for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt)
         for (File f = FILE_A; f <= FILE_D; ++f)
         {
-            // Restart the index at every file because TB table is splitted
+            // Restart the index at every file because TB table is split
             // by file, so we can reuse the same index for different files.
             int idx = 0;
 
@@ -1512,7 +1523,7 @@ int Tablebases::probe_dtz(Position& pos, ProbeState* result) {
 // A return value false indicates that not all probes were successful.
 bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
 
-    ProbeState result;
+    ProbeState result = OK;
     StateInfo st;
 
     // Obtain 50-move counter for the root position
@@ -1521,7 +1532,7 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
     // Check whether a position was repeated since the last zeroing move.
     bool rep = pos.has_repeated();
 
-    int dtz, bound = Options["Syzygy50MoveRule"] ? 900 : 1;
+    int dtz, bound = Options["Syzygy50MoveRule"] ? (MAX_DTZ - 100) : 1;
 
     // Probe and rank each move
     for (auto& m : rootMoves)
@@ -1535,6 +1546,14 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
             WDLScore wdl = -probe_wdl(pos, &result);
             dtz = dtz_before_zeroing(wdl);
         }
+        else if (pos.is_draw(1))
+        {
+            // In case a root move leads to a draw by repetition or
+            // 50-move rule, we set dtz to zero. Note: since we are
+            // only 1 ply from the root, this must be a true 3-fold
+            // repetition inside the game history.
+            dtz = 0;
+        }
         else
         {
             // Otherwise, take dtz for the new position and correct by 1 ply
@@ -1556,8 +1575,8 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
 
         // Better moves are ranked higher. Certain wins are ranked equally.
         // Losing moves are ranked equally unless a 50-move draw is in sight.
-        int r =  dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? 1000 : 1000 - (dtz + cnt50))
-               : dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -1000 : -1000 + (-dtz + cnt50))
+        int r =  dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? MAX_DTZ : MAX_DTZ - (dtz + cnt50))
+               : dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -MAX_DTZ : -MAX_DTZ + (-dtz + cnt50))
                : 0;
         m.tbRank = r;
 
@@ -1565,9 +1584,9 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
         // 1 cp to cursed wins and let it grow to 49 cp as the positions gets
         // closer to a real win.
         m.tbScore =  r >= bound ? VALUE_MATE - MAX_PLY - 1
-                   : r >  0     ? Value((std::max( 3, r - 800) * int(PawnValueEg)) / 200)
+                   : r >  0     ? Value((std::max( 3, r - (MAX_DTZ - 200)) * int(PawnValue)) / 200)
                    : r == 0     ? VALUE_DRAW
-                   : r > -bound ? Value((std::min(-3, r + 800) * int(PawnValueEg)) / 200)
+                   : r > -bound ? Value((std::min(-3, r + (MAX_DTZ - 200)) * int(PawnValue)) / 200)
                    :             -VALUE_MATE + MAX_PLY + 1;
     }
 
@@ -1581,10 +1600,11 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
 // A return value false indicates that not all probes were successful.
 bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) {
 
-    static const int WDL_to_rank[] = { -1000, -899, 0, 899, 1000 };
+    static const int WDL_to_rank[] = { -MAX_DTZ, -MAX_DTZ + 101, 0, MAX_DTZ - 101, MAX_DTZ };
 
-    ProbeState result;
+    ProbeState result = OK;
     StateInfo st;
+    WDLScore wdl;
 
     bool rule50 = Options["Syzygy50MoveRule"];
 
@@ -1593,7 +1613,10 @@ bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) {
     {
         pos.do_move(m.pv[0], st);
 
-        WDLScore wdl = -probe_wdl(pos, &result);
+        if (pos.is_draw(1))
+            wdl = WDLDraw;
+        else
+            wdl = -probe_wdl(pos, &result);
 
         pos.undo_move(m.pv[0]);
 
@@ -1610,3 +1633,5 @@ bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) {
 
     return true;
 }
+
+} // namespace Stockfish