pvanand committed on
Commit
84d79ad
·
verified ·
1 Parent(s): 6ac25d0

Update observability.py

Browse files
Files changed (1) hide show
  1. observability.py +175 -175
observability.py CHANGED
@@ -1,176 +1,176 @@
1
- # File: llm_observability.py
2
- import sqlite3
3
- import json
4
- from datetime import datetime
5
- from typing import Dict, Any, List, Optional, Callable
6
- import logging
7
- import functools
8
-
9
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
- logger = logging.getLogger(__name__)
11
-
12
- def log_execution(func: Callable) -> Callable:
13
- @functools.wraps(func)
14
- def wrapper(*args: Any, **kwargs: Any) -> Any:
15
- logger.info(f"Executing {func.__name__}")
16
- try:
17
- result = func(*args, **kwargs)
18
- logger.info(f"{func.__name__} completed successfully")
19
- return result
20
- except Exception as e:
21
- logger.error(f"Error in {func.__name__}: {e}")
22
- raise
23
- return wrapper
24
-
25
-
26
- class LLMObservabilityManager:
27
- def __init__(self, db_path: str = "llm_observability_v2.db"):
28
- self.db_path = db_path
29
- self.create_table()
30
-
31
- def create_table(self):
32
- with sqlite3.connect(self.db_path) as conn:
33
- cursor = conn.cursor()
34
- cursor.execute('''
35
- CREATE TABLE IF NOT EXISTS llm_observations (
36
- id INTEGER PRIMARY KEY AUTOINCREMENT,
37
- conversation_id TEXT,
38
- created_at DATETIME,
39
- status TEXT,
40
- request TEXT,
41
- response TEXT,
42
- model TEXT,
43
- prompt_tokens INTEGER,
44
- completion_tokens INTEGER,
45
- total_tokens INTEGER,
46
- cost FLOAT,
47
- latency FLOAT,
48
- user TEXT
49
- )
50
- ''')
51
-
52
- def insert_observation(self, response: str, conversation_id: str, status: str, request: str, model: str, prompt_tokens: int,completion_tokens: int, total_tokens: int, cost: float, latency: float, user: str):
53
- created_at = datetime.now()
54
-
55
- with sqlite3.connect(self.db_path) as conn:
56
- cursor = conn.cursor()
57
- cursor.execute('''
58
- INSERT INTO llm_observations
59
- (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens,total_tokens, cost, latency, user)
60
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
61
- ''', (
62
- conversation_id,
63
- created_at,
64
- status,
65
- request,
66
- response,
67
- model,
68
- prompt_tokens,
69
- completion_tokens,
70
- total_tokens,
71
- cost,
72
- latency,
73
- user
74
- ))
75
-
76
- def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
77
- with sqlite3.connect(self.db_path) as conn:
78
- cursor = conn.cursor()
79
- if conversation_id:
80
- cursor.execute('SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at', (conversation_id,))
81
- else:
82
- cursor.execute('SELECT * FROM llm_observations ORDER BY created_at')
83
- rows = cursor.fetchall()
84
-
85
- column_names = [description[0] for description in cursor.description]
86
- return [dict(zip(column_names, row)) for row in rows]
87
-
88
- def get_all_observations(self) -> List[Dict[str, Any]]:
89
- return self.get_observations()
90
-
91
- def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
92
- with sqlite3.connect(self.db_path) as conn:
93
- cursor = conn.cursor()
94
- # Get the latest observation for each unique conversation_id
95
- query = '''
96
- SELECT * FROM llm_observations o1
97
- WHERE created_at = (
98
- SELECT MAX(created_at)
99
- FROM llm_observations o2
100
- WHERE o2.conversation_id = o1.conversation_id
101
- )
102
- ORDER BY created_at DESC
103
- '''
104
- if limit is not None:
105
- query += f' LIMIT {limit}'
106
-
107
- cursor.execute(query)
108
- rows = cursor.fetchall()
109
-
110
- column_names = [description[0] for description in cursor.description]
111
- return [dict(zip(column_names, row)) for row in rows]
112
-
113
- ## OBSERVABILITY
114
- from uuid import uuid4
115
- import csv
116
- from io import StringIO
117
- from fastapi import APIRouter, HTTPException
118
- from pydantic import BaseModel
119
- from starlette.responses import StreamingResponse
120
-
121
-
122
-
123
- router = APIRouter(
124
- prefix="/observability",
125
- tags=["observability"]
126
- )
127
-
128
- class ObservationResponse(BaseModel):
129
- observations: List[Dict]
130
-
131
- def create_csv_response(observations: List[Dict]) -> StreamingResponse:
132
- def iter_csv(data):
133
- output = StringIO()
134
- writer = csv.DictWriter(output, fieldnames=data[0].keys() if data else [])
135
- writer.writeheader()
136
- for row in data:
137
- writer.writerow(row)
138
- output.seek(0)
139
- yield output.read()
140
-
141
- headers = {
142
- 'Content-Disposition': 'attachment; filename="observations.csv"'
143
- }
144
- return StreamingResponse(iter_csv(observations), media_type="text/csv", headers=headers)
145
-
146
-
147
- @router.get("/last-observations/{limit}")
148
- async def get_last_observations(limit: int = 10, format: str = "json"):
149
- observability_manager = LLMObservabilityManager()
150
-
151
- try:
152
- # Get all observations, sorted by created_at in descending order
153
- all_observations = observability_manager.get_observations()
154
- all_observations.sort(key=lambda x: x['created_at'], reverse=True)
155
-
156
- # Get the last conversation_id
157
- if all_observations:
158
- last_conversation_id = all_observations[0]['conversation_id']
159
-
160
- # Filter observations for the last conversation
161
- last_conversation_observations = [
162
- obs for obs in all_observations
163
- if obs['conversation_id'] == last_conversation_id
164
- ][:limit]
165
-
166
- if format.lower() == "csv":
167
- return create_csv_response(last_conversation_observations)
168
- else:
169
- return ObservationResponse(observations=last_conversation_observations)
170
- else:
171
- if format.lower() == "csv":
172
- return create_csv_response([])
173
- else:
174
- return ObservationResponse(observations=[])
175
- except Exception as e:
176
  raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}")
 
