Skip to content

Commit cd976ff

Browse files
committed
add prepare generation node
- temporarily add raw citations to response.
1 parent e999fa8 commit cd976ff

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

willa/chatbot/graph_manager.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ def _create_workflow(self) -> CompiledStateGraph:
5353
workflow.add_node("summarize", summarization_node)
5454
workflow.add_node("prepare_search", self._prepare_search_query)
5555
workflow.add_node("retrieve_context", self._retrieve_context)
56+
workflow.add_node("prepare_for_generation", self._prepare_for_generation)
5657
workflow.add_node("generate_response", self._generate_response)
5758

5859
# Define edges
5960
workflow.add_edge("filter_messages", "summarize")
6061
workflow.add_edge("summarize", "prepare_search")
6162
workflow.add_edge("prepare_search", "retrieve_context")
62-
workflow.add_edge("retrieve_context", "generate_response")
63+
workflow.add_edge("retrieve_context", "prepare_for_generation")
64+
workflow.add_edge("prepare_for_generation", "generate_response")
6365

6466
workflow.set_entry_point("filter_messages")
6567
workflow.set_finish_point("generate_response")
@@ -109,50 +111,51 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
109111

110112
return {"docs_context": docs_context, "tind_metadata": tind_metadata, "documents": formatted_documents}
111113

112-
# This should be refactored probably. Very bulky
113-
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
114-
"""Generate response using the model."""
114+
def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
115+
"""Prepare the current and past messages for response generation."""
115116
messages = state["messages"]
116117
summarized_conversation = state.get("summarized_messages", messages)
117-
docs_context = state.get("docs_context", "")
118-
tind_metadata = state.get("tind_metadata", "")
119-
model = self._model
120-
documents = state.get("documents", [])
121-
122-
if not model:
123-
return {"messages": [AIMessage(content="Model not available.")]}
124-
125-
# Get the latest human message
126-
latest_message = next(
127-
(msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
128-
None
129-
)
130-
131-
if not latest_message:
118+
119+
if not any(isinstance(msg, HumanMessage) for msg in messages):
132120
return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
133-
121+
134122
prompt = get_langfuse_prompt()
135123
system_messages = prompt.invoke({})
136-
124+
137125
if hasattr(system_messages, "messages"):
138126
all_messages = summarized_conversation + system_messages.messages
139127
else:
140128
all_messages = summarized_conversation + [system_messages]
129+
130+
return {"messages": all_messages}
131+
132+
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
133+
"""Generate response using the model."""
134+
tind_metadata = state.get("tind_metadata", "")
135+
model = self._model
136+
documents = state.get("documents", [])
137+
messages = state["messages"]
138+
139+
if not model:
140+
return {"messages": [AIMessage(content="Model not available.")]}
141141

142142
# Get response from model
143143
response = model.invoke(
144-
all_messages,
144+
messages,
145145
additional_model_request_fields={"documents": documents},
146146
additional_model_response_field_paths=["/citations"]
147147
)
148148
citations = response.response_metadata.get('additionalModelResponseFields').get('citations') if response.response_metadata else None
149149

150-
# add citations to graph state
151-
if citations:
152-
state['citations'] = citations
153-
154150
# Create clean response content
155151
response_content = str(response.content) if hasattr(response, 'content') else str(response)
152+
153+
if citations:
154+
state['citations'] = citations
155+
response_content += "\n\nCitations:\n"
156+
for citation in citations:
157+
response_content += f"- {citation.get('text', '')} (docs: {citation.get('document_ids', [])})\n"
158+
156159
response_messages: list[AnyMessage] = [AIMessage(content=response_content),
157160
ChatMessage(content=tind_metadata, role='TIND',
158161
response_metadata={'tind': True})]

0 commit comments

Comments
 (0)