Skip to content

Commit e5417f6

Browse files
committed
Implements #32
1 parent 7ea50ae commit e5417f6

File tree

2 files changed

+141
-7
lines changed

2 files changed

+141
-7
lines changed

llms_wrapper/cost_logger.py

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,39 @@
77
from typing import Any, Optional
88
from pathlib import Path
99
import json
10+
import socket
11+
import getpass
12+
import datetime
13+
14+
15+
def get_username() -> str:
16+
"""
17+
Get the current username in a portable way.
18+
19+
Returns:
20+
Username as a string
21+
"""
22+
try:
23+
# getpass.getuser() works on Windows, Linux, and macOS
24+
# It checks environment variables in order: LOGNAME, USER, LNAME, USERNAME
25+
return getpass.getuser()
26+
except Exception:
27+
# Fallback to environment variables
28+
return os.environ.get('USER') or os.environ.get('USERNAME') or 'unknown'
29+
30+
31+
def get_hostname() -> str:
32+
"""
33+
Get the current hostname in a portable way.
34+
35+
Returns:
36+
Hostname as a string
37+
"""
38+
try:
39+
# socket.gethostname() works on Windows, Linux, and macOS
40+
return socket.gethostname()
41+
except Exception:
42+
return 'unknown'
1043

1144

1245
class Log2Sqlite:
@@ -39,6 +72,13 @@ def __init__(self, db_path: str, **defaults):
3972
Exception: If database initialization fails
4073
"""
4174
self.db_path = Path(db_path)
75+
76+
if defaults.get("user") is None:
77+
defaults["user"] = get_username()
78+
if defaults.get("hostname") is None:
79+
defaults["hostname"] = get_hostname()
80+
81+
4282
self.defaults = defaults
4383

4484
# Validate that defaults only contain known fields
@@ -120,6 +160,10 @@ def log(self, row: dict):
120160
if invalid_fields:
121161
raise Exception(f"Invalid fields in row: {invalid_fields}")
122162

163+
# the datetime is always set fixed here!
164+
row["datetime"] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
165+
166+
123167
# Merge defaults with provided row (row takes precedence)
124168
merged_row = {**self.defaults, **row}
125169

@@ -245,9 +289,89 @@ def import_file(self, file: str):
245289
raise Exception(f"Failed to import from {file}: {e}") from e
246290
raise
247291

292+
def rows(self, model=None, modelalias=None, hostname=None, user=None,
293+
project=None, task=None, note=None, apikey_name=None,
294+
date_from=None, date_to=None) -> list[dict]:
295+
"""
296+
Get all rows matching specified criteria.
297+
298+
Args:
299+
model: Filter by model name
300+
modelalias: Filter by model alias
301+
hostname: Filter by hostname
302+
user: Filter by user
303+
project: Filter by project
304+
task: Filter by task
305+
note: Filter by note
306+
apikey_name: Filter by API key name
307+
date_from: Filter by datetime >= this value (inclusive)
308+
date_to: Filter by datetime <= this value (inclusive)
309+
310+
Returns:
311+
List of dictionaries, each containing a matching row's data
312+
(excluding the auto-increment id field)
313+
314+
Raises:
315+
Exception: If query fails
316+
"""
317+
conditions = []
318+
values = []
319+
320+
# Add field filters
321+
field_filters = {
322+
'model': model,
323+
'modelalias': modelalias,
324+
'hostname': hostname,
325+
'user': user,
326+
'project': project,
327+
'task': task,
328+
'note': note,
329+
'apikey_name': apikey_name
330+
}
331+
332+
for field, value in field_filters.items():
333+
if value is not None:
334+
conditions.append(f'{field} = ?')
335+
values.append(value)
336+
337+
# Add date range filters
338+
if date_from is not None:
339+
conditions.append('datetime >= ?')
340+
values.append(date_from)
341+
342+
if date_to is not None:
343+
conditions.append('datetime <= ?')
344+
values.append(date_to)
345+
346+
# Build WHERE clause
347+
where = ''
348+
if conditions:
349+
where = 'WHERE ' + ' AND '.join(conditions)
350+
351+
try:
352+
conn = sqlite3.connect(self.db_path, timeout=5.0)
353+
try:
354+
conn.row_factory = sqlite3.Row # Access columns by name
355+
cursor = conn.execute(
356+
f'SELECT * FROM logs {where} ORDER BY datetime DESC, id DESC',
357+
values
358+
)
359+
360+
# Convert rows to list of dicts, excluding id field
361+
result = []
362+
for row in cursor:
363+
row_dict = {key: row[key] for key in row.keys() if key != 'id'}
364+
result.append(row_dict)
365+
366+
return result
367+
finally:
368+
conn.close()
369+
except Exception as e:
370+
raise Exception(f"Failed to get rows: {e}") from e
371+
248372
def get(self, model=None, modelalias=None, hostname=None, user=None,
249373
project=None, task=None, note=None, apikey_name=None,
250-
date_from=None, date_to=None) -> tuple[float, int, int]:
374+
date_from=None, date_to=None) -> tuple[float, int, int, int]:
251375
"""
252376
Get aggregated cost and token sums for rows matching specified criteria.
253377
@@ -264,7 +388,7 @@ def get(self, model=None, modelalias=None, hostname=None, user=None,
264388
date_to: Filter by datetime <= this value (inclusive)
265389
266390
Returns:
267-
Tuple of (total_cost, total_input_tokens, total_output_tokens)
391+
Tuple of (total_cost, total_input_tokens, total_output_tokens, row_count)
268392
269393
Raises:
270394
Exception: If query fails
@@ -310,12 +434,13 @@ def get(self, model=None, modelalias=None, hostname=None, user=None,
310434
f'''SELECT
311435
COALESCE(SUM(cost), 0.0),
312436
COALESCE(SUM(input_tokens), 0),
313-
COALESCE(SUM(output_tokens), 0)
437+
COALESCE(SUM(output_tokens), 0),
438+
COUNT(*)
314439
FROM logs {where}''',
315440
values
316441
)
317442
result = cursor.fetchone()
318-
return (float(result[0]), int(result[1]), int(result[2]))
443+
return (float(result[0]), int(result[1]), int(result[2]), int(result[3]))
319444
finally:
320445
conn.close()
321446
except Exception as e:

llms_wrapper/llms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,8 @@ def query(
687687
tools: a list of tool dictionaries, each dictionary describing a tool.
688688
See https://docs.litellm.ai/docs/completion/function_call for the format.
689689
However, this can be created using the `make_tooling` function.
690-
return_cost: whether or not LLM invocation costs should get returned
690+
return_cost: whether or not LLM invocation costs should get returned. Gets automatically enabled if
691+
cost logging is enabled.
691692
return_response: whether or not the complete reponse should get returned
692693
debug: if True, emits debug messages to aid development and debugging
693694
litellm_debug: if True, litellm debug logging is enabled, if False, disabled, if None, use debug setting
@@ -712,6 +713,8 @@ def cleaned_args(args: dict):
712713
return args
713714
if self.debug:
714715
debug = True
716+
if self.cost_logger:
717+
return_cost = True
715718
if litellm_debug is None and debug or litellm_debug:
716719
# litellm.set_verbose = True ## deprecated!
717720
os.environ['LITELLM_LOG'] = 'DEBUG'
@@ -863,7 +866,10 @@ def cost_callback(kwargs, completion_response, start_time, end_time):
863866
ret["cost"] = callback_data.get("cost")
864867
ret["n_prompt_tokens"] = callback_data.get("prompt_tokens")
865868
ret["n_completion_tokens"] = callback_data.get("completion_tokens")
866-
self.cost_logger.log(dict(cost=ret["cost"], input_tokens=ret["input_tokens"], output_tokens=ret["output_tokens"]))
869+
self.cost_logger.log(
870+
dict(
871+
model=llm["llm"], modelalias=llm["alias"],
872+
cost=ret["cost"], input_tokens=ret["input_tokens"], output_tokens=ret["output_tokens"]))
867873
return ret
868874
except Exception as e:
869875
tb = traceback.extract_tb(e.__traceback__)
@@ -937,7 +943,10 @@ def chunk_generator(model_generator, retobj):
937943
ret["n_completion_tokens"] = usage.completion_tokens
938944
ret["n_prompt_tokens"] = usage.prompt_tokens
939945
ret["n_total_tokens"] = usage.total_tokens
940-
self.cost_logger.log(dict(cost=ret["cost"], input_tokens=ret["n_prompt_tokens"], output_tokens=ret["n_completion_tokens"]))
946+
self.cost_logger.log(
947+
dict(
948+
model=llm["llm"], modelalias=llm["alias"],
949+
cost=ret["cost"], input_tokens=ret["n_prompt_tokens"], output_tokens=ret["n_completion_tokens"]))
941950
# add the cost and tokens from the recursive call info, if available
942951
if recursive_call_info.get("cost") is not None:
943952
ret["cost"] += recursive_call_info["cost"]

0 commit comments

Comments
 (0)