@@ -53,13 +53,15 @@ def _create_workflow(self) -> CompiledStateGraph:
5353 workflow .add_node ("summarize" , summarization_node )
5454 workflow .add_node ("prepare_search" , self ._prepare_search_query )
5555 workflow .add_node ("retrieve_context" , self ._retrieve_context )
56+ workflow .add_node ("prepare_for_generation" , self ._prepare_for_generation )
5657 workflow .add_node ("generate_response" , self ._generate_response )
5758
5859 # Define edges
5960 workflow .add_edge ("filter_messages" , "summarize" )
6061 workflow .add_edge ("summarize" , "prepare_search" )
6162 workflow .add_edge ("prepare_search" , "retrieve_context" )
62- workflow .add_edge ("retrieve_context" , "generate_response" )
63+ workflow .add_edge ("retrieve_context" , "prepare_for_generation" )
64+ workflow .add_edge ("prepare_for_generation" , "generate_response" )
6365
6466 workflow .set_entry_point ("filter_messages" )
6567 workflow .set_finish_point ("generate_response" )
@@ -109,50 +111,51 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
109111
110112 return {"docs_context" : docs_context , "tind_metadata" : tind_metadata , "documents" : formatted_documents }
111113
112- # This should be refactored probably. Very bulky
113- def _generate_response (self , state : WillaChatbotState ) -> dict [str , list [AnyMessage ]]:
114- """Generate response using the model."""
def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
    """Assemble the message list handed to the model for response generation.

    Combines the summarized conversation (falling back to the raw history
    when no summary exists yet) with the Langfuse system prompt, or short-
    circuits with a canned apology when the history contains no user turn.

    Returns:
        A state update whose "messages" key holds the combined message list.
    """
    history = state["messages"]
    # Use the summarized form of the conversation when available;
    # otherwise fall back to the unmodified history.
    condensed = state.get("summarized_messages", history)

    # Guard clause: nothing to generate if the user never asked anything.
    has_question = any(isinstance(m, HumanMessage) for m in history)
    if not has_question:
        # NOTE(review): with LangGraph's add_messages reducer this APPENDS
        # the apology rather than replacing the history — confirm the
        # state's reducer semantics match the intent here.
        return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}

    system_messages = get_langfuse_prompt().invoke({})

    # The prompt may come back as a container with a `.messages` attribute
    # or as a single message object; normalize to a list either way.
    if hasattr(system_messages, "messages"):
        prompt_messages = list(system_messages.messages)
    else:
        prompt_messages = [system_messages]

    # NOTE(review): the system prompt is appended AFTER the conversation,
    # mirroring the pre-refactor behavior — confirm this ordering is intended.
    return {"messages": condensed + prompt_messages}
131+
132+ def _generate_response (self , state : WillaChatbotState ) -> dict [str , list [AnyMessage ]]:
133+ """Generate response using the model."""
134+ tind_metadata = state .get ("tind_metadata" , "" )
135+ model = self ._model
136+ documents = state .get ("documents" , [])
137+ messages = state ["messages" ]
138+
139+ if not model :
140+ return {"messages" : [AIMessage (content = "Model not available." )]}
141141
142142 # Get response from model
143143 response = model .invoke (
144- all_messages ,
144+ messages ,
145145 additional_model_request_fields = {"documents" : documents },
146146 additional_model_response_field_paths = ["/citations" ]
147147 )
148148 citations = response .response_metadata .get ('additionalModelResponseFields' ).get ('citations' ) if response .response_metadata else None
149149
150- # add citations to graph state
151- if citations :
152- state ['citations' ] = citations
153-
154150 # Create clean response content
155151 response_content = str (response .content ) if hasattr (response , 'content' ) else str (response )
152+
153+ if citations :
154+ state ['citations' ] = citations
155+ response_content += "\n\nCitations:\n"
156+ for citation in citations :
157+ response_content += f"- {citation.get('text', '')} (docs: {citation.get('document_ids', [])})\n"
158+
156159 response_messages : list [AnyMessage ] = [AIMessage (content = response_content ),
157160 ChatMessage (content = tind_metadata , role = 'TIND' ,
158161 response_metadata = {'tind' : True })]
0 commit comments