Skip to content

Commit 70880e9

Browse files
authored
Merge pull request #155 from vondele/wdlPR
Introduce skipping based on the likelihood of the game outcome
2 parents 87d2a9d + b7dd8d3 commit 70880e9

File tree

4 files changed

+82
-24
lines changed

4 files changed

+82
-24
lines changed

lib/nnue_training_data_formats.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2626

2727
#pragma once
2828

29+
#include <algorithm>
2930
#include <cstdio>
3031
#include <cassert>
3132
#include <string>
3233
#include <string_view>
3334
#include <vector>
35+
#include <cmath>
3436
#include <memory>
3537
#include <fstream>
3638
#include <cstring>
@@ -6850,6 +6852,41 @@ namespace binpack
68506852
pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling
68516853
}
68526854

6855+
// The win rate model returns the probability (per mille) of winning given an eval
6856+
// and a game-ply. The model fits rather accurately the LTC fishtest statistics.
6857+
std::tuple<double, double, double> win_rate_model() const {
6858+
6859+
// The model captures only up to 240 plies, so limit input (and rescale)
6860+
double m = std::min(240, int(ply)) / 64.0;
6861+
6862+
// Coefficients of a 3rd order polynomial fit based on fishtest data
6863+
// for two parameters needed to transform eval to the argument of a
6864+
// logistic function.
6865+
double as[] = {-3.68389304, 30.07065921, -60.52878723, 149.53378557};
6866+
double bs[] = {-2.0181857, 15.85685038, -29.83452023, 47.59078827};
6867+
double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3];
6868+
double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3];
6869+
6870+
// Transform eval to centipawns with limited range
6871+
double x = std::clamp(double(100 * score) / 208, -2000.0, 2000.0);
6872+
double w = 1.0 / (1 + std::exp((a - x) / b));
6873+
double l = 1.0 / (1 + std::exp((a + x) / b));
6874+
double d = 1.0 - w - l;
6875+
6876+
// Return win, loss, draw rate in per mille (rounded to nearest)
6877+
return std::make_tuple(w, l, d);
6878+
}
6879+
6880+
// how likely is end-game result with the current score?
6881+
double score_result_prob() const {
6882+
auto [w, l, d] = win_rate_model();
6883+
if (result > 0)
6884+
return w;
6885+
if (result < 0)
6886+
return l;
6887+
return d;
6888+
}
6889+
68536890
[[nodiscard]] bool isInCheck() const
68546891
{
68556892
return pos.isCheck();

nnue_dataset.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def get_fens(self):
6767
return strings
6868

6969
FenBatchPtr = ctypes.POINTER(FenBatch)
70-
70+
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered)
7171
create_fen_batch_stream = dll.create_fen_batch_stream
7272
create_fen_batch_stream.restype = ctypes.c_void_p
73-
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int]
73+
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool]
7474
destroy_fen_batch_stream = dll.destroy_fen_batch_stream
7575
destroy_fen_batch_stream.argtypes = [ctypes.c_void_p]
7676

@@ -87,19 +87,21 @@ def __init__(
8787
num_workers,
8888
batch_size=None,
8989
filtered=False,
90-
random_fen_skipping=0):
90+
random_fen_skipping=0,
91+
wld_filtered=False):
9192

9293
self.filename = filename.encode('utf-8')
9394
self.cyclic = cyclic
9495
self.num_workers = num_workers
9596
self.batch_size = batch_size
9697
self.filtered = filtered
98+
self.wld_filtered = wld_filtered
9799
self.random_fen_skipping = random_fen_skipping
98100

99101
if batch_size:
100-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping)
102+
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered)
101103
else:
102-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping)
104+
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered)
103105

104106
def __iter__(self):
105107
return self
@@ -131,6 +133,7 @@ def __init__(
131133
batch_size=None,
132134
filtered=False,
133135
random_fen_skipping=0,
136+
wld_filtered=False,
134137
device='cpu'):
135138

136139
self.feature_set = feature_set.encode('utf-8')
@@ -143,13 +146,14 @@ def __init__(
143146
self.num_workers = num_workers
144147
self.batch_size = batch_size
145148
self.filtered = filtered
149+
self.wld_filtered = wld_filtered
146150
self.random_fen_skipping = random_fen_skipping
147151
self.device = device
148152

149153
if batch_size:
150-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping)
154+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered)
151155
else:
152-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping)
156+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered)
153157

