Skip to content

Commit 015d2ef

Browse files
Merge pull request #552 from microsoft/hb-us-23028
refactor: removes the use of `contextvars` and refactors the codebase to explicitly pass `user_id`
2 parents e523570 + dc47dc6 commit 015d2ef

File tree

6 files changed

+52
-81
lines changed

6 files changed

+52
-81
lines changed

src/backend/v3/api/router.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import contextvars
32
import json
43
import logging
54
import uuid
@@ -31,7 +30,6 @@
3130
from v3.common.services.team_service import TeamService
3231
from v3.config.settings import (
3332
connection_config,
34-
current_user_id,
3533
orchestration_config,
3634
team_config,
3735
)
@@ -57,8 +55,6 @@ async def start_comms(
5755

5856
user_id = user_id or "00000000-0000-0000-0000-000000000000"
5957

60-
current_user_id.set(user_id)
61-
6258
# Add to the connection manager for backend updates
6359
connection_config.add_connection(
6460
process_id=process_id, connection=websocket, user_id=user_id
@@ -90,7 +86,7 @@ async def start_comms(
9086
logging.error(f"Error in WebSocket connection: {str(e)}")
9187
finally:
9288
# Always clean up the connection
93-
await connection_config.close_connection(user_id)
89+
await connection_config.close_connection(process_id=process_id)
9490

9591

9692
@app_v3.get("/init_team")
@@ -304,18 +300,14 @@ async def process_request(
304300
raise HTTPException(status_code=500, detail="Failed to create plan")
305301

306302
try:
307-
current_user_id.set(user_id) # Set context
308-
current_context = contextvars.copy_context() # Capture context
309303
# background_tasks.add_task(
310304
# lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task)
311305
# )
312306

313-
async def run_with_context():
314-
return await current_context.run(
315-
OrchestrationManager().run_orchestration, user_id, input_task
316-
)
307+
async def run_orchestration_task():
308+
await OrchestrationManager().run_orchestration(user_id, input_task)
317309

318-
background_tasks.add_task(run_with_context)
310+
background_tasks.add_task(run_orchestration_task)
319311

320312
return {
321313
"status": "Request started successfully",

src/backend/v3/config/settings.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
"""
55

66
import asyncio
7-
import contextvars
87
import json
98
import logging
10-
from typing import Dict, Optional
9+
from typing import Dict
1110

1211
from common.config.app_config import config
1312
from common.models.messages_kernel import TeamConfiguration
@@ -21,11 +20,6 @@
2120

2221
logger = logging.getLogger(__name__)
2322

24-
# Create a context variable to track current user
25-
current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
26-
"current_user_id", default=None
27-
)
28-
2923

3024
class AzureConfig:
3125
"""Azure OpenAI and authentication configuration."""
@@ -181,13 +175,10 @@ async def close_connection(self, process_id):
181175
async def send_status_update_async(
182176
self,
183177
message: any,
184-
user_id: Optional[str] = None,
178+
user_id: str,
185179
message_type: WebsocketMessageType = WebsocketMessageType.SYSTEM_MESSAGE,
186180
):
187181
"""Send a status update to a specific client."""
188-
# If no process_id provided, get from context
189-
if user_id is None:
190-
user_id = current_user_id.get()
191182

192183
if not user_id:
193184
logger.warning("No user_id available for WebSocket message")

src/backend/v3/magentic_agents/magentic_agent_factory.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from common.config.app_config import config
1010
from common.models.messages_kernel import TeamConfiguration
11-
from v3.config.settings import current_user_id
1211
from v3.magentic_agents.foundry_agent import FoundryAgentTemplate
1312
from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig
1413

@@ -40,13 +39,12 @@ def __init__(self):
4039
# data = json.load(f)
4140
# return json.loads(json.dumps(data), object_hook=lambda d: SimpleNamespace(**d))
4241

43-
async def create_agent_from_config(
44-
self, agent_obj: SimpleNamespace
45-
) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]:
42+
async def create_agent_from_config(self, user_id: str, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]:
4643
"""
4744
Create an agent from configuration object.
4845
4946
Args:
47+
user_id: User ID
5048
agent_obj: Agent object from parsed JSON (SimpleNamespace)
5149
team_model: Model name to determine which template to use
5250
@@ -62,7 +60,6 @@ async def create_agent_from_config(
6260

6361
if not deployment_name and agent_obj.name.lower() == "proxyagent":
6462
self.logger.info("Creating ProxyAgent")
65-
user_id = current_user_id.get()
6663
return ProxyAgent(user_id=user_id)
6764

6865
# Validate supported models
@@ -133,11 +130,12 @@ async def create_agent_from_config(
133130
)
134131
return agent
135132

136-
async def get_agents(self, team_config_input: TeamConfiguration) -> List:
133+
async def get_agents(self, user_id: str, team_config_input: TeamConfiguration) -> List:
137134
"""
138135
Create and return a team of agents from JSON configuration.
139136
140137
Args:
138+
user_id: User ID
141139
team_config_input: team configuration object from cosmos db
142140
143141
Returns:
@@ -151,11 +149,9 @@ async def get_agents(self, team_config_input: TeamConfiguration) -> List:
151149

152150
for i, agent_cfg in enumerate(team_config_input.agents, 1):
153151
try:
154-
self.logger.info(
155-
f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}"
156-
)
152+
self.logger.info(f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}")
157153

158-
agent = await self.create_agent_from_config(agent_cfg)
154+
agent = await self.create_agent_from_config(user_id, agent_cfg)
159155
initalized_agents.append(agent)
160156
self._agent_list.append(agent) # Keep track for cleanup
161157

src/backend/v3/magentic_agents/proxy_agent.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,11 @@
2424
)
2525
from semantic_kernel.exceptions.agent_exceptions import AgentThreadOperationException
2626
from typing_extensions import override
27-
from v3.callbacks.response_handlers import (
28-
agent_response_callback,
29-
streaming_agent_response_callback,
30-
)
31-
from v3.config.settings import connection_config, current_user_id, orchestration_config
32-
from v3.models.messages import (
33-
UserClarificationRequest,
34-
UserClarificationResponse,
35-
WebsocketMessageType,
36-
)
27+
from v3.callbacks.response_handlers import (agent_response_callback,
28+
streaming_agent_response_callback)
29+
from v3.config.settings import connection_config, orchestration_config
30+
from v3.models.messages import (UserClarificationRequest,
31+
UserClarificationResponse, WebsocketMessageType)
3732

3833

3934
class DummyAgentThread(AgentThread):
@@ -110,13 +105,13 @@ class ProxyAgent(Agent):
110105
"""Simple proxy agent that prompts for human clarification."""
111106

112107
# Declare as Pydantic field
113-
user_id: Optional[str] = Field(
108+
user_id: str = Field(
114109
default=None, description="User ID for WebSocket messaging"
115110
)
116111

117-
def __init__(self, user_id: str = None, **kwargs):
118-
# Get user_id from parameter or context, fallback to empty string
119-
effective_user_id = user_id or current_user_id.get() or ""
112+
def __init__(self, user_id: str, **kwargs):
113+
# Get user_id from parameter, fallback to empty string
114+
effective_user_id = user_id or ""
120115
super().__init__(
121116
name="ProxyAgent",
122117
description="Call this agent when you need to clarify requests by asking the human user for more information. Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.",
@@ -139,15 +134,15 @@ def _create_message_content(
139134
async def _trigger_response_callbacks(self, message_content: ChatMessageContent):
140135
"""Manually trigger the same response callbacks used by other agents."""
141136
# Get current user_id dynamically instead of using stored value
142-
current_user = current_user_id.get() or self.user_id or ""
137+
current_user = self.user_id or ""
143138

144139
# Trigger the standard agent response callback
145140
agent_response_callback(message_content, current_user)
146141

147142
async def _trigger_streaming_callbacks(self, content: str, is_final: bool = False):
148143
"""Manually trigger streaming callbacks for real-time updates."""
149144
# Get current user_id dynamically instead of using stored value
150-
current_user = current_user_id.get() or self.user_id or ""
145+
current_user = self.user_id or ""
151146
streaming_message = StreamingChatMessageContent(
152147
role=AuthorRole.ASSISTANT, content=content, name=self.name, choice_index=0
153148
)
@@ -181,7 +176,7 @@ async def invoke(
181176
"type": WebsocketMessageType.USER_CLARIFICATION_REQUEST,
182177
"data": clarification_message,
183178
},
184-
user_id=current_user_id.get(),
179+
user_id=self.user_id,
185180
message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST,
186181
)
187182

@@ -238,7 +233,7 @@ async def invoke_stream(
238233
"type": WebsocketMessageType.USER_CLARIFICATION_REQUEST,
239234
"data": clarification_message,
240235
},
241-
user_id=current_user_id.get(),
236+
user_id=self.user_id,
242237
message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST,
243238
)
244239

src/backend/v3/orchestration/human_approval_manager.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
2121
)
2222
from semantic_kernel.contents import ChatMessageContent
23-
from v3.config.settings import connection_config, current_user_id, orchestration_config
23+
from v3.config.settings import connection_config, orchestration_config
2424
from v3.models.models import MPlan
25-
from v3.orchestration.helper.plan_to_mplan_converter import PlanToMPlanConverter
25+
from v3.orchestration.helper.plan_to_mplan_converter import \
26+
PlanToMPlanConverter
2627

2728
# Using a module level logger to avoid pydantic issues around inherited fields
2829
logger = logging.getLogger(__name__)
@@ -38,9 +39,17 @@ class HumanApprovalMagenticManager(StandardMagenticManager):
3839
# Define Pydantic fields to avoid validation errors
3940
approval_enabled: bool = True
4041
magentic_plan: Optional[MPlan] = None
41-
current_user_id: Optional[str] = None
42+
current_user_id: str
43+
44+
def __init__(self, user_id: str, *args, **kwargs):
45+
"""
46+
Initialize the HumanApprovalMagenticManager.
47+
Args:
48+
user_id: ID of the user to associate with this orchestration instance.
49+
*args: Additional positional arguments for the parent StandardMagenticManager.
50+
**kwargs: Additional keyword arguments for the parent StandardMagenticManager.
51+
"""
4252

43-
def __init__(self, *args, **kwargs):
4453
# Remove any custom kwargs before passing to parent
4554

4655
plan_append = """
@@ -76,6 +85,8 @@ def __init__(self, *args, **kwargs):
7685
)
7786
kwargs["final_answer_prompt"] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append
7887

88+
kwargs['current_user_id'] = user_id
89+
7990
super().__init__(*args, **kwargs)
8091

8192
async def plan(self, magentic_context: MagenticContext) -> Any:
@@ -100,7 +111,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
100111

101112
self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger)
102113

103-
self.magentic_plan.user_id = current_user_id.get()
114+
self.magentic_plan.user_id = self.current_user_id
104115

105116
# Request approval from the user before executing the plan
106117
approval_message = messages.PlanApprovalRequest(
@@ -124,7 +135,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
124135
# The user_id will be automatically retrieved from context
125136
await connection_config.send_status_update_async(
126137
message=approval_message,
127-
user_id=current_user_id.get(),
138+
user_id=self.current_user_id,
128139
message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST,
129140
)
130141

@@ -141,7 +152,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
141152
"type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE,
142153
"data": approval_response,
143154
},
144-
user_id=current_user_id.get(),
155+
user_id=self.current_user_id,
145156
message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE,
146157
)
147158
raise Exception("Plan execution cancelled by user")
@@ -170,7 +181,7 @@ async def create_progress_ledger(
170181

171182
await connection_config.send_status_update_async(
172183
message=final_message,
173-
user_id=current_user_id.get(),
184+
user_id=self.current_user_id,
174185
message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE,
175186
)
176187

src/backend/v3/orchestration/orchestration_manager.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# Copyright (c) Microsoft. All rights reserved.
22
"""Orchestration manager to handle the orchestration logic."""
33
import asyncio
4-
import contextvars
54
import logging
65
import uuid
7-
from contextvars import ContextVar
86
from typing import List, Optional
97

108
from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential
@@ -15,27 +13,16 @@
1513

1614
# Create custom execution settings to fix schema issues
1715
from semantic_kernel.connectors.ai.open_ai import (
18-
AzureChatCompletion,
19-
OpenAIChatPromptExecutionSettings,
20-
)
21-
from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent
22-
from v3.callbacks.response_handlers import (
23-
agent_response_callback,
24-
streaming_agent_response_callback,
25-
)
26-
from v3.config.settings import (
27-
connection_config,
28-
orchestration_config,
29-
)
16+
AzureChatCompletion, OpenAIChatPromptExecutionSettings)
17+
from semantic_kernel.contents import (ChatMessageContent,
18+
StreamingChatMessageContent)
19+
from v3.callbacks.response_handlers import (agent_response_callback,
20+
streaming_agent_response_callback)
21+
from v3.config.settings import connection_config, orchestration_config
3022
from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory
3123
from v3.models.messages import WebsocketMessageType
3224
from v3.orchestration.human_approval_manager import HumanApprovalMagenticManager
3325

34-
# Context variable to hold the current user ID
35-
current_user_id: ContextVar[Optional[str]] = contextvars.ContextVar(
36-
"current_user_id", default=None
37-
)
38-
3926

4027
class OrchestrationManager:
4128
"""Manager for handling orchestration logic."""
@@ -69,6 +56,7 @@ def get_token():
6956
magentic_orchestration = MagenticOrchestration(
7057
members=agents,
7158
manager=HumanApprovalMagenticManager(
59+
user_id=user_id,
7260
chat_completion_service=AzureChatCompletion(
7361
deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME,
7462
endpoint=config.AZURE_OPENAI_ENDPOINT,
@@ -122,15 +110,14 @@ async def get_current_or_new_orchestration(
122110
except Exception as e:
123111
cls.logger.error("Error closing agent: %s", e)
124112
factory = MagenticAgentFactory()
125-
agents = await factory.get_agents(team_config_input=team_config)
113+
agents = await factory.get_agents(user_id=user_id, team_config_input=team_config)
126114
orchestration_config.orchestrations[user_id] = await cls.init_orchestration(
127115
agents, user_id
128116
)
129117
return orchestration_config.get_current_orchestration(user_id)
130118

131119
async def run_orchestration(self, user_id, input_task) -> None:
132120
"""Run the orchestration with user input loop."""
133-
token = current_user_id.set(user_id)
134121

135122
job_id = str(uuid.uuid4())
136123
orchestration_config.approvals[job_id] = None
@@ -190,4 +177,3 @@ async def run_orchestration(self, user_id, input_task) -> None:
190177
self.logger.error(f"Unexpected error: {e}")
191178
finally:
192179
await runtime.stop_when_idle()
193-
current_user_id.reset(token)

0 commit comments

Comments
 (0)