Skip to content

Commit b7dd8d3

Browse files
committed
Introduce skipping based on the likelihood of the game outcome
This change intends to skip (or, more precisely, to weight) data points that are possibly incorrectly evaluated — i.e. it retains the data that are more likely to be correct. It was used to train two recent SF nets: official-stockfish/Stockfish#3816 and official-stockfish/Stockfish#3808. The --no-wld-fen-skipping option can be used to disable this new default behavior.
1 parent 87d2a9d commit b7dd8d3

File tree

4 files changed

+82
-24
lines changed

4 files changed

+82
-24
lines changed

lib/nnue_training_data_formats.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2626

2727
#pragma once
2828

29+
#include <algorithm>
2930
#include <cstdio>
3031
#include <cassert>
3132
#include <string>
3233
#include <string_view>
3334
#include <vector>
35+
#include <cmath>
3436
#include <memory>
3537
#include <fstream>
3638
#include <cstring>
@@ -6850,6 +6852,41 @@ namespace binpack
68506852
pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling
68516853
}
68526854

6855+
// The win rate model returns the probability (per mille) of winning given an eval
6856+
// and a game-ply. The model fits rather accurately the LTC fishtest statistics.
6857+
std::tuple<double, double, double> win_rate_model() const {
6858+
6859+
// The model captures only up to 240 plies, so limit input (and rescale)
6860+
double m = std::min(240, int(ply)) / 64.0;
6861+
6862+
// Coefficients of a 3rd order polynomial fit based on fishtest data
6863+
// for two parameters needed to transform eval to the argument of a
6864+
// logistic function.
6865+
double as[] = {-3.68389304, 30.07065921, -60.52878723, 149.53378557};
6866+
double bs[] = {-2.0181857, 15.85685038, -29.83452023, 47.59078827};
6867+
double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3];
6868+
double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3];
6869+
6870+
// Transform eval to centipawns with limited range
6871+
double x = std::clamp(double(100 * score) / 208, -2000.0, 2000.0);
6872+
double w = 1.0 / (1 + std::exp((a - x) / b));
6873+
double l = 1.0 / (1 + std::exp((a + x) / b));
6874+
double d = 1.0 - w - l;
6875+
6876+
// Return win, loss, draw rate in per mille (rounded to nearest)
6877+
return std::make_tuple(w, l, d);
6878+
}
6879+
6880+
// how likely is end-game result with the current score?
6881+
double score_result_prob() const {
6882+
auto [w, l, d] = win_rate_model();
6883+
if (result > 0)
6884+
return w;
6885+
if (result < 0)
6886+
return l;
6887+
return d;
6888+
}
6889+
68536890
[[nodiscard]] bool isInCheck() const
68546891
{
68556892
return pos.isCheck();

nnue_dataset.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def get_fens(self):
6767
return strings
6868

6969
FenBatchPtr = ctypes.POINTER(FenBatch)
70-
70+
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered)
7171
create_fen_batch_stream = dll.create_fen_batch_stream
7272
create_fen_batch_stream.restype = ctypes.c_void_p
73-
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int]
73+
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool]
7474
destroy_fen_batch_stream = dll.destroy_fen_batch_stream
7575
destroy_fen_batch_stream.argtypes = [ctypes.c_void_p]
7676

@@ -87,19 +87,21 @@ def __init__(
8787
num_workers,
8888
batch_size=None,
8989
filtered=False,
90-
random_fen_skipping=0):
90+
random_fen_skipping=0,
91+
wld_filtered=False):
9192

9293
self.filename = filename.encode('utf-8')
9394
self.cyclic = cyclic
9495
self.num_workers = num_workers
9596
self.batch_size = batch_size
9697
self.filtered = filtered
98+
self.wld_filtered = wld_filtered
9799
self.random_fen_skipping = random_fen_skipping
98100

99101
if batch_size:
100-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping)
102+
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered)
101103
else:
102-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping)
104+
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered)
103105

104106
def __iter__(self):
105107
return self
@@ -131,6 +133,7 @@ def __init__(
131133
batch_size=None,
132134
filtered=False,
133135
random_fen_skipping=0,
136+
wld_filtered=False,
134137
device='cpu'):
135138

