@@ -22,6 +22,7 @@ class WillaChatbotState(TypedDict):
2222 docs_context : NotRequired [str ]
2323 search_query : NotRequired [str ]
2424 tind_metadata : NotRequired [str ]
25+ documents : NotRequired [list [Any ]]
2526 context : NotRequired [dict [str , Any ]]
2627
2728
@@ -87,17 +88,25 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
8788 vector_store = self ._vector_store
8889
8990 if not search_query or not vector_store :
90- return {"docs_context" : "" , "tind_metadata" : "" }
91+ return {"docs_context" : "" , "tind_metadata" : "" , "documents" : [] }
9192
9293 # Search for relevant documents
9394 retriever = vector_store .as_retriever (search_kwargs = {"k" : int (CONFIG ['K_VALUE' ])})
9495 matching_docs = retriever .invoke (search_query )
96+ formatted_documents = [
97+ {
98+ "page_content" : doc .page_content ,
99+ "start_index" : str (doc .metadata .get ('start_index' )) if doc .metadata .get ('start_index' ) else '' ,
100+ "total_pages" : str (doc .metadata .get ('total_pages' )) if doc .metadata .get ('total_pages' ) else '' ,
101+ }
102+ for doc in matching_docs
103+ ]
95104
96105 # Format context and metadata
97106 docs_context = '\n \n ' .join (doc .page_content for doc in matching_docs )
98107 tind_metadata = format_tind_context .get_tind_context (matching_docs )
99108
100- return {"docs_context" : docs_context , "tind_metadata" : tind_metadata }
109+ return {"docs_context" : docs_context , "tind_metadata" : tind_metadata , "documents" : formatted_documents }
101110
102111 # This should be refactored probably. Very bulky
103112 def _generate_response (self , state : WillaChatbotState ) -> dict [str , list [AnyMessage ]]:
@@ -107,6 +116,7 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
107116 docs_context = state .get ("docs_context" , "" )
108117 tind_metadata = state .get ("tind_metadata" , "" )
109118 model = self ._model
119+ documents = state .get ("documents" , [])
110120
111121 if not model :
112122 return {"messages" : [AIMessage (content = "Model not available." )]}
@@ -121,16 +131,20 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
121131 return {"messages" : [AIMessage (content = "I'm sorry, I didn't receive a question." )]}
122132
123133 prompt = get_langfuse_prompt ()
124- system_messages = prompt .invoke ({'context' : docs_context ,
125- 'question' : latest_message . content })
134+ system_messages = prompt .invoke ({})
135+
126136 if hasattr (system_messages , "messages" ):
127137 all_messages = summarized_conversation + system_messages .messages
128138 else :
129139 all_messages = summarized_conversation + [system_messages ]
130140
131141 # Get response from model
132- response = model .invoke (all_messages )
133-
142+ response = model .invoke (
143+ all_messages ,
144+ additional_model_request_fields = {"documents" : documents },
145+ additional_model_response_field_paths = ["/citations" ]
146+ )
147+ # print(response.response_metadata)
134148 # Create clean response content
135149 response_content = str (response .content ) if hasattr (response , 'content' ) else str (response )
136150 response_messages : list [AnyMessage ] = [AIMessage (content = response_content ),
0 commit comments