Skip to content

Commit ce0dc7d

Browse files
committed
This should fix #31 partly
1 parent 77fc654 commit ce0dc7d

File tree

3 files changed

+67
-22
lines changed

3 files changed

+67
-22
lines changed

llms_wrapper/llms.py

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""
22
Module related to using LLMs.
33
"""
4-
import os
4+
import os, sys
55
import warnings
66
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
77
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")
88
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
99
import litellm
1010
import json
1111
import time
12+
import threading
1213
import traceback
1314
import inspect
1415
import docstring_parser
@@ -18,7 +19,7 @@
1819
from typing import Optional, Dict, List, Union, Tuple, Callable, get_args, get_origin
1920
from copy import deepcopy
2021

21-
from litellm import completion, completion_cost, token_counter
22+
from litellm import completion, completion_cost, token_counter, cost_per_token
2223
from litellm.utils import get_model_info, get_supported_openai_params, supports_response_schema
2324
from litellm.utils import supports_function_calling, supports_parallel_function_calling
2425
from litellm._logging import _enable_debugging as litellm_enable_debugging
@@ -740,12 +741,14 @@ def cleaned_args(args: dict):
740741
if via_streaming:
741742
# TODO: check if model supports streaming
742743
completion_kwargs["stream"] = True
744+
completion_kwargs["stream_options"] = {"include_usage": True}
743745
logger.debug(f"completion kwargs after detecting via_streaming: {cleaned_args(completion_kwargs)}")
744746
elif stream:
745747
# TODO: check if model supports streaming
746748
# if streaming is enabled, we always return the original response
747749
return_response = True
748750
completion_kwargs["stream"] = True
751+
completion_kwargs["stream_options"] = {"include_usage": True}
749752
logger.debug(f"completion kwargs after detecting stream: {cleaned_args(completion_kwargs)}")
750753
ret = {}
751754
# before adding the kwargs, save the recursive_call_info and remove it from kwargs
@@ -773,6 +776,20 @@ def cleaned_args(args: dict):
773776
else:
774777
start = time.time()
775778
recursive_call_info["start"] = start
779+
callback_data = {}
780+
# NOTE: this does not seem to work!?!?!?
781+
callback_complete = threading.Event()
782+
def cost_callback(kwargs, completion_response, start_time, end_time):
783+
cost = kwargs.get("response_cost")
784+
if cost is not None:
785+
callback_data["cost"] = cost
786+
callback_complete.set()
787+
if hasattr(completion_response, "usage"):
788+
callback_data["prompt_tokens"] = completion_response.usage.prompt_tokens
789+
callback_data["completion_tokens"] = completion_response.usage.completion_tokens
790+
callback_data["total_tokens"] = completion_response.usage.prompt_tokens + completion_response.usage.completion_tokens
791+
litellm.success_callback = [cost_callback]
792+
logger.info(f"completion kwargs: {completion_kwargs}")
776793
response = litellm.completion(
777794
model=llm["llm"],
778795
messages=messages,
@@ -786,25 +803,40 @@ def cleaned_args(args: dict):
786803
logger.debug(f"Retrieving chunks ...")
787804
n_chunks = 0
788805
for chunk in response:
789-
choice0 = chunk["choices"][0]
806+
# TODO: this should work, but does not!!!!
807+
if hasattr(chunk, "usage"):
808+
if chunk.usage:
809+
prompt_cost, completion_cost_value = cost_per_token(
810+
model=llm["llm"],
811+
prompt_tokens=chunk.usage.prompt_tokens,
812+
completion_tokens=chunk.usage.completion_tokens
813+
)
814+
cost = prompt_cost + completion_cost_value
815+
else:
816+
pass
817+
# logger.info(f"!!!!DEBUG: chunk attributes: {dir(chunk)}")
818+
choice0 = chunk.choices[0]
790819
if choice0.finish_reason == "stop":
791820
logger.debug(f"Streaming got stop. Chunk {chunk}")
792-
break
821+
# break
793822
n_chunks += 1
794823
content = choice0["delta"].get("content", "")
795824
logger.debug(f"Got streaming content: {content}")
796-
answer += content
825+
if isinstance(content, str):
826+
answer += content
797827
if return_response:
798828
ret["response"] = response
799829
ret["answer"] = answer
800830
ret["n_chunks"] = n_chunks
801831
ret["elapsed_time"] = time.time() - start
802832
ret["ok"] = True
803833
ret["error"] = ""
804-
# TODO: for now return 0, may perhaps be possible to do better?
805-
ret["cost"] = 0
806-
ret["n_prompt_tokens"] = 0
807-
ret["n_completion_tokens"] = 0
834+
# get the cost from the callback
835+
if not callback_complete.is_set():
836+
callback_complete.wait(timeout=2.0)
837+
ret["cost"] = callback_data.get("cost")
838+
ret["n_prompt_tokens"] = callback_data.get("prompt_tokens")
839+
ret["n_completion_tokens"] = callback_data.get("completion_tokens")
808840
return ret
809841
except Exception as e:
810842
tb = traceback.extract_tb(e.__traceback__)
@@ -827,12 +859,14 @@ def chunk_generator(model_generator, retobj):
827859
except Exception as e:
828860
yield dict(error=str(e), answer="", ok=False)
829861
finally:
830-
# TODO: add cost and elapsed time information into retobj
831-
# litellm does not support cost on streaming responses
832-
# response.__hidden_params["response_cost"] is 0.0
833-
ret["cost"] = None
862+
# get the cost from the callback
863+
if not callback_complete.is_set():
864+
logger.info(f"Waiting for callback completion")
865+
callback_complete.wait(timeout=2.0)
866+
ret["cost"] = callback_data.get("cost")
867+
ret["n_prompt_tokens"] = callback_data.get("prompt_tokens")
868+
ret["n_completion_tokens"] = callback_data.get("completion_tokens")
834869
ret["elapsed_time"] = time.time() - start
835-
pass
836870
if return_response:
837871
ret["response"] = response
838872
ret["chunks"] = chunk_generator(response, ret)
@@ -851,15 +885,20 @@ def chunk_generator(model_generator, retobj):
851885
del completion_kwargs["api_key"]
852886
ret["kwargs"] = completion_kwargs
853887
if return_cost:
854-
# TODO: replace with response._hidden_params["response_cost"] ?
855-
# but what if cost not supported for the model?
856-
857888
try:
858-
ret["cost"] = completion_cost(
859-
completion_response=response,
860-
model=llm["llm"],
861-
messages=messages,
862-
)
889+
cost = response._hidden_params.get("response_cost", None)
890+
logger.info(f"DEBUG: cost from hidden parms: {cost}")
891+
# if cost is None:
892+
if True:
893+
cost = completion_cost(
894+
completion_response=response,
895+
model=llm["llm"],
896+
messages=messages,
897+
)
898+
ret["cost"] = cost
899+
logger.info(f"DEBUG: cost from completion_cost: {cost}")
900+
else:
901+
ret["cost"] = cost
863902
if debug:
864903
logger.debug(f"Cost for this call {ret['cost']}")
865904
except Exception as e:

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ dependencies = [
2626
"loguru",
2727
"docstring_parser",
2828
"python-dotenv",
29+
"arize-phoenix-otel",
30+
"openinference-instrumentation-litellm",
2931
]
3032

3133
[project.optional-dependencies]

uv.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)