X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=src%2Fnnue%2Fnnue_common.h;h=779f4e755557aceb8f30fbab29203afad8778a2d;hb=8a912951de6d4bff78d3ff5258213a0c7e6f494e;hp=1bce00ae4650aa72a03ec9d577ceb7f8a4e2b41c;hpb=ad926d34c0105d523bfa5cb92cbcf9f337d54c08;p=stockfish diff --git a/src/nnue/nnue_common.h b/src/nnue/nnue_common.h index 1bce00ae..779f4e75 100644 --- a/src/nnue/nnue_common.h +++ b/src/nnue/nnue_common.h @@ -1,6 +1,6 @@ /* Stockfish, a UCI chess playing engine derived from Glaurung 2.1 - Copyright (C) 2004-2022 The Stockfish developers (see AUTHORS file) + Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file) Stockfish is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -21,10 +21,14 @@ #ifndef NNUE_COMMON_H_INCLUDED #define NNUE_COMMON_H_INCLUDED +#include +#include +#include #include #include +#include -#include "../misc.h" // for IsLittleEndian +#include "../misc.h" #if defined(USE_AVX2) #include @@ -38,9 +42,6 @@ #elif defined(USE_SSE2) #include -#elif defined(USE_MMX) -#include - #elif defined(USE_NEON) #include #endif @@ -57,6 +58,9 @@ namespace Stockfish::Eval::NNUE { // Size of cache line (in bytes) constexpr std::size_t CacheLineSize = 64; + constexpr const char Leb128MagicString[] = "COMPRESSED_LEB128"; + constexpr const std::size_t Leb128MagicStringSize = sizeof(Leb128MagicString) - 1; + // SIMD width (in bytes) #if defined(USE_AVX2) constexpr std::size_t SimdWidth = 32; @@ -64,9 +68,6 @@ namespace Stockfish::Eval::NNUE { #elif defined(USE_SSE2) constexpr std::size_t SimdWidth = 16; - #elif defined(USE_MMX) - constexpr std::size_t SimdWidth = 8; - #elif defined(USE_NEON) constexpr std::size_t SimdWidth = 16; #endif @@ -83,6 +84,7 @@ namespace Stockfish::Eval::NNUE { return (n + base - 1) / base * base; } + // read_little_endian() is our utility to read an integer (signed or unsigned, any size) // from a stream in little-endian order. We swap the byte order after the read if // necessary to return a result with the byte ordering of the compiling machine. @@ -95,7 +97,7 @@ namespace Stockfish::Eval::NNUE { else { std::uint8_t u[sizeof(IntType)]; - typename std::make_unsigned::type v = 0; + std::make_unsigned_t v = 0; stream.read(reinterpret_cast(u), sizeof(IntType)); for (std::size_t i = 0; i < sizeof(IntType); ++i) @@ -107,6 +109,7 @@ namespace Stockfish::Eval::NNUE { return result; } + // write_little_endian() is our utility to write an integer (signed or unsigned, any size) // to a stream in little-endian order. We swap the byte order before the write if // necessary to always write in little endian order, independently of the byte @@ -119,7 +122,7 @@ namespace Stockfish::Eval::NNUE { else { std::uint8_t u[sizeof(IntType)]; - typename std::make_unsigned::type v = value; + std::make_unsigned_t v = value; std::size_t i = 0; // if constexpr to silence the warning about shift by 8 @@ -127,16 +130,17 @@ namespace Stockfish::Eval::NNUE { { for (; i + 1 < sizeof(IntType); ++i) { - u[i] = v; + u[i] = (std::uint8_t)v; v >>= 8; } } - u[i] = v; + u[i] = (std::uint8_t)v; stream.write(reinterpret_cast(u), sizeof(IntType)); } } + // read_little_endian(s, out, N) : read integers in bulk from a little indian stream. // This reads N integers from stream s and put them in array out. template @@ -148,6 +152,7 @@ namespace Stockfish::Eval::NNUE { out[i] = read_little_endian(stream); } + // write_little_endian(s, values, N) : write integers in bulk to a little indian stream. // This takes N integers from array values and writes them on stream s. template @@ -159,6 +164,122 @@ namespace Stockfish::Eval::NNUE { write_little_endian(stream, values[i]); } + + // read_leb_128(s, out, N) : read N signed integers from the stream s, putting them in + // the array out. The stream is assumed to be compressed using the signed LEB128 format. + // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. + template + inline void read_leb_128(std::istream& stream, IntType* out, std::size_t count) { + + // Check the presence of our LEB128 magic string + char leb128MagicString[Leb128MagicStringSize]; + stream.read(leb128MagicString, Leb128MagicStringSize); + assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0); + + static_assert(std::is_signed_v, "Not implemented for unsigned types"); + + const std::uint32_t BUF_SIZE = 4096; + std::uint8_t buf[BUF_SIZE]; + + auto bytes_left = read_little_endian(stream); + + std::uint32_t buf_pos = BUF_SIZE; + for (std::size_t i = 0; i < count; ++i) + { + IntType result = 0; + size_t shift = 0; + do + { + if (buf_pos == BUF_SIZE) + { + stream.read(reinterpret_cast(buf), std::min(bytes_left, BUF_SIZE)); + buf_pos = 0; + } + + std::uint8_t byte = buf[buf_pos++]; + --bytes_left; + result |= (byte & 0x7f) << shift; + shift += 7; + + if ((byte & 0x80) == 0) + { + out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0) ? result + : result | ~((1 << shift) - 1); + break; + } + } + while (shift < sizeof(IntType) * 8); + } + + assert(bytes_left == 0); + } + + + // write_leb_128(s, values, N) : write signed integers to a stream with LEB128 compression. + // This takes N integers from array values, compress them with the LEB128 algorithm and + // writes the result on the stream s. + // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. + template + inline void write_leb_128(std::ostream& stream, const IntType* values, std::size_t count) { + + // Write our LEB128 magic string + stream.write(Leb128MagicString, Leb128MagicStringSize); + + static_assert(std::is_signed_v, "Not implemented for unsigned types"); + + std::uint32_t byte_count = 0; + for (std::size_t i = 0; i < count; ++i) + { + IntType value = values[i]; + std::uint8_t byte; + do + { + byte = value & 0x7f; + value >>= 7; + ++byte_count; + } + while ((byte & 0x40) == 0 ? value != 0 : value != -1); + } + + write_little_endian(stream, byte_count); + + const std::uint32_t BUF_SIZE = 4096; + std::uint8_t buf[BUF_SIZE]; + std::uint32_t buf_pos = 0; + + auto flush = [&]() { + if (buf_pos > 0) + { + stream.write(reinterpret_cast(buf), buf_pos); + buf_pos = 0; + } + }; + + auto write = [&](std::uint8_t byte) { + buf[buf_pos++] = byte; + if (buf_pos == BUF_SIZE) + flush(); + }; + + for (std::size_t i = 0; i < count; ++i) + { + IntType value = values[i]; + while (true) + { + std::uint8_t byte = value & 0x7f; + value >>= 7; + if ((byte & 0x40) == 0 ? value == 0 : value == -1) + { + write(byte); + break; + } + write(byte | 0x80); + } + } + + flush(); + } + } // namespace Stockfish::Eval::NNUE #endif // #ifndef NNUE_COMMON_H_INCLUDED