-
Notifications
You must be signed in to change notification settings - Fork 287
Expand file tree
/
Copy pathdistributed_mcts_reasoner.py
More file actions
416 lines (341 loc) · 13.2 KB
/
distributed_mcts_reasoner.py
File metadata and controls
416 lines (341 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
"""
Distributed Monte Carlo Tree Search (MCTS) Reasoner
====================================================
A "build your own o1" implementation using Modal for parallel reasoning exploration.
This system explores many candidate reasoning paths in parallel, using MCTS to
navigate the space of possible thought sequences and return the highest-scoring
solution path it finds for a given problem.
"""
import asyncio
import math
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import modal
# Modal setup
app = modal.App("mcts-reasoner")
# NOTE(review): numpy is installed into the image but nothing visible in this
# file imports it — confirm it is still needed.
image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "openai>=1.12.0", "numpy>=1.24.0"
)

# Configuration
# Modal secret attached (via `secrets=[...]`) to the remote class and functions;
# presumably it exposes the OpenAI API key to the client — confirm its contents.
OPENAI_API_KEY = modal.Secret.from_name("openai-api-key")
MODEL_NAME = "gpt-4o-mini"  # Fast and cheap for exploration
MAX_DEPTH = 8  # Maximum reasoning steps; nodes at this depth are terminal
# UCB1 exploration parameter c = sqrt(2); computed rather than truncated to 1.414.
EXPLORATION_CONSTANT = math.sqrt(2)
NUM_SIMULATIONS = 100  # Total leaf expansions requested per run
PARALLEL_WORKERS = 20  # Leaf expansions dispatched concurrently per batch
@dataclass(eq=False)
class Node:
    """
    A node in the MCTS reasoning tree.

    Each node adds one reasoning step (``state``) to the chain formed by its
    ancestors, and carries the visit/value statistics UCB1 selection needs.

    ``eq=False`` keeps identity-based equality and hashing: children point
    back at their parent, so the field-wise ``__eq__`` a dataclass would
    generate walks parent/children links and can recurse without bound.
    ``parent`` is excluded from ``repr`` so printing one node does not dump
    every ancestor and their entire subtrees.
    """

    state: str  # The reasoning step this node adds ("" for the root)
    parent: Optional["Node"] = field(default=None, repr=False)
    children: List["Node"] = field(default_factory=list)
    visits: int = 0  # Number of rewards recorded at this node
    value: float = 0.0  # Sum of rewards recorded at this node
    depth: int = 0  # Distance from the root (root is 0)
    is_terminal: bool = False  # No further expansion allowed below this node
    is_solution: bool = False  # NOTE(review): not set anywhere in the visible code

    def ucb1_score(self, exploration_weight: float = EXPLORATION_CONSTANT) -> float:
        """
        Return the UCB1 score balancing exploitation against exploration.

        Unvisited nodes score ``inf`` so they are tried before any visited
        sibling. When the parent is missing or unvisited, log(parent.visits)
        is undefined, so only the mean reward is returned.
        """
        if self.visits == 0:
            return float("inf")
        exploitation = self.value / self.visits
        if self.parent is None or self.parent.visits == 0:
            return exploitation
        exploration = exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )
        return exploitation + exploration

    def best_child(self) -> Optional["Node"]:
        """Return the child with the highest UCB1 score, or None if childless."""
        if not self.children:
            return None
        return max(self.children, key=lambda c: c.ucb1_score())

    def update(self, reward: float):
        """
        Record one reward at this node only.

        Propagation to ancestors is the caller's job (it walks ``parent``).
        """
        self.visits += 1
        self.value += reward

    def get_path(self) -> List[str]:
        """Return the step texts from the root down to this node.

        Nodes with an empty ``state`` (such as the root) are omitted.
        """
        path = []
        current = self
        while current is not None:
            if current.state:  # Skip the empty root
                path.append(current.state)
            current = current.parent
        return list(reversed(path))