1
+ # File: llm_observability.py
2
+ import sqlite3
3
+ import json
4
+ from datetime import datetime
5
+ from typing import Dict, Any, List, Optional, Callable
6
+ import logging
7
+ import functools
8
+
9
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
+ logger = logging.getLogger(__name__)
11
+
12
def log_execution(func: Callable) -> Callable:
    """Decorator that logs entry, successful completion, and failure of *func*.

    Exceptions are logged with their full traceback and re-raised unchanged,
    so callers' own error handling is unaffected.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Lazy %-style arguments avoid string formatting when INFO is disabled.
        logger.info("Executing %s", func.__name__)
        try:
            result = func(*args, **kwargs)
        except Exception:
            # logger.exception records the traceback; plain logger.error did not.
            logger.exception("Error in %s", func.__name__)
            raise
        logger.info("%s completed successfully", func.__name__)
        return result
    return wrapper
24
+
25
+
26
class LLMObservabilityManager:
    """SQLite-backed store of per-call LLM telemetry (tokens, cost, latency).

    Every public method opens its own short-lived connection and closes it on
    exit.  (``with sqlite3.connect(...)`` only scopes the *transaction*, not
    the connection, so the previous version leaked one connection per call.)
    """

    def __init__(self, db_path: str = "/data/llm_observability_v2.db"):
        """Remember *db_path* and ensure the observations table exists."""
        self.db_path = db_path
        self.create_table()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the configured database file."""
        return sqlite3.connect(self.db_path)

    def create_table(self):
        """Create the ``llm_observations`` table if it does not exist yet."""
        conn = self._connect()
        try:
            with conn:  # commit the DDL (or roll back on error)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS llm_observations (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        conversation_id TEXT,
                        created_at DATETIME,
                        status TEXT,
                        request TEXT,
                        response TEXT,
                        model TEXT,
                        prompt_tokens INTEGER,
                        completion_tokens INTEGER,
                        total_tokens INTEGER,
                        cost FLOAT,
                        latency FLOAT,
                        user TEXT
                    )
                ''')
        finally:
            conn.close()

    def insert_observation(self, response: str, conversation_id: str, status: str,
                           request: str, model: str, prompt_tokens: int,
                           completion_tokens: int, total_tokens: int,
                           cost: float, latency: float, user: str):
        """Insert one observation row, timestamped with the current local time."""
        created_at = datetime.now()
        conn = self._connect()
        try:
            with conn:  # commit the INSERT, roll back on error
                conn.execute('''
                    INSERT INTO llm_observations
                    (conversation_id, created_at, status, request, response, model,
                     prompt_tokens, completion_tokens, total_tokens, cost, latency, user)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    conversation_id,
                    created_at,
                    status,
                    request,
                    response,
                    model,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                    cost,
                    latency,
                    user,
                ))
        finally:
            conn.close()

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations as dicts, oldest first.

        When *conversation_id* is given, only that conversation's rows are
        returned.
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            if conversation_id:
                cursor.execute(
                    'SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at',
                    (conversation_id,))
            else:
                cursor.execute('SELECT * FROM llm_observations ORDER BY created_at')
            rows = cursor.fetchall()
            columns = [description[0] for description in cursor.description]
            return [dict(zip(columns, row)) for row in rows]
        finally:
            conn.close()

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every observation in the database, oldest first."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the newest observation of each conversation, newest first.

        *limit*, when given, caps the number of rows returned; it is bound as
        a SQL parameter instead of being interpolated into the query string.
        """
        query = '''
            SELECT * FROM llm_observations o1
            WHERE created_at = (
                SELECT MAX(created_at)
                FROM llm_observations o2
                WHERE o2.conversation_id = o1.conversation_id
            )
            ORDER BY created_at DESC
        '''
        params: tuple = ()
        if limit is not None:
            query += ' LIMIT ?'
            params = (limit,)

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(query, params)
            rows = cursor.fetchall()
            columns = [description[0] for description in cursor.description]
            return [dict(zip(columns, row)) for row in rows]
        finally:
            conn.close()
112
+
113
+ ## OBSERVABILITY
114
+ from uuid import uuid4
115
+ import csv
116
+ from io import StringIO
117
+ from fastapi import APIRouter, HTTPException
118
+ from pydantic import BaseModel
119
+ from starlette.responses import StreamingResponse
120
+
121
+
122
+
123
# FastAPI sub-application: every endpoint defined on this router is served
# under the /observability path prefix and grouped under the "observability"
# tag in the generated OpenAPI docs.
router = APIRouter(
    prefix="/observability",
    tags=["observability"]
)
127
+
128
class ObservationResponse(BaseModel):
    """JSON response envelope: a list of observation rows as plain dicts."""
    # Rows come straight from LLMObservabilityManager.get_observations().
    observations: List[Dict]
130
+
131
def create_csv_response(observations: List[Dict]) -> StreamingResponse:
    """Stream *observations* as a downloadable CSV attachment.

    The header row is taken from the keys of the first observation; an empty
    list yields a CSV with no columns.
    """
    def generate_csv(rows):
        buffer = StringIO()
        fieldnames = rows[0].keys() if rows else []
        csv_writer = csv.DictWriter(buffer, fieldnames=fieldnames)
        csv_writer.writeheader()
        csv_writer.writerows(rows)
        buffer.seek(0)
        yield buffer.read()

    return StreamingResponse(
        generate_csv(observations),
        media_type="text/csv",
        headers={'Content-Disposition': 'attachment; filename="observations.csv"'},
    )
145
+
146
+
147
@router.get("/last-observations/{limit}")
async def get_last_observations(limit: int = 10, format: str = "json"):
    """Return up to *limit* observations from the most recent conversation.

    The most recent conversation is the one owning the observation with the
    greatest ``created_at``.  ``format=csv`` streams a CSV download; any other
    value returns JSON.  Storage failures surface as HTTP 500.
    """
    manager = LLMObservabilityManager()

    try:
        # Newest observations first across all conversations.
        rows = manager.get_observations()
        rows.sort(key=lambda row: row['created_at'], reverse=True)

        # Keep only rows belonging to the newest conversation (none -> empty).
        selected = []
        if rows:
            newest_conversation = rows[0]['conversation_id']
            selected = [
                row for row in rows
                if row['conversation_id'] == newest_conversation
            ][:limit]

        if format.lower() == "csv":
            return create_csv_response(selected)
        return ObservationResponse(observations=selected)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}")