]> git.sesse.net Git - stockfish/commitdiff
Merge remote-tracking branch 'upstream/master'
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Mon, 25 Dec 2023 17:45:40 +0000 (18:45 +0100)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Mon, 25 Dec 2023 17:45:40 +0000 (18:45 +0100)
1  2 
src/Makefile
src/main.cpp
src/misc.cpp
src/ucioption.cpp

diff --cc src/Makefile
index 71a940f5496903d1c5d8d457584fae0a0b70cf75,761b40869eeae285f793746fe1771a9e53ec514c..7a65345ee56121a91f7791c2103d39e2a6145cce
@@@ -57,14 -53,19 +53,22 @@@ PGOBENCH = $(WINE_PATH) ./$(EXE) benc
  
  ### Source and object files
  SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \
-       misc.cpp movegen.cpp movepick.cpp position.cpp psqt.cpp \
+       misc.cpp movegen.cpp movepick.cpp position.cpp \
        search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
 -      nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp
 +      nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp \
 +      hashprobe.grpc.pb.cc hashprobe.pb.cc
 +CLISRCS = client.cpp hashprobe.grpc.pb.cc hashprobe.pb.cc uci.cpp
  
+ HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
+               nnue/evaluate_nnue.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
+               nnue/layers/affine_transform_sparse_input.h nnue/layers/clipped_relu.h nnue/layers/simd.h \
+               nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
+               nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
+               search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \
+               tt.h tune.h types.h uci.h
  OBJS = $(notdir $(SRCS:.cpp=.o))
 +CLIOBJS = $(notdir $(CLISRCS:.cpp=.o))
  
  VPATH = syzygy:nnue:nnue/features
  
@@@ -567,7 -588,7 +591,7 @@@ endi
  ### 3.3 Optimization
  ifeq ($(optimize),yes)
  
-       CXXFLAGS += -O3 -g
 -      CXXFLAGS += -O3 -funroll-loops
++      CXXFLAGS += -O3 -g -funroll-loops
  
        ifeq ($(comp),gcc)
                ifeq ($(OS), Android)
diff --cc src/main.cpp
index e5f3b329543f22869fa0cc593b1371aff96bfe31,04879cc46733f1a229ba054a98ba5177820fc52e..7fefa5b365ba95b157e0260ecce4335040d40bb9
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
  
 +#include <deque>
+ #include <cstddef>
  #include <iostream>
 +#include <stack>
 +#include <thread>
  
  #include "bitboard.h"
+ #include "evaluate.h"
+ #include "misc.h"
  #include "position.h"
- #include "psqt.h"
  #include "search.h"
- #include "syzygy/tbprobe.h"
  #include "thread.h"
