Spaces:
Sleeping
Sleeping
File size: 6,743 Bytes
1a6d961 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# File: llm_observability.py
import sqlite3
import json
from datetime import datetime
from typing import Dict, Any, List, Optional, Callable
import logging
import functools
# Configure the root logger at import time: INFO threshold, and every record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the helpers defined in this file.
logger = logging.getLogger(__name__)
def log_execution(func: Callable) -> Callable:
    """Decorate *func* so that every call is logged.

    An INFO record is emitted before the call and another after it returns.
    If the call raises, the error is logged together with its traceback and
    the original exception is re-raised unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same name and metadata as *func* (via
        ``functools.wraps``) that forwards all positional and keyword
        arguments and returns whatever *func* returns.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Lazy %-style arguments: the message is only formatted if the
        # record is actually emitted.
        logger.info("Executing %s", func.__name__)
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            # logger.exception logs at ERROR level and appends the traceback,
            # which a plain logger.error call would drop.
            logger.exception("Error in %s: %s", func.__name__, e)
            raise
        logger.info("%s completed successfully", func.__name__)
        return result
    return wrapper
class LLMObservabilityManager:
    """SQLite-backed store for LLM request/response observations.

    Each observation is one row of the ``llm_observations`` table. The table
    is created on construction if it does not already exist. Every operation
    opens its own short-lived connection to ``db_path`` and closes it again
    before returning.
    """

    def __init__(self, db_path: str = "llm_observability_v2.db"):
        """Remember *db_path* and make sure the observations table exists."""
        self.db_path = db_path
        self.create_table()

    def _run(self, sql: str, params: tuple = (), fetch: bool = False) -> List[Dict[str, Any]]:
        """Execute one SQL statement on a fresh connection.

        Args:
            sql: The statement to execute, using ``?`` placeholders.
            params: Values bound to the placeholders.
            fetch: When true, return the result rows; otherwise return ``[]``.

        Returns:
            One dict per result row, keyed by column name, when *fetch* is
            true; an empty list otherwise.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # Using the connection as a context manager commits on success and
            # rolls back on error, but it does NOT close the connection --
            # hence the explicit close() in the finally clause.
            with conn:
                cursor = conn.execute(sql, params)
                if not fetch:
                    return []
                column_names = [description[0] for description in cursor.description]
                return [dict(zip(column_names, row)) for row in cursor.fetchall()]
        finally:
            conn.close()

    def create_table(self):
        """Create the ``llm_observations`` table if it is missing."""
        self._run('''
            CREATE TABLE IF NOT EXISTS llm_observations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id TEXT,
                created_at DATETIME,
                status TEXT,
                request TEXT,
                response TEXT,
                model TEXT,
                prompt_tokens INTEGER,
                completion_tokens INTEGER,
                total_tokens INTEGER,
                cost FLOAT,
                latency FLOAT,
                user TEXT
            )
        ''')

    def insert_observation(self, response: str, conversation_id: str, status: str, request: str, model: str, prompt_tokens: int, completion_tokens: int, total_tokens: int, cost: float, latency: float, user: str):
        """Insert one observation row, timestamped with the current local time.

        All arguments are stored as given in the column of the same name.
        ``created_at`` is taken from ``datetime.now()`` (naive, local time)
        and written as ``YYYY-MM-DD HH:MM:SS[.ffffff]`` text.
        """
        # Serialize explicitly instead of relying on sqlite3's implicit
        # datetime adapter (deprecated since Python 3.12). isoformat(" ")
        # is the same text that adapter writes, so rows written before and
        # after this change sort together.
        created_at = datetime.now().isoformat(" ")
        self._run('''
            INSERT INTO llm_observations
            (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens, total_tokens, cost, latency, user)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            conversation_id,
            created_at,
            status,
            request,
            response,
            model,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost,
            latency,
            user,
        ))

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations ordered by ``created_at`` ascending.

        Args:
            conversation_id: When truthy, only rows of that conversation are
                returned. ``None`` or an empty string returns every row.

        Returns:
            One dict per row, keyed by column name.
        """
        if conversation_id:
            return self._run(
                'SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at',
                (conversation_id,),
                fetch=True,
            )
        return self._run('SELECT * FROM llm_observations ORDER BY created_at', fetch=True)

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every observation, ordered by ``created_at`` ascending."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the most recent observation of each conversation.

        Rows are ordered by ``created_at`` descending.

        Args:
            limit: Maximum number of rows to return; ``None`` means no limit.

        Returns:
            One dict per selected row, keyed by column name.
        """
        # Get the latest observation for each unique conversation_id
        query = '''
            SELECT * FROM llm_observations o1
            WHERE created_at = (
                SELECT MAX(created_at)
                FROM llm_observations o2
                WHERE o2.conversation_id = o1.conversation_id
            )
            ORDER BY created_at DESC
        '''
        params: tuple = ()
        if limit is not None:
            # Bind the limit as a parameter rather than splicing it into the
            # SQL text.
            query += ' LIMIT ?'
            params = (limit,)
        return self._run(query, params, fetch=True)
## OBSERVABILITY
from uuid import uuid4
import csv
from io import StringIO
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse
# All routes registered on this router are served under /observability and
# carry the "observability" tag.
router = APIRouter(prefix="/observability", tags=["observability"])
class ObservationResponse(BaseModel):
    """JSON response body wrapping a list of observation dicts."""

    # One dict per observation.
    observations: List[Dict]
def create_csv_response(observations: List[Dict]) -> StreamingResponse:
    """Wrap *observations* in a downloadable CSV response.

    The column order is taken from the keys of the first observation. With no
    observations the header is written with no field names.

    Args:
        observations: Row dicts to serialize; each becomes one CSV line.

    Returns:
        A ``text/csv`` ``StreamingResponse`` whose Content-Disposition header
        marks it as an attachment named ``observations.csv``.
    """
    def render(rows):
        # Generator, so the CSV text is only built once the response body is
        # iterated, and is then emitted as a single chunk.
        columns = rows[0].keys() if rows else []
        buffer = StringIO()
        csv_writer = csv.DictWriter(buffer, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(rows)
        yield buffer.getvalue()

    return StreamingResponse(
        render(observations),
        media_type="text/csv",
        headers={'Content-Disposition': 'attachment; filename="observations.csv"'},
    )
@router.get("/last-observations/{limit}")
async def get_last_observations(limit: int = 10, format: str = "json"):
    """Return up to *limit* observations of the most recent conversation.

    The "most recent conversation" is the conversation of the observation
    with the greatest ``created_at``. Its observations are returned newest
    first. When there are no observations at all, an empty result is returned.

    Args:
        limit: Maximum number of observations to return; must not be negative.
        format: ``"csv"`` (case-insensitive) for a CSV download; any other
            value yields the JSON ``ObservationResponse``.

    Raises:
        HTTPException: 400 if *limit* is negative; 500 if reading or
            serializing the observations fails.
    """
    # A negative limit used as a slice bound would silently drop rows from
    # the end instead of capping the result size, so reject it up front
    # (outside the try block, so it is not rewrapped as a 500).
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must be a non-negative integer")

    try:
        # NOTE(review): the SQLite calls below are synchronous inside an
        # async handler.
        observability_manager = LLMObservabilityManager()

        # Get all observations, sorted by created_at in descending order
        all_observations = observability_manager.get_observations()
        all_observations.sort(key=lambda x: x['created_at'], reverse=True)

        last_conversation_observations = []
        if all_observations:
            # The newest row identifies the last conversation; keep only its rows.
            last_conversation_id = all_observations[0]['conversation_id']
            last_conversation_observations = [
                obs for obs in all_observations
                if obs['conversation_id'] == last_conversation_id
            ][:limit]

        if format.lower() == "csv":
            return create_csv_response(last_conversation_observations)
        return ObservationResponse(observations=last_conversation_observations)
    except Exception as e:
        # Keep the traceback in the server log and chain the cause for debugging.
        logger.exception("Failed to retrieve observations")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}") from e