136139
self.feature_set = feature_set.encode('utf-8')
@@ -143,13 +146,14 @@ def __init__(
143146
self.num_workers = num_workers
144147
self.batch_size = batch_size
145148
self.filtered = filtered
149+
self.wld_filtered = wld_filtered
146150
self.random_fen_skipping = random_fen_skipping
147151
self.device = device
148152

149153
if batch_size:
150-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping)
154+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered)
151155
else:
152-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping)
156+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered)
153157

154158
def __iter__(self):
155159
return self
@@ -167,9 +171,11 @@ def __next__(self):
167171
def __del__(self):
168172
self.destroy_stream(self.stream)
169173

174+
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
175+
# bool filtered, int random_fen_skipping, bool wld_filtered)
170176
create_sparse_batch_stream = dll.create_sparse_batch_stream
171177
create_sparse_batch_stream.restype = ctypes.c_void_p
172-
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int]
178+
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool]
173179
destroy_sparse_batch_stream = dll.destroy_sparse_batch_stream
174180
destroy_sparse_batch_stream.argtypes = [ctypes.c_void_p]
175181

@@ -198,7 +204,7 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
198204
return b
199205

200206
class SparseBatchProvider(TrainingDataProvider):
201-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, device='cpu'):
207+
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, device='cpu'):
202208
super(SparseBatchProvider, self).__init__(
203209
feature_set,
204210
create_sparse_batch_stream,
@@ -211,10 +217,11 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
211217
batch_size,
212218
filtered,
213219
random_fen_skipping,
220+
wld_filtered,
214221
device)
215222

216223
class SparseBatchDataset(torch.utils.data.IterableDataset):
217-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, device='cpu'):
224+
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, device='cpu'):
218225
super(SparseBatchDataset).__init__()
219226
self.feature_set = feature_set
220227
self.filename = filename
@@ -223,10 +230,12 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
223230
self.num_workers = num_workers
224231
self.filtered = filtered
225232
self.random_fen_skipping = random_fen_skipping
233+
self.wld_filtered = wld_filtered
226234
self.device = device
227235

228236
def __iter__(self):
229-
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers, filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, device=self.device)
237+
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
238+
filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, wld_filtered=self.wld_filtered, device=self.device)
230239

231240
class FixedNumBatchesDataset(Dataset):
232241
def __init__(self, dataset, num_batches):

train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from pytorch_lightning import loggers as pl_loggers
1010
from torch.utils.data import DataLoader, Dataset
1111

