# File: llm_observability.py
import contextlib
import functools
import json
import logging
import sqlite3
from datetime import datetime
from typing import Dict, Any, List, Optional, Callable

# Configure root logging once at import time, then grab this module's logger.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

def log_execution(func: Callable) -> Callable:
    """Decorate ``func`` so every call is logged on entry, success and failure.

    The wrapped function's return value is passed through untouched; any
    exception it raises is logged at ERROR level and then re-raised as-is.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        fn_name = func.__name__
        logger.info(f"Executing {fn_name}")
        try:
            outcome = func(*args, **kwargs)
            logger.info(f"{fn_name} completed successfully")
        except Exception as exc:
            logger.error(f"Error in {fn_name}: {exc}")
            raise
        return outcome

    return wrapper


class LLMObservabilityManager:
    """Store and query LLM call observations in a SQLite database.

    Each public method opens its own connection to ``db_path``, commits on
    success (rolls back on error) and closes it before returning, so an
    instance keeps no database handle open between calls.
    """

    def __init__(self, db_path: str = "/data/llm_observability_v2.db"):
        """Bind the manager to ``db_path`` and make sure the table exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self.create_table()

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection to ``db_path``; commit/rollback, then always close.

        Using ``sqlite3.Connection`` as a context manager only ends the
        transaction -- it does not close the connection -- so the explicit
        ``close()`` here is what actually releases the handle.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, rollback if the body raises
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _rows_to_dicts(cursor: sqlite3.Cursor) -> List[Dict[str, Any]]:
        """Fetch all remaining rows of ``cursor`` as column-name -> value dicts."""
        column_names = [description[0] for description in cursor.description]
        return [dict(zip(column_names, row)) for row in cursor.fetchall()]

    def create_table(self):
        """Create the ``llm_observations`` table if it does not exist yet."""
        with self._connect() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS llm_observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    conversation_id TEXT,
                    created_at DATETIME,
                    status TEXT,
                    request TEXT,
                    response TEXT,
                    model TEXT,
                    prompt_tokens INTEGER,
                    completion_tokens INTEGER,
                    total_tokens INTEGER,
                    cost FLOAT,
                    latency FLOAT,
                    user TEXT
                )
            ''')

    def insert_observation(
        self,
        response: str,
        conversation_id: str,
        status: str,
        request: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        cost: float,
        latency: float,
        user: str,
    ):
        """Insert one observation row stamped with the current local time.

        ``created_at`` is taken from ``datetime.now()`` (naive local time) at
        insert time; all other columns are stored exactly as given.
        """
        # Store the timestamp as an explicit "YYYY-MM-DD HH:MM:SS[.ffffff]"
        # string. That is exactly what sqlite3's default datetime adapter
        # wrote, but relying on that adapter is deprecated since Python 3.12.
        created_at = datetime.now().isoformat(" ")

        with self._connect() as conn:
            conn.execute('''
                INSERT INTO llm_observations
                (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens, total_tokens, cost, latency, user)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                conversation_id,
                created_at,
                status,
                request,
                response,
                model,
                prompt_tokens,
                completion_tokens,
                total_tokens,
                cost,
                latency,
                user,
            ))

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations as dicts, oldest first.

        Args:
            conversation_id: When truthy, only rows of that conversation are
                returned; ``None`` (or any falsy value, e.g. ``""``) returns
                every row.

        Returns:
            One dict per row, keyed by column name, ordered by ``created_at``
            and then by ``id`` so rows with equal timestamps sort stably.
        """
        with self._connect() as conn:
            if conversation_id:
                cursor = conn.execute(
                    'SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at, id',
                    (conversation_id,),
                )
            else:
                cursor = conn.execute('SELECT * FROM llm_observations ORDER BY created_at, id')
            return self._rows_to_dicts(cursor)

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every observation, oldest first (same as ``get_observations()``)."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the latest observation of each conversation, newest first.

        Exactly one row is returned per ``conversation_id``: the one with the
        greatest ``created_at``, using the higher ``id`` to break ties. Rows
        whose ``conversation_id`` is NULL are not returned.

        Args:
            limit: Optional maximum number of rows. It is coerced with
                ``int()`` and bound as a query parameter, never spliced into
                the SQL text.

        Raises:
            ValueError / TypeError: if ``limit`` cannot be converted to int.
        """
        # A correlated "ORDER BY ... LIMIT 1" subquery picks a single row id
        # per conversation; comparing on MAX(created_at) instead would return
        # every row that shares the maximal timestamp.
        query = '''
            SELECT * FROM llm_observations o1
            WHERE o1.id = (
                SELECT o2.id
                FROM llm_observations o2
                WHERE o2.conversation_id = o1.conversation_id
                ORDER BY o2.created_at DESC, o2.id DESC
                LIMIT 1
            )
            ORDER BY o1.created_at DESC, o1.id DESC
        '''
        params: tuple = ()
        if limit is not None:
            query += ' LIMIT ?'
            params = (int(limit),)

        with self._connect() as conn:
            cursor = conn.execute(query, params)
            return self._rows_to_dicts(cursor)

    ## OBSERVABILITY
from uuid import uuid4
import csv
from io import StringIO
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse



# Routes that expose recorded LLM observations, all mounted under /observability.
router = APIRouter(
    tags=["observability"],
    prefix="/observability",
)

# JSON response body for the observability endpoints. In this module each
# entry is one observation dict as returned by LLMObservabilityManager
# (column name -> value). Documented with comments rather than a docstring
# because pydantic would copy a docstring into the OpenAPI schema.
class ObservationResponse(BaseModel):
    observations: List[Dict]
    
def create_csv_response(observations: List[Dict]) -> StreamingResponse:
    def iter_csv(data):
        output = StringIO()
        writer = csv.DictWriter(output, fieldnames=data[0].keys() if data else [])
        writer.writeheader()
        for row in data:
            writer.writerow(row)
        output.seek(0)
        yield output.read()

    headers = {
        'Content-Disposition': 'attachment; filename="observations.csv"'
    }
    return StreamingResponse(iter_csv(observations), media_type="text/csv", headers=headers)
    

@router.get("/last-observations/{limit}")
async def get_last_observations(limit: int = 10, format: str = "json"):
    """Return up to ``limit`` observations of the most recent conversation, newest first.

    The most recent conversation is the one that owns the observation with
    the greatest ``created_at``.
    \f
    Args:
        limit: Maximum number of observations to return (path parameter);
            must be non-negative.
        format: ``"csv"`` (case-insensitive) for a CSV attachment; any other
            value returns a JSON ``ObservationResponse``.

    Raises:
        HTTPException: 400 if ``limit`` is negative; 500 if retrieving or
            serialising the observations fails.
    """
    # A negative slice bound would silently drop items from the end rather
    # than act as a limit. Reject it before the try block below, which would
    # otherwise turn this client error into a 500.
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must be non-negative")

    observability_manager = LLMObservabilityManager()

    try:
        # NOTE(review): this reads every row in the table on each request;
        # a query scoped to the latest conversation would scale better.
        all_observations = observability_manager.get_observations()
        all_observations.sort(key=lambda x: x['created_at'], reverse=True)

        if all_observations:
            last_conversation_id = all_observations[0]['conversation_id']
            selected = [
                obs for obs in all_observations
                if obs['conversation_id'] == last_conversation_id
            ][:limit]
        else:
            selected = []

        if format.lower() == "csv":
            return create_csv_response(selected)
        return ObservationResponse(observations=selected)
    except Exception as e:
        # Keep the traceback server-side and chain it to the HTTP error.
        logger.exception("Failed to retrieve observations")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}") from e