- #include "tt.h"
+ #include "tune.h"
+ #include "types.h"
  #include "uci.h"
  
 +#include <grpc/grpc.h>
 +#include <grpc++/server.h>
 +#include <grpc++/server_builder.h>
 +#include "hashprobe.h"
 +#include "hashprobe.grpc.pb.h"
 +#include "tt.h"
 +
 +using grpc::Server;
 +using grpc::ServerBuilder;
 +using grpc::ServerContext;
 +using grpc::Status;
 +using grpc::StatusCode;
 +using namespace hashprobe;
  using namespace Stockfish;
  
 +Status HashProbeImpl::Probe(ServerContext* context,
 +                            const HashProbeRequest* request,
 +                          HashProbeResponse *response) {
 +      Position pos;
 +      StateInfo st;
 +      pos.set(request->fen(), /*isChess960=*/false, &st, Threads.main());
 +      if (!pos.pos_is_ok()) {
 +              return Status(StatusCode::INVALID_ARGUMENT, "Invalid FEN");
 +      }
 +
 +      bool invert = (pos.side_to_move() == BLACK);
 +      StateListPtr setup_states = StateListPtr(new std::deque<StateInfo>(1));
 +
 +      ProbeMove(&pos, setup_states.get(), invert, response->mutable_root());
 +
 +      MoveList<LEGAL> moves(pos);
 +      for (const ExtMove* em = moves.begin(); em != moves.end(); ++em) {
 +              HashProbeLine *line = response->add_line();
 +              FillMove(&pos, em->move, line->mutable_move());
 +              setup_states->push_back(StateInfo());
 +              pos.do_move(em->move, setup_states->back());
 +              ProbeMove(&pos, setup_states.get(), !invert, line);
 +              pos.undo_move(em->move);
 +      }
 +
 +      return Status::OK;
 +}
 +
 +void HashProbeImpl::FillMove(Position *pos, Move move, HashProbeMove* decoded) {
 +      if (!is_ok(move)) return;
 +
 +      Square from = from_sq(move);
 +      Square to = to_sq(move);
 +
 +      if (type_of(move) == CASTLING) {
 +              to = make_square(to > from ? FILE_G : FILE_C, rank_of(from));
 +      }
 +
 +      Piece moved_piece = pos->moved_piece(move);
 +      std::string pretty;
 +      if (type_of(move) == CASTLING) {
 +              if (to > from) {
 +                      pretty = "O-O";
 +              } else {
 +                      pretty = "O-O-O";
 +              }
 +      } else if (type_of(moved_piece) == PAWN) {
 +              if (type_of(move) == EN_PASSANT || pos->piece_on(to) != NO_PIECE) {
 +                      // Capture.
 +                      pretty = char('a' + file_of(from));
 +                      pretty += "x";
 +              }
 +              pretty += UCI::square(to);
 +              if (type_of(move) == PROMOTION) {
 +                      pretty += "=";
 +                      pretty += " PNBRQK"[promotion_type(move)];
 +              }
 +      } else {
 +              pretty = " PNBRQK"[type_of(moved_piece)];
 +              Bitboard attackers = pos->attackers_to(to) & pos->pieces(color_of(moved_piece), type_of(moved_piece));
 +              if (more_than_one(attackers)) {
 +                      // Remove all illegal moves to disambiguate.
 +                      Bitboard att_copy = attackers;
 +                      while (att_copy) {
 +                              Square s = pop_lsb(att_copy);
 +                              Move m = make_move(s, to);
 +                              if (!pos->pseudo_legal(m) || !pos->legal(m)) {
 +                                      attackers &= ~square_bb(s);
 +                              }
 +                      }
 +              }
 +              if (more_than_one(attackers)) {
 +                      // Disambiguate by file if possible.
 +                      Bitboard attackers_this_file = attackers & file_bb(file_of(from));
 +                      if (attackers != attackers_this_file) {
 +                              pretty += char('a' + file_of(from));
 +                              attackers = attackers_this_file;
 +                      }
 +                      if (more_than_one(attackers)) {
 +                              // Still ambiguous, so need to disambiguate by rank.
 +                              pretty += char('1' + rank_of(from));
 +                      }
 +              }
 +
 +              if (type_of(move) == EN_PASSANT || pos->piece_on(to) != NO_PIECE) {
 +                      pretty += "x";
 +              }
 +
 +              pretty += UCI::square(to);
 +      }
 +
 +      if (pos->gives_check(move)) {
 +              // Check if mate.
 +              StateInfo si;
 +              pos->do_move(move, si, true);
 +              if (MoveList<LEGAL>(*pos).size() > 0) {
 +                      pretty += "+";
 +              } else {
 +                      pretty += "#";
 +              }
 +              pos->undo_move(move);
 +      }
 +
 +      decoded->set_pretty(pretty);
 +}
 +
 +void HashProbeImpl::ProbeMove(Position* pos, std::deque<StateInfo>* setup_states, bool invert, HashProbeLine* response) {
 +      bool found;
 +      TTEntry *entry = TT.probe(pos->key(), found);
 +      response->set_found(found);
 +      if (found) {
 +              TTEntry entry_copy = *entry;
 +              Value value = entry_copy.value();
 +              Value eval = entry_copy.eval();
 +              Bound bound = entry_copy.bound();
 +
 +              if (invert) {
 +                      value = -value;
 +                      eval = -eval;
 +                      if (bound == BOUND_UPPER) {
 +                              bound = BOUND_LOWER;
 +                      } else if (bound == BOUND_LOWER) {
 +                              bound = BOUND_UPPER;
 +                      }
 +              }
 +
 +              response->set_depth(entry_copy.depth());
 +              FillValue(eval, response->mutable_eval());
 +              if (entry_copy.depth() > DEPTH_NONE) {
 +                      FillValue(value, response->mutable_value());
 +              }
 +              response->set_bound(HashProbeLine::ValueBound(bound));
 +
 +              // Follow the PV until we hit an illegal move.
 +              std::stack<Move> pv;
 +              std::set<Key> seen;
 +              while (is_ok(entry_copy.move()) &&
 +                     pos->pseudo_legal(entry_copy.move()) &&
 +                     pos->legal(entry_copy.move())) {
 +                      FillMove(pos, entry_copy.move(), response->add_pv());
 +                      if (seen.count(pos->key())) break;
 +                      pv.push(entry_copy.move());
 +                      seen.insert(pos->key());
 +                      setup_states->push_back(StateInfo());
 +                      pos->do_move(entry_copy.move(), setup_states->back());
 +                      entry = TT.probe(pos->key(), found);
 +                      if (!found) {
 +                              break;
 +                      }
 +                      entry_copy = *entry;
 +              }
 +
 +              // Unroll the PV back again, so the Position object remains unchanged.
 +              while (!pv.empty()) {
 +                      pos->undo_move(pv.top());
 +                      pv.pop();
 +              }
 +      }
 +}
 +
 +void HashProbeImpl::FillValue(Value value, HashProbeScore* score) {
 +      if (abs(value) < VALUE_MATE - MAX_PLY) {
 +              score->set_score_type(HashProbeScore::SCORE_CP);
 +              score->set_score_cp(value * 100 / PawnValueEg);
 +      } else {
 +              score->set_score_type(HashProbeScore::SCORE_MATE);
 +              score->set_score_mate((value > 0 ? VALUE_MATE - value + 1 : -VALUE_MATE - value) / 2);
 +      }
 +}
 +
 +HashProbeThread::HashProbeThread(const std::string &server_address) {
 +      builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
 +      builder.RegisterService(&service);
 +      server = std::move(builder.BuildAndStart());
 +      std::cout << "Server listening on " << server_address << std::endl;
 +      std::thread([this]{ server->Wait(); }).detach();
 +}
 +
 +void HashProbeThread::Shutdown() {
 +      server->Shutdown();
 +}
 +
  int main(int argc, char* argv[]) {
  
-   std::cout << engine_info() << std::endl;
+     std::cout << engine_info() << std::endl;
  
-   CommandLine::init(argc, argv);
-   UCI::init(Options);
-   Tune::init();
-   PSQT::init();
-   Bitboards::init();
-   Position::init();
-   Threads.set(size_t(Options["Threads"]));
-   Search::clear(); // After threads are up
-   Eval::NNUE::init();
+     CommandLine::init(argc, argv);
+     UCI::init(Options);
+     Tune::init();
+     Bitboards::init();
+     Position::init();
+     Threads.set(size_t(Options["Threads"]));
+     Search::clear();  // After threads are up
+     Eval::NNUE::init();
  
-   UCI::loop(argc, argv);
+     UCI::loop(argc, argv);
  
-   Threads.set(0);
-   return 0;
+     Threads.set(0);
+     return 0;
  }
