File size: 6,743 Bytes
1a6d961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# File: llm_observability.py
import sqlite3
import json
from datetime import datetime
from typing import Dict, Any, List, Optional, Callable
import logging
import functools

# Configure root logging once at import time; module logger is used by log_execution.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def log_execution(func: Callable) -> Callable:
    """Decorator that logs entry, successful completion, and failure of *func*.

    Exceptions are logged and re-raised unchanged, so callers see the
    original error.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"Executing {func.__name__}")
        try:
            outcome = func(*args, **kwargs)
        except Exception as e:
            logger.error(f"Error in {func.__name__}: {e}")
            raise
        logger.info(f"{func.__name__} completed successfully")
        return outcome
    return wrapper


class LLMObservabilityManager:
    """Persists and queries LLM call observations in a SQLite database.

    Each observation records one LLM call: conversation id, timestamps,
    request/response text, model name, token usage, cost, latency and the
    requesting user. A fresh connection is opened per operation, so the
    manager is safe to construct cheaply wherever it is needed.
    """

    def __init__(self, db_path: str = "llm_observability_v2.db"):
        """Open (or create) the observation database at *db_path*."""
        self.db_path = db_path
        self.create_table()

    def create_table(self):
        """Create the ``llm_observations`` table if it does not already exist."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS llm_observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    conversation_id TEXT,
                    created_at DATETIME,
                    status TEXT,
                    request TEXT,
                    response TEXT,
                    model TEXT,
                    prompt_tokens INTEGER,
                    completion_tokens INTEGER,
                    total_tokens INTEGER,
                    cost FLOAT,
                    latency FLOAT,
                    user TEXT
                )
            ''')

    def insert_observation(self, response: str, conversation_id: str, status: str, request: str, model: str, prompt_tokens: int,completion_tokens: int, total_tokens: int, cost: float, latency: float, user: str):
        """Store one LLM call observation, stamped with the current local time.

        All values are bound as SQL parameters; the ``with`` block commits on
        success and rolls back on error.
        """
        created_at = datetime.now()

        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                INSERT INTO llm_observations
                (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens,total_tokens, cost, latency, user)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                conversation_id,
                created_at,
                status,
                request,
                response,
                model,
                prompt_tokens,
                completion_tokens,
                total_tokens,
                cost,
                latency,
                user
            ))

    @staticmethod
    def _rows_as_dicts(cursor: sqlite3.Cursor) -> List[Dict[str, Any]]:
        """Fetch all remaining rows from *cursor* as column-name → value dicts."""
        column_names = [description[0] for description in cursor.description]
        return [dict(zip(column_names, row)) for row in cursor.fetchall()]

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations ordered by creation time.

        When *conversation_id* is given, only that conversation's rows are
        returned; otherwise all observations are returned.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            if conversation_id:
                cursor.execute('SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at', (conversation_id,))
            else:
                cursor.execute('SELECT * FROM llm_observations ORDER BY created_at')
            return self._rows_as_dicts(cursor)

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every stored observation, oldest first."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the latest observation per conversation, newest first.

        *limit*, when given, caps the number of conversations returned. The
        limit is bound as a SQL parameter rather than interpolated into the
        query string, so no value can alter the SQL itself.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Correlated subquery: keep each conversation's most recent row.
            query = '''
                SELECT * FROM llm_observations o1
                WHERE created_at = (
                    SELECT MAX(created_at)
                    FROM llm_observations o2
                    WHERE o2.conversation_id = o1.conversation_id
                )
                ORDER BY created_at DESC
            '''
            params: tuple = ()
            if limit is not None:
                # Bind the limit instead of f-string interpolation (avoids SQL injection).
                query += ' LIMIT ?'
                params = (limit,)

            cursor.execute(query, params)
            return self._rows_as_dicts(cursor)

    ## OBSERVABILITY
from uuid import uuid4
import csv
from io import StringIO
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse



# Router exposing the observability endpoints under the /observability prefix.
router = APIRouter(prefix="/observability", tags=["observability"])

class ObservationResponse(BaseModel):
    """API response envelope carrying a list of observation rows as plain dicts."""
    observations: List[Dict]
    
def create_csv_response(observations: List[Dict]) -> StreamingResponse:
    """Render *observations* as a CSV download (attachment "observations.csv").

    Column headers come from the keys of the first row; an empty list yields
    an empty CSV body.
    """
    def iter_csv(data):
        # Build the entire CSV in an in-memory buffer and yield it in one chunk.
        buffer = StringIO()
        fieldnames = list(data[0].keys()) if data else []
        writer = csv.DictWriter(buffer, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(data)
        buffer.seek(0)
        yield buffer.read()

    headers = {
        'Content-Disposition': 'attachment; filename="observations.csv"'
    }
    return StreamingResponse(iter_csv(observations), media_type="text/csv", headers=headers)
    

@router.get("/last-observations/{limit}")
async def get_last_observations(limit: int = 10, format: str = "json"):
    """Return up to *limit* observations from the most recent conversation.

    The most recent conversation is the one owning the newest observation
    overall. With ``format=csv`` the result is a CSV download; otherwise a
    JSON ``ObservationResponse``. Any failure maps to HTTP 500.
    """
    observability_manager = LLMObservabilityManager()

    try:
        # Newest first across all conversations.
        observations = observability_manager.get_observations()
        observations.sort(key=lambda obs: obs['created_at'], reverse=True)

        if observations:
            newest_conversation = observations[0]['conversation_id']
            selected = [
                obs for obs in observations
                if obs['conversation_id'] == newest_conversation
            ][:limit]
        else:
            selected = []

        if format.lower() == "csv":
            return create_csv_response(selected)
        return ObservationResponse(observations=selected)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}")