154158
def __iter__(self):
155159
return self
@@ -167,9 +171,11 @@ def __next__(self):
167171
def __del__(self):
168172
self.destroy_stream(self.stream)
169173

174+
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
175+
# bool filtered, int random_fen_skipping, bool wld_filtered)
170176
create_sparse_batch_stream = dll.create_sparse_batch_stream
171177
create_sparse_batch_stream.restype = ctypes.c_void_p
172-
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int]
178+
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool]
173179
destroy_sparse_batch_stream = dll.destroy_sparse_batch_stream
174180
destroy_sparse_batch_stream.argtypes = [ctypes.c_void_p]
175181

@@ -198,7 +204,7 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
198204
return b
199205

200206
class SparseBatchProvider(TrainingDataProvider):
201-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, device='cpu'):
207+
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, device='cpu'):
202208
super(SparseBatchProvider, self).__init__(
203209
feature_set,
204210
create_sparse_batch_stream,
@@ -211,10 +217,11 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
211217
batch_size,
212218
filtered,
213219
random_fen_skipping,
220+
wld_filtered,
214221
device)
215222

216223
class SparseBatchDataset(torch.utils.data.IterableDataset):
217-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, device='cpu'):
224+
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, device='cpu'):
218225
super(SparseBatchDataset).__init__()
219226
self.feature_set = feature_set
220227
self.filename = filename
@@ -223,10 +230,12 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
223230
self.num_workers = num_workers
224231
self.filtered = filtered
225232
self.random_fen_skipping = random_fen_skipping
233+
self.wld_filtered = wld_filtered
226234
self.device = device
227235

228236
def __iter__(self):
229-
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers, filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, device=self.device)
237+
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
238+
filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, wld_filtered=self.wld_filtered, device=self.device)
230239

231240
class FixedNumBatchesDataset(Dataset):
232241
def __init__(self, dataset, num_batches):

train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from pytorch_lightning import loggers as pl_loggers
1010
from torch.utils.data import DataLoader, Dataset
1111

