Skip to content

Commit e9e8a50

Browse files
feat: agent_graph SDK method (#85)
**Requirements** - [x] I have added test coverage for new or changed functionality - [x] I have followed the repository's [pull request submission guidelines](../blob/main/CONTRIBUTING.md#submitting-pull-requests) - [x] I have validated my changes against all supported platform versions **Describe the solution you've provided** This pull requests implements the following functionality: - Adds preliminary support for "Agent Graph" objects within the LaunchDarkly SDK, fetched by their configuration key ```python ai_client.agent_graph(config_key, context_value) -> AgentGraphDefinition ``` - Adds the `AgentGraphDefinition` class which allows traversal of a graph through `traverse` and `reverse_traverse` methods: ```python def handle_traversal(node: AIAgentConfig, ctx: Dict[str, Agent]): node_config = node.get_config() # Returns an AIAgentConfig node_edges = node.get_edges() # Returns Edge[] handoffs = [ctx[edge.target_config] for edge in node_edges] # Specific to OpenAI implementation return Agent( name=node_config.key, instructions=node_config.instructions, handoffs=handoffs, tools=[], ) root = graph.reverse_traverse(fn=handle_traversal) ``` **Describe alternatives you've considered** This is the implementation of a new feature within LaunchDarkly. There weren't alternatives considered as this is the initial implementation and offering regarding this feature. **Additional context** This PR will be followed up with other PRs introducing: - A layer for instrumenting framework-specific implementations of this graph method that will provide easier entry-points for specific frameworks - Framework-specific implementations <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Adds first-class Agent Graph support with traversal utilities and SDK integration. > > - New `AgentGraphDefinition` and `AgentGraphNode` implement graph construction, `traverse` and `reverse_traverse`, child/parent queries, terminals, and depth-ordered execution > - Introduces `AIAgentGraphConfig` and `Edge` types in `models` to define graphs and handoffs > - Adds `LDAIClient.agent_graph(key, context)` to fetch flag variation, validate all referenced agents are enabled, build edges/nodes, and return a disabled graph if invalid/missing root > - Updates package exports to include `AgentGraphDefinition`, `AIAgentGraphConfig`, and `Edge` > - Comprehensive tests verify node building, traversal orders, handoff data, and invalid graph scenarios > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 363337d. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY -->
2 parents 0d0cecc + 363337d commit e9e8a50

File tree

5 files changed

+829
-8
lines changed

5 files changed

+829
-8
lines changed

packages/sdk/server-ai/src/ldai/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
from ldclient import log
44

5+
from ldai.agent_graph import AgentGraphDefinition
56
from ldai.chat import Chat
67
from ldai.client import LDAIClient
78
from ldai.judge import Judge
89
from ldai.models import ( # Deprecated aliases for backward compatibility
9-
AIAgentConfig, AIAgentConfigDefault, AIAgentConfigRequest, AIAgents,
10-
AICompletionConfig, AICompletionConfigDefault, AIConfig, AIJudgeConfig,
11-
AIJudgeConfigDefault, JudgeConfiguration, LDAIAgent, LDAIAgentConfig,
12-
LDAIAgentDefaults, LDMessage, ModelConfig, ProviderConfig)
10+
AIAgentConfig, AIAgentConfigDefault, AIAgentConfigRequest,
11+
AIAgentGraphConfig, AIAgents, AICompletionConfig,
12+
AICompletionConfigDefault, AIConfig, AIJudgeConfig, AIJudgeConfigDefault,
13+
Edge, JudgeConfiguration, LDAIAgent, LDAIAgentConfig, LDAIAgentDefaults,
14+
LDMessage, ModelConfig, ProviderConfig)
1315
from ldai.providers.types import EvalScore, JudgeResponse
1416