12-
def make_data_loaders(train_filename, val_filename, feature_set, num_workers, batch_size, filtered, random_fen_skipping, main_device):
12+
def make_data_loaders(train_filename, val_filename, feature_set, num_workers, batch_size, filtered, random_fen_skipping, wld_filtered, main_device):
1313
# Epoch and validation sizes are arbitrary
1414
epoch_size = 100000000
1515
val_size = 1000000
1616
features_name = feature_set.name
1717
train_infinite = nnue_dataset.SparseBatchDataset(features_name, train_filename, batch_size, num_workers=num_workers,
18-
filtered=filtered, random_fen_skipping=random_fen_skipping, device=main_device)
18+
filtered=filtered, random_fen_skipping=random_fen_skipping, wld_filtered=wld_filtered, device=main_device)
1919
val_infinite = nnue_dataset.SparseBatchDataset(features_name, val_filename, batch_size, filtered=filtered,
20-
random_fen_skipping=random_fen_skipping, device=main_device)
20+
random_fen_skipping=random_fen_skipping, wld_filtered=wld_filtered, device=main_device)
2121
# num_workers has to be 0 for sparse, and 1 for dense
2222
# it currently cannot work in parallel mode but it shouldn't need to
2323
train = DataLoader(nnue_dataset.FixedNumBatchesDataset(train_infinite, (epoch_size + batch_size - 1) // batch_size), batch_size=None, batch_sampler=None)
@@ -36,6 +36,7 @@ def main():
3636
parser.add_argument("--seed", default=42, type=int, dest='seed', help="torch seed to use.")
3737
parser.add_argument("--smart-fen-skipping", action='store_true', dest='smart_fen_skipping_deprecated', help="If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored")
3838
parser.add_argument("--no-smart-fen-skipping", action='store_true', dest='no_smart_fen_skipping', help="If used then no smart fen skipping will be done. By default smart fen skipping is done.")
39+
parser.add_argument("--no-wld-fen-skipping", action='store_true', dest='no_wld_fen_skipping', help="If used then no wld fen skipping will be done. By default wld fen skipping is done.")
3940
parser.add_argument("--random-fen-skipping", default=3, type=int, dest='random_fen_skipping', help="skip fens randomly on average random_fen_skipping before using one.")
4041
parser.add_argument("--resume-from-model", dest='resume_from_model', help="Initializes training using the weights from the given .pt model")
4142
features.add_argparse_args(parser)
@@ -71,6 +72,7 @@ def main():
7172
print('Using batch size {}'.format(batch_size))
7273

7374
print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
75+
print('WLD fen skipping: {}'.format(not args.no_wld_fen_skipping))
7476
print('Random fen skipping: {}'.format(args.random_fen_skipping))
7577

7678
if args.threads > 0:
@@ -89,7 +91,7 @@ def main():
8991
nnue.to(device=main_device)
9092

9193
print('Using c++ data loader')
92-
train, val = make_data_loaders(args.train, args.val, feature_set, args.num_workers, batch_size, not args.no_smart_fen_skipping, args.random_fen_skipping, main_device)
94+
train, val = make_data_loaders(args.train, args.val, feature_set, args.num_workers, batch_size, not args.no_smart_fen_skipping, args.random_fen_skipping, not args.no_wld_fen_skipping, main_device)
9395

9496
trainer.fit(nnue, train, val)
9597

training_data_loader.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -808,16 +808,23 @@ struct FenBatchStream : Stream<FenBatch>
808808
std::vector<std::thread> m_workers;
809809
};
810810

811-
std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered, int random_fen_skipping)
811+
std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered, int random_fen_skipping, bool wld_filtered)
812812
{
813-
if (filtered || random_fen_skipping)
813+
if (filtered || random_fen_skipping || wld_filtered)
814814
{
815815
return [
816816
random_fen_skipping,
817817
prob = double(random_fen_skipping) / (random_fen_skipping + 1),
818-
filtered
818+
filtered,
819+
wld_filtered
819820
](const TrainingDataEntry& e){
820821

822+
auto do_wld_skip = [&]() {
823+
std::bernoulli_distribution distrib(1.0 - e.score_result_prob());
824+
auto& prng = rng::get_thread_local_rng();
825+
return distrib(prng);
826+
};
827+
821828
auto do_skip = [&]() {
822829
std::bernoulli_distribution distrib(prob);
823830
auto& prng = rng::get_thread_local_rng();
@@ -829,7 +836,7 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered,
829836
};
830837

831838
static thread_local std::mt19937 gen(std::random_device{}());
832-
return (random_fen_skipping && do_skip()) || (filtered && do_filter());
839+
return (random_fen_skipping && do_skip()) || (filtered && do_filter()) || (wld_filtered && do_wld_skip());
833840
};
834841
}
835842

@@ -896,9 +903,10 @@ extern "C" {
896903
return nullptr;
897904
}
898905

899-
EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping)
906+
// changing the signature needs matching changes in nnue_dataset.py
907+
EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered)
900908
{
901-
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping);
909+
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered);
902910

903911
return new FenBatchStream(concurrency, filename, batch_size, cyclic, skipPredicate);
904912
}
@@ -908,9 +916,11 @@ extern "C" {
908916
delete stream;
909917
}
910918

911-
EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping)
919+
// changing the signature needs matching changes in nnue_dataset.py
920+
EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
921+
bool filtered, int random_fen_skipping, bool wld_filtered)
912922
{
913-
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping);
923+
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered);
914924

915925
std::string_view feature_set(feature_set_c);
916926
if (feature_set == "HalfKP")
@@ -981,7 +991,7 @@ extern "C" {
981991

982992
int main()
983993
{
984-
auto stream = create_sparse_batch_stream("HalfKP", 4, "10m_d3_q_2.binpack", 8192, true, false, 0);
994+
auto stream = create_sparse_batch_stream("HalfKP", 4, "10m_d3_q_2.binpack", 8192, true, false, 0, false);
985995
auto t0 = std::chrono::high_resolution_clock::now();
986996
for (int i = 0; i < 1000; ++i)
987997
{

0 commit comments

Comments
 (0)