Skip to content

Commit 1db3b39

Browse files
authored
Merge pull request #10 from GeneralUserModels/batcher_upgrade
Move all processing of observations to the batch level
2 parents dab4d92 + 29bca23 commit 1db3b39

File tree

9 files changed

+121
-310
lines changed

9 files changed

+121
-310
lines changed

gum/batcher.py

Lines changed: 42 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,35 @@
1-
import asyncio
2-
import json
31
import logging
4-
import os
5-
from datetime import datetime, timezone, timedelta
6-
from typing import List, Dict, Any, Optional
7-
from dataclasses import dataclass, asdict
2+
from datetime import datetime, timezone
3+
from typing import List, Optional, Dict, Any
84
from pathlib import Path
9-
10-
@dataclass
11-
class BatchedObservation:
12-
"""Represents a batched observation waiting for processing."""
13-
id: str
14-
observer_name: str
15-
content: str
16-
content_type: str
17-
timestamp: datetime
18-
processed: bool = False
5+
import uuid
6+
from persistqueue import Queue
197

208
class ObservationBatcher:
21-
"""Handles batching of observations to reduce API calls."""
9+
"""A persistent queue for batching observations to reduce API calls."""
2210

23-
def __init__(self, data_directory: str, batch_interval_hours: float = 1, max_batch_size: int = 50):
11+
def __init__(self, data_directory: str, batch_interval_minutes: float = 2, max_batch_size: int = 50):
2412
self.data_directory = Path(data_directory)
25-
self.batch_interval_hours = batch_interval_hours
13+
self.batch_interval_minutes = batch_interval_minutes
2614
self.max_batch_size = max_batch_size
27-
self.batch_file = self.data_directory / "batches" / "pending_observations.json"
28-
self.batch_file.parent.mkdir(exist_ok=True)
15+
16+
# Create persistent file-backed queue (persist-queue's default Queue storage)
17+
queue_dir = self.data_directory / "batches"
18+
queue_dir.mkdir(parents=True, exist_ok=True)
19+
self._queue = Queue(path=str(queue_dir / "queue"))
2920

3021
self.logger = logging.getLogger("gum.batcher")
31-
self._pending_observations: List[BatchedObservation] = []
32-
self._batch_task: Optional[asyncio.Task] = None
3322

3423
async def start(self):
3524
"""Start the batching system."""
36-
self._load_pending_observations()
37-
self._batch_task = asyncio.create_task(self._batch_loop())
38-
self.logger.info(f"Started batcher with {len(self._pending_observations)} pending observations")
25+
self.logger.info(f"Started batcher with {self._queue.qsize()} items in queue")
3926

4027
async def stop(self):
4128
"""Stop the batching system."""
42-
if self._batch_task:
43-
self._batch_task.cancel()
44-
try:
45-
await self._batch_task
46-
except asyncio.CancelledError:
47-
pass
48-
self._save_pending_observations()
4929
self.logger.info("Stopped batcher")
5030

