@@ -1,19 +1,23 @@
 from __future__ import annotations
 
 import copy
+import random
 import re
 import warnings
-from typing import Any, cast
+from typing import Any, Callable, Literal, TypeVar, cast
 
 import colorama
 from langchain.chat_models.base import BaseChatModel
+from langchain.embeddings.base import Embeddings
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from scipy.spatial.distance import cosine as cosine_distance
 
 from .prompt_utils import (
     CHECK_INSTRUCTION,
     EXPLAIN_INSTRUCTION,
     FIX_INSTRUCTION,
     GENERATE_INSTRUCTION,
+    GENERATE_INSTRUCTION_COT,
     SANITY_CHECK_INSTRUCTION,
     SUGGEST_INSTRUCTION,
     FixContext,
@@ -33,6 +37,8 @@
 
 _debug_mode: bool = False
 
+_Type = TypeVar("_Type")
+
 
 def debug_messages(*messages: BaseMessage) -> None:
     if not _debug_mode:
@@ -91,8 +97,58 @@ def parse_code(response: str) -> str:
 
 
 class CoMLAgent:
-    def __init__(self, llm: BaseChatModel):
+    """
+    CoML agent that accepts data science requests and generates code.
+
+    Attributes:
+        llm: The language model that generates responses.
+        prompt_version: The version of the prompt to use (one of ``v1``, ``v2``,
+            ``kaggle``, ``leetcode``).
+        prompt_validation: A function that takes the list of messages and returns
+            whether the prompt is valid, which is useful for limiting the number
+            of tokens in the prompt.
+        num_examples: The number of examples to show in the prompt. A float
+            between 0 and 1 is interpreted as the fraction of available examples
+            to show; an integer is interpreted as an absolute count.
+        message_style: Can be ``chatgpt``, in which system messages are used, or
+            ``gemini``, in which only human and AI messages are used.
+        chain_of_thought: Whether to use chain of thought (CoT) in the prompt.
+        context_order: The order of the context in the prompt. Defaults to
+            ``vcr``: ``v`` for variable descriptions, ``c`` for codes, ``r`` for
+            request.
+        ensemble: Perform this many LLM calls and ensemble the results, keeping
+            the response with the highest mean token log-probability.
+        ensemble_shuffle: Whether to shuffle the examples in the prompt before
+            each ensemble call.
+        example_ranking: An embedding model that ranks the examples. If provided,
+            the examples are ranked by similarity to the request before selection.
+        intact_instruction: Whether to instruct the LLM to keep the variables
+            unmodified. For experimentation purposes only.
+    """
+
+    def __init__(
+        self,
+        llm: BaseChatModel,
+        prompt_version: Literal["v1", "v2", "kaggle", "leetcode"] = "v2",
+        prompt_validation: Callable[[list[BaseMessage]], bool] | None = None,
+        num_examples: float | int = 1.0,
+        message_style: Literal["chatgpt", "gemini"] = "chatgpt",
+        chain_of_thought: bool = False,
+        context_order: Literal[
+            "vcr", "cvr", "rvc", "rcv", "vr", "rv", "cr", "rc", "r"
+        ] = "vcr",
+        ensemble: int | None = None,
+        ensemble_shuffle: bool = True,
+        example_ranking: Embeddings | None = None,
+        intact_instruction: bool = True,
+    ):
         self.llm = llm
+        self.prompt_version = prompt_version
+        self.prompt_validation = prompt_validation
+        self.num_examples = num_examples
+        self.message_style = message_style
+        self.chain_of_thought = chain_of_thought
+        self.context_order = context_order
+        self.ensemble = ensemble
+        self.ensemble_shuffle = ensemble_shuffle
+        self.example_ranking = example_ranking
+        self.intact_instruction = intact_instruction
 
     def _fix_context_from_any_context(
         self, context: GenerateContext | FixContext, **kwargs: Any
@@ -110,27 +166,162 @@ def _fix_context_from_any_context(
         context["interactions"].append(InteractionIncomplete(**kwargs))
         return context
 
+    def _pre_generation(self, messages: list[BaseMessage]) -> list[BaseMessage]:
+        if self.message_style == "gemini":
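+            # The ``gemini`` style uses only human and AI messages, so the
+            # system prompt must be folded into the first human message.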
+            # Merge the first two messages.
+            if len(messages) > 1 and isinstance(messages[1], HumanMessage):
+                messages1_content = cast(str, messages[1].content)
+                if not messages1_content.startswith("### Task Start ###"):
+                    messages1_content = "### Task Start ###\n\n" + messages1_content
+                messages[1] = HumanMessage(
+                    content=cast(str, messages[0].content) + "\n\n" + messages1_content
+                )
+                messages = messages[1:]
+            else:
+                messages[0] = HumanMessage(content=cast(str, messages[0].content))
+
+        if self.prompt_validation is not None and not self.prompt_validation(messages):
+            raise ValueError("Prompt validation failed.")
+
+        return messages
+
+    def _ensemble_generate(self, messages: list[BaseMessage]) -> BaseMessage:
+        """Ensemble the result from multiple LLM calls."""
+
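+        # Run the model ``ensemble`` times (optionally reshuffling the few-shot
+        # examples each round) and keep the single response whose tokens have
+        # the highest mean log-probability.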
+        if not self.ensemble:
+            return self._generate(messages)
+
+        results: list[tuple[float, BaseMessage]] = []
+        for _ in range(self.ensemble):
+            if self.ensemble_shuffle and len(messages) > 2:
+                # Shuffle the examples.
+                first_message = messages[0]
+                interactions = messages[1:]
+
+                start_indices = []
+                # Can be [Task A Human, AI, Human, AI, Task B Human, AI]
+                for index, message in enumerate(interactions):
+                    if isinstance(message, HumanMessage) and cast(
+                        str, message.content
+                    ).startswith("### Task Start ###"):
+                        start_indices.append(index)
+
+                # Can be [Human, AI, Human, AI]
+                if not start_indices:
+                    # Loosen the constraint and find all human messages.
+                    start_indices = [
+                        index
+                        for index, message in enumerate(interactions)
+                        if isinstance(message, HumanMessage)
+                    ]
+
+                groups = [
+                    interactions[index:index_next]
+                    for index, index_next in zip(start_indices, start_indices[1:])
+                ]
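+                # zip() pairs consecutive start indices, so the final segment
+                # (from the last start index to the end, i.e. the current
+                # query) is left out here and re-appended, unshuffled, below.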
+
+                # Shuffle the groups and combine them.
+                random.shuffle(groups)
+                messages = (
+                    [first_message]
+                    + [message for group in groups for message in group]
+                    + interactions[start_indices[-1] :]
+                )
+
+            messages = self._pre_generation(messages)
+            result = self.llm.generate([messages], logprobs=True)
+            message = result.generations[0][0].message  # type: ignore
+            generation_info = result.generations[0][0].generation_info
+            if generation_info is None or "logprobs" not in generation_info:
+                raise ValueError("Logprobs not found in generation_info.")
+            logprobs = [
+                content["logprob"] for content in generation_info["logprobs"]["content"]
+            ]
+            if not logprobs:
+                mean_logprobs = float("-inf")
+            else:
+                mean_logprobs = sum(logprobs) / len(logprobs)
+
+            results.append((mean_logprobs, message))
+
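+        # Highest mean log-probability first.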
+        results.sort(key=lambda x: x[0], reverse=True)
+
+        return results[0][1]
+
+    def _generate(self, messages: list[BaseMessage]) -> BaseMessage:
+        """Generate a response from the LLM."""
+        messages = self._pre_generation(messages)
+        return self.llm(messages)
+
+    def _select_examples(self, query: str, fewshots: list[_Type]) -> list[_Type]:
+        """Select examples from the fewshots."""
+        if self.num_examples == 0:
+            return []
+
+        if self.example_ranking is not None:
+            documents = [cast(Any, shot).get("request", "N/A") for shot in fewshots]
+            embeddings = self.example_ranking.embed_documents(documents)
+            # Use embed_documents instead of embed_query, because the latter has a cache.
+            query_embedding = self.example_ranking.embed_documents([query])[0]
+            similarities = [
+                (cosine_distance(query_embedding, embedding), shot)
+                for embedding, shot in zip(embeddings, fewshots)
+            ]
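+            # cosine_distance is a distance, not a similarity, so an ascending
+            # sort puts the most similar examples first.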
+            similarities.sort(key=lambda x: x[0])
+            fewshots = [shot for _, shot in similarities]
+
+        if isinstance(self.num_examples, int):
+            return fewshots[: self.num_examples]
+        else:
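+            # A fractional num_examples selects that fraction of the example
+            # pool, but always keeps at least one example.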
+            num_shots = max(int(len(fewshots) * self.num_examples), 1)
+            return fewshots[:num_shots]
+
     def generate_code(
-        self, request: str, variable_descriptions: dict[str, str], codes: list[str]
+        self,
+        request: str,
+        variable_descriptions: dict[str, str],
+        codes: list[str],
     ) -> GenerateContext:
-        fewshots = cached_generate_fewshots()
-        messages: list[BaseMessage] = [
-            SystemMessage(content=GENERATE_INSTRUCTION),
-        ]
-        for shot in fewshots:
-            question, answer = render_generate_context(shot)
+        fewshots = cached_generate_fewshots(self.prompt_version)
+        messages: list[BaseMessage] = []
+
+        if self.chain_of_thought:
+            generate_instruction = GENERATE_INSTRUCTION_COT
+        else:
+            generate_instruction = GENERATE_INSTRUCTION
+        if not self.intact_instruction:
+            generate_instruction = re.sub(
+                r"- Do not overwrite or modify.*\n", "", generate_instruction
+            )
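+            # The few-shot answers also come in variants written without the
+            # "keep variables intact" instruction; swap those in here.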
+            for shot in fewshots:
+                if "answer_wo_intact" in shot:
+                    shot["answer"] = shot.pop("answer_wo_intact")
+                if "rationale_wo_intact" in shot:
+                    shot["rationale"] = shot.pop("rationale_wo_intact")
+        messages.append(SystemMessage(content=generate_instruction))
+
+        for shot in self._select_examples(request, fewshots):
+            question, answer = render_generate_context(
+                shot, cot=self.chain_of_thought, context_order=self.context_order
+            )
             messages.append(HumanMessage(content=question))
             if answer is not None:
                 messages.append(AIMessage(content=answer))
         context = GenerateContextIncomplete(
             variables=variable_descriptions, codes=codes, request=request
         )
-        question, _ = render_generate_context(context)
+        question, _ = render_generate_context(
+            context, cot=self.chain_of_thought, context_order=self.context_order
+        )
         messages.append(HumanMessage(content=question))
+
         debug_messages(*messages)
 
-        response = self.llm(messages)
+        response = self._ensemble_generate(messages)
         debug_messages(response)
+
+        if not isinstance(response.content, str):
+            raise ValueError(f"Response is not a string: {response.content}")
         code = parse_code(response.content)
         return {**context, "answer": code}
 
@@ -142,22 +333,24 @@ def fix_code(
         prev_context: GenerateContext | FixContext,
     ) -> FixContext | None:
         fewshots = cached_fix_fewshots()
+        fewshots = self._select_examples(prev_context["request"] or "N/A", fewshots)
         messages: list[BaseMessage] = [
             SystemMessage(content=FIX_INSTRUCTION),
         ]
         context = self._fix_context_from_any_context(
             prev_context, error=error, output=output, hint=hint
         )
         for shot in fewshots + [context]:
-            interactions = render_fix_context(shot)
+            interactions = render_fix_context(shot, context_order=self.context_order)
             for index, interaction in enumerate(interactions):
                 if index % 2 == 0:
                     messages.append(HumanMessage(content=interaction))
                 else:
                     messages.append(AIMessage(content=interaction))
-        debug_messages(*messages[-2:])
 
-        response = self.llm(messages)
+        debug_messages(*messages)
+
+        response = self._ensemble_generate(messages)
         debug_messages(response)
         explanation, observation, code = parse_fix(response.content)
         if "THE CODE IS CORRECT." in observation:
@@ -186,7 +379,7 @@ def suggest(self, codes: list[str]) -> list[str]:
             HumanMessage(content=human_message),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         suggestions = re.split(r"\d+\.\s+", response.content)
         suggestions = [s.strip().replace("\n", " ") for s in suggestions if s.strip()]
         debug_messages(response)
@@ -199,7 +392,7 @@ def explain(self, code: str) -> str:
             HumanMessage(content=code),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         debug_messages(response)
         return response.content
 
@@ -212,7 +405,7 @@ def static_check(
             HumanMessage(content=render_check_context(code, context)),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         debug_messages(response)
         reason, last_line = response.content.rstrip().rsplit("\n", 1)
         if "INCORRECT" in last_line.upper():
@@ -236,7 +429,7 @@ def output_sanity_check(
             ),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         debug_messages(response)
         reason, last_line = response.content.rstrip().rsplit("\n", 1)
         if "INCORRECT" in last_line.upper():