]> git.sesse.net Git - stockfish/commitdiff
Transform search output to engine callbacks
authorDisservin <disservin.social@gmail.com>
Sat, 23 Mar 2024 09:22:20 +0000 (10:22 +0100)
committerDisservin <disservin.social@gmail.com>
Fri, 5 Apr 2024 19:03:58 +0000 (21:03 +0200)
Part 2 of the Split UCI into UCIEngine and Engine refactor.
This creates function callbacks for search to use when an update should occur.
The benching in uci.cpp for example does this to extract the total nodes
searched.

No functional change

12 files changed:
src/Makefile
src/engine.cpp
src/engine.h
src/main.cpp
src/score.cpp [new file with mode: 0644]
src/score.h [new file with mode: 0644]
src/search.cpp
src/search.h
src/thread.cpp
src/thread.h
src/uci.cpp
src/uci.h

index 6315bda82df06cbb5af1276a7aaebb80ee052b91..550f5404d143663f8623c60dcf869ded3ec88c11 100644 (file)
@@ -55,7 +55,7 @@ PGOBENCH = $(WINE_PATH) ./$(EXE) bench
 SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \
        misc.cpp movegen.cpp movepick.cpp position.cpp \
        search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
-       nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp engine.cpp
+       nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp engine.cpp score.cpp
 
 HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
                nnue/nnue_misc.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
@@ -63,7 +63,7 @@ HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
                nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
                nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
                search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \
-               tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h
+               tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h score.h
 
 OBJS = $(notdir $(SRCS:.cpp=.o))
 
index 79a2c6047424272c201f927957aefaee4310664d..12fa5c3fd025d3e7a24de6ce3ce7555f4d2b194e 100644 (file)
 
 #include "engine.h"
 
-#include <algorithm>
-#include <cassert>
-#include <cctype>
-#include <cmath>
-#include <cstdint>
-#include <cstdlib>
 #include <deque>
 #include <memory>
-#include <optional>
-#include <sstream>
+#include <ostream>
+#include <string_view>
+#include <utility>
 #include <vector>
 
-#include "benchmark.h"
 #include "evaluate.h"
-#include "movegen.h"
+#include "misc.h"
 #include "nnue/network.h"
 #include "nnue/nnue_common.h"
 #include "perft.h"
@@ -40,6 +34,7 @@
 #include "search.h"
 #include "syzygy/tbprobe.h"
 #include "types.h"
+#include "uci.h"
 #include "ucioption.h"
 
 namespace Stockfish {
@@ -54,7 +49,6 @@ Engine::Engine(std::string path) :
     networks(NN::Networks(
       NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG),
       NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))) {
-    Tune::init(options);
     pos.set(StartFEN, false, &states->back());
 }
 
@@ -77,10 +71,26 @@ void Engine::search_clear() {
     tt.clear(options["Threads"]);
     threads.clear();
 
-    // @TODO wont work multiple instances
+    // @TODO wont work with multiple instances
     Tablebases::init(options["SyzygyPath"]);  // Free mapped files
 }
 
