"""
OpenAI GPT Implementation
"""
import os
import openai
from typing import Dict, List, Any, Optional
from llm_interface import LLMInterface
from utils.logger import log_info, log_error, log_warning, LogTimer

DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
MAX_RESPONSE_LENGTH = 4096  # Max response length in characters; longer replies are truncated


class OpenAILLM(LLMInterface):
    """OpenAI GPT integration with improved error handling"""

    def __init__(self, api_key: str, model: str = "gpt-4", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        # Reuse a single async client for all requests instead of rebuilding it per call
        self.client = openai.AsyncOpenAI(api_key=api_key, timeout=self.timeout)
        log_info("OpenAI LLM initialized", model=self.model, timeout=self.timeout)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response with consistent error handling"""
        # Build the chat message list
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add conversation context, keeping only the last 10 messages to bound prompt size
        for msg in context[-10:]:
            role = "assistant" if msg.get("role") == "assistant" else "user"
            messages.append({"role": role, "content": msg.get("content", "")})

        # Add the current user input
        messages.append({"role": "user", "content": user_input})

        try:
            with LogTimer(f"OpenAI {self.model} request"):
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.settings.get("max_tokens", 2048),
                    temperature=self.settings.get("temperature", 0.7),
                    stream=False
                )

            # Extract content (can be None in edge cases, so fall back to an empty string)
            content = response.choices[0].message.content or ""

            # Truncate over-long responses instead of failing
            if len(content) > MAX_RESPONSE_LENGTH:
                log_warning("Response exceeded max length, truncating",
                            original_length=len(content),
                            max_length=MAX_RESPONSE_LENGTH)
                content = content[:MAX_RESPONSE_LENGTH] + "..."

            # Log token usage
            if response.usage:
                log_info("Token usage",
                         prompt_tokens=response.usage.prompt_tokens,
                         completion_tokens=response.usage.completion_tokens,
                         total_tokens=response.usage.total_tokens)

            return content

        except openai.RateLimitError as e:
            log_warning("OpenAI rate limit", error=str(e))
            raise
        except openai.APITimeoutError as e:
            log_error("OpenAI timeout", error=str(e))
            raise
        except openai.APIError as e:
            log_error("OpenAI API error",
                      status_code=getattr(e, "status_code", None),
                      error=str(e))
            raise
        except Exception as e:
            log_error("OpenAI unexpected error", error=str(e))
            raise

    async def startup(self, project_config: Dict) -> bool:
        """OpenAI doesn't need a startup step"""
        log_info("OpenAI startup called (no-op)")
        return True

    def get_provider_name(self) -> str:
        return f"openai-{self.model}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "openai",
            "model": self.model,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7)
        }
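

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not taken from the file above: an
# OPENAI_API_KEY environment variable is set, and the sibling llm_interface /
# utils.logger modules are importable. Illustrative only, not a production
# entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = OpenAILLM(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4")
        try:
            reply = await llm.generate(
                system_prompt="You are a concise assistant.",
                user_input="Say hello in one sentence.",
                context=[],  # no prior conversation
            )
            print(reply)
        except openai.RateLimitError:
            # generate() re-raises rate limits so callers can back off and retry
            print("Rate limited; retry with backoff.")

    asyncio.run(_demo())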