From: Steinar H. Gunderson Date: Sat, 15 Jan 2022 10:01:13 +0000 (+0100) Subject: Merge remote-tracking branch 'upstream/master' X-Git-Url: https://git.sesse.net/?a=commitdiff_plain;h=c357c4ad6f7318234c4d745eaa6b0c4774e28741;hp=7678d63cf2323e51c01e60cdff4ac3d685313790;p=stockfish Merge remote-tracking branch 'upstream/master' --- diff --git a/src/Makefile b/src/Makefile index 0e889888..aeb8ba9d 100644 --- a/src/Makefile +++ b/src/Makefile @@ -41,9 +41,12 @@ endif SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp \ material.cpp misc.cpp movegen.cpp movepick.cpp pawns.cpp position.cpp psqt.cpp \ search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ - nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp + nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp \ + hashprobe.grpc.pb.cc hashprobe.pb.cc +CLISRCS = client.cpp hashprobe.grpc.pb.cc hashprobe.pb.cc uci.cpp OBJS = $(notdir $(SRCS:.cpp=.o)) +CLIOBJS = $(notdir $(CLISRCS:.cpp=.o)) VPATH = syzygy:nnue:nnue/features @@ -344,7 +347,7 @@ endif ifeq ($(COMP),gcc) comp=gcc CXX=g++ - CXXFLAGS += -pedantic -Wextra -Wshadow + CXXFLAGS += -pedantic -Wextra ifeq ($(arch),$(filter $(arch),armv7 armv8)) ifeq ($(OS),Android) @@ -521,7 +524,7 @@ endif ### 3.3 Optimization ifeq ($(optimize),yes) - CXXFLAGS += -O3 + CXXFLAGS += -O3 -g ifeq ($(comp),gcc) ifeq ($(OS), Android) @@ -858,7 +861,7 @@ default: ### Section 5. Private Targets ### ========================================================================== -all: $(EXE) .depend +all: $(EXE) client .depend config-sanity: net @echo "" @@ -958,6 +961,32 @@ icc-profile-use: EXTRACXXFLAGS='-prof_use -prof_dir ./profdir' \ all +### GRPC + +PROTOS_PATH = . +PROTOC = protoc +GRPC_CPP_PLUGIN = grpc_cpp_plugin +GRPC_CPP_PLUGIN_PATH ?= `which $(GRPC_CPP_PLUGIN)` + +%.grpc.pb.h %.grpc.pb.cc: %.proto + $(PROTOC) -I $(PROTOS_PATH) --grpc_out=. --plugin=protoc-gen-grpc=$(GRPC_CPP_PLUGIN_PATH) $< + +# oh my +%.cpp: %.cc + cp $< $@ + +%.pb.h %.pb.cc: %.proto + $(PROTOC) -I $(PROTOS_PATH) --cpp_out=. $< + +#LDFLAGS += -Wl,-Bstatic -Wl,-\( -lprotobuf -lgrpc++_unsecure -lgrpc_unsecure -lgrpc -lz -Wl,-\) -Wl,-Bdynamic -ldl +LDFLAGS += /usr/lib/x86_64-linux-gnu/libprotobuf.a /usr/lib/x86_64-linux-gnu/libgrpc++_unsecure.a /usr/lib/x86_64-linux-gnu/libgrpc_unsecure.a /usr/lib/x86_64-linux-gnu/libgrpc.a /usr/lib/x86_64-linux-gnu/libcares.a /usr/lib/x86_64-linux-gnu/libgpr.a /usr/lib/x86_64-linux-gnu/libabsl_str_format_internal.a /usr/lib/x86_64-linux-gnu/libabsl_strings.a /usr/lib/x86_64-linux-gnu/libabsl_flags_marshalling.a /usr/lib/x86_64-linux-gnu/libabsl_throw_delegate.a /usr/lib/x86_64-linux-gnu/libabsl_raw_logging_internal.a /usr/lib/x86_64-linux-gnu/libabsl_base.a /usr/lib/x86_64-linux-gnu/libabsl_int128.a /usr/lib/x86_64-linux-gnu/libabsl_bad_optional_access.a -ldl -lz +#LDFLAGS += /usr/lib/x86_64-linux-gnu/libprotobuf.a /usr/lib/libgrpc++_unsecure.a /usr/lib/libgrpc_unsecure.a /usr/lib/libgrpc.a /usr/lib/x86_64-linux-gnu/libcares.a -ldl -lz + +client: $(CLIOBJS) + $(CXX) -o $@ $(CLIOBJS) $(LDFLAGS) + +# Other stuff + .depend: $(SRCS) -@$(CXX) $(DEPENDFLAGS) -MM $(SRCS) > $@ 2> /dev/null diff --git a/src/client.cpp b/src/client.cpp new file mode 100644 index 00000000..0400f933 --- /dev/null +++ b/src/client.cpp @@ -0,0 +1,81 @@ +#include +#include +#include + +#include + +#include "hashprobe.grpc.pb.h" +#include "types.h" +#include "uci.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using namespace hashprobe; + +std::string FormatMove(const HashProbeMove &move) { + if (move.pretty().empty()) return "MOVE_NONE"; + return move.pretty(); +} + +int main(int argc, char** argv) { + std::shared_ptr channel(grpc::CreateChannel( + "localhost:50051", grpc::InsecureChannelCredentials())); + std::unique_ptr stub(HashProbe::NewStub(channel)); + + for ( ;; ) { + char buf[256]; + if (fgets(buf, sizeof(buf), stdin) == nullptr || buf[0] == '\n') { + exit(0); + } + + char *ptr = strchr(buf, '\n'); + if (ptr != nullptr) *ptr = 0; + + HashProbeRequest request; + request.set_fen(buf); + + HashProbeResponse response; + ClientContext context; + Status status = stub->Probe(&context, request, &response); + + if (status.ok()) { + for (const HashProbeLine &line : response.line()) { + std::cout << FormatMove(line.move()) << " "; + std::cout << line.found() << " "; + for (const HashProbeMove &move : line.pv()) { + std::cout << FormatMove(move) << ","; + } + std::cout << " "; + switch (line.bound()) { + case HashProbeLine::BOUND_NONE: + std::cout << "?"; + break; + case HashProbeLine::BOUND_EXACT: + std::cout << "=="; + break; + case HashProbeLine::BOUND_UPPER: + std::cout << "<="; + break; + case HashProbeLine::BOUND_LOWER: + std::cout << ">="; + break; + } + switch (line.value().score_type()) { + case HashProbeScore::SCORE_CP: + std::cout << " cp " << line.value().score_cp() << " "; + break; + case HashProbeScore::SCORE_MATE: + std::cout << " mate " << line.value().score_mate() << " "; + break; + } + std::cout << line.depth() << std::endl; + } + std::cout << "END" << std::endl; + } else { + std::cout << "ERROR" << std::endl; + } + } + + return 0; +} diff --git a/src/hashprobe.h b/src/hashprobe.h new file mode 100644 index 00000000..42c13eef --- /dev/null +++ b/src/hashprobe.h @@ -0,0 +1,38 @@ +#ifndef HASHPROBE_H_INCLUDED +#define HASHPROBE_H_INCLUDED + +#include "position.h" +#include "types.h" + +#include +#include + +#include +#include +#include +#include "hashprobe.grpc.pb.h" + +class HashProbeImpl final : public hashprobe::HashProbe::Service { +public: + grpc::Status Probe(grpc::ServerContext* context, + const hashprobe::HashProbeRequest* request, + hashprobe::HashProbeResponse *response); + +private: + void FillMove(Stockfish::Position* pos, Stockfish::Move move, hashprobe::HashProbeMove* decoded); + void ProbeMove(Stockfish::Position* pos, std::deque* setup_states, bool invert, hashprobe::HashProbeLine* response); + void FillValue(Stockfish::Value value, hashprobe::HashProbeScore* score); +}; + +class HashProbeThread { +public: + HashProbeThread(const std::string &server_address); + void Shutdown(); + +private: + HashProbeImpl service; + grpc::ServerBuilder builder; + std::unique_ptr server; +}; + +#endif diff --git a/src/hashprobe.proto b/src/hashprobe.proto new file mode 100644 index 00000000..175cd773 --- /dev/null +++ b/src/hashprobe.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; +package hashprobe; + +message HashProbeRequest { + string fen = 1; +} +message HashProbeResponse { + HashProbeLine root = 2; + repeated HashProbeLine line = 1; +} +message HashProbeLine { + HashProbeMove move = 1; + bool found = 2; + + repeated HashProbeMove pv = 3; + HashProbeScore value = 4; // Dynamic eval (may be inexact, see the "bound" field) + HashProbeScore eval = 5; // Static eval + int32 depth = 6; + + enum ValueBound { + BOUND_NONE = 0; + BOUND_UPPER = 1; + BOUND_LOWER = 2; + BOUND_EXACT = 3; + }; + ValueBound bound = 7; +} + +message HashProbeMove { + string from_sq = 1; // a1, a2, etc. + string to_sq = 2; + string promotion = 3; // Q, R, etc. + + string pretty = 4; // e.g. Rxf6+ +} +message HashProbeScore { + enum ScoreType { + SCORE_NONE = 0; + SCORE_CP = 1; + SCORE_MATE = 2; + } + ScoreType score_type = 1; + int32 score_cp = 2; + int32 score_mate = 3; +} + +service HashProbe { + rpc Probe(HashProbeRequest) returns (HashProbeResponse) {} +} diff --git a/src/main.cpp b/src/main.cpp index fad0ef84..435e436c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -16,7 +16,10 @@ along with this program. If not, see . */ +#include #include +#include +#include #include "bitboard.h" #include "endgame.h" @@ -28,8 +31,203 @@ #include "tt.h" #include "uci.h" +#include +#include +#include +#include "hashprobe.h" +#include "hashprobe.grpc.pb.h" +#include "tt.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using grpc::StatusCode; +using namespace hashprobe; using namespace Stockfish; +Status HashProbeImpl::Probe(ServerContext* context, + const HashProbeRequest* request, + HashProbeResponse *response) { + Position pos; + StateInfo st; + pos.set(request->fen(), /*isChess960=*/false, &st, Threads.main()); + if (!pos.pos_is_ok()) { + return Status(StatusCode::INVALID_ARGUMENT, "Invalid FEN"); + } + + bool invert = (pos.side_to_move() == BLACK); + StateListPtr setup_states = StateListPtr(new std::deque(1)); + + ProbeMove(&pos, setup_states.get(), invert, response->mutable_root()); + + MoveList moves(pos); + for (const ExtMove* em = moves.begin(); em != moves.end(); ++em) { + HashProbeLine *line = response->add_line(); + FillMove(&pos, em->move, line->mutable_move()); + setup_states->push_back(StateInfo()); + pos.do_move(em->move, setup_states->back()); + ProbeMove(&pos, setup_states.get(), !invert, line); + pos.undo_move(em->move); + } + + return Status::OK; +} + +void HashProbeImpl::FillMove(Position *pos, Move move, HashProbeMove* decoded) { + if (!is_ok(move)) return; + + Square from = from_sq(move); + Square to = to_sq(move); + + if (type_of(move) == CASTLING) { + to = make_square(to > from ? FILE_G : FILE_C, rank_of(from)); + } + + Piece moved_piece = pos->moved_piece(move); + std::string pretty; + if (type_of(move) == CASTLING) { + if (to > from) { + pretty = "O-O"; + } else { + pretty = "O-O-O"; + } + } else if (type_of(moved_piece) == PAWN) { + if (type_of(move) == EN_PASSANT || pos->piece_on(to) != NO_PIECE) { + // Capture. + pretty = char('a' + file_of(from)); + pretty += "x"; + } + pretty += UCI::square(to); + if (type_of(move) == PROMOTION) { + pretty += "="; + pretty += " PNBRQK"[promotion_type(move)]; + } + } else { + pretty = " PNBRQK"[type_of(moved_piece)]; + Bitboard attackers = pos->attackers_to(to) & pos->pieces(color_of(moved_piece), type_of(moved_piece)); + if (more_than_one(attackers)) { + // Remove all illegal moves to disambiguate. + Bitboard att_copy = attackers; + while (att_copy) { + Square s = pop_lsb(att_copy); + Move m = make_move(s, to); + if (!pos->pseudo_legal(m) || !pos->legal(m)) { + attackers &= ~SquareBB[s]; + } + } + } + if (more_than_one(attackers)) { + // Disambiguate by file if possible. + Bitboard attackers_this_file = attackers & file_bb(file_of(from)); + if (attackers != attackers_this_file) { + pretty += char('a' + file_of(from)); + attackers = attackers_this_file; + } + if (more_than_one(attackers)) { + // Still ambiguous, so need to disambiguate by rank. + pretty += char('1' + rank_of(from)); + } + } + + if (type_of(move) == EN_PASSANT || pos->piece_on(to) != NO_PIECE) { + pretty += "x"; + } + + pretty += UCI::square(to); + } + + if (pos->gives_check(move)) { + // Check if mate. + StateInfo si; + pos->do_move(move, si, true); + if (MoveList(*pos).size() > 0) { + pretty += "+"; + } else { + pretty += "#"; + } + pos->undo_move(move); + } + + decoded->set_pretty(pretty); +} + +void HashProbeImpl::ProbeMove(Position* pos, std::deque* setup_states, bool invert, HashProbeLine* response) { + bool found; + TTEntry *entry = TT.probe(pos->key(), found); + response->set_found(found); + if (found) { + TTEntry entry_copy = *entry; + Value value = entry_copy.value(); + Value eval = entry_copy.eval(); + Bound bound = entry_copy.bound(); + + if (invert) { + value = -value; + eval = -eval; + if (bound == BOUND_UPPER) { + bound = BOUND_LOWER; + } else if (bound == BOUND_LOWER) { + bound = BOUND_UPPER; + } + } + + response->set_depth(entry_copy.depth()); + FillValue(eval, response->mutable_eval()); + if (entry_copy.depth() > DEPTH_NONE) { + FillValue(value, response->mutable_value()); + } + response->set_bound(HashProbeLine::ValueBound(bound)); + + // Follow the PV until we hit an illegal move. + std::stack pv; + std::set seen; + while (is_ok(entry_copy.move()) && + pos->pseudo_legal(entry_copy.move()) && + pos->legal(entry_copy.move())) { + FillMove(pos, entry_copy.move(), response->add_pv()); + if (seen.count(pos->key())) break; + pv.push(entry_copy.move()); + seen.insert(pos->key()); + setup_states->push_back(StateInfo()); + pos->do_move(entry_copy.move(), setup_states->back()); + entry = TT.probe(pos->key(), found); + if (!found) { + break; + } + entry_copy = *entry; + } + + // Unroll the PV back again, so the Position object remains unchanged. + while (!pv.empty()) { + pos->undo_move(pv.top()); + pv.pop(); + } + } +} + +void HashProbeImpl::FillValue(Value value, HashProbeScore* score) { + if (abs(value) < VALUE_MATE - MAX_PLY) { + score->set_score_type(HashProbeScore::SCORE_CP); + score->set_score_cp(value * 100 / PawnValueEg); + } else { + score->set_score_type(HashProbeScore::SCORE_MATE); + score->set_score_mate((value > 0 ? VALUE_MATE - value + 1 : -VALUE_MATE - value) / 2); + } +} + +HashProbeThread::HashProbeThread(const std::string &server_address) { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server = std::move(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + std::thread([this]{ server->Wait(); }).detach(); +} + +void HashProbeThread::Shutdown() { + server->Shutdown(); +} + int main(int argc, char* argv[]) { std::cout << engine_info() << std::endl; diff --git a/src/misc.cpp b/src/misc.cpp index 41c59b3f..9618ba76 100644 --- a/src/misc.cpp +++ b/src/misc.cpp @@ -155,6 +155,7 @@ string engine_info(bool to_uci) { { date >> month >> day >> year; ss << setw(2) << day << setw(2) << (1 + months.find(month) / 4) << year.substr(2); + ss << "-asn"; } ss << (to_uci ? "\nid author ": " by ") diff --git a/src/position.cpp b/src/position.cpp index ec9229ea..a1643d6d 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -283,8 +283,6 @@ Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Th thisThread = th; set_state(st); - assert(pos_is_ok()); - return *this; } diff --git a/src/ucioption.cpp b/src/ucioption.cpp index 922fa34f..6adbbde1 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -27,6 +27,7 @@ #include "thread.h" #include "tt.h" #include "uci.h" +#include "hashprobe.h" #include "syzygy/tbprobe.h" using std::string; @@ -34,6 +35,7 @@ using std::string; namespace Stockfish { UCI::OptionsMap Options; // Global object +std::unique_ptr hash_probe_thread; namespace UCI { @@ -45,6 +47,13 @@ void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_tb_path(const Option& o) { Tablebases::init(o); } void on_use_NNUE(const Option& ) { Eval::NNUE::init(); } void on_eval_file(const Option& ) { Eval::NNUE::init(); } +void on_rpc_server_address(const Option& o) { + if (hash_probe_thread) { + hash_probe_thread->Shutdown(); + } + std::string addr = o; + hash_probe_thread.reset(new HashProbeThread(addr)); +} /// Our case insensitive less() function as required by UCI protocol bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { @@ -81,6 +90,7 @@ void init(OptionsMap& o) { o["SyzygyProbeLimit"] << Option(7, 0, 7); o["Use NNUE"] << Option(true, on_use_NNUE); o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file); + o["RPCServerAddress"] << Option("", on_rpc_server_address); }