+void Engine::set_on_update_no_moves(std::function<void(const Engine::InfoShort&)>&& f) {
+    updateContext.onUpdateNoMoves = std::move(f);
+}
+
+void Engine::set_on_update_full(std::function<void(const Engine::InfoFull&)>&& f) {
+    updateContext.onUpdateFull = std::move(f);
+}
+
+void Engine::set_on_iter(std::function<void(const Engine::InfoIter&)>&& f) {
+    updateContext.onIter = std::move(f);
+}
+
+void Engine::set_on_bestmove(std::function<void(std::string_view, std::string_view)>&& f) {
+    updateContext.onBestmove = std::move(f);
+}
+
 void Engine::wait_for_search_finished() { threads.main_thread()->wait_for_search_finished(); }
 
 void Engine::set_position(const std::string& fen, const std::vector<std::string>& moves) {
@@ -102,7 +112,7 @@ void Engine::set_position(const std::string& fen, const std::vector<std::string>
 
 // modifiers
 
-void Engine::resize_threads() { threads.set({options, threads, tt, networks}); }
+void Engine::resize_threads() { threads.set({options, threads, tt, networks}, updateContext); }
 
 void Engine::set_tt_size(size_t mb) {
     wait_for_search_finished();
@@ -113,7 +123,7 @@ void Engine::set_ponderhit(bool b) { threads.main_manager()->ponder = b; }
 
 // network related
 
-void Engine::verify_networks() {
+void Engine::verify_networks() const {
     networks.big.verify(options["EvalFile"]);
     networks.small.verify(options["EvalFileSmall"]);
 }
@@ -138,9 +148,7 @@ void Engine::save_network(const std::pair<std::optional<std::string>, std::strin
 
 OptionsMap& Engine::get_options() { return options; }
 
-uint64_t Engine::nodes_searched() const { return threads.nodes_searched(); }
-
-void Engine::trace_eval() {
+void Engine::trace_eval() const {
     StateListPtr trace_states(new std::deque<StateInfo>(1));
     Position     p;
     p.set(pos.fen(), options["UCI_Chess960"], &trace_states->back());
index 6afc423de93f628e0e5f611779b73adb4ee53dcb..f74209d90954228c925eec97802f9d40c4ef86d4 100644 (file)
 #ifndef ENGINE_H_INCLUDED
 #define ENGINE_H_INCLUDED
 
-#include "misc.h"
+#include <cstddef>
+#include <functional>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
 #include "nnue/network.h"
 #include "position.h"
 #include "search.h"
@@ -31,6 +38,10 @@ namespace Stockfish {
 
 class Engine {
    public:
+    using InfoShort = Search::InfoShort;
+    using InfoFull  = Search::InfoFull;
+    using InfoIter  = Search::InfoIteration;
+
     Engine(std::string path = "");
     ~Engine() { wait_for_search_finished(); }
 
@@ -51,9 +62,14 @@ class Engine {
     void set_ponderhit(bool);
     void search_clear();
 
+    void set_on_update_no_moves(std::function<void(const InfoShort&)>&&);
+    void set_on_update_full(std::function<void(const InfoFull&)>&&);
+    void set_on_iter(std::function<void(const InfoIter&)>&&);
+    void set_on_bestmove(std::function<void(std::string_view, std::string_view)>&&);
+
     // network related
 
-    void verify_networks();
+    void verify_networks() const;
     void load_networks();
     void load_big_network(const std::string& file);
     void load_small_network(const std::string& file);
@@ -61,9 +77,7 @@ class Engine {
 
     // utility functions
 
-    void trace_eval();
-    // nodes since last search clear
-    uint64_t    nodes_searched() const;
+    void        trace_eval() const;
     OptionsMap& get_options();
 
    private:
@@ -76,6 +90,8 @@ class Engine {
     ThreadPool           threads;
     TranspositionTable   tt;
     Eval::NNUE::Networks networks;
+
+    Search::SearchManager::UpdateContext updateContext;
 };
 
 }  // namespace Stockfish
index 4e72c00398a477f6de38bdb39b39546b772b074d..a6a3d1c4e857ecc6ee66c0895702ffeb0bd5d2cf 100644 (file)
@@ -21,9 +21,9 @@
 #include "bitboard.h"
 #include "misc.h"
 #include "position.h"
-#include "tune.h"
 #include "types.h"
 #include "uci.h"
+#include "tune.h"
 
 using namespace Stockfish;
 
@@ -35,6 +35,9 @@ int main(int argc, char* argv[]) {
     Position::init();
 
     UCIEngine uci(argc, argv);
+
+    Tune::init(uci.engine_options());
+
     uci.loop();
 
     return 0;
diff --git a/src/score.cpp b/src/score.cpp
new file mode 100644 (file)
index 0000000..d1a8a6a
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
+  Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file)
+
+  Stockfish is free software: you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation, either version 3 of the License, or
+  (at your option) any later version.
+
+  Stockfish is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#include "score.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+
+#include "uci.h"
+
+namespace Stockfish {
+
+Score::Score(Value v, const Position& pos) {
+    assert(-VALUE_INFINITE < v && v < VALUE_INFINITE);
+
+    if (std::abs(v) < VALUE_TB_WIN_IN_MAX_PLY)
+    {
+        score = InternalUnits{UCIEngine::to_cp(v, pos)};
+    }
+    else if (std::abs(v) <= VALUE_TB)
+    {
+        auto distance = VALUE_TB - std::abs(v);
+        score         = (v > 0) ? TBWin{distance} : TBWin{-distance};
+    }
+    else
+    {
+        auto distance = VALUE_MATE - std::abs(v);
+        score         = (v > 0) ? Mate{distance} : Mate{-distance};
+    }
+}
+
+}
\ No newline at end of file
diff --git a/src/score.h b/src/score.h
new file mode 100644 (file)
index 0000000..b94d9f7
--- /dev/null
@@ -0,0 +1,69 @@
+/*
+  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
+  Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file)
+
+  Stockfish is free software: you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation, either version 3 of the License, or
+  (at your option) any later version.
+
+  Stockfish is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef SCORE_H_INCLUDED
+#define SCORE_H_INCLUDED
+
+#include <variant>
+#include <utility>
+
+#include "types.h"
+
+namespace Stockfish {
+
+class Position;
+
+class Score {
+   public:
+    struct Mate {
+        int plies;
+    };
+
+    struct TBWin {
+        int plies;
+    };
+
+    struct InternalUnits {
+        int value;
+    };
+
+    Score() = default;
+    Score(Value v, const Position& pos);
+
+    template<typename T>
+    bool is() const {
+        return std::holds_alternative<T>(score);
+    }
+
+    template<typename T>
+    T get() const {
+        return std::get<T>(score);
+    }
+
+    template<typename F>
+    decltype(auto) visit(F&& f) const {
+        return std::visit(std::forward<F>(f), score);
+    }
+
+   private:
+    std::variant<Mate, TBWin, InternalUnits> score;
+};
+
+}
+
+#endif  // #ifndef SCORE_H_INCLUDED
index efc0075061322e62bc1ecb6f49a880355d069d3c..51cd1ae11dc85f6b021e26b3f800e5cc7eb007dc 100644 (file)
@@ -26,8 +26,7 @@
 #include <cstdint>
 #include <cstdlib>
 #include <initializer_list>
-#include <iostream>
-#include <sstream>
+#include <string>
 #include <utility>
 
 #include "evaluate.h"
@@ -157,9 +156,8 @@ void Search::Worker::start_searching() {
     if (rootMoves.empty())
     {
         rootMoves.emplace_back(Move::none());
-        sync_cout << "info depth 0 score "
-                  << UCIEngine::to_score(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW, rootPos)
-                  << sync_endl;
+        main_manager()->updates.onUpdateNoMoves(
+          {0, {rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW, rootPos}});
     }
     else
     {
@@ -201,18 +199,16 @@ void Search::Worker::start_searching() {
 
     // Send again PV info if we have a new best thread
     if (bestThread != this)
-        sync_cout << main_manager()->pv(*bestThread, threads, tt, bestThread->completedDepth)
-                  << sync_endl;
+        main_manager()->pv(*bestThread, threads, tt, bestThread->completedDepth);
 
-    sync_cout << "bestmove "
-              << UCIEngine::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960());
+    std::string ponder;
 
     if (bestThread->rootMoves[0].pv.size() > 1
         || bestThread->rootMoves[0].extract_ponder_from_tt(tt, rootPos))
-        std::cout << " ponder "
-                  << UCIEngine::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960());
+        ponder = UCIEngine::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960());
 
-    std::cout << sync_endl;
+    auto bestmove = UCIEngine::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960());
+    main_manager()->updates.onBestmove(bestmove, ponder);
 }
 
 // Main iterative deepening loop. It calls search()
@@ -345,7 +341,7 @@ void Search::Worker::iterative_deepening() {
                 // the UI) before a re-search.
                 if (mainThread && multiPV == 1 && (bestValue <= alpha || bestValue >= beta)
                     && mainThread->tm.elapsed(threads.nodes_searched()) > 3000)
-                    sync_cout << main_manager()->pv(*this, threads, tt, rootDepth) << sync_endl;
+                    main_manager()->pv(*this, threads, tt, rootDepth);
 
                 // In case of failing low/high increase aspiration window and
                 // re-search, otherwise exit the loop.
@@ -382,7 +378,7 @@ void Search::Worker::iterative_deepening() {
                 // had time to fully search other root-moves. Thus we suppress this output and
                 // below pick a proven score/PV for this thread (from the previous iteration).
                 && !(threads.abortedSearch && rootMoves[0].uciScore <= VALUE_TB_LOSS_IN_MAX_PLY))
-                sync_cout << main_manager()->pv(*this, threads, tt, rootDepth) << sync_endl;
+                main_manager()->pv(*this, threads, tt, rootDepth);
         }
 
         if (!threads.stop)
@@ -934,9 +930,10 @@ moves_loop:  // When in check, search starts here
 
         if (rootNode && is_mainthread()
             && main_manager()->tm.elapsed(threads.nodes_searched()) > 3000)
-            sync_cout << "info depth " << depth << " currmove "
-                      << UCIEngine::move(move, pos.is_chess960()) << " currmovenumber "
-                      << moveCount + thisThread->pvIdx << sync_endl;
+        {
+            main_manager()->updates.onIter(
+              {depth, UCIEngine::move(move, pos.is_chess960()), moveCount + thisThread->pvIdx});
+        }
         if (PvNode)
             (ss + 1)->pv = nullptr;
 
@@ -1871,11 +1868,10 @@ void SearchManager::check_time(Search::Worker& worker) {
         worker.threads.stop = worker.threads.abortedSearch = true;
 }
 
-std::string SearchManager::pv(const Search::Worker&     worker,
-                              const ThreadPool&         threads,
-                              const TranspositionTable& tt,
-                              Depth                     depth) const {
-    std::stringstream ss;
+void SearchManager::pv(const Search::Worker&     worker,
+                       const ThreadPool&         threads,
+                       const TranspositionTable& tt,
+                       Depth                     depth) const {
 
     const auto  nodes     = threads.nodes_searched();
     const auto& rootMoves = worker.rootMoves;
@@ -1901,29 +1897,39 @@ std::string SearchManager::pv(const Search::Worker&     worker,
         bool tb = worker.tbConfig.rootInTB && std::abs(v) <= VALUE_TB;
         v       = tb ? rootMoves[i].tbScore : v;
 
-        if (ss.rdbuf()->in_avail())  // Not at first line
-            ss << "\n";
+        std::string pv;
+        for (Move m : rootMoves[i].pv)
+            pv += UCIEngine::move(m, pos.is_chess960()) + " ";
+
+        // remove last whitespace
+        if (!pv.empty())
+            pv.pop_back();
 
-        ss << "info"
-           << " depth " << d << " seldepth " << rootMoves[i].selDepth << " multipv " << i + 1
-           << " score " << UCIEngine::to_score(v, pos);
+        auto wdl   = worker.options["UCI_ShowWDL"] ? UCIEngine::wdl(v, pos) : "";
+        auto bound = rootMoves[i].scoreLowerbound
+                     ? "lowerbound"
+                     : (rootMoves[i].scoreUpperbound ? "upperbound" : "");
 
-        if (worker.options["UCI_ShowWDL"])
-            ss << UCIEngine::wdl(v, pos);
+        InfoFull info;
+
+        info.depth    = d;
+        info.selDepth = rootMoves[i].selDepth;
+        info.multiPV  = i + 1;
+        info.score    = {v, pos};
+        info.wdl      = wdl;
 
         if (i == pvIdx && !tb && updated)  // tablebase- and previous-scores are exact
-            ss << (rootMoves[i].scoreLowerbound
-                     ? " lowerbound"
-                     : (rootMoves[i].scoreUpperbound ? " upperbound" : ""));
+            info.bound = bound;
 
-        ss << " nodes " << nodes << " nps " << nodes * 1000 / time << " hashfull " << tt.hashfull()
-           << " tbhits " << tbHits << " time " << time << " pv";
+        info.timeMs   = time;
+        info.nodes    = nodes;
+        info.nps      = nodes * 1000 / time;
+        info.tbHits   = tbHits;
+        info.pv       = pv;
+        info.hashfull = tt.hashfull();
 
-        for (Move m : rootMoves[i].pv)
-            ss << " " << UCIEngine::move(m, pos.is_chess960());
+        updates.onUpdateFull(info);
     }
-
-    return ss.str();
 }
 
 // Called in case we have no ponder move before exiting the search,
index 22f75ffd4d8fb9c4c857d857f7f63d29e64083e8..d1464840310255f87ca58e372f21b8c63ea67c2c 100644 (file)
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <functional>
 #include <memory>
-#include <string>
+#include <string_view>
 #include <vector>
 
 #include "misc.h"
 #include "movepick.h"
 #include "position.h"
+#include "score.h"
 #include "syzygy/tbprobe.h"
 #include "timeman.h"
 #include "types.h"
@@ -139,7 +141,6 @@ struct SharedState {
         tt(transpositionTable),
         networks(nets) {}
 
-
     const OptionsMap&           options;
     ThreadPool&                 threads;
     TranspositionTable&         tt;
@@ -156,16 +157,56 @@ class ISearchManager {
     virtual void check_time(Search::Worker&) = 0;
 };
 
+struct InfoShort {
+    int   depth;
+    Score score;
+};
+
+struct InfoFull: InfoShort {
+    int              selDepth;
+    size_t           multiPV;
+    std::string_view wdl;
+    std::string_view bound;
+    size_t           timeMs;
+    size_t           nodes;
+    size_t           nps;
+    size_t           tbHits;
+    std::string_view pv;
+    int              hashfull;
+};
+
+struct InfoIteration {
+    int              depth;
+    std::string_view currmove;
+    size_t           currmovenumber;
+};
+
 // SearchManager manages the search from the main thread. It is responsible for
 // keeping track of the time, and storing data strictly related to the main thread.
 class SearchManager: public ISearchManager {
    public:
+    using UpdateShort    = std::function<void(const InfoShort&)>;
+    using UpdateFull     = std::function<void(const InfoFull&)>;
+    using UpdateIter     = std::function<void(const InfoIteration&)>;
+    using UpdateBestmove = std::function<void(std::string_view, std::string_view)>;
+
+    struct UpdateContext {
+        UpdateShort    onUpdateNoMoves;
+        UpdateFull     onUpdateFull;
+        UpdateIter     onIter;
+        UpdateBestmove onBestmove;
+    };
+
+
+    SearchManager(const UpdateContext& updateContext) :
+        updates(updateContext) {}
+
     void check_time(Search::Worker& worker) override;
 
-    std::string pv(const Search::Worker&     worker,
-                   const ThreadPool&         threads,
-                   const TranspositionTable& tt,
-                   Depth                     depth) const;
+    void pv(const Search::Worker&     worker,
+            const ThreadPool&         threads,
+            const TranspositionTable& tt,
+            Depth                     depth) const;
 
     Stockfish::TimeManagement tm;
     int                       callsCnt;
@@ -178,6 +219,8 @@ class SearchManager: public ISearchManager {
     bool                 stopOnPonderhit;
 
     size_t id;
+
+    const UpdateContext& updates;
 };
 
 class NullSearchManager: public ISearchManager {
index 90add4ad0594b18458be44b213c4752cbcae3c41..85a2bcbb167708446369e6bf9244482331f30d31 100644 (file)
@@ -119,7 +119,8 @@ uint64_t ThreadPool::tb_hits() const { return accumulate(&Search::Worker::tbHits
 // Creates/destroys threads to match the requested number.
 // Created and launched threads will immediately go to sleep in idle_loop.
 // Upon resizing, threads are recreated to allow for binding if necessary.
-void ThreadPool::set(Search::SharedState sharedState) {
+void ThreadPool::set(Search::SharedState                         sharedState,
+                     const Search::SearchManager::UpdateContext& updateContext) {
 
     if (threads.size() > 0)  // destroy any existing thread(s)
     {
@@ -133,14 +134,15 @@ void ThreadPool::set(Search::SharedState sharedState) {
 
     if (requested > 0)  // create new thread(s)
     {
-        threads.push_back(new Thread(
-          sharedState, std::unique_ptr<Search::ISearchManager>(new Search::SearchManager()), 0));
-
+        auto manager = std::make_unique<Search::SearchManager>(updateContext);
+        threads.push_back(new Thread(sharedState, std::move(manager), 0));
 
         while (threads.size() < requested)
-            threads.push_back(new Thread(
-              sharedState, std::unique_ptr<Search::ISearchManager>(new Search::NullSearchManager()),
-              threads.size()));
+        {
+            auto null_manager = std::make_unique<Search::NullSearchManager>();
+            threads.push_back(new Thread(sharedState, std::move(null_manager), threads.size()));
+        }
+
         clear();
 
         main_thread()->wait_for_search_finished();
index 81fcc72a7ee88fcfa4de2f83418c7c09a98651ff..223652aec99e340ff4657be1212d2e9803d22479 100644 (file)
@@ -82,7 +82,7 @@ class ThreadPool {
 
     void start_thinking(const OptionsMap&, Position&, StateListPtr&, Search::LimitsType);
     void clear();
-    void set(Search::SharedState);
+    void set(Search::SharedState, const Search::SearchManager::UpdateContext&);
 
     Search::SearchManager* main_manager();
     Thread*                main_thread() const { return threads.front(); }
index ed23c00a49df8cf7c12010b7ead583f24d2f6b24..d6936d38b234a91ccd8ad6f526a83fbe2be54e82 100644 (file)
 #include "uci.h"
 
 #include <algorithm>
-#include <cassert>
 #include <cctype>
 #include <cmath>
 #include <cstdint>
-#include <cstdlib>
 #include <deque>
 #include <memory>
 #include <optional>
 #include <sstream>
+#include <string_view>
 #include <utility>
 #include <vector>
 
 #include "engine.h"
 #include "evaluate.h"
 #include "movegen.h"
-#include "nnue/network.h"
-#include "nnue/nnue_common.h"
-#include "perft.h"
 #include "position.h"
+#include "score.h"
 #include "search.h"
 #include "syzygy/tbprobe.h"
 #include "types.h"
@@ -49,6 +46,13 @@ namespace Stockfish {
 constexpr auto StartFEN  = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
 constexpr int  MaxHashMB = Is64Bit ? 33554432 : 2048;
 
+template<typename... Ts>
+struct overload: Ts... {
+    using Ts::operator()...;
+};
+
+template<typename... Ts>
+overload(Ts...) -> overload<Ts...>;
 
 UCIEngine::UCIEngine(int argc, char** argv) :
     engine(argv[0]),
@@ -81,6 +85,12 @@ UCIEngine::UCIEngine(int argc, char** argv) :
     options["EvalFileSmall"] << Option(EvalFileDefaultNameSmall,
                                        [this](const Option& o) { engine.load_small_network(o); });
 
+
+    engine.set_on_iter([](const auto& i) { on_iter(i); });
+    engine.set_on_update_no_moves([](const auto& i) { on_update_no_moves(i); });
+    engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
+    engine.set_on_bestmove([](const auto& bm, const auto& p) { on_bestmove(bm, p); });
+
     engine.load_networks();
     engine.resize_threads();
     engine.search_clear();  // After threads are up
@@ -221,6 +231,13 @@ void UCIEngine::go(Position& pos, std::istringstream& is) {
 void UCIEngine::bench(Position& pos, std::istream& args) {
     std::string token;
     uint64_t    num, nodes = 0, cnt = 1;
+    uint64_t    nodesSearched = 0;
+    const auto& options       = engine.get_options();
+
+    engine.set_on_update_full([&](const auto& i) {
+        nodesSearched = i.nodes;
+        on_update_full(i, options["UCI_ShowWDL"]);
+    });
 
     std::vector<std::string> list = setup_bench(pos, args);
 
@@ -242,7 +259,8 @@ void UCIEngine::bench(Position& pos, std::istream& args) {
             {
                 go(pos, is);
                 engine.wait_for_search_finished();
-                nodes += engine.nodes_searched();
+                nodes += nodesSearched;
+                nodesSearched = 0;
             }
             else
                 engine.trace_eval();
@@ -265,6 +283,9 @@ void UCIEngine::bench(Position& pos, std::istream& args) {
     std::cerr << "\n==========================="
               << "\nTotal time (ms) : " << elapsed << "\nNodes searched  : " << nodes
               << "\nNodes/second    : " << 1000 * nodes / elapsed << std::endl;
+
+    // reset callback, to not capture a dangling reference to nodesSearched
+    engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
 }
 
 
@@ -335,22 +356,22 @@ int win_rate_model(Value v, const Position& pos) {
 }
 }
 
-std::string UCIEngine::to_score(Value v, const Position& pos) {
-    assert(-VALUE_INFINITE < v && v < VALUE_INFINITE);
-
-    std::stringstream ss;
-
-    if (std::abs(v) < VALUE_TB_WIN_IN_MAX_PLY)
-        ss << "cp " << to_cp(v, pos);
-    else if (std::abs(v) <= VALUE_TB)
-    {
-        const int ply = VALUE_TB - std::abs(v);  // recompute ss->ply
-        ss << "cp " << (v > 0 ? 20000 - ply : -20000 + ply);
-    }
-    else
-        ss << "mate " << (v > 0 ? VALUE_MATE - v + 1 : -VALUE_MATE - v) / 2;
-
-    return ss.str();
+std::string UCIEngine::format_score(const Score& s) {
+    constexpr int TB_CP = 20000;
+    const auto    format =
+      overload{[](Score::Mate mate) -> std::string {
+                   auto m = (mate.plies > 0 ? (mate.plies + 1) : -mate.plies) / 2;
+                   return std::string("mate ") + std::to_string(m);
+               },
+               [](Score::TBWin tb) -> std::string {
+                   return std::string("cp ")
+                        + std::to_string((tb.plies > 0 ? TB_CP - tb.plies : -TB_CP + tb.plies));
+               },
+               [](Score::InternalUnits units) -> std::string {
+                   return std::string("cp ") + std::to_string(units.value);
+               }};
+
+    return s.visit(format);
 }
 
 // Turns a Value to an integer centipawn number,
@@ -414,4 +435,51 @@ Move UCIEngine::to_move(const Position& pos, std::string str) {
     return Move::none();
 }
 
+void UCIEngine::on_update_no_moves(const Engine::InfoShort& info) {
+    sync_cout << "info depth" << info.depth << " score " << format_score(info.score) << sync_endl;
+}
+
+void UCIEngine::on_update_full(const Engine::InfoFull& info, bool showWDL) {
+    std::stringstream ss;
+
+    ss << "info";
+    ss << " depth " << info.depth                 //
+       << " seldepth " << info.selDepth           //
+       << " multipv " << info.multiPV             //
+       << " score " << format_score(info.score);  //
+
+    if (showWDL)
+        ss << " wdl " << info.wdl;
+
+    if (!info.bound.empty())
+        ss << " " << info.bound;
+
+    ss << " nodes " << info.nodes        //
+       << " nps " << info.nps            //
+       << " hashfull " << info.hashfull  //
+       << " tbhits " << info.tbHits      //
+       << " time " << info.timeMs        //
+       << " pv " << info.pv;             //
+
+    sync_cout << ss.str() << sync_endl;
+}
+
+void UCIEngine::on_iter(const Engine::InfoIter& info) {
+    std::stringstream ss;
+
+    ss << "info";
+    ss << " depth " << info.depth                     //
+       << " currmove " << info.currmove               //
+       << " currmovenumber " << info.currmovenumber;  //
+
+    sync_cout << ss.str() << sync_endl;
+}
+
+void UCIEngine::on_bestmove(std::string_view bestmove, std::string_view ponder) {
+    sync_cout << "bestmove " << bestmove;
+    if (!ponder.empty())
+        std::cout << " ponder " << ponder;
+    std::cout << sync_endl;
+}
+
 }  // namespace Stockfish
index c4e90b48d3580078b157b1db203f6daff675341c..fa8c57fd912af5649debc19fe854a90a3453b23d 100644 (file)
--- a/src/uci.h
+++ b/src/uci.h
 
 #include <iostream>
 #include <string>
+#include <string_view>
 
 #include "engine.h"
 #include "misc.h"
-#include "nnue/network.h"
-#include "position.h"
 #include "search.h"
-#include "thread.h"
-#include "tt.h"
-#include "ucioption.h"
 
 namespace Stockfish {
 
+class Position;
 class Move;
+class Score;
 enum Square : int;
 using Value = int;
 
@@ -44,7 +42,7 @@ class UCIEngine {
     void loop();
 
     static int         to_cp(Value v, const Position& pos);
-    static std::string to_score(Value v, const Position& pos);
+    static std::string format_score(const Score& s);
     static std::string square(Square s);
     static std::string move(Move m, bool chess960);
     static std::string wdl(Value v, const Position& pos);
@@ -52,6 +50,8 @@ class UCIEngine {
 
     static Search::LimitsType parse_limits(const Position& pos, std::istream& is);
 
+    auto& engine_options() { return engine.get_options(); }
+
    private:
     Engine      engine;
     CommandLine cli;
@@ -60,6 +60,11 @@ class UCIEngine {
     void bench(Position& pos, std::istream& args);
     void position(std::istringstream& is);
     void setoption(std::istringstream& is);
+
+    static void on_update_no_moves(const Engine::InfoShort& info);
+    static void on_update_full(const Engine::InfoFull& info, bool showWDL);
+    static void on_iter(const Engine::InfoIter& info);
+    static void on_bestmove(std::string_view bestmove, std::string_view ponder);
 };
 
 }  // namespace Stockfish