Skip to content

Commit de259ea

Browse files
committed
First working version of tooling support.
1 parent 57225da commit de259ea

File tree

2 files changed

+427
-41
lines changed

2 files changed

+427
-41
lines changed

llms_wrapper/llms.py

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
77
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")
88
import litellm
9+
import json
910
import time
1011
import traceback
1112
import inspect
@@ -43,6 +44,25 @@
4344
]
4445

4546

47+
def toolnames2funcs(tools):
48+
fmap = {}
49+
for tool in tools:
50+
name = tool["function"]["name"]
51+
func = get_func_by_name(name)
52+
if func is None:
53+
raise Exception(f"Function {name} not found")
54+
fmap[name] = func
55+
return fmap
56+
57+
def get_func_by_name(name):
58+
# Walk up the call stack
59+
for frame_info in inspect.stack():
60+
frame = frame_info.frame
61+
func = frame.f_locals.get(name) or frame.f_globals.get(name)
62+
if callable(func):
63+
return func
64+
return None # Not found
65+
4666
class LLMS:
4767
"""
4868
Class that represents a preconfigured set of large language modelservices.
@@ -420,6 +440,7 @@ def query(
420440
return_response: bool = False,
421441
debug=False,
422442
litellm_debug=None,
443+
recursive_call_info: Optional[Dict[str, any]] = None,
423444
**kwargs,
424445
) -> Dict[str, any]:
425446
"""
@@ -457,19 +478,38 @@ def query(
457478
KNOWN_LLM_CONFIG_FIELDS,
458479
ignore_underscored=True,
459480
)
481+
if recursive_call_info is None:
482+
recursive_call_info = {}
460483
if llm.get("api_key"):
461484
completion_kwargs["api_key"] = llm["api_key"]
462485
elif llm.get("api_key_env"):
463486
completion_kwargs["api_key"] = os.getenv(llm["api_key_env"])
464487
if llm.get("api_url"):
465488
completion_kwargs["api_base"] = llm["api_url"]
466489
if tools is not None:
490+
# add tooling-related arguments to completion_kwargs
467491
completion_kwargs["tools"] = tools
492+
if not self.supports_function_calling(llmalias):
493+
# see https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
494+
litellm.add_function_to_prompt = True
495+
else:
496+
if "tool_choice" not in completion_kwargs:
497+
# this is the default, but lets be explicit
498+
completion_kwargs["tool_choice"] = "auto"
499+
# Not known/supported by litellm, apparently
500+
# if "parallel_tool_choice" not in completion_kwargs:
501+
# completion_kwargs["parallel_tool_choice"] = True
502+
fmap = toolnames2funcs(tools)
503+
else:
504+
fmap = {}
468505
ret = {}
506+
# before adding the kwargs, save the recursive_call_info and remove it from kwargs
507+
if debug:
508+
print(f"DEBUG: Received recursive call info: {recursive_call_info}")
469509
if kwargs:
470510
completion_kwargs.update(dict_except(kwargs, KNOWN_LLM_CONFIG_FIELDS, ignore_underscored=True))
471511
if debug:
472-
logger.debug(f"Calling completion with kwargs {completion_kwargs}")
512+
print(f"DEBUG: Calling completion with kwargs {completion_kwargs}")
473513
# if we have min_delay set, we look at the _last_request_time for the LLM and caclulate the time
474514
# to wait until we can send the next request and then just wait
475515
min_delay = llm.get("min_delay", kwargs.get("min_delay", 0.0))
@@ -481,7 +521,13 @@ def query(
481521
if "min_delay" in completion_kwargs:
482522
raise ValueError("Error: min_delay should not be passed as a keyword argument")
483523
try:
484-
start = time.time()
524+
# if we have been called recursively and the recursive_call_info has a start time,
525+
# use that as the start time
526+
if recursive_call_info.get("start") is not None:
527+
start = recursive_call_info["start"]
528+
else:
529+
start = time.time()
530+
recursive_call_info["start"] = start
485531
response = litellm.completion(
486532
model=llm["llm"],
487533
messages=messages,
@@ -503,6 +549,8 @@ def query(
503549
model=llm["llm"],
504550
messages=messages,
505551
)
552+
if debug:
553+
print(f"DEBUG: cost for this call {ret['cost']}")
506554
except Exception as e:
507555
logger.debug(f"Error in completion_cost for model {llm['llm']}: {e}")
508556
ret["cost"] = 0.0
@@ -512,6 +560,22 @@ def query(
512560
ret["n_completion_tokens"] = usage.completion_tokens
513561
ret["n_prompt_tokens"] = usage.prompt_tokens
514562
ret["n_total_tokens"] = usage.total_tokens
563+
# add the cost and tokens from the recursive call info, if available
564+
if recursive_call_info.get("cost") is not None:
565+
ret["cost"] += recursive_call_info["cost"]
566+
if debug:
567+
print(f"DEBUG: cost for this and previous calls {ret['cost']}")
568+
if recursive_call_info.get("n_completion_tokens") is not None:
569+
ret["n_completion_tokens"] += recursive_call_info["n_completion_tokens"]
570+
if recursive_call_info.get("n_prompt_tokens") is not None:
571+
ret["n_prompt_tokens"] += recursive_call_info["n_prompt_tokens"]
572+
if recursive_call_info.get("n_total_tokens") is not None:
573+
ret["n_total_tokens"] += recursive_call_info["n_total_tokens"]
574+
recursive_call_info["cost"] = ret["cost"]
575+
recursive_call_info["n_completion_tokens"] = ret["n_completion_tokens"]
576+
recursive_call_info["n_prompt_tokens"] = ret["n_prompt_tokens"]
577+
recursive_call_info["n_total_tokens"] = ret["n_total_tokens"]
578+
515579
response_message = response['choices'][0]['message']
516580
# Does not seem to work see https://github.com/BerriAI/litellm/issues/389
517581
# ret["response_ms"] = response["response_ms"]
@@ -521,9 +585,70 @@ def query(
521585
ret["ok"] = True
522586
# TODO: if feasable handle all tool calling here or in a separate method which does
523587
# all the tool calling steps (up to a specified maximum).
588+
if debug:
589+
print(f"DEBUG: checking for tool_calls: {response_message}, have tools: {tools is not None}")
524590
if tools is not None:
525-
ret["tool_calls"] = response_message.tool_calls
526-
ret["response_message"] = response_message
591+
if hasattr(response_message, "tool_calls") and response_message.tool_calls is not None:
592+
tool_calls = response_message.tool_calls
593+
else:
594+
tool_calls = []
595+
if debug:
596+
print(f"DEBUG: got {len(tool_calls)} tool calls:")
597+
for tool_call in tool_calls:
598+
print(f"DEBUG: {tool_call}")
599+
if len(tool_calls) > 0: # not an empty list
600+
if debug:
601+
print(f"DEBUG: appending response message: {response_message}")
602+
messages.append(response_message)
603+
for tool_call in tool_calls:
604+
function_name = tool_call.function.name
605+
if debug:
606+
print(f"DEBUG: tool call {function_name}")
607+
fun2call = fmap.get(function_name)
608+
if fun2call is None:
609+
ret["error"] = f"Unknown tooling function name: {function_name}"
610+
ret["answer"] = ""
611+
ret["ok"] = False
612+
return ret
613+
function_args = json.loads(tool_call.function.arguments)
614+
try:
615+
if debug:
616+
print(f"DEBUG: calling {function_name} with args {function_args}")
617+
function_response = fun2call(**function_args)
618+
if debug:
619+
print(f"DEBUG: got response {function_response}")
620+
except Exception as e:
621+
tb = traceback.extract_tb(e.__traceback__)
622+
filename, lineno, funcname, text = tb[-1]
623+
if debug:
624+
print(f"DEBUG: function call got error {e}")
625+
ret["error"] = f"Error executing tool function {function_name}: {str(e)} in {filename}:{lineno} {funcname}"
626+
if debug:
627+
logger.error(f"Returning error: {e}")
628+
ret["answer"] = ""
629+
ret["ok"] = False
630+
return ret
631+
messages.append(
632+
dict(
633+
tool_call_id=tool_call.id,
634+
role="tool", name=function_name,
635+
content=json.dumps(function_response)))
636+
# recursively call query
637+
if debug:
638+
print(f"DEBUG: recursively calling query with messages:")
639+
for idx, msg in enumerate(messages):
640+
print(f"DEBUG: Message {idx}: {msg}")
641+
print(f"DEBUG: recursively_call_info is {recursive_call_info}")
642+
return self.query(
643+
llmalias,
644+
messages,
645+
tools=tools,
646+
return_cost=return_cost,
647+
return_response=return_response,
648+
debug=debug,
649+
litellm_debug=litellm_debug,
650+
recursive_call_info=recursive_call_info,
651+
**kwargs)
527652
except Exception as e:
528653
tb = traceback.extract_tb(e.__traceback__)
529654
filename, lineno, funcname, text = tb[-1]

0 commit comments

Comments
 (0)