""" | |
OpenAI GPT Implementation | |
""" | |
import os | |
import openai | |
from typing import Dict, List, Any | |
from llm_interface import LLMInterface | |
from logger import log_info, log_error, log_warning, log_debug, LogTimer | |
DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60")) | |
MAX_RESPONSE_LENGTH = 4096 # Max response length | |
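# Note: llm_interface and logger are project-local modules. This file assumes
# LLMInterface.__init__ stores `settings` (defaulting to an empty dict) on the
# instance, and that the log_* helpers accept structured keyword fields.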
class OpenAILLM(LLMInterface):
    """OpenAI GPT integration with improved error handling"""

    def __init__(self, api_key: str, model: str = "gpt-4",
                 settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        # Create the async client once and reuse it for every request;
        # passing the key here replaces the legacy `openai.api_key` global.
        self.client = openai.AsyncOpenAI(api_key=api_key, timeout=self.timeout)
        log_info("OpenAI LLM initialized", model=self.model, timeout=self.timeout)
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response with consistent error handling"""
        # Build the message list: system prompt, recent context, current input
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add only the most recent context (last 10 messages) to bound prompt size
        for msg in context[-10:]:
            role = "assistant" if msg.get("role") == "assistant" else "user"
            messages.append({"role": role, "content": msg.get("content", "")})

        # Add the current input
        messages.append({"role": "user", "content": user_input})
        try:
            with LogTimer(f"OpenAI {self.model} request"):
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.settings.get("max_tokens", 2048),
                    temperature=self.settings.get("temperature", 0.7),
                    stream=False,
                )

            # Extract content; the SDK can return None for an empty completion
            content = response.choices[0].message.content or ""

            # Truncate overly long responses
            if len(content) > MAX_RESPONSE_LENGTH:
                log_warning("Response exceeded max length, truncating",
                            original_length=len(content),
                            max_length=MAX_RESPONSE_LENGTH)
                content = content[:MAX_RESPONSE_LENGTH] + "..."

            # Log token usage when the API reports it
            if response.usage:
                log_info("Token usage",
                         prompt_tokens=response.usage.prompt_tokens,
                         completion_tokens=response.usage.completion_tokens,
                         total_tokens=response.usage.total_tokens)

            return content
        except openai.RateLimitError as e:
            log_warning("OpenAI rate limit", error=str(e))
            raise
        except openai.APITimeoutError as e:
            log_error("OpenAI timeout", error=str(e))
            raise
        except openai.APIError as e:
            log_error("OpenAI API error",
                      status_code=getattr(e, "status_code", None),
                      error=str(e))
            raise
        except Exception as e:
            log_error("OpenAI unexpected error", error=str(e))
            raise
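    # Errors are re-raised so callers can apply their own policy; e.g. a
    # caller might retry RateLimitError with exponential backoff (a sketch,
    # not part of this interface's contract):
    #
    #     for attempt in range(3):
    #         try:
    #             return await llm.generate(system_prompt, user_input, context)
    #         except openai.RateLimitError:
    #             await asyncio.sleep(2 ** attempt)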
    async def startup(self, project_config: Dict) -> bool:
        """OpenAI needs no startup step; this exists to satisfy the interface"""
        log_info("OpenAI startup called (no-op)")
        return True

    def get_provider_name(self) -> str:
        return f"openai-{self.model}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "openai",
            "model": self.model,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
        }
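
# Usage sketch (illustrative): a minimal async entry point for manual testing.
# It assumes OPENAI_API_KEY is exported in the environment and that the
# project-local llm_interface and logger modules are importable.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = OpenAILLM(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4")
        reply = await llm.generate(
            system_prompt="You are a helpful assistant.",
            user_input="Say hello in one sentence.",
            context=[],
        )
        print(reply)

    asyncio.run(_demo())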