]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_common.h
Assume network file is in little-endian byte order
[stockfish] / src / nnue / nnue_common.h
index ff33cc7974b12d616d6f6f49c3dcbcaa5d252e5e..61f18aeec848be9b25edde2b68733e97c8cf254b 100644 (file)
@@ -21,6 +21,9 @@
 #ifndef NNUE_COMMON_H_INCLUDED
 #define NNUE_COMMON_H_INCLUDED
 
+#include <cstring>
+#include <iostream>
+
 #if defined(USE_AVX2)
 #include <immintrin.h>
 
@@ -33,6 +36,9 @@
 #elif defined(USE_SSE2)
 #include <emmintrin.h>
 
+#elif defined(USE_MMX)
+#include <mmintrin.h>
+
 #elif defined(USE_NEON)
 #include <arm_neon.h>
 #endif
@@ -41,7 +47,7 @@
 //       compiled with older g++ crashes because the output memory is not aligned
 //       even though alignas is specified.
 #if defined(USE_AVX2)
-#if defined(__GNUC__ ) && (__GNUC__ < 9)
+#if defined(__GNUC__ ) && (__GNUC__ < 9) && defined(_WIN32)
 #define _mm256_loadA_si256  _mm256_loadu_si256
 #define _mm256_storeA_si256 _mm256_storeu_si256
 #else
@@ -51,7 +57,7 @@
 #endif
 
 #if defined(USE_AVX512)
-#if defined(__GNUC__ ) && (__GNUC__ < 9)
+#if defined(__GNUC__ ) && (__GNUC__ < 9) && defined(_WIN32)
 #define _mm512_loadA_si512   _mm512_loadu_si512
 #define _mm512_storeA_si512  _mm512_storeu_si512
 #else
@@ -79,6 +85,9 @@ namespace Eval::NNUE {
   #elif defined(USE_SSE2)
   constexpr std::size_t kSimdWidth = 16;
 
+  #elif defined(USE_MMX)
+  constexpr std::size_t kSimdWidth = 8;
+
   #elif defined(USE_NEON)
   constexpr std::size_t kSimdWidth = 16;
   #endif
@@ -95,6 +104,22 @@ namespace Eval::NNUE {
     return (n + base - 1) / base * base;
   }
 
+  // Read a signed or unsigned integer from  a stream in little-endian order
+  template <typename IntType>
+  inline IntType read_le(std::istream& stream) {
+    // Read the relevant bytes from the stream in little-endian order
+    std::uint8_t u[sizeof(IntType)];
+    stream.read(reinterpret_cast<char*>(u), sizeof(IntType));
+    // Use unsigned arithmetic to convert to machine order
+    typename std::make_unsigned<IntType>::type v = 0;
+    for (std::size_t i = 0; i < sizeof(IntType); ++i)
+      v = (v << 8) | u[sizeof(IntType) - i - 1];
+    // Copy the machine-ordered bytes into a potentially signed value
+    IntType w;
+    std::memcpy(&w, &v, sizeof(IntType));
+    return w;
+  }
+
 }  // namespace Eval::NNUE
 
 #endif // #ifndef NNUE_COMMON_H_INCLUDED