Commit 0e94669

[WIP] Support a wider range of model families and prompt methods (#18)
1 parent 4e0cfb4 commit 0e94669

9 files changed: +672 −54 lines changed

coml/core.py

Lines changed: 211 additions & 18 deletions
@@ -1,19 +1,23 @@
 from __future__ import annotations
 
 import copy
+import random
 import re
 import warnings
-from typing import Any, cast
+from typing import Any, Callable, Literal, TypeVar, cast
 
 import colorama
 from langchain.chat_models.base import BaseChatModel
+from langchain.embeddings.base import Embeddings
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from scipy.spatial.distance import cosine as cosine_distance
 
 from .prompt_utils import (
     CHECK_INSTRUCTION,
     EXPLAIN_INSTRUCTION,
     FIX_INSTRUCTION,
     GENERATE_INSTRUCTION,
+    GENERATE_INSTRUCTION_COT,
     SANITY_CHECK_INSTRUCTION,
     SUGGEST_INSTRUCTION,
     FixContext,
@@ -33,6 +37,8 @@
 
 _debug_mode: bool = False
 
+_Type = TypeVar("_Type")
+
 
 def debug_messages(*messages: BaseMessage) -> None:
     if not _debug_mode:
@@ -91,8 +97,58 @@ def parse_code(response: str) -> str:
 
 
 class CoMLAgent:
-    def __init__(self, llm: BaseChatModel):
+    """
+    CoML agent that accepts data science requests and generates code.
+
+    Attributes:
+        llm: The language model that generates responses.
+        prompt_version: The version of prompt to use (``v1``, ``v2``, ``kaggle`` or ``leetcode``).
+        prompt_validation: A function that takes a list of messages and returns
+            whether the prompt is valid, which is useful for limiting the number of
+            tokens in the prompt.
+        num_examples: The number of examples to show in the prompt. It can be a
+            number between 0 and 1, interpreted as the fraction of examples to show.
+            It can also be an integer, interpreted as the number of examples to show.
+        message_style: Can be ``chatgpt``, in which system messages are shown, or
+            ``gemini``, in which only human and AI messages are shown.
+        chain_of_thought: Whether to use chain of thought (CoT) in the prompt.
+        context_order: The order of the context in the prompt. Defaults to ``vcr``.
+            ``v`` stands for variable descriptions, ``c`` for codes, ``r`` for request.
+        ensemble: Perform ``ensemble`` LLM calls and ensemble the results.
+        ensemble_shuffle: Shuffle the examples in the prompt before ensembling.
+        example_ranking: A model that ranks the examples. If provided, the examples
+            will be ranked by the model before selection.
+        intact_instruction: Whether to instruct the LLM to keep the variables unmodified.
+            For experimentation purposes only.
+    """
+
+    def __init__(
+        self,
+        llm: BaseChatModel,
+        prompt_version: Literal["v1", "v2", "kaggle", "leetcode"] = "v2",
+        prompt_validation: Callable[[list[BaseMessage]], bool] | None = None,
+        num_examples: float | int = 1.0,
+        message_style: Literal["chatgpt", "gemini"] = "chatgpt",
+        chain_of_thought: bool = False,
+        context_order: Literal[
+            "vcr", "cvr", "rvc", "rcv", "vr", "rv", "cr", "rc", "r"
+        ] = "vcr",
+        ensemble: int | None = None,
+        ensemble_shuffle: bool = True,
+        example_ranking: Embeddings | None = None,
+        intact_instruction: bool = True,
+    ):
         self.llm = llm
+        self.prompt_version = prompt_version
+        self.prompt_validation = prompt_validation
+        self.num_examples = num_examples
+        self.message_style = message_style
+        self.chain_of_thought = chain_of_thought
+        self.context_order = context_order
+        self.ensemble = ensemble
+        self.ensemble_shuffle = ensemble_shuffle
+        self.example_ranking = example_ranking
+        self.intact_instruction = intact_instruction
 
     def _fix_context_from_any_context(
         self, context: GenerateContext | FixContext, **kwargs: Any
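
A minimal usage sketch of the widened constructor, assuming an OpenAI-backed langchain setup; the model choices and option values below are illustrative stand-ins, not values mandated by this commit:

from langchain.chat_models import ChatOpenAI        # any BaseChatModel works
from langchain.embeddings import OpenAIEmbeddings   # any Embeddings implementation works

from coml.core import CoMLAgent

agent = CoMLAgent(
    llm=ChatOpenAI(temperature=0),
    prompt_version="v2",                 # or "v1", "kaggle", "leetcode"
    num_examples=0.5,                    # float: keep half of the few-shot examples
    message_style="gemini",              # fold the system message into the first human turn
    chain_of_thought=True,               # use the CoT variant of the generate instruction
    context_order="rvc",                 # request, then variables, then codes
    ensemble=3,                          # 3 LLM calls, keep the highest mean logprob
    example_ranking=OpenAIEmbeddings(),  # rank few-shot examples by embedding similarity
)
context = agent.generate_code(
    "Drop rows with missing values.",
    {"df": "pandas DataFrame with 1000 rows"},
    [],
)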
@@ -110,27 +166,162 @@ def _fix_context_from_any_context(
         context["interactions"].append(InteractionIncomplete(**kwargs))
         return context
 
+    def _pre_generation(self, messages: list[BaseMessage]) -> list[BaseMessage]:
+        if self.message_style == "gemini":
+            # Merge the first two messages.
+            if len(messages) > 1 and isinstance(messages[1], HumanMessage):
+                messages1_content = cast(str, messages[1].content)
+                if not messages1_content.startswith("### Task Start ###"):
+                    messages1_content = "### Task Start ###\n\n" + messages1_content
+                messages[1] = HumanMessage(
+                    content=cast(str, messages[0].content) + "\n\n" + messages1_content
+                )
+                messages = messages[1:]
+            else:
+                messages[0] = HumanMessage(content=cast(str, messages[0].content))
+
+        if self.prompt_validation is not None and not self.prompt_validation(messages):
+            raise ValueError("Prompt validation failed.")
+
+        return messages
+
+    def _ensemble_generate(self, messages: list[BaseMessage]) -> BaseMessage:
+        """Ensemble the result from multiple LLM calls."""
+
+        if not self.ensemble:
+            return self._generate(messages)
+
+        results: list[tuple[float, BaseMessage]] = []
+        for _ in range(self.ensemble):
+            if self.ensemble_shuffle and len(messages) > 2:
+                # Shuffle the examples
+                first_message = messages[0]
+                interactions = messages[1:]
+
+                start_indices = []
+                # Can be [Task A Human, AI, Human, AI, Task B Human, AI]
+                for index, message in enumerate(interactions):
+                    if isinstance(message, HumanMessage) and cast(
+                        str, message.content
+                    ).startswith("### Task Start ###"):
+                        start_indices.append(index)
+
+                # Can be [Human, AI, Human, AI]
+                if not start_indices:
+                    # Loosen the constraint and find all human messages
+                    start_indices = [
+                        index
+                        for index, message in enumerate(interactions)
+                        if isinstance(message, HumanMessage)
+                    ]
+
+                groups = [
+                    interactions[index:index_next]
+                    for index, index_next in zip(start_indices, start_indices[1:])
+                ]
+
+                # Shuffle the groups and combine them
+                random.shuffle(groups)
+                messages = (
+                    [first_message]
+                    + [message for group in groups for message in group]
+                    + interactions[start_indices[-1] :]
+                )
+
+            messages = self._pre_generation(messages)
+            result = self.llm.generate([messages], logprobs=True)
+            message = result.generations[0][0].message  # type: ignore
+            generation_info = result.generations[0][0].generation_info
+            if generation_info is None or "logprobs" not in generation_info:
+                raise ValueError("Logprobs not found in generation_info.")
+            logprobs = [
+                content["logprob"] for content in generation_info["logprobs"]["content"]
+            ]
+            if not logprobs:
+                mean_logprobs = float("-inf")
+            else:
+                mean_logprobs = sum(logprobs) / len(logprobs)
+
+            results.append((mean_logprobs, message))
+
+        results.sort(key=lambda x: x[0], reverse=True)
+
+        return results[0][1]
+
+    def _generate(self, messages: list[BaseMessage]) -> BaseMessage:
+        """Generate a response from the LLM."""
+        messages = self._pre_generation(messages)
+        return self.llm(messages)
+
+    def _select_examples(self, query: str, fewshots: list[_Type]) -> list[_Type]:
+        """Select examples from the fewshots."""
+        if self.num_examples == 0:
+            return []
+
+        if self.example_ranking is not None:
+            documents = [cast(Any, shot).get("request", "N/A") for shot in fewshots]
+            embeddings = self.example_ranking.embed_documents(documents)
+            # Use embed_documents instead of embed_query because the latter has a cache
+            query_embedding = self.example_ranking.embed_documents([query])[0]
+            similarities = [
+                (cosine_distance(query_embedding, embedding), shot)
+                for embedding, shot in zip(embeddings, fewshots)
+            ]
+            similarities.sort(key=lambda x: x[0])
+            fewshots = [shot for _, shot in similarities]
+
+        if isinstance(self.num_examples, int):
+            return fewshots[: self.num_examples]
+        else:
+            num_shots = max(int(len(fewshots) * self.num_examples), 1)
+            return fewshots[:num_shots]
+
     def generate_code(
-        self, request: str, variable_descriptions: dict[str, str], codes: list[str]
+        self,
+        request: str,
+        variable_descriptions: dict[str, str],
+        codes: list[str],
     ) -> GenerateContext:
-        fewshots = cached_generate_fewshots()
-        messages: list[BaseMessage] = [
-            SystemMessage(content=GENERATE_INSTRUCTION),
-        ]
-        for shot in fewshots:
-            question, answer = render_generate_context(shot)
+        fewshots = cached_generate_fewshots(self.prompt_version)
+        messages: list[BaseMessage] = []
+
+        if self.chain_of_thought:
+            generate_instruction = GENERATE_INSTRUCTION_COT
+        else:
+            generate_instruction = GENERATE_INSTRUCTION
+        if not self.intact_instruction:
+            generate_instruction = re.sub(
+                r"- Do not overwrite or modify.*\n", "", generate_instruction
+            )
+            for shot in fewshots:
+                if "answer_wo_intact" in shot:
+                    shot["answer"] = shot.pop("answer_wo_intact")
+                if "rationale_wo_intact" in shot:
+                    shot["rationale"] = shot.pop("rationale_wo_intact")
+        messages.append(SystemMessage(content=generate_instruction))
+
+        for shot in self._select_examples(request, fewshots):
+            question, answer = render_generate_context(
+                shot, cot=self.chain_of_thought, context_order=self.context_order
+            )
             messages.append(HumanMessage(content=question))
             if answer is not None:
                 messages.append(AIMessage(content=answer))
         context = GenerateContextIncomplete(
             variables=variable_descriptions, codes=codes, request=request
         )
-        question, _ = render_generate_context(context)
+        question, _ = render_generate_context(
+            context, cot=self.chain_of_thought, context_order=self.context_order
+        )
         messages.append(HumanMessage(content=question))
+
         debug_messages(*messages)
 
-        response = self.llm(messages)
+        response = self._ensemble_generate(messages)
         debug_messages(response)
+
+        if not isinstance(response.content, str):
+            raise ValueError(f"Response is not a string: {response.content}")
         code = parse_code(response.content)
         return {**context, "answer": code}
 
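The `_ensemble_generate` method above scores each of the `ensemble` candidates by its mean token log-probability and keeps the best one. A minimal standalone sketch of that selection rule, with fabricated candidates and logprobs:

# Toy illustration of the mean-logprob selection; strings and logprobs are made up.
candidates = [
    ("df = df.dropna()", [-0.12, -0.40, -0.05]),
    ("df = df.fillna(0)", [-0.30, -0.25]),
    ("raise NotImplementedError", []),  # no logprobs -> treated as -inf
]

def mean_logprob(logprobs: list[float]) -> float:
    return sum(logprobs) / len(logprobs) if logprobs else float("-inf")

best_code, _ = max(candidates, key=lambda c: mean_logprob(c[1]))
print(best_code)  # the candidate with the highest average token logprob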
@@ -142,22 +333,24 @@ def fix_code(
         prev_context: GenerateContext | FixContext,
     ) -> FixContext | None:
         fewshots = cached_fix_fewshots()
+        fewshots = self._select_examples(prev_context["request"] or "N/A", fewshots)
         messages: list[BaseMessage] = [
            SystemMessage(content=FIX_INSTRUCTION),
         ]
         context = self._fix_context_from_any_context(
             prev_context, error=error, output=output, hint=hint
         )
         for shot in fewshots + [context]:
-            interactions = render_fix_context(shot)
+            interactions = render_fix_context(shot, context_order=self.context_order)
             for index, interaction in enumerate(interactions):
                 if index % 2 == 0:
                     messages.append(HumanMessage(content=interaction))
                 else:
                     messages.append(AIMessage(content=interaction))
-            debug_messages(*messages[-2:])
 
-        response = self.llm(messages)
+        debug_messages(*messages)
+
+        response = self._ensemble_generate(messages)
         debug_messages(response)
         explanation, observation, code = parse_fix(response.content)
         if "THE CODE IS CORRECT." in observation:
@@ -186,7 +379,7 @@ def suggest(self, codes: list[str]) -> list[str]:
             HumanMessage(content=human_message),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         suggestions = re.split(r"\d+\.\s+", response.content)
         suggestions = [s.strip().replace("\n", " ") for s in suggestions if s.strip()]
         debug_messages(response)
@@ -199,7 +392,7 @@ def explain(self, code: str) -> str:
             HumanMessage(content=code),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         debug_messages(response)
         return response.content
 
@@ -212,7 +405,7 @@ def static_check(
             HumanMessage(content=render_check_context(code, context)),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         debug_messages(response)
         reason, last_line = response.content.rstrip().rsplit("\n", 1)
         if "INCORRECT" in last_line.upper():
@@ -236,7 +429,7 @@ def output_sanity_check(
             ),
         ]
         debug_messages(*messages)
-        response = self.llm(messages)
+        response = self._generate(messages)
         debug_messages(response)
         reason, last_line = response.content.rstrip().rsplit("\n", 1)
         if "INCORRECT" in last_line.upper():

coml/magics.py

Lines changed: 5 additions & 3 deletions
@@ -285,9 +285,11 @@ def display_statuses(statuses):
             )
             html += message_template.format(
                 display_names[name],
-                loading
-                if name not in statuses
-                else VERIFY_STATUS_ICON[statuses[name]["result"]],
+                (
+                    loading
+                    if name not in statuses
+                    else VERIFY_STATUS_ICON[statuses[name]["result"]]
+                ),
                 detail_message,
             )
 