diff --cc src/misc.cpp
index d8895d87e53e5928c598768ce5f695c45195ff44,4193f8d2c7de056faa26616afb1424c22e694ee0..59c5e406030e25b6b5937957a0540920ba7ec099
@@@ -105,78 -109,85 +109,86 @@@ struct Tie: public std::streambuf {  /
  
  class Logger {
  
-   Logger() : in(cin.rdbuf(), file.rdbuf()), out(cout.rdbuf(), file.rdbuf()) {}
-  ~Logger() { start(""); }
+     Logger() :
+         in(std::cin.rdbuf(), file.rdbuf()),
+         out(std::cout.rdbuf(), file.rdbuf()) {}
+     ~Logger() { start(""); }
  
-   ofstream file;
-   Tie in, out;
+     std::ofstream file;
+     Tie           in, out;
  
- public:
-   static void start(const std::string& fname) {
-     static Logger l;
-     if (l.file.is_open())
-     {
-         cout.rdbuf(l.out.buf);
-         cin.rdbuf(l.in.buf);
-         l.file.close();
-     }
+    public:
+     static void start(const std::string& fname) {
  
-     if (!fname.empty())
-     {
-         l.file.open(fname, ifstream::out);
+         static Logger l;
  
-         if (!l.file.is_open())
+         if (l.file.is_open())
          {
-             cerr << "Unable to open debug log file " << fname << endl;
-             exit(EXIT_FAILURE);
+             std::cout.rdbuf(l.out.buf);
+             std::cin.rdbuf(l.in.buf);
+             l.file.close();
          }
  
-         cin.rdbuf(&l.in);
-         cout.rdbuf(&l.out);
+         if (!fname.empty())
+         {
+             l.file.open(fname, std::ifstream::out);
+             if (!l.file.is_open())
+             {
+                 std::cerr << "Unable to open debug log file " << fname << std::endl;
+                 exit(EXIT_FAILURE);
+             }
+             std::cin.rdbuf(&l.in);
+             std::cout.rdbuf(&l.out);
+         }
      }
-   }
  };
  
- } // namespace
+ }  // namespace
  