51-
def add_observation(self, observer_name: str, content: str, content_type: str) -> str:
52-
"""Add an observation to the batch queue.
31+
def push(self, observer_name: str, content: str, content_type: str) -> str:
32+
"""Push an observation onto the queue.
5333
5434
Args:
5535
observer_name: Name of the observer
@@ -59,116 +39,42 @@ def add_observation(self, observer_name: str, content: str, content_type: str) -
5939
Returns:
6040
str: Observation ID
6141
"""
62-
import uuid
63-
64-
observation = BatchedObservation(
65-
id=str(uuid.uuid4()),
66-
observer_name=observer_name,
67-
content=content,
68-
content_type=content_type,
69-
timestamp=datetime.now(timezone.utc)
70-
)
71-
72-
self._pending_observations.append(observation)
73-
self.logger.debug(f"Added observation {observation.id} to batch (total: {len(self._pending_observations)})")
42+
observation_id = str(uuid.uuid4())
43+
observation_dict = {
44+
'id': observation_id,
45+
'observer_name': observer_name,
46+
'content': content,
47+
'content_type': content_type,
48+
'timestamp': datetime.now(timezone.utc).isoformat()
49+
}
7450

75-
# Save immediately to prevent data loss
76-
self._save_pending_observations()
51+
# Add to queue - automatically persisted by persist-queue
52+
self._queue.put(observation_dict)
53+
self.logger.debug(f"Pushed observation {observation_id} to queue (size: {self._queue.qsize()})")
7754

78-
return observation.id
55+
return observation_id
7956

80-
def get_pending_count(self) -> int:
81-
"""Get the number of pending observations."""
82-
return len([obs for obs in self._pending_observations if not obs.processed])
57+
def size(self) -> int:
58+
"""Get the current size of the queue."""
59+
return self._queue.qsize()
8360

84-
def get_batch(self, max_size: Optional[int] = None) -> List[BatchedObservation]:
85-
"""Get a batch of unprocessed observations.
61+
def pop_batch(self, batch_size: Optional[int] = None) -> List[Dict[str, Any]]:
62+
"""Pop a batch of observations from the front of the queue (FIFO).
8663
8764
Args:
88-
max_size: Maximum number of observations to return
65+
batch_size: Number of items to pop. Defaults to max_batch_size
8966
9067
Returns:
91-
List of batched observations
68+
List of observation dictionaries popped from queue
9269
"""
93-
unprocessed = [obs for obs in self._pending_observations if not obs.processed]
94-
max_size = max_size or self.max_batch_size
95-
return unprocessed[:max_size]
70+
batch_size = batch_size or self.max_batch_size
9671

97-
def mark_processed(self, observation_ids: List[str]):
98-
"""Mark observations as processed.
72+
batch = []
73+
for _ in range(min(batch_size, self._queue.qsize())):
74+
batch.append(self._queue.get_nowait())
9975

100-
Args:
101-
observation_ids: List of observation IDs to mark as processed
102-
"""
103-
for obs in self._pending_observations:
104-
if obs.id in observation_ids:
105-
obs.processed = True
106-
107-
# Remove processed observations older than 24 hours
108-
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=24)
109-
self._pending_observations = [
110-
obs for obs in self._pending_observations
111-
if not obs.processed or obs.timestamp > cutoff_time
112-
]
76+
if batch:
77+
self.logger.debug(f"Popped batch of {len(batch)} observations (queue size: {self._queue.qsize()})")
11378

114-
self._save_pending_observations()
115-
self.logger.debug(f"Marked {len(observation_ids)} observations as processed")
116-
117-
async def _batch_loop(self):
118-
"""Main batching loop that processes observations periodically."""
119-
while True:
120-
try:
121-
# Wait for the batch interval
122-
await asyncio.sleep(self.batch_interval_hours * 3600)
123-
124-
# Get pending observations
125-
batch = self.get_batch()
126-
if batch:
127-
self.logger.info(f"Processing batch of {len(batch)} observations")
128-
# Signal that we have a batch ready
129-
# The main GUM class will handle the actual processing
130-
# For now, just log that we have a batch
131-
self.logger.info(f"Batch ready with {len(batch)} observations")
132-
else:
133-
self.logger.debug("No observations to process in this batch")
134-
135-
except asyncio.CancelledError:
136-
break
137-
except Exception as e:
138-
self.logger.error(f"Error in batch loop: {e}")
139-
await asyncio.sleep(60) # Wait a minute before retrying
140-
141-
def _load_pending_observations(self):
142-
"""Load pending observations from disk."""
143-
if self.batch_file.exists():
144-
try:
145-
with open(self.batch_file, 'r') as f:
146-
data = json.load(f)
147-
self._pending_observations = [
148-
BatchedObservation(**obs_data)
149-
for obs_data in data
150-
]
151-
# Convert timestamp strings back to datetime objects
152-
for obs in self._pending_observations:
153-
if isinstance(obs.timestamp, str):
154-
obs.timestamp = datetime.fromisoformat(obs.timestamp.replace('Z', '+00:00'))
155-
except Exception as e:
156-
self.logger.error(f"Error loading pending observations: {e}")
157-
self._pending_observations = []
158-
else:
159-
self._pending_observations = []
160-
161-
def _save_pending_observations(self):
162-
"""Save pending observations to disk."""
163-
try:
164-
# Convert datetime objects to ISO format strings
165-
data = []
166-
for obs in self._pending_observations:
167-
obs_dict = asdict(obs)
168-
obs_dict['timestamp'] = obs.timestamp.isoformat()
169-
data.append(obs_dict)
170-
171-
with open(self.batch_file, 'w') as f:
172-
json.dump(data, f, indent=2)
173-
except Exception as e:
174-
self.logger.error(f"Error saving pending observations: {e}")
79+
return batch
80+

