55import logging
66import time
77from datetime import datetime , timezone
8- from typing import Dict , Optional , Sequence , Tuple
8+ from typing import Callable , Dict , Optional , Sequence , Tuple
99
1010import clickhouse_driver
1111import numpy
@@ -215,9 +215,14 @@ class ClickHouseRun(Run):
215215 """Represents an MCMC run stored in ClickHouse."""
216216
217217 def __init__ (
218- self , meta : RunMeta , * , created_at : datetime = None , client : clickhouse_driver .Client
218+ self ,
219+ meta : RunMeta ,
220+ * ,
221+ created_at : datetime = None ,
222+ client_fn : Callable [[], clickhouse_driver .Client ],
219223 ) -> None :
220- self ._client = client
224+ self ._client_fn = client_fn
225+ self ._client = client_fn ()
221226 if created_at is None :
222227 created_at = datetime .now ().astimezone (timezone .utc )
223228 self .created_at = created_at
@@ -229,7 +234,7 @@ def __init__(
229234 def init_chain (self , chain_number : int ) -> ClickHouseChain :
230235 cmeta = ChainMeta (self .meta .rid , chain_number )
231236 create_chain_table (self ._client , cmeta , self .meta )
232- chain = ClickHouseChain (cmeta , self .meta , client = self ._client )
237+ chain = ClickHouseChain (cmeta , self .meta , client = self ._client_fn () )
233238 if self ._chains is None :
234239 self ._chains = []
235240 self ._chains .append (chain )
@@ -245,16 +250,39 @@ def get_chains(self) -> Tuple[ClickHouseChain]:
245250 chains = []
246251 for (cid ,) in self ._client .execute (f"SHOW TABLES LIKE '{ self .meta .rid } %'" ):
247252 cm = ChainMeta (self .meta .rid , int (cid .split ("_" )[- 1 ]))
248- chains .append (ClickHouseChain (cm , self .meta , client = self ._client ))
253+ chains .append (ClickHouseChain (cm , self .meta , client = self ._client_fn () ))
249254 return tuple (chains )
250255
251256
252257class ClickHouseBackend (Backend ):
253258 """A backend to store samples in a ClickHouse database."""
254259
255- def __init__ (self , client : clickhouse_driver .Client ) -> None :
260+ def __init__ (
261+ self ,
262+ client : clickhouse_driver .Client = None ,
263+ client_fn : Callable [[], clickhouse_driver .Client ] = None ,
264+ ):
265+ """Create a ClickHouse backend around a database client.
266+
267+ Parameters
268+ ----------
269+ client : clickhouse_driver.Client
270+ One client to use for all runs and chains.
271+ client_fn : callable
272+ A function to create database clients.
273+ Use this in multithreading scenarios to get higher insert performance.
274+ """
275+ if client is None and client_fn is None :
276+ raise ValueError ("Either a `client` or a `client_fn` must be provided." )
277+ self ._client_fn = client_fn
256278 self ._client = client
257- create_runs_table (client )
279+
280+ if client_fn is None :
281+ self ._client_fn = lambda : client
282+ if client is None :
283+ self ._client = self ._client_fn ()
284+
285+ create_runs_table (self ._client )
258286 super ().__init__ ()
259287
260288 def init_run (self , meta : RunMeta ) -> ClickHouseRun :
@@ -271,7 +299,7 @@ def init_run(self, meta: RunMeta) -> ClickHouseRun:
271299 proto = base64 .encodebytes (bytes (meta )).decode ("ascii" ),
272300 )
273301 self ._client .execute (query , [params ])
274- return ClickHouseRun (meta , client = self ._client , created_at = created_at )
302+ return ClickHouseRun (meta , client_fn = self ._client_fn , created_at = created_at )
275303
276304 def get_runs (self ) -> pandas .DataFrame :
277305 df = self ._client .query_dataframe (
@@ -295,5 +323,5 @@ def get_run(self, rid: str) -> ClickHouseRun:
295323 data = base64 .decodebytes (rows [0 ][2 ].encode ("ascii" ))
296324 meta = RunMeta ().parse (data )
297325 return ClickHouseRun (
298- meta , client = self ._client , created_at = rows [0 ][1 ].replace (tzinfo = timezone .utc )
326+ meta , client_fn = self ._client_fn , created_at = rows [0 ][1 ].replace (tzinfo = timezone .utc )
299327 )
0 commit comments