From: Steinar H. Gunderson Date: Mon, 28 Dec 2020 18:20:52 +0000 (+0100) Subject: Merge remote-tracking branch 'upstream/master' into HEAD X-Git-Url: https://git.sesse.net/?p=stockfish;a=commitdiff_plain;h=f2e94d6d35c14b274ed29fb67475acea5adc285f;hp=51deae899814bbbfd9db5686b824f23105ca8a39 Merge remote-tracking branch 'upstream/master' into HEAD --- diff --git a/src/Makefile b/src/Makefile index 87203547..e3466ea1 100644 --- a/src/Makefile +++ b/src/Makefile @@ -39,9 +39,12 @@ PGOBENCH = ./$(EXE) bench SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp \ material.cpp misc.cpp movegen.cpp movepick.cpp pawns.cpp position.cpp psqt.cpp \ search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ - nnue/evaluate_nnue.cpp nnue/features/half_kp.cpp + nnue/evaluate_nnue.cpp nnue/features/half_kp.cpp \ + hashprobe.grpc.pb.cc hashprobe.pb.cc +CLISRCS = client.cpp hashprobe.grpc.pb.cc hashprobe.pb.cc uci.cpp OBJS = $(notdir $(SRCS:.cpp=.o)) +CLIOBJS = $(notdir $(CLISRCS:.cpp=.o)) VPATH = syzygy:nnue:nnue/features @@ -307,7 +310,7 @@ endif ifeq ($(COMP),gcc) comp=gcc CXX=g++ - CXXFLAGS += -pedantic -Wextra -Wshadow + CXXFLAGS += -pedantic -Wextra ifeq ($(arch),$(filter $(arch),armv7 armv8)) ifeq ($(OS),Android) @@ -466,7 +469,7 @@ endif ### 3.3 Optimization ifeq ($(optimize),yes) - CXXFLAGS += -O3 + CXXFLAGS += -O3 -g ifeq ($(comp),gcc) ifeq ($(OS), Android) @@ -782,7 +785,7 @@ default: ### Section 5. Private Targets ### ========================================================================== -all: $(EXE) .depend +all: $(EXE) client .depend config-sanity: net @echo "" @@ -878,6 +881,32 @@ icc-profile-use: EXTRACXXFLAGS='-prof_use -prof_dir ./profdir' \ all +### GRPC + +PROTOS_PATH = . +PROTOC = protoc +GRPC_CPP_PLUGIN = grpc_cpp_plugin +GRPC_CPP_PLUGIN_PATH ?= `which $(GRPC_CPP_PLUGIN)` + +%.grpc.pb.h %.grpc.pb.cc: %.proto + $(PROTOC) -I $(PROTOS_PATH) --grpc_out=. --plugin=protoc-gen-grpc=$(GRPC_CPP_PLUGIN_PATH) $< + +# oh my +%.cpp: %.cc + cp $< $@ + +%.pb.h %.pb.cc: %.proto + $(PROTOC) -I $(PROTOS_PATH) --cpp_out=. $< + +#LDFLAGS += -Wl,-Bstatic -Wl,-\( -lprotobuf -lgrpc++_unsecure -lgrpc_unsecure -lgrpc -lz -Wl,-\) -Wl,-Bdynamic -ldl +LDFLAGS += /usr/lib/x86_64-linux-gnu/libprotobuf.a /usr/lib/x86_64-linux-gnu/libgrpc++_unsecure.a /usr/lib/x86_64-linux-gnu/libgrpc_unsecure.a /usr/lib/x86_64-linux-gnu/libgrpc.a /usr/lib/x86_64-linux-gnu/libcares.a -ldl -lz +#LDFLAGS += /usr/lib/x86_64-linux-gnu/libprotobuf.a /usr/lib/libgrpc++_unsecure.a /usr/lib/libgrpc_unsecure.a /usr/lib/libgrpc.a /usr/lib/x86_64-linux-gnu/libcares.a -ldl -lz + +client: $(CLIOBJS) + $(CXX) -o $@ $(CLIOBJS) $(LDFLAGS) + +# Other stuff + .depend: -@$(CXX) $(DEPENDFLAGS) -MM $(SRCS) > $@ 2> /dev/null diff --git a/src/client.cpp b/src/client.cpp new file mode 100644 index 00000000..a59b961a --- /dev/null +++ b/src/client.cpp @@ -0,0 +1,81 @@ +#include +#include +#include + +#include + +#include "hashprobe.grpc.pb.h" +#include "types.h" +#include "uci.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using namespace hashprobe; + +std::string FormatMove(const HashProbeMove &move) { + if (move.pretty().empty()) return "MOVE_NONE"; + return move.pretty(); +} + +int main(int argc, char** argv) { + std::shared_ptr channel(grpc::CreateChannel( + "localhost:50051", grpc::InsecureChannelCredentials())); + std::unique_ptr stub(HashProbe::NewStub(channel)); + + for ( ;; ) { + char buf[256]; + if (fgets(buf, sizeof(buf), stdin) == nullptr || buf[0] == '\n') { + exit(0); + } + + char *ptr = strchr(buf, '\n'); + if (ptr != nullptr) *ptr = 0; + + HashProbeRequest request; + request.set_fen(buf); + + HashProbeResponse response; + ClientContext context; + Status status = stub->Probe(&context, request, &response); + + if (status.ok()) { + for (const HashProbeLine &line : response.line()) { + std::cout << FormatMove(line.move()) << " "; + std::cout << line.found() << " "; + for (const HashProbeMove &move : line.pv()) { + std::cout << FormatMove(move) << ","; + } + std::cout << " "; + switch (line.bound()) { + case HashProbeLine::BOUND_NONE: + std::cout << "?"; + break; + case HashProbeLine::BOUND_EXACT: + std::cout << "=="; + break; + case HashProbeLine::BOUND_UPPER: + std::cout << "<="; + break; + case HashProbeLine::BOUND_LOWER: + std::cout << ">="; + break; + } + switch (line.value().score_type()) { + case HashProbeScore::SCORE_CP: + std::cout << " cp " << line.value().score_cp() << " "; + break; + case HashProbeScore::SCORE_MATE: + std::cout << " mate " << line.value().score_mate() << " "; + break; + } + std::cout << line.depth() << std::endl; + } + std::cout << "END" << std::endl; + } else { + std::cout << "ERROR" << std::endl; + } + } + + return 0; +} diff --git a/src/hashprobe.h b/src/hashprobe.h new file mode 100644 index 00000000..999266d5 --- /dev/null +++ b/src/hashprobe.h @@ -0,0 +1,37 @@ +#ifndef HASHPROBE_H_INCLUDED +#define HASHPROBE_H_INCLUDED + +#include "types.h" + +#include +#include + +#include +#include +#include +#include "hashprobe.grpc.pb.h" + +class HashProbeImpl final : public hashprobe::HashProbe::Service { +public: + grpc::Status Probe(grpc::ServerContext* context, + const hashprobe::HashProbeRequest* request, + hashprobe::HashProbeResponse *response); + +private: + void FillMove(Position* pos, Move move, hashprobe::HashProbeMove* decoded); + void ProbeMove(Position* pos, std::deque* setup_states, bool invert, hashprobe::HashProbeLine* response); + void FillValue(Value value, hashprobe::HashProbeScore* score); +}; + +class HashProbeThread { +public: + HashProbeThread(const std::string &server_address); + void Shutdown(); + +private: + HashProbeImpl service; + grpc::ServerBuilder builder; + std::unique_ptr server; +}; + +#endif diff --git a/src/hashprobe.proto b/src/hashprobe.proto new file mode 100644 index 00000000..175cd773 --- /dev/null +++ b/src/hashprobe.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; +package hashprobe; + +message HashProbeRequest { + string fen = 1; +} +message HashProbeResponse { + HashProbeLine root = 2; + repeated HashProbeLine line = 1; +} +message HashProbeLine { + HashProbeMove move = 1; + bool found = 2; + + repeated HashProbeMove pv = 3; + HashProbeScore value = 4; // Dynamic eval (may be inexact, see the "bound" field) + HashProbeScore eval = 5; // Static eval + int32 depth = 6; + + enum ValueBound { + BOUND_NONE = 0; + BOUND_UPPER = 1; + BOUND_LOWER = 2; + BOUND_EXACT = 3; + }; + ValueBound bound = 7; +} + +message HashProbeMove { + string from_sq = 1; // a1, a2, etc. + string to_sq = 2; + string promotion = 3; // Q, R, etc. + + string pretty = 4; // e.g. Rxf6+ +} +message HashProbeScore { + enum ScoreType { + SCORE_NONE = 0; + SCORE_CP = 1; + SCORE_MATE = 2; + } + ScoreType score_type = 1; + int32 score_cp = 2; + int32 score_mate = 3; +} + +service HashProbe { + rpc Probe(HashProbeRequest) returns (HashProbeResponse) {} +} diff --git a/src/main.cpp b/src/main.cpp index e6dff918..e8548ac1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -16,7 +16,10 @@ along with this program. If not, see . */ +#include #include +#include +#include #include "bitboard.h" #include "endgame.h" @@ -27,6 +30,197 @@ #include "uci.h" #include "syzygy/tbprobe.h" +#include +#include +#include +#include "hashprobe.h" +#include "hashprobe.grpc.pb.h" +#include "tt.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using grpc::StatusCode; +using namespace hashprobe; + +Status HashProbeImpl::Probe(ServerContext* context, + const HashProbeRequest* request, + HashProbeResponse *response) { + Position pos; + StateInfo st; + pos.set(request->fen(), /*isChess960=*/false, &st, Threads.main()); + if (!pos.pos_is_ok()) { + return Status(StatusCode::INVALID_ARGUMENT, "Invalid FEN"); + } + + bool invert = (pos.side_to_move() == BLACK); + StateListPtr setup_states = StateListPtr(new std::deque(1)); + + ProbeMove(&pos, setup_states.get(), invert, response->mutable_root()); + + MoveList moves(pos); + for (const ExtMove* em = moves.begin(); em != moves.end(); ++em) { + HashProbeLine *line = response->add_line(); + FillMove(&pos, em->move, line->mutable_move()); + setup_states->push_back(StateInfo()); + pos.do_move(em->move, setup_states->back()); + ProbeMove(&pos, setup_states.get(), !invert, line); + pos.undo_move(em->move); + } + + return Status::OK; +} + +void HashProbeImpl::FillMove(Position *pos, Move move, HashProbeMove* decoded) { + if (!is_ok(move)) return; + + Square from = from_sq(move); + Square to = to_sq(move); + + if (type_of(move) == CASTLING) { + to = make_square(to > from ? FILE_G : FILE_C, rank_of(from)); + } + + Piece moved_piece = pos->moved_piece(move); + std::string pretty; + if (type_of(move) == CASTLING) { + if (to > from) { + pretty = "O-O"; + } else { + pretty = "O-O-O"; + } + } else if (type_of(moved_piece) == PAWN) { + if (type_of(move) == ENPASSANT || pos->piece_on(to) != NO_PIECE) { + // Capture. + pretty = char('a' + file_of(from)); + pretty += "x"; + } + pretty += UCI::square(to); + if (type_of(move) == PROMOTION) { + pretty += "="; + pretty += " PNBRQK"[promotion_type(move)]; + } + } else { + pretty = " PNBRQK"[type_of(moved_piece)]; + Bitboard attackers = pos->attackers_to(to) & pos->pieces(color_of(moved_piece), type_of(moved_piece)); + if (more_than_one(attackers)) { + // Remove all illegal moves to disambiguate. + Bitboard att_copy = attackers; + while (att_copy) { + Square s = pop_lsb(&att_copy); + Move m = make_move(s, to); + if (!pos->pseudo_legal(m) || !pos->legal(m)) { + attackers &= ~SquareBB[s]; + } + } + } + if (more_than_one(attackers)) { + // Disambiguate by file if possible. + Bitboard attackers_this_file = attackers & file_bb(file_of(from)); + if (attackers != attackers_this_file) { + pretty += char('a' + file_of(from)); + attackers = attackers_this_file; + } + if (more_than_one(attackers)) { + // Still ambiguous, so need to disambiguate by rank. + pretty += char('1' + rank_of(from)); + } + } + + if (type_of(move) == ENPASSANT || pos->piece_on(to) != NO_PIECE) { + pretty += "x"; + } + + pretty += UCI::square(to); + } + + if (pos->gives_check(move)) { + // Check if mate. + StateInfo si; + pos->do_move(move, si, true); + if (MoveList(*pos).size() > 0) { + pretty += "+"; + } else { + pretty += "#"; + } + pos->undo_move(move); + } + + decoded->set_pretty(pretty); +} + +void HashProbeImpl::ProbeMove(Position* pos, std::deque* setup_states, bool invert, HashProbeLine* response) { + bool found; + TTEntry *entry = TT.probe(pos->key(), found); + response->set_found(found); + if (found) { + Value value = entry->value(); + Value eval = entry->eval(); + Bound bound = entry->bound(); + + if (invert) { + value = -value; + eval = -eval; + if (bound == BOUND_UPPER) { + bound = BOUND_LOWER; + } else if (bound == BOUND_LOWER) { + bound = BOUND_UPPER; + } + } + + response->set_depth(entry->depth()); + FillValue(eval, response->mutable_eval()); + if (entry->depth() > DEPTH_NONE) { + FillValue(value, response->mutable_value()); + } + response->set_bound(HashProbeLine::ValueBound(bound)); + + // Follow the PV until we hit an illegal move. + std::stack pv; + std::set seen; + while (found && is_ok(entry->move()) && + pos->pseudo_legal(entry->move()) && + pos->legal(entry->move())) { + FillMove(pos, entry->move(), response->add_pv()); + if (seen.count(pos->key())) break; + pv.push(entry->move()); + seen.insert(pos->key()); + setup_states->push_back(StateInfo()); + pos->do_move(entry->move(), setup_states->back()); + entry = TT.probe(pos->key(), found); + } + + // Unroll the PV back again, so the Position object remains unchanged. + while (!pv.empty()) { + pos->undo_move(pv.top()); + pv.pop(); + } + } +} + +void HashProbeImpl::FillValue(Value value, HashProbeScore* score) { + if (abs(value) < VALUE_MATE - MAX_PLY) { + score->set_score_type(HashProbeScore::SCORE_CP); + score->set_score_cp(value * 100 / PawnValueEg); + } else { + score->set_score_type(HashProbeScore::SCORE_MATE); + score->set_score_mate((value > 0 ? VALUE_MATE - value + 1 : -VALUE_MATE - value) / 2); + } +} + +HashProbeThread::HashProbeThread(const std::string &server_address) { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server = std::move(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + std::thread([this]{ server->Wait(); }).detach(); +} + +void HashProbeThread::Shutdown() { + server->Shutdown(); +} + namespace PSQT { void init(); } diff --git a/src/misc.cpp b/src/misc.cpp index f2bce6b0..832a9ac1 100644 --- a/src/misc.cpp +++ b/src/misc.cpp @@ -150,6 +150,7 @@ const string engine_info(bool to_uci) { { date >> month >> day >> year; ss << setw(2) << day << setw(2) << (1 + months.find(month) / 4) << year.substr(2); + ss << "-asn"; } ss << (to_uci ? "\nid author ": " by ") diff --git a/src/position.cpp b/src/position.cpp index 07ce0a7c..13010c1a 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -283,8 +283,6 @@ Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Th st->accumulator.state[WHITE] = Eval::NNUE::INIT; st->accumulator.state[BLACK] = Eval::NNUE::INIT; - assert(pos_is_ok()); - return *this; } diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index 4d682f1a..28b70a4a 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -74,7 +74,7 @@ int MapB1H1H7[SQUARE_NB]; int MapA1D1D4[SQUARE_NB]; int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB] -int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements +int Binomial[7][SQUARE_NB]; // [k][n] k elements from a set of n elements int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] int LeadPawnsSize[6][4]; // [leadPawnsCnt][FILE_A..FILE_D] @@ -1321,7 +1321,7 @@ void Tablebases::init(const std::string& paths) { Binomial[0][0] = 1; for (int n = 1; n < 64; n++) // Squares - for (int k = 0; k < 6 && k <= n; ++k) // Pieces + for (int k = 0; k < 7 && k <= n; ++k) // Pieces Binomial[k][n] = (k > 0 ? Binomial[k - 1][n - 1] : 0) + (k < n ? Binomial[k ][n - 1] : 0); diff --git a/src/ucioption.cpp b/src/ucioption.cpp index bb0b8311..71239159 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -27,11 +27,13 @@ #include "thread.h" #include "tt.h" #include "uci.h" +#include "hashprobe.h" #include "syzygy/tbprobe.h" using std::string; UCI::OptionsMap Options; // Global object +std::unique_ptr hash_probe_thread; namespace UCI { @@ -43,6 +45,13 @@ void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_tb_path(const Option& o) { Tablebases::init(o); } void on_use_NNUE(const Option& ) { Eval::NNUE::init(); } void on_eval_file(const Option& ) { Eval::NNUE::init(); } +void on_rpc_server_address(const Option& o) { + if (hash_probe_thread) { + hash_probe_thread->Shutdown(); + } + std::string addr = o; + hash_probe_thread.reset(new HashProbeThread(addr)); +} /// Our case insensitive less() function as required by UCI protocol bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { @@ -81,6 +90,7 @@ void init(OptionsMap& o) { o["SyzygyProbeLimit"] << Option(7, 0, 7); o["Use NNUE"] << Option(true, on_use_NNUE); o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file); + o["RPCServerAddress"] << Option("", on_rpc_server_address); }