@app.cls(
    image=image,
    secrets=[OPENAI_API_KEY],
    timeout=1800,
)
class MCTSMaster:
    """
    Stateful master holding the global MCTS tree.

    Workers never touch the tree directly. They ask for a leaf
    (``select_leaf``), generate candidate steps elsewhere, and report back
    through ``expand_and_evaluate``. An ``asyncio.Lock`` serializes access to
    the tree within one container.

    NOTE(review): the tree lives in this instance's memory, so every call
    must reach the same container for the search to share state. Confirm
    that the Modal deployment pins this class to a single container that
    accepts concurrent inputs.
    """

    def __init__(self, problem: str):
        self.problem = problem
        self.root = Node(state="", depth=0)
        self.lock = asyncio.Lock()
        self.best_solution = None
        self.best_reward = -float("inf")
        # id(node) -> node, so a leaf reported back by a worker can be found
        # directly instead of re-walking the tree by step text (which is
        # ambiguous for duplicate steps and impossible for empty ones). The
        # registry keeps every node alive, so an id cannot be reused.
        self._nodes: Dict[int, Node] = {id(self.root): self.root}

    @modal.method()
    async def select_leaf(self) -> Dict[str, Any]:
        """
        Selection phase: descend from the root by UCB1 until reaching a node
        that has no children or is terminal.

        Returns a serializable dict with the node's ``state``, ``depth``,
        ``path`` (root-to-node step texts, empty states omitted), and
        ``node_id`` (its ``id()``, used by ``expand_and_evaluate``).

        NOTE(review): selection is read-only (there is no virtual loss), so
        concurrent calls made before any expansion is reported all return the
        same leaf.
        """
        async with self.lock:
            current = self.root
            while current.children and not current.is_terminal:
                chosen = current.best_child()
                if chosen is None:  # Defensive: unreachable while children exist
                    break
                current = chosen
            return {
                "state": current.state,
                "depth": current.depth,
                "path": current.get_path(),
                "node_id": id(current),
            }

    def _find_node(self, node_data: Dict[str, Any]) -> Optional[Node]:
        """Locate the node described by a ``select_leaf`` result, or None."""
        path = list(node_data.get("path", []))
        candidate = self._nodes.get(node_data.get("node_id"))
        # Cross-check the path so a stale or foreign id cannot redirect the
        # expansion to an unrelated node.
        if candidate is not None and candidate.get_path() == path:
            return candidate
        # Fallback: walk down by step text.
        current = self.root
        for step in path:
            for child in current.children:
                if child.state == step:
                    current = child
                    break
            else:
                return None  # Path no longer matches the tree
        return current

    @modal.method()
    async def expand_and_evaluate(
        self, node_data: Dict[str, Any], new_steps: List[str], rewards: List[float]
    ):
        """
        Expansion phase: attach ``new_steps`` as children of the selected node
        and backpropagate each paired reward to the root.

        If the node cannot be located, nothing is changed. The original
        lookup silently attached children to the deepest matching ancestor
        instead.

        Terminal nodes are never expanded. Re-selecting one re-credits its own
        mean reward, so its visit count grows and UCB1 moves on. Otherwise
        children would pile up past MAX_DEPTH under a node that selection
        never descends below.
        """
        async with self.lock:
            parent = self._find_node(node_data)
            if parent is None:
                print(f"⚠️ Could not locate node for path {node_data.get('path')}; skipping")
                return

            if parent.is_terminal:
                if parent.visits:
                    self._backpropagate(parent, parent.value / parent.visits)
                return

            child_depth = parent.depth + 1
            for step, reward in zip(new_steps, rewards):
                child = Node(
                    state=step,
                    parent=parent,
                    depth=child_depth,
                    is_terminal=(child_depth >= MAX_DEPTH),
                )
                parent.children.append(child)
                self._nodes[id(child)] = child
                self._backpropagate(child, reward)
                # Track the highest-reward path seen so far.
                if reward > self.best_reward:
                    self.best_reward = reward
                    self.best_solution = child.get_path()

    def _backpropagate(self, node: Node, reward: float):
        """Record ``reward`` at ``node`` and at every ancestor up to the root."""
        current = node
        while current is not None:
            current.update(reward)
            current = current.parent

    @modal.method()
    async def get_best_solution(self) -> Dict[str, Any]:
        """
        Return the best reasoning path found.

        The result always contains ``path``, ``reward`` and
        ``total_simulations`` (root visit count). If no reward was ever
        recorded, it falls back to following the most-visited child from the
        root, with reward 0.0.
        """
        async with self.lock:
            if self.best_solution is None:
                path = []
                current = self.root
                while current.children:
                    current = max(current.children, key=lambda c: c.visits)
                    path.append(current.state)
                return {
                    "path": path,
                    "reward": 0.0,
                    "total_simulations": self.root.visits,
                }
            return {
                "path": self.best_solution,
                "reward": self.best_reward,
                "total_simulations": self.root.visits,
            }

    @modal.method()
    async def get_tree_stats(self) -> Dict[str, int]:
        """Return node count, root visit count and number of root children."""
        async with self.lock:
            return {
                # Every node is registered in _nodes on creation.
                "total_nodes": len(self._nodes),
                "root_visits": self.root.visits,
                "num_children": len(self.root.children),
            }
