X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fsyzygy%2Ftbprobe.cpp;h=dc7c523ae79900e5cc0634926596db6303743b9a;hp=da6dc49f7bbd3f07d77672ff789c265844f8a354;hb=9afa03b80ea4610729427ee8287a5bbadba03e02;hpb=ba2a2c34bb3098648d6f772a4b26d84d761afe2e diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index da6dc49f..dc7c523a 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -55,13 +55,13 @@ int Tablebases::MaxCardinality; namespace { -constexpr int TBPIECES = 6; // Max number of supported pieces +constexpr int TBPIECES = 7; // Max number of supported pieces enum { BigEndian, LittleEndian }; enum TBType { KEY, WDL, DTZ }; // Used as template parameter // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables -enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, SingleValue = 128 }; +enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, SingleValue = 128 }; inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); } inline Square operator^=(Square& s, int i) { return s = Square(int(s) ^ i); } @@ -75,8 +75,8 @@ int MapA1D1D4[SQUARE_NB]; int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB] int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements -int LeadPawnIdx[5][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] -int LeadPawnsSize[5][4]; // [leadPawnsCnt][FILE_A..FILE_D] +int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] +int LeadPawnsSize[6][4]; // [leadPawnsCnt][FILE_A..FILE_D] // Comparison function to sort leading pawns in ascending MapPawns[] order bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; } @@ -144,16 +144,15 @@ static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); typedef uint16_t Sym; // Huffman symbol struct LR { - enum Side { Left, Right, Value }; + enum Side { Left, Right }; uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 // bits is the right-hand symbol. If symbol has length 1, - // then the first byte is the stored value. + // then the left-hand symbol is the stored value. template Sym get() { return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : - S == Right ? (lr[2] << 4) | (lr[1] >> 4) : - S == Value ? lr[0] : (assert(false), Sym(-1)); + S == Right ? (lr[2] << 4) | (lr[1] >> 4) : (assert(false), Sym(-1)); } }; @@ -385,7 +384,7 @@ class TBTables { typedef std::tuple*, TBTable*> Entry; - static const int Size = 1 << 12; // 4K table, indexed by key's 12 lsb + static const int Size = 1 << 16; // 64K table, indexed by key's 16 lsb Entry hashTable[Size]; @@ -512,7 +511,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) { offset -= d->blockLength[block++] + 1; // Finally, we find the start address of our block of canonical Huffman symbols - uint32_t* ptr = (uint32_t*)(d->data + block * d->sizeofBlock); + uint32_t* ptr = (uint32_t*)(d->data + ((uint64_t)block * d->sizeofBlock)); // Read the first 64 bits in our block, this is a (truncated) sequence of // unknown number of symbols of unknown length but we know the first one @@ -575,7 +574,7 @@ int decompress_pairs(PairsData* d, uint64_t idx) { } } - return d->btree[sym].get(); + return d->btree[sym].get(); } bool check_dtz_stm(TBTable*, int, File) { return true; } @@ -601,8 +600,12 @@ int map_score(TBTable* entry, File f, int value, WDLScore wdl) { uint8_t* map = entry->map; uint16_t* idx = entry->get(0, f)->map_idx; - if (flags & TBFlag::Mapped) - value = map[idx[WDLMap[wdl + 2]] + value]; + if (flags & TBFlag::Mapped) { + if (flags & TBFlag::Wide) + value = ((uint16_t *)map)[idx[WDLMap[wdl + 2]] + value]; + else + value = map[idx[WDLMap[wdl + 2]] + value]; + } // DTZ tables store distance to zero in number of moves or plies. We // want to return plies, so we have convert to plies when needed. @@ -994,11 +997,22 @@ uint8_t* set_dtz_map(TBTable& e, uint8_t* data, File maxFile) { e.map = data; for (File f = FILE_A; f <= maxFile; ++f) { - if (e.get(0, f)->flags & TBFlag::Mapped) - for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x - e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1); - data += *data + 1; + auto flags = e.get(0, f)->flags; + if (flags & TBFlag::Mapped) { + if (flags & TBFlag::Wide) { + data += (uintptr_t)data & 1; // Word alignment, we may have a mixed table + for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x + e.get(0, f)->map_idx[i] = (uint16_t)((uint16_t *)data - (uint16_t *)e.map + 1); + data += 2 * number(data) + 2; + } } + else { + for (int i = 0; i < 4; ++i) { + e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1); + data += *data + 1; + } + } + } } return data += (uintptr_t)data & 1; // Word alignment @@ -1274,9 +1288,9 @@ void Tablebases::init(const std::string& paths) { // among pawns with same file, the one with lowest rank. int availableSquares = 47; // Available squares when lead pawn is in a2 - // Init the tables for the encoding of leading pawns group: with 6-men TB we - // can have up to 4 leading pawns (KPPPPK). - for (int leadPawnsCnt = 1; leadPawnsCnt <= 4; ++leadPawnsCnt) + // Init the tables for the encoding of leading pawns group: with 7-men TB we + // can have up to 5 leading pawns (KPPPPPK). + for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt) for (File f = FILE_A; f <= FILE_D; ++f) { // Restart the index at every file because TB table is splitted @@ -1320,11 +1334,22 @@ void Tablebases::init(const std::string& paths) { for (PieceType p3 = PAWN; p3 <= p2; ++p3) { TBTables.add({KING, p1, p2, p3, KING}); - for (PieceType p4 = PAWN; p4 <= p3; ++p4) + for (PieceType p4 = PAWN; p4 <= p3; ++p4) { TBTables.add({KING, p1, p2, p3, p4, KING}); - for (PieceType p4 = PAWN; p4 < KING; ++p4) + for (PieceType p5 = PAWN; p5 <= p4; ++p5) + TBTables.add({KING, p1, p2, p3, p4, p5, KING}); + + for (PieceType p5 = PAWN; p5 < KING; ++p5) + TBTables.add({KING, p1, p2, p3, p4, KING, p5}); + } + + for (PieceType p4 = PAWN; p4 < KING; ++p4) { TBTables.add({KING, p1, p2, p3, KING, p4}); + + for (PieceType p5 = PAWN; p5 <= p4; ++p5) + TBTables.add({KING, p1, p2, p3, KING, p4, p5}); + } } for (PieceType p3 = PAWN; p3 <= p1; ++p3)