1517
__all__ = [
@@ -18,12 +20,15 @@
1820
'AIAgentConfigDefault',
1921
'AIAgentConfigRequest',
2022
'AIAgents',
23+
'AIAgentGraphConfig',
24+
'Edge',
2125
'AICompletionConfig',
2226
'AICompletionConfigDefault',
2327
'AIJudgeConfig',
2428
'AIJudgeConfigDefault',
2529
'Chat',
2630
'EvalScore',
31+
'AgentGraphDefinition',
2732
'Judge',
2833
'JudgeConfiguration',
2934
'JudgeResponse',
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""Graph implementation for managing AI agent graphs."""
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Callable, Dict, List, Optional, Set
5+
6+
from ldclient import Context
7+
8+
from ldai.models import AIAgentConfig, AIAgentGraphConfig, Edge
9+
10+
DEFAULT_FALSE = AIAgentConfig(key="", enabled=False)
11+
12+
13+
class AgentGraphNode:
14+
"""
15+
Node in an agent graph.
16+
"""
17+
18+
def __init__(
19+
self,
20+
key: str,
21+
config: AIAgentConfig,
22+
children: List[Edge],
23+
):
24+
self._key = key
25+
self._config = config
26+
self._children = children
27+
28+
def get_key(self) -> str:
29+
"""Get the key of the node."""
30+
return self._key
31+
32+
def get_config(self) -> AIAgentConfig:
33+
"""Get the config of the node."""
34+
return self._config
35+
36+
def is_terminal(self) -> bool:
37+
"""Check if the node is a terminal node."""
38+
return len(self._children) == 0
39+
40+
def get_edges(self) -> List[Edge]:
41+
"""Get the edges of the node."""
42+
return self._children
43+
44+
45+
class AgentGraphDefinition:
46+
"""
47+
Graph implementation for managing AI agent graphs.
48+
"""
49+
enabled: bool
50+
51+
def __init__(
52+
self,
53+
agent_graph: AIAgentGraphConfig,
54+
nodes: Dict[str, AgentGraphNode],
55+
context: Context,
56+
enabled: bool,
57+
):
58+
self._agent_graph = agent_graph
59+
self._context = context
60+
self._nodes = nodes
61+
self.enabled = enabled
62+
63+
def is_enabled(self) -> bool:
64+
"""Check if the graph is enabled."""
65+
return self.enabled
66+
67+
@staticmethod
68+
def build_nodes(
69+
agent_graph: AIAgentGraphConfig,
70+
graph_nodes: Dict[str, AIAgentConfig],
71+
) -> Dict[str, "AgentGraphNode"]:
72+
"""Build the nodes of the graph into AgentGraphNode objects."""
73+
nodes = {
74+
agent_graph.root_config_key: AgentGraphNode(
75+
agent_graph.root_config_key,
76+
graph_nodes[agent_graph.root_config_key],
77+
[
78+
edge
79+
for edge in agent_graph.edges
80+
if edge.source_config == agent_graph.root_config_key
81+
],
82+
),
83+
}
84+
85+
for edge in agent_graph.edges:
86+
nodes[edge.target_config] = AgentGraphNode(
87+
edge.target_config,
88+
graph_nodes[edge.target_config],
89+
[e for e in agent_graph.edges if e.source_config == edge.target_config],
90+
)
91+
92+
return nodes
93+
94+
def get_node(self, key: str) -> Optional[AgentGraphNode]:
95+
"""Get a node by its key."""
96+
return self._nodes.get(key)
97+
98+
def _get_child_edges(self, config_key: str) -> List[Edge]:
99+
"""Get the child edges of the given config."""
100+
return [
101+
edge for edge in self._agent_graph.edges if edge.source_config == config_key
102+
]
103+
104+
def get_child_nodes(self, node_key: str) -> List[AgentGraphNode]:
105+
"""Get the child nodes of the given node key as AgentGraphNode objects."""
106+
nodes: List[AgentGraphNode] = []
107+
for edge in self._agent_graph.edges:
108+
if edge.source_config == node_key:
109+
node = self.get_node(edge.target_config)
110+
if node is not None:
111+
nodes.append(node)
112+
return nodes
113+
114+
def get_parent_nodes(self, node_key: str) -> List[AgentGraphNode]:
115+
"""Get the parent nodes of the given node key as AgentGraphNode objects."""
116+
nodes: List[AgentGraphNode] = []
117+
for edge in self._agent_graph.edges:
118+
if edge.target_config == node_key:
119+
node = self.get_node(edge.source_config)
120+
if node is not None:
121+
nodes.append(node)
122+
return nodes
123+
124+
def _collect_nodes(
125+
self,
126+
node: AgentGraphNode,
127+
node_depths: Dict[str, int],
128+
nodes_by_depth: Dict[int, List[AgentGraphNode]],
129+
visited: Set[str],
130+
max_depth: int,
131+
) -> None:
132+
"""Collect all reachable nodes from the given node and group them by depth."""
133+
node_key = node.get_key()
134+
if node_key in visited:
135+
return
136+
visited.add(node_key)
137+
138+
# Use max_depth for nodes not in node_depths to ensure they execute last
139+
node_depth = node_depths.get(node_key, max_depth)
140+
if node_depth not in nodes_by_depth:
141+
nodes_by_depth[node_depth] = []
142+
nodes_by_depth[node_depth].append(node)
143+
144+
for child in self.get_child_nodes(node_key):
145+
self._collect_nodes(child, node_depths, nodes_by_depth, visited, max_depth)
146+
147+
def terminal_nodes(self) -> List[AgentGraphNode]:
148+
"""Get the terminal nodes of the graph, meaning any nodes without children."""
149+
return [
150+
node
151+
for node in self._nodes.values()
152+
if len(self.get_child_nodes(node.get_key())) == 0
153+
]
154+
155+
def root(self) -> Optional[AgentGraphNode]:
156+
"""Get the root node of the graph."""
157+
return self._nodes.get(self._agent_graph.root_config_key)
158+
159+
def traverse(
160+
self,
161+
fn: Callable[["AgentGraphNode", Dict[str, Any]], Any],
162+
execution_context: Optional[Dict[str, Any]] = None,
163+
) -> Any:
164+
"""Traverse from the root down to terminal nodes, visiting nodes in order of depth.
165+
Nodes with the longest paths from the root (deepest nodes) will always be visited last."""
166+
if execution_context is None:
167+
execution_context = {}
168+
169+
root_node = self.root()
170+
if root_node is None:
171+
return
172+
173+
node_depths: Dict[str, int] = {root_node.get_key(): 0}
174+
current_level: List[AgentGraphNode] = [root_node]
175+
depth = 0
176+
max_depth_limit = 10 # Infinite loop protection limit
177+
max_depth_encountered = 0
178+
seen_nodes: Set[str] = {root_node.get_key()}
179+
180+
while current_level:
181+
next_level: List[AgentGraphNode] = []
182+
depth += 1
183+
184+
for node in current_level:
185+
node_key = node.get_key()
186+
for child in self.get_child_nodes(node_key):
187+
child_key = child.get_key()
188+
if depth <= max_depth_limit:
189+
# Defer this child to the next level if it's at a longer path
190+
if child_key not in node_depths or depth > node_depths[child_key]:
191+
node_depths[child_key] = depth
192+
max_depth_encountered = max(max_depth_encountered, depth)
193+
# Add to next level if not already visited (prevents cycles)
194+
if child_key not in seen_nodes:
195+
seen_nodes.add(child_key)
196+
next_level.append(child)
197+
else:
198+
max_depth_encountered = max(max_depth_encountered, depth)
199+
if child_key not in seen_nodes:
200+
# Push this to the next level to be visited
201+
seen_nodes.add(child_key)
202+
next_level.append(child)
203+
204+
current_level = next_level
205+
206+
# Use max_depth_limit + 1 to ensure they execute after all recorded nodes
207+
max_depth = max(max_depth_limit + 1, max_depth_encountered + 1)
208+
209+
# Group all nodes by depth
210+
nodes_by_depth: Dict[int, List[AgentGraphNode]] = {}
211+
# New visited for children nodes
212+
visited: Set[str] = set()
213+
214+
self._collect_nodes(root_node, node_depths, nodes_by_depth, visited, max_depth)
215+
# Execute the lambda at this level for the nodes at this depth
216+
for depth_level in sorted(nodes_by_depth.keys()):
217+
for node in nodes_by_depth[depth_level]:
218+
execution_context[node.get_key()] = fn(node, execution_context)
219+
220+
return execution_context[self._agent_graph.root_config_key]
221+
222+
def reverse_traverse(
223+
self,
224+
fn: Callable[["AgentGraphNode", Dict[str, Any]], Any],
225+
execution_context: Optional[Dict[str, Any]] = None,
226+
) -> Any:
227+
"""Traverse from terminal nodes up to the root, visiting nodes level by level.
228+
The root node will always be visited last, even if multiple paths converge at it."""
229+
if execution_context is None:
230+
execution_context = {}
231+
232+
terminal_nodes = self.terminal_nodes()
233+
if not terminal_nodes:
234+
return
235+
236+
visited: Set[str] = set()
237+
current_level: List[AgentGraphNode] = terminal_nodes
238+
root_key = self._agent_graph.root_config_key
239+
root_node_seen = False
240+
241+
while current_level:
242+
next_level: List[AgentGraphNode] = []
243+
244+
for node in current_level:
245+
node_key = node.get_key()
246+
if node_key in visited:
247+
continue
248+
249+
visited.add(node_key)
250+
# Skip the root node if we reach a terminus, it will be visited last
251+
if node_key == root_key:
252+
root_node_seen = True
253+
continue
254+
255+
execution_context[node_key] = fn(node, execution_context)
256+
257+
for parent in self.get_parent_nodes(node_key):
258+
parent_key = parent.get_key()
259+
if parent_key not in visited:
260+
next_level.append(parent)
261+
262+
current_level = next_level
263+
264+
# If we saw the root node, append it at the end as it'll always be the last node in a
265+
# reverse traversal (this should always happen, non-contiguous graphs are invalid)
266+
if root_node_seen:
267+
root_node = self.root()
268+
if root_node is not None:
269+
execution_context[root_node.get_key()] = fn(
270+
root_node, execution_context
271+
)
272+
273+
return execution_context[self._agent_graph.root_config_key]

0 commit comments

Comments
 (0)