12-
def make_data_loaders(train_filename, val_filename, feature_set, num_workers, batch_size, filtered, random_fen_skipping, main_device):
12+
def make_data_loaders(train_filename, val_filename, feature_set, num_workers, batch_size, filtered, random_fen_skipping, wld_filtered, main_device):
1313
# Epoch and validation sizes are arbitrary
1414
epoch_size = 100000000
1515
val_size = 1000000
1616
features_name = feature_set.name
1717
train_infinite = nnue_dataset.SparseBatchDataset(features_name, train_filename, batch_size, num_workers=num_workers,
18-
filtered=filtered, random_fen_skipping=random_fen_skipping, device=main_device)
18+
filtered=filtered, random_fen_skipping=random_fen_skipping, wld_filtered=wld_filtered, device=main_device)
1919
val_infinite = nnue_dataset.SparseBatchDataset(features_name, val_filename, batch_size, filtered=filtered,
20-
random_fen_skipping=random_fen_skipping, device=main_device)
20+
random_fen_skipping=random_fen_skipping, wld_filtered=wld_filtered, device=main_device)
2121
# num_workers has to be 0 for sparse, and 1 for dense
2222
# it currently cannot work in parallel mode but it shouldn't need to
2323
train = DataLoader(nnue_dataset.FixedNumBatchesDataset(train_infinite, (epoch_size + batch_size - 1) // batch_size), batch_size=None, batch_sampler=None)
@@ -36,6 +36,7 @@ def main():
3636
parser.add_argument("--seed", default=42, type=int, dest='seed', help="torch seed to use.")
3737
parser.add_argument("--smart-fen-skipping", action='store_true', dest='smart_fen_skipping_deprecated', help="If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored")
3838
parser.add_argument("--no-smart-fen-skipping", action='store_true', dest='no_smart_fen_skipping', help="If used then no smart fen skipping will be done. By default smart fen skipping is done.")
39+
parser.add_argument("--no-wld-fen-skipping", action='store_true', dest='no_wld_fen_skipping', help="If used then no wld fen skipping will be done. By default wld fen skipping is done.")
3940
parser.add_argument("--random-fen-skipping", default=3, type=int, dest='random_fen_skipping', help="skip fens randomly on average random_fen_skipping before using one.")
4041
parser.add_argument("--resume-from-model", dest='resume_from_model', help="Initializes training using the weights from the given .pt model")
4142
features.add_argparse_args(parser)
@@ -71,6 +72,7 @@ def main():
7172
print('Using batch size {}'.format(batch_size))
7273

7374
print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
75+
print('WLD fen skipping: {}'.format(not args.no_wld_fen_skipping))
7476
print('Random fen skipping: {}'.format(args.random_fen_skipping))
7577

7678
if args.threads > 0:
@@ -89,7 +91,7 @@ def main():
8991
nnue.to(device=main_device)
9092

9193
print('Using c++ data loader')
92-
train, val = make_data_loaders(args.train, args.val, feature_set, args.num_workers, batch_size, not args.no_smart_fen_skipping, args.random_fen_skipping, main_device)
94+
train, val = make_data_loaders(args.train, args.val, feature_set, args.num_workers, batch_size, not args.no_smart_fen_skipping, args.random_fen_skipping, not args.no_wld_fen_skipping, main_device)
9395

9496
trainer.fit(nnue, train, val)
9597

training_data_loader.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -808,16 +808,23 @@ struct FenBatchStream : Stream<FenBatch>
808808
std::vector<std::thread> m_workers;
809809
};
810810

811-
std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered, int random_fen_skipping)
811+
std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered, int random_fen_skipping, bool wld_filtered)
812812
{
813-
if (filtered || random_fen_skipping)
813+
if (filtered || random_fen_skipping || wld_filtered)
814814
{
815815
return [
816816
random_fen_skipping,
817817
prob = double(random_fen_skipping) / (random_fen_skipping + 1),
818-
filtered
818+
filtered,
819+
wld_filtered
819820
](const TrainingDataEntry& e){
820821

822+
auto do_wld_skip = [&]() {
823+
std::bernoulli_distribution distrib(1.0 - e.score_result_prob());
824+
auto& prng = rng::get_thread_local_rng();
825+
return distrib(prng);
826+
};
827+
821828
auto do_skip = [&]() {
822829
std::bernoulli_distribution distrib(prob);
823830
auto& prng = rng::get_thread_local_rng();
@@ -829,7 +836,7 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered,
829836
};
830837

831838
static thread_local std::mt19937 gen(std::random_device{}());
832-
return (random_fen_skipping && do_skip()) || (filtered && do_filter());
839+
return (random_fen_skipping && do_skip()) || (filtered && do_filter()) || (wld_filtered && do_wld_skip());
833840
};
834841
}
835842

@@ -896,9 +903,10 @@ extern "C" {
896903
return nullptr;
897904
}
898905

899-
EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping)
906+
// changing the signature needs matching changes in nnue_dataset.py
907+
EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered)
900908
{
901-
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping);
909+
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered);
902910

903911
return new FenBatchStream(concurrency, filename, batch_size, cyclic, skipPredicate);
904912
}
@@ -908,9 +916,11 @@ extern "C" {
908916
delete stream;
909917
}
910918

911-
EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping)
919+
// changing the signature needs matching changes in nnue_dataset.py
920+
EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
921+
bool filtered, int random_fen_skipping, bool wld_filtered)
912922
{
913-
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping);
923+
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered);
914924

915925
std::string_view feature_set(feature_set_c);
916926
if (feature_set == "HalfKP")
@@ -981,7 +991,7 @@ extern "C" {
981991

982992
int main()
983993
{
984-
auto stream = create_sparse_batch_stream("HalfKP", 4, "10m_d3_q_2.binpack", 8192, true, false, 0);
994+
auto stream = create_sparse_batch_stream("HalfKP", 4, "10m_d3_q_2.binpack", 8192, true, false, 0, false);
985995
auto t0 = std::chrono::high_resolution_clock::now();
986996
for (int i = 0; i < 1000; ++i)
987997
{

0 commit comments

Comments
 (0)