1010import warnings
1111from pathlib import Path
1212from threading import Thread , Semaphore , Lock
13+
1314from typing import (
1415 Any ,
1516 Callable ,
@@ -223,8 +224,11 @@ def __init__(
223224 )
224225 self ._thread = None
225226 self ._max_concurrent_job_launch = 5
226- self ._stats_lock = Lock ()
227227 self ._db_lock = Lock ()
228+ self .jobs_done = []
229+ self .jobs_error = []
230+ self .jobs_cancel = []
231+ self .jobs_prolonged = []
228232
229233 def add_backend (
230234 self ,
@@ -523,7 +527,7 @@ def _job_update_loop(self, job_db: JobDatabaseInterface, start_job: Callable[[],
523527 self ._handle_completed_jobs (stats )
524528
525529
526- def _launch_pending_jobs (self , job_db , start_job , stats ):
530+ def _launch_pending_jobs (self , job_db : JobDatabaseInterface , start_job : Callable [[], BatchJob ], stats : Optional [ dict ] = None ):
527531 """Launches jobs concurrently with controlled threading."""
528532 not_started = job_db .get_by_status (statuses = ["not_started" ], max = 200 ).copy ()
529533 if not not_started .empty :
@@ -536,7 +540,7 @@ def _launch_pending_jobs(self, job_db, start_job, stats):
536540 jobs_to_add = self ._get_jobs_to_launch (not_started , per_backend )
537541 self ._run_job_threads (jobs_to_add , start_job , not_started , stats , job_db )
538542
539- def _get_jobs_to_launch (self , not_started , per_backend ) :
543+ def _get_jobs_to_launch (self , not_started : pd . DataFrame , per_backend : Dict [ str , int ]) -> list [ tuple [ Any , str ]] :
540544 """Determines which jobs to launch based on backend availability."""
541545 jobs_to_add = []
542546 total_added = 0
@@ -554,7 +558,7 @@ def _get_jobs_to_launch(self, not_started, per_backend):
554558
555559 return jobs_to_add
556560
557- def _run_job_threads (self , jobs_to_add , start_job , not_started , stats , job_db ) :
561+ def _run_job_threads (self , jobs_to_add : list [ tuple [ Any , str ]], start_job : Callable [[], 'BatchJob' ], not_started : pd . DataFrame , stats : Dict [ str , int ], job_db : 'JobDatabaseInterface' ) -> None :
558562 """Manages threading for job launching."""
559563 semaphore = Semaphore (self ._max_concurrent_job_launch )
560564 threads = []
@@ -563,13 +567,12 @@ def job_worker(i, backend_name):
563567 with semaphore :
564568 try :
565569 self ._launch_job (start_job , not_started , i , backend_name , stats )
570+ stats ["job launch" ] += 1
566571
567572 with self ._db_lock :
568573 job_db .persist (not_started .loc [i : i + 1 ])
569574
570- with self ._stats_lock :
571- stats ["job launch" ] += 1
572- stats ["job_db persist" ] += 1
575+ stats ["job_db persist" ] += 1
573576 except Exception as e :
574577 _log .error (f"Job launch failed for index { i } : { e } " )
575578
@@ -580,6 +583,7 @@ def job_worker(i, backend_name):
580583
581584 for thread in threads :
582585 thread .join ()
586+
583587
584588 def _handle_completed_jobs (self ,stats ):
585589 """Processes completed, canceled, and errored jobs."""
@@ -762,11 +766,6 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
762766
763767 active = job_db .get_by_status (statuses = ["created" , "queued" , "running" ]).copy ()
764768
765- self .jobs_done = []
766- self .jobs_error = []
767- self .jobs_cancel = []
768- self .jobs_prolonged = []
769-
770769 for i in active .index :
771770 job_id = active .loc [i , "id" ]
772771 backend_name = active .loc [i , "backend_name" ]
0 commit comments