gum/cli.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def parse_args():
3131
parser.add_argument('--reset-cache', action='store_true', help='Reset the GUM cache and exit') # Add this line
3232

3333
# Batching configuration arguments
34-
parser.add_argument('--use-batched-client', action='store_true', help='Enable batched client processing')
35-
parser.add_argument('--batch-interval-hours', type=float, help='Hours between batch processing')
34+
parser.add_argument('--batch-interval-minutes', type=float, help='Minutes between batch processing')
3635
parser.add_argument('--max-batch-size', type=int, help='Maximum number of observations per batch')
3736

3837
args = parser.parse_args()
@@ -58,10 +57,8 @@ async def main():
5857
model = args.model or os.getenv('MODEL_NAME') or 'gpt-4o-mini'
5958
user_name = args.user_name or os.getenv('USER_NAME')
6059

61-
# Batching configuration - follow same pattern as other args
62-
use_batched_client = args.use_batched_client or os.getenv('USE_BATCHED_CLIENT', 'false').lower() == 'true'
63-
64-
batch_interval_hours = args.batch_interval_hours or float(os.getenv('BATCH_INTERVAL_HOURS', '1'))
60+
# Batching configuration - follow same pattern as other args
61+
batch_interval_minutes = args.batch_interval_minutes or float(os.getenv('BATCH_INTERVAL_MINUTES', '2'))
6562
max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '50'))
6663

6764
# you need one or the other
@@ -86,17 +83,12 @@ async def main():
8683
print("-" * 80)
8784
else:
8885
print(f"Listening to {user_name} with model {model}")
89-
if use_batched_client:
90-
print(f"Batching enabled: processing every {batch_interval_hours} hours (max {max_batch_size} observations per batch)")
91-
else:
92-
print("Batching disabled: processing observations immediately")
9386

9487
async with gum(
9588
user_name,
9689
model,
9790
Screen(model),
98-
use_batched_client=use_batched_client,
99-
batch_interval_hours=batch_interval_hours,
91+
batch_interval_minutes=batch_interval_minutes,
10092
max_batch_size=max_batch_size
10193
) as gum_instance:
10294
await asyncio.Future() # run forever (Ctrl-C to stop)

gum/db_utils.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from .models import (
2727
Observation,
2828
Proposition,
29-
proposition_parent,
3029
observation_proposition,
3130
)
3231

@@ -45,13 +44,7 @@ def build_fts_query(raw: str, mode: str = "OR") -> str:
4544
else: # implicit AND
4645
return " ".join(tokens)
4746

48-
def _has_child_subquery() -> select:
49-
return (
50-
select(literal_column("1"))
51-
.select_from(proposition_parent)
52-
.where(proposition_parent.c.parent_id == Proposition.id)
53-
.exists()
54-
)
47+
5548

5649

5750
async def search_propositions_bm25(
@@ -74,7 +67,6 @@ async def search_propositions_bm25(
7467
# 1 Build candidate list
7568
# --------------------------------------------------------
7669
candidate_pool = limit * 10 if enable_mmr else limit
77-
has_child = _has_child_subquery()
7870

7971
if has_query:
8072
fts_prop = Table("propositions_fts", MetaData())
@@ -143,14 +135,12 @@ async def search_propositions_bm25(
143135
stmt = (
144136
select(Proposition, best_scores.c.bm25)
145137
.join(best_scores, best_scores.c.pid == Proposition.id)
146-
.where(~has_child)
147138
.order_by(best_scores.c.bm25.asc()) # smallest→best
148139
)
149140
else:
150141
# --- 1-b No user query ------------------------------
151142
stmt = (
152143
select(Proposition, literal_column("0.0").label("bm25"))
153-
.where(~has_child)
154144
.order_by(Proposition.created_at.desc())
155145
)
156146

0 commit comments

Comments
 (0)