@app.function(
    image=image,
    secrets=[OPENAI_API_KEY],
    timeout=300,
)
async def mcts_worker(
    problem: str, node_data: Dict[str, Any], worker_id: int
) -> Dict[str, Any]:
    """
    Expand one leaf by asking the LLM for candidate next reasoning steps.

    Requests three completions (``n=3``) for the reasoning chain in
    ``node_data["path"]``. Empty and duplicate candidates are dropped. Each
    remaining candidate is scored concurrently with ``evaluate_step``.

    Returns ``{"steps": [...], "rewards": [...], "worker_id": worker_id}``
    with ``steps`` and ``rewards`` as parallel lists. On any error both lists
    are empty (best-effort: the caller skips empty expansions).
    """
    import openai

    current_path = node_data["path"]
    depth = node_data["depth"]

    reasoning_so_far = (
        "\n".join(f"{i + 1}. {step}" for i, step in enumerate(current_path))
        if current_path
        else "(starting)"
    )
    prompt = f"""You are solving this problem using step-by-step reasoning:
PROBLEM: {problem}
REASONING SO FAR:
{reasoning_so_far}
Generate the NEXT logical reasoning step. Be concise and specific.
If you've reached a solution, state it clearly with "SOLUTION: [answer]"
Next step:"""

    try:
        # `async with` closes the client's HTTP connection pool when done.
        async with openai.AsyncOpenAI() as client:
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=150,
                temperature=0.8,  # Higher temp for diversity
                n=3,  # Generate multiple candidate steps
            )

            steps: List[str] = []
            for choice in response.choices:
                # content may be None; an empty step also cannot be located
                # again by path (get_path skips empty states).
                step = (choice.message.content or "").strip()
                if step and step not in steps:
                    steps.append(step)

            # Score all candidates concurrently. A step counts as a claimed
            # solution if it says so, or if it would land at the depth limit.
            rewards = await asyncio.gather(
                *(
                    evaluate_step(
                        problem,
                        current_path + [step],
                        "SOLUTION:" in step.upper() or depth >= MAX_DEPTH - 1,
                        client,
                    )
                    for step in steps
                )
            )

        return {"steps": steps, "rewards": list(rewards), "worker_id": worker_id}

    except Exception as e:
        print(f"Worker {worker_id} error: {e}")
        return {"steps": [], "rewards": [], "worker_id": worker_id}