- /// engine_info() returns the full name of the current Stockfish version.
- /// For local dev compiles we try to append the commit sha and commit date
- /// from git if that fails only the local compilation date is set and "nogit" is specified:
- /// Stockfish dev-YYYYMMDD-SHA
- /// or
- /// Stockfish dev-YYYYMMDD-nogit
- ///
- /// For releases (non dev builds) we only include the version number:
- /// Stockfish version
  
- string engine_info(bool to_uci) {
-   stringstream ss;
-   ss << "Stockfish " << version << setfill('0');
+ // Returns the full name of the current Stockfish version.
+ // For local dev compiles we try to append the commit sha and commit date
+ // from git if that fails only the local compilation date is set and "nogit" is specified:
+ // Stockfish dev-YYYYMMDD-SHA
+ // or
+ // Stockfish dev-YYYYMMDD-nogit
+ //
+ // For releases (non-dev builds) we only include the version number:
+ // Stockfish version
+ std::string engine_info(bool to_uci) {
+     std::stringstream ss;
+     ss << "Stockfish " << version << std::setfill('0');
+     if constexpr (version == "dev")
+     {
+         ss << "-";
+ #ifdef GIT_DATE
+         ss << stringify(GIT_DATE);
+ #else
+         constexpr std::string_view months("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec");
+         std::string                month, day, year;
+         std::stringstream          date(__DATE__);  // From compiler, format is "Sep 21 2008"
  
-   if constexpr (version == "dev")
-   {
-       ss << "-";
-       #ifdef GIT_DATE
-       ss << stringify(GIT_DATE);
-       #else
-       constexpr string_view months("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec");
-       string month, day, year;
-       stringstream date(__DATE__); // From compiler, format is "Sep 21 2008"
+         date >> month >> day >> year;
+         ss << year << std::setw(2) << std::setfill('0') << (1 + months.find(month) / 4)
+            << std::setw(2) << std::setfill('0') << day;
+ #endif
  
-       date >> month >> day >> year;
-       ss << year << setw(2) << setfill('0') << (1 + months.find(month) / 4) << setw(2) << setfill('0') << day;
-       #endif
+         ss << "-";
  
-       ss << "-asn";
-   }
+ #ifdef GIT_SHA
+         ss << stringify(GIT_SHA);
+ #else
+         ss << "nogit";
+ #endif
++      ss << "-asn";
+     }
  
-   ss << (to_uci  ? "\nid author ": " by ")
-      << "the Stockfish developers (see AUTHORS file)";
+     ss << (to_uci ? "\nid author " : " by ") << "the Stockfish developers (see AUTHORS file)";
  
-   return ss.str();
+     return ss.str();
  }
  
  
index e10ba00a36f7147e1df9148dd6fa6a9c5c107c61,1dc9b89baed4e3298db8e79b6c6cafb32ec8c62a..2c84084813f1d8ae369054c5b9c9fa535b75d532
  #include "evaluate.h"
  #include "misc.h"
  #include "search.h"
+ #include "syzygy/tbprobe.h"
  #include "thread.h"
  #include "tt.h"
+ #include "types.h"
  #include "uci.h"
- #include "syzygy/tbprobe.h"
 +#include "hashprobe.h"
  
  using std::string;
  
  namespace Stockfish {
  
- UCI::OptionsMap Options; // Global object
+ UCI::OptionsMap Options;  // Global object
 +std::unique_ptr<HashProbeThread> hash_probe_thread;
  
  namespace UCI {
  
@@@ -45,51 -50,39 +52,48 @@@ static void on_hash_size(const Option& 
  static void on_logger(const Option& o) { start_logger(o); }
  static void on_threads(const Option& o) { Threads.set(size_t(o)); }
  static void on_tb_path(const Option& o) { Tablebases::init(o); }
 -static void on_eval_file(const Option&) { Eval::NNUE::init(); }
 +static void on_use_NNUE(const Option& ) { Eval::NNUE::init(); }
 +static void on_eval_file(const Option& ) { Eval::NNUE::init(); }
 +static void on_rpc_server_address(const Option& o) {
 +      if (hash_probe_thread) {
 +              hash_probe_thread->Shutdown();
 +      }
 +      std::string addr = o;
 +      hash_probe_thread.reset(new HashProbeThread(addr));
 +}
  
- /// Our case insensitive less() function as required by UCI protocol
- bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const {
+ // Our case insensitive less() function as required by UCI protocol
+ bool CaseInsensitiveLess::operator()(const string& s1, const string& s2) const {
  
-   return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(),
-          [](char c1, char c2) { return tolower(c1) < tolower(c2); });
+     return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(),
+                                         [](char c1, char c2) { return tolower(c1) < tolower(c2); });
  }
  
  
- /// UCI::init() initializes the UCI options to their hard-coded default values
+ // Initializes the UCI options to their hard-coded default values
  void init(OptionsMap& o) {
  
-   constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048;
-   o["Debug Log File"]        << Option("", on_logger);
-   o["Threads"]               << Option(1, 1, 1024, on_threads);
-   o["Hash"]                  << Option(16, 1, MaxHashMB, on_hash_size);
-   o["Clear Hash"]            << Option(on_clear_hash);
-   o["Ponder"]                << Option(false);
-   o["MultiPV"]               << Option(1, 1, 500);
-   o["Skill Level"]           << Option(20, 0, 20);
-   o["Move Overhead"]         << Option(10, 0, 5000);
-   o["Slow Mover"]            << Option(100, 10, 1000);
-   o["nodestime"]             << Option(0, 0, 10000);
-   o["UCI_Chess960"]          << Option(false);
-   o["UCI_AnalyseMode"]       << Option(false);
-   o["UCI_LimitStrength"]     << Option(false);
-   o["UCI_Elo"]               << Option(1320, 1320, 3190);
-   o["UCI_ShowWDL"]           << Option(false);
-   o["SyzygyPath"]            << Option("<empty>", on_tb_path);
-   o["SyzygyProbeDepth"]      << Option(1, 1, 100);
-   o["Syzygy50MoveRule"]      << Option(true);
-   o["SyzygyProbeLimit"]      << Option(7, 0, 7);
-   o["EvalFile"]              << Option(EvalFileDefaultName, on_eval_file);
-   o["RPCServerAddress"]      << Option("<empty>", on_rpc_server_address);
+     constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048;
+     o["Debug Log File"] << Option("", on_logger);
+     o["Threads"] << Option(1, 1, 1024, on_threads);
+     o["Hash"] << Option(16, 1, MaxHashMB, on_hash_size);
+     o["Clear Hash"] << Option(on_clear_hash);
+     o["Ponder"] << Option(false);
+     o["MultiPV"] << Option(1, 1, 500);
+     o["Skill Level"] << Option(20, 0, 20);
+     o["Move Overhead"] << Option(10, 0, 5000);
+     o["nodestime"] << Option(0, 0, 10000);
+     o["UCI_Chess960"] << Option(false);
+     o["UCI_LimitStrength"] << Option(false);
+     o["UCI_Elo"] << Option(1320, 1320, 3190);
+     o["UCI_ShowWDL"] << Option(false);
+     o["SyzygyPath"] << Option("<empty>", on_tb_path);
+     o["SyzygyProbeDepth"] << Option(1, 1, 100);
+     o["Syzygy50MoveRule"] << Option(true);
+     o["SyzygyProbeLimit"] << Option(7, 0, 7);
+     o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file);
++    o["RPCServerAddress"] << Option("<empty>", on_rpc_server_address);
  }