]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/simd.h
Add support for ARM dot product instructions
[stockfish] / src / nnue / layers / simd.h
index aeab39c4fc393cd86cd211304ad5f5242ebeaf37..22c51980eccd5de5b253b91609377d38df418c08 100644 (file)
@@ -153,7 +153,7 @@ namespace Stockfish::Simd {
       asm(
         "vpdpbusd %[b0], %[a0], %[acc]\n\t"
         "vpdpbusd %[b1], %[a1], %[acc]\n\t"
-        : [acc]"+v"(acc)
+        : [acc]"+&v"(acc)
         : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
       );
 #   else
@@ -165,18 +165,19 @@ namespace Stockfish::Simd {
       __m512i tmp0 = _mm512_maddubs_epi16(a0, b0);
       __m512i tmp1 = _mm512_maddubs_epi16(a1, b1);
       asm(
-          "vpaddsw     %[tmp0], %[tmp1], %[tmp0]\n\t"
           "vpmaddwd    %[tmp0], %[ones], %[tmp0]\n\t"
+          "vpmaddwd    %[tmp1], %[ones], %[tmp1]\n\t"
+          "vpaddd      %[tmp0], %[tmp1], %[tmp0]\n\t"
           "vpaddd      %[acc], %[tmp0], %[acc]\n\t"
-          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0)
-          : [tmp1]"v"(tmp1), [ones]"v"(_mm512_set1_epi16(1))
+          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
+          : [ones]"v"(_mm512_set1_epi16(1))
       );
 #   else
       __m512i product0 = _mm512_maddubs_epi16(a0, b0);
       __m512i product1 = _mm512_maddubs_epi16(a1, b1);
-      product0 = _mm512_adds_epi16(product0, product1);
       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
-      acc = _mm512_add_epi32(acc, product0);
+      product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1));
+      acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1));
 #   endif
 # endif
     }
@@ -249,7 +250,7 @@ namespace Stockfish::Simd {
       asm(
         VNNI_PREFIX "vpdpbusd %[b0], %[a0], %[acc]\n\t"
         VNNI_PREFIX "vpdpbusd %[b1], %[a1], %[acc]\n\t"
-        : [acc]"+v"(acc)
+        : [acc]"+&v"(acc)
         : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
       );
 #   else
@@ -261,18 +262,19 @@ namespace Stockfish::Simd {
       __m256i tmp0 = _mm256_maddubs_epi16(a0, b0);
       __m256i tmp1 = _mm256_maddubs_epi16(a1, b1);
       asm(
-          "vpaddsw     %[tmp0], %[tmp1], %[tmp0]\n\t"
           "vpmaddwd    %[tmp0], %[ones], %[tmp0]\n\t"
+          "vpmaddwd    %[tmp1], %[ones], %[tmp1]\n\t"
+          "vpaddd      %[tmp0], %[tmp1], %[tmp0]\n\t"
           "vpaddd      %[acc], %[tmp0], %[acc]\n\t"
-          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0)
-          : [tmp1]"v"(tmp1), [ones]"v"(_mm256_set1_epi16(1))
+          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
+          : [ones]"v"(_mm256_set1_epi16(1))
       );
 #   else
       __m256i product0 = _mm256_maddubs_epi16(a0, b0);
       __m256i product1 = _mm256_maddubs_epi16(a1, b1);
-      product0 = _mm256_adds_epi16(product0, product1);
       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
-      acc = _mm256_add_epi32(acc, product0);
+      product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1));
+      acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1));
 #   endif
 # endif
     }
@@ -326,23 +328,37 @@ namespace Stockfish::Simd {
       __m128i tmp0 = _mm_maddubs_epi16(a0, b0);
       __m128i tmp1 = _mm_maddubs_epi16(a1, b1);
       asm(
-          "paddsw     %[tmp1], %[tmp0]\n\t"
           "pmaddwd    %[ones], %[tmp0]\n\t"
+          "pmaddwd    %[ones], %[tmp1]\n\t"
+          "paddd      %[tmp1], %[tmp0]\n\t"
           "paddd      %[tmp0], %[acc]\n\t"
-          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0)
-          : [tmp1]"v"(tmp1), [ones]"v"(_mm_set1_epi16(1))
+          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
+          : [ones]"v"(_mm_set1_epi16(1))
       );
 #   else
       __m128i product0 = _mm_maddubs_epi16(a0, b0);
       __m128i product1 = _mm_maddubs_epi16(a1, b1);
-      product0 = _mm_adds_epi16(product0, product1);
       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
-      acc = _mm_add_epi32(acc, product0);
+      product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1));
+      acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1));
 #   endif
     }
 
 #endif
 
+#if defined (USE_NEON_DOTPROD)
+
+    [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
+        int32x4_t& acc,
+        int8x16_t a0, int8x16_t b0,
+        int8x16_t a1, int8x16_t b1) {
+
+        acc = vdotq_s32(acc, a0, b0);
+        acc = vdotq_s32(acc, a1, b1);
+    }
+
+#endif
+
 #if defined (USE_NEON)
 
     [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {