rag_chat_with_analytics_aws / aws_aiclient.py
pvanand's picture
Create aws_aiclient.py
e577d93 verified
raw
history blame
5.24 kB
# aws_aiclient.py
import json
import logging
import os
import time
from contextlib import closing
from typing import List, Dict, Optional, Union, AsyncGenerator

import boto3
import psycopg2
from langchain_aws import ChatBedrockConverse
from starlette.responses import StreamingResponse

from observability import log_execution, LLMObservabilityManager
logger = logging.getLogger(__name__)

# Known Bedrock text models as rows of
# (display name, Bedrock model id, input price, output price),
# with both prices given in US cents per million tokens.
_TEXT_MODEL_PRICES = (
    ('Claude 3 Sonnet', 'anthropic.claude-3-sonnet-20240229-v1:0', 300, 1500),
    ('Claude 3 Haiku', 'anthropic.claude-3-haiku-20240307-v1:0', 25, 125),
    ('Llama 3 8B', 'meta.llama3-8b-instruct-v1:0', 19, 19),
    ('Llama 3 70B', 'meta.llama3-70b-instruct-v1:0', 143, 143),
)

# cents per million tokens -> dollars per token (100 cents * 1M tokens).
# Dividing these exact integers yields the same floats as writing the
# per-token decimals directly, e.g. 300 / 100_000_000 == 0.000003.
_CENTS_PER_MILLION_DIVISOR = 100 * 1_000_000

# Display name -> {'model': Bedrock model id,
#                  'input_cost': USD per prompt token,
#                  'output_cost': USD per completion token}
text_models = {
    display_name: {
        'model': model_id,
        'input_cost': input_cents / _CENTS_PER_MILLION_DIVISOR,
        'output_cost': output_cents / _CENTS_PER_MILLION_DIVISOR,
    }
    for display_name, model_id, input_cents, output_cents in _TEXT_MODEL_PRICES
}
class AIClient:
    """Stream chat completions from AWS Bedrock and record an observation per call.

    Completions come from ``ChatBedrockConverse`` clients (one per Bedrock
    model id). After every call, the response text, token counts, estimated
    cost, latency and status are written via ``LLMObservabilityManager``.
    """

    # Bedrock model id used when the caller does not choose one.
    _DEFAULT_MODEL = 'meta.llama3-70b-instruct-v1:0'
    # Same logger instance as the module-level ``logger`` (getLogger is keyed by name).
    _logger = logging.getLogger(__name__)

    def __init__(self, region_name: str = "ap-south-1"):
        """Create the default-model client and the observability manager.

        Args:
            region_name: AWS region in which Bedrock clients are created.
        """
        self._region_name = region_name
        # Credentials are read from the environment once, at construction.
        self._aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
        self._aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
        # One client per model id: switching models never mutates a client
        # that another in-flight stream may still be using.
        self._clients: Dict[str, object] = {}
        self.client = self._get_client(self._DEFAULT_MODEL)
        self.observability_manager = LLMObservabilityManager()
        self.models = text_models

    def _get_client(self, model: str):
        """Return the cached ``ChatBedrockConverse`` for *model*, creating it on first use."""
        client = self._clients.get(model)
        if client is None:
            client = ChatBedrockConverse(
                model=model,
                region_name=self._region_name,
                aws_access_key_id=self._aws_access_key_id,
                aws_secret_access_key=self._aws_secret_access_key,
            )
            self._clients[model] = client
        return client

    @staticmethod
    def _chunk_text(content: Union[str, list]) -> str:
        """Return the text carried by one streamed chunk's ``content``.

        ``content`` may be a plain string or a list of content blocks. Text is
        collected from string blocks and from dict blocks whose ``"text"``
        entry is a string; all other blocks are ignored.
        """
        if isinstance(content, str):
            return content
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                pieces.append(block["text"])
        return "".join(pieces)

    def _compute_cost(self, model: str, usage: Dict[str, int]) -> float:
        """Return the estimated USD cost of *usage* for Bedrock model id *model*.

        Prices come from ``self.models``; an unknown model id costs ``0``.
        """
        for info in self.models.values():
            if info['model'] == model:
                return (usage["prompt_tokens"] * info["input_cost"] +
                        usage["completion_tokens"] * info["output_cost"])
        return 0

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = _DEFAULT_MODEL,
        max_tokens: int = 32000,
        conversation_id: str = "default",
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        """Stream the model's reply to *messages*, then record an observation.

        Args:
            messages: chat messages as dicts; entries whose ``role`` is
                ``"system"`` are left out of the logged request.
            model: Bedrock model id to call.
            max_tokens: accepted for interface compatibility but not applied
                to the request. NOTE(review): forwarding it would send 32000 by
                default, which may exceed some models' output limits — confirm
                per-model limits before wiring it through.
            conversation_id: identifier stored with the observation.
            user: user identifier stored with the observation.

        Yields:
            Text fragments as they arrive. Nothing is yielded, and nothing is
            logged, when *messages* is empty.

        Errors are not raised to the caller. The stream simply ends, the
        exception is logged, and the observation is stored with
        ``status="error"`` and the exception text as the response.
        """
        if not messages:
            return
        start_time = time.time()
        parts: List[str] = []
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"
        error_text = ""
        try:
            client = self._get_client(model)
            async for chunk in client.astream(messages):
                text = self._chunk_text(chunk.content)
                if text:
                    parts.append(text)
                    yield text
                if chunk.usage_metadata:
                    # Accumulate rather than overwrite, so usage reported
                    # across several chunks is not lost.
                    usage["prompt_tokens"] += chunk.usage_metadata.get("input_tokens", 0)
                    usage["completion_tokens"] += chunk.usage_metadata.get("output_tokens", 0)
                    usage["total_tokens"] += chunk.usage_metadata.get("total_tokens", 0)
        except Exception as e:
            status = "error"
            error_text = str(e)
            self._logger.exception("Error in generate_response: %s", e)
        finally:
            latency = time.time() - start_time
            full_response = error_text if status == "error" else "".join(parts)
            cost = self._compute_cost(model, usage)
            try:
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id,
                    status=status,
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                self._logger.exception("Error logging observation: %s", obs_error)
class DatabaseManager:
    """Manages database operations against the Supabase-hosted Postgres instance."""

    def __init__(self):
        """Build the psycopg2 connection parameters.

        The user name and password are read from the ``SUPABASE_USER`` and
        ``SUPABASE_PASSWORD`` environment variables.

        Raises:
            KeyError: if either environment variable is unset.
        """
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }

    @log_execution
    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        """Insert one row into the ``ai_document_generator`` table.

        Args:
            user_id: value for the ``user_id`` column.
            user_query: value for the ``user_query`` column.
            response: value for the ``response`` column.

        The insert is committed if it succeeds and rolled back if it raises.
        The connection is closed in both cases.
        """
        insert_query = """
        INSERT INTO ai_document_generator (user_id, user_query, response)
        VALUES (%s, %s, %s);
        """
        # psycopg2's ``with conn:`` only ends the transaction (commit or
        # rollback); it does not close the connection. Without ``closing``,
        # every call would leak one pooled connection.
        with closing(psycopg2.connect(**self.db_params)) as conn:
            with conn, conn.cursor() as cur:
                cur.execute(insert_query, (user_id, user_query, response))