@@ -67,10 +67,10 @@ def get_fens(self):
6767 return strings
6868
6969FenBatchPtr = ctypes .POINTER (FenBatch )
70-
70+ # EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered)
7171create_fen_batch_stream = dll .create_fen_batch_stream
7272create_fen_batch_stream .restype = ctypes .c_void_p
73- create_fen_batch_stream .argtypes = [ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int ]
73+ create_fen_batch_stream .argtypes = [ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes . c_bool ]
7474destroy_fen_batch_stream = dll .destroy_fen_batch_stream
7575destroy_fen_batch_stream .argtypes = [ctypes .c_void_p ]
7676
@@ -87,19 +87,21 @@ def __init__(
8787 num_workers ,
8888 batch_size = None ,
8989 filtered = False ,
90- random_fen_skipping = 0 ):
90+ random_fen_skipping = 0 ,
91+ wld_filtered = False ):
9192
9293 self .filename = filename .encode ('utf-8' )
9394 self .cyclic = cyclic
9495 self .num_workers = num_workers
9596 self .batch_size = batch_size
9697 self .filtered = filtered
98+ self .wld_filtered = wld_filtered
9799 self .random_fen_skipping = random_fen_skipping
98100
99101 if batch_size :
100- self .stream = create_fen_batch_stream (self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping )
102+ self .stream = create_fen_batch_stream (self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered )
101103 else :
102- self .stream = create_fen_batch_stream (self .num_workers , self .filename , cyclic , filtered , random_fen_skipping )
104+ self .stream = create_fen_batch_stream (self .num_workers , self .filename , cyclic , filtered , random_fen_skipping , wld_filtered )
103105
104106 def __iter__ (self ):
105107 return self
@@ -131,6 +133,7 @@ def __init__(
131133 batch_size = None ,
132134 filtered = False ,
133135 random_fen_skipping = 0 ,
136+ wld_filtered = False ,
134137 device = 'cpu' ):
135138
136139 self .feature_set = feature_set .encode ('utf-8' )
@@ -143,13 +146,14 @@ def __init__(
143146 self .num_workers = num_workers
144147 self .batch_size = batch_size
145148 self .filtered = filtered
149+ self .wld_filtered = wld_filtered
146150 self .random_fen_skipping = random_fen_skipping
147151 self .device = device
148152
149153 if batch_size :
150- self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping )
154+ self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered )
151155 else :
152- self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , cyclic , filtered , random_fen_skipping )
156+ self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , cyclic , filtered , random_fen_skipping , wld_filtered )
153157
154158 def __iter__ (self ):
155159 return self
@@ -167,9 +171,11 @@ def __next__(self):
167171 def __del__ (self ):
168172 self .destroy_stream (self .stream )
169173
174+ # EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
175+ # bool filtered, int random_fen_skipping, bool wld_filtered)
170176create_sparse_batch_stream = dll .create_sparse_batch_stream
171177create_sparse_batch_stream .restype = ctypes .c_void_p
172- create_sparse_batch_stream .argtypes = [ctypes .c_char_p , ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int ]
178+ create_sparse_batch_stream .argtypes = [ctypes .c_char_p , ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes . c_bool ]
173179destroy_sparse_batch_stream = dll .destroy_sparse_batch_stream
174180destroy_sparse_batch_stream .argtypes = [ctypes .c_void_p ]
175181
@@ -198,7 +204,7 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
198204 return b
199205
200206class SparseBatchProvider (TrainingDataProvider ):
201- def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , device = 'cpu' ):
207+ def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , device = 'cpu' ):
202208 super (SparseBatchProvider , self ).__init__ (
203209 feature_set ,
204210 create_sparse_batch_stream ,
@@ -211,10 +217,11 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
211217 batch_size ,
212218 filtered ,
213219 random_fen_skipping ,
220+ wld_filtered ,
214221 device )
215222
216223class SparseBatchDataset (torch .utils .data .IterableDataset ):
217- def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , device = 'cpu' ):
224+ def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , device = 'cpu' ):
218225 super (SparseBatchDataset ).__init__ ()
219226 self .feature_set = feature_set
220227 self .filename = filename
@@ -223,10 +230,12 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
223230 self .num_workers = num_workers
224231 self .filtered = filtered
225232 self .random_fen_skipping = random_fen_skipping
233+ self .wld_filtered = wld_filtered
226234 self .device = device
227235
228236 def __iter__ (self ):
229- return SparseBatchProvider (self .feature_set , self .filename , self .batch_size , cyclic = self .cyclic , num_workers = self .num_workers , filtered = self .filtered , random_fen_skipping = self .random_fen_skipping , device = self .device )
237+ return SparseBatchProvider (self .feature_set , self .filename , self .batch_size , cyclic = self .cyclic , num_workers = self .num_workers ,
238+ filtered = self .filtered , random_fen_skipping = self .random_fen_skipping , wld_filtered = self .wld_filtered , device = self .device )
230239
231240class FixedNumBatchesDataset (Dataset ):
232241 def __init__ (self , dataset , num_batches ):
0 commit comments