async def evaluate_step(
problem: str, path: List[str], is_solution: bool, client
) -> float:
"""
Evaluate the quality of a reasoning path.
Uses LLM to score logical consistency and correctness.
"""
# Fast heuristic evaluation
path_text = "\n".join(f"{i + 1}. {step}" for i, step in enumerate(path))
if is_solution:
# Full verification for claimed solutions
verify_prompt = f"""Problem: {problem}
Proposed solution path:
{path_text}
Is this solution correct? Answer with a score from 0.0 (completely wrong) to 1.0 (perfect solution).
Score:"""
try:
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": verify_prompt}],
max_tokens=10,
temperature=0.0,
)
score_text = response.choices[0].message.content.strip()
# Extract numeric score
match = re.search(r"(\d+\.?\d*)", score_text)
if match:
return float(match.group(1))
return 0.5
except Exception:
return 0.5
else:
# Heuristic: longer valid paths get higher base reward
# Penalize repetition
unique_ratio = len(set(path)) / len(path) if path else 1.0
depth_bonus = len(path) * 0.1
return min(unique_ratio * depth_bonus, 0.8) # Cap non-solution rewards
@app.function(
    image=image,
    secrets=[OPENAI_API_KEY],
    timeout=1800,
)
async def run_mcts(
    problem: str, num_simulations: int = NUM_SIMULATIONS
) -> Dict[str, Any]:
    """
    Run MCTS in batches of up to PARALLEL_WORKERS leaf expansions.

    Each batch selects leaves, expands them concurrently with ``mcts_worker``,
    then reports results back to the master tree. Returns the master's
    best-solution dict (``path``, ``reward``, and ``total_simulations`` when
    the master provides it).
    """
    print(f"🧠 Starting MCTS with {num_simulations} simulations...")
    print(f"Problem: {problem}\n")

    master = MCTSMaster(problem)

    batch_size = PARALLEL_WORKERS
    num_batches = (num_simulations + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        print(f"📊 Batch {batch_idx + 1}/{num_batches}")
        batch_count = min(batch_size, num_simulations - batch_idx * batch_size)

        # Inside async code every remote call must use `.remote.aio(...)`.
        # Plain `.remote()` blocks and returns the result itself, which
        # `asyncio.gather`/`await` reject with TypeError.
        leaf_nodes = await asyncio.gather(
            *(master.select_leaf.remote.aio() for _ in range(batch_count))
        )

        # Expand leaves in parallel. return_exceptions keeps one failed remote
        # call (e.g. a timeout) from aborting the whole run.
        results = await asyncio.gather(
            *(
                mcts_worker.remote.aio(problem, node_data, worker_id)
                for worker_id, node_data in enumerate(leaf_nodes)
            ),
            return_exceptions=True,
        )

        for worker_id, (node_data, result) in enumerate(zip(leaf_nodes, results)):
            if isinstance(result, BaseException):
                print(f"Worker {worker_id} failed: {result}")
                continue
            if result["steps"]:
                await master.expand_and_evaluate.remote.aio(
                    node_data, result["steps"], result["rewards"]
                )

        stats = await master.get_tree_stats.remote.aio()
        print(f" Nodes: {stats['total_nodes']}, Root visits: {stats['root_visits']}")

    solution = await master.get_best_solution.remote.aio()
    print("\n✅ MCTS Complete!")
    print(f"Best reward: {solution['reward']:.3f}")
    # .get: some master versions omit this key when no reward was recorded.
    print(f"Total simulations: {solution.get('total_simulations', 'n/a')}")
    return solution
@app.local_entrypoint()
def main(problem: str = None):
    """
    Command-line entry point: run the remote search and print its best path.

    Without ``--problem``, lists the built-in example problems and solves the
    "math" one.
    """
    examples = {
        "math": "If x + 5 = 12 and y = 2x, what is y²?",
        "logic": "Three people (A, B, C) are in a room. A says 'B is lying'. B says 'C is lying'. C says 'A and B are both lying'. Who is telling the truth?",
        "code": "Write a Python function to find the longest palindromic substring in O(n²) time.",
    }

    if problem is None:
        print("🎯 Example Problems:")
        for name, text in examples.items():
            print(f" {name}: {text}")
        problem = examples["math"]
        print("\n🔥 Running default problem: math\n")

    # Blocking call from the local entrypoint; runs the search remotely.
    result = run_mcts.remote(problem)

    rule = "=" * 60
    print("\n" + rule)
    print("🎯 FINAL REASONING PATH:")
    print(rule)
    for number, step in enumerate(result["path"], start=1):
        print(f"{number}. {step}")
    print(rule)
    print(f"\n💎 Confidence Score: {result['reward']:.2f}")
# Usage:
# modal run distributed_mcts_reasoner.py
# modal run distributed_mcts_reasoner.py --problem "Your custom problem here"