Skip to content

Commit ff6ab73

Browse files
committed
Investigating separating documents out from the rest of the message history and instructions.
1 parent cd8f211 commit ff6ab73

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

willa/chatbot/graph_manager.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class WillaChatbotState(TypedDict):
2222
docs_context: NotRequired[str]
2323
search_query: NotRequired[str]
2424
tind_metadata: NotRequired[str]
25+
documents: NotRequired[list[Any]]
2526
context: NotRequired[dict[str, Any]]
2627

2728

@@ -87,17 +88,25 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
8788
vector_store = self._vector_store
8889

8990
if not search_query or not vector_store:
90-
return {"docs_context": "", "tind_metadata": ""}
91+
return {"docs_context": "", "tind_metadata": "", "documents": []}
9192

9293
# Search for relevant documents
9394
retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
9495
matching_docs = retriever.invoke(search_query)
96+
formatted_documents = [
97+
{
98+
"page_content": doc.page_content,
99+
"start_index": str(doc.metadata.get('start_index')) if doc.metadata.get('start_index') else '',
100+
"total_pages": str(doc.metadata.get('total_pages')) if doc.metadata.get('total_pages') else '',
101+
}
102+
for doc in matching_docs
103+
]
95104

96105
# Format context and metadata
97106
docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
98107
tind_metadata = format_tind_context.get_tind_context(matching_docs)
99108

100-
return {"docs_context": docs_context, "tind_metadata": tind_metadata}
109+
return {"docs_context": docs_context, "tind_metadata": tind_metadata, "documents": formatted_documents}
101110

102111
# This should be refactored probably. Very bulky
103112
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
@@ -107,6 +116,7 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
107116
docs_context = state.get("docs_context", "")
108117
tind_metadata = state.get("tind_metadata", "")
109118
model = self._model
119+
documents = state.get("documents", [])
110120

111121
if not model:
112122
return {"messages": [AIMessage(content="Model not available.")]}
@@ -121,16 +131,20 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
121131
return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
122132

123133
prompt = get_langfuse_prompt()
124-
system_messages = prompt.invoke({'context': docs_context,
125-
'question': latest_message.content})
134+
system_messages = prompt.invoke({})
135+
126136
if hasattr(system_messages, "messages"):
127137
all_messages = summarized_conversation + system_messages.messages
128138
else:
129139
all_messages = summarized_conversation + [system_messages]
130140

131141
# Get response from model
132-
response = model.invoke(all_messages)
133-
142+
response = model.invoke(
143+
all_messages,
144+
additional_model_request_fields={"documents": documents},
145+
additional_model_response_field_paths=["/citations"]
146+
)
147+
# print(response.response_metadata)
134148
# Create clean response content
135149
response_content = str(response.content) if hasattr(response, 'content') else str(response)
136150
response_messages: list[AnyMessage] = [AIMessage(content=response_content),

0 commit comments

Comments (0)