Spaces:
Running
Running
File size: 3,470 Bytes
6ff1f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from typing import AsyncGenerator, Dict, Any
import importlib
import os
import logging
from core.exceptions import ModelError, ProviderError
from core.key_manager import key_manager
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TextGenerator:
    """Handles text generation across different model providers."""

    @staticmethod
    def _load_provider_module(provider: str) -> Any:
        """Import and validate the ``models.<name>`` module for *provider*.

        The module name is the first whitespace-separated word of the provider
        name, lower-cased (a provider called "Foo Bar" maps to ``models.foo``).

        Args:
            provider (str): Provider name as returned by ``key_manager``.

        Returns:
            The imported module, which exposes a callable ``run_model_stream``.

        Raises:
            ProviderError: If the provider name is empty or not a string, the
                module file does not exist, importing it fails, or it lacks a
                callable ``run_model_stream``.
        """
        words = provider.lower().split() if isinstance(provider, str) else []
        if not words:
            # Fail with a clear message instead of an IndexError from split()[0].
            raise ProviderError(f"Invalid provider name: {provider!r}")
        module_name = words[0]
        logger.info("Importing module: models.%s", module_name)

        # Check the file up front so a missing provider reports the exact path searched.
        backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        module_path = os.path.join(backend_dir, 'models', f'{module_name}.py')
        if not os.path.exists(module_path):
            raise ProviderError(f"Provider module not found at: {module_path}")

        # Only the import itself is guarded, so the ProviderErrors raised in this
        # method are never caught and re-wrapped a second time.
        try:
            provider_module = importlib.import_module(f"models.{module_name}")
        except ImportError as e:
            raise ProviderError(f"Failed to import provider module: {str(e)}") from e
        except Exception as e:
            # Any other error raised while executing the module's top-level code.
            raise ProviderError(f"Error loading provider module: {str(e)}") from e

        # Verify the required entry point exists and can actually be called.
        if not callable(getattr(provider_module, 'run_model_stream', None)):
            raise ProviderError(
                f"Provider module {module_name} missing required function: run_model_stream"
            )
        return provider_module

    @staticmethod
    async def generate_stream(model: str, prompt: str) -> AsyncGenerator[str, None]:
        """
        Generate streaming text responses using the specified model.

        Resolves the provider for *model* via ``key_manager``, takes the next API
        key for that provider, loads ``models.<provider>`` and forwards every chunk
        yielded by its ``run_model_stream(api_key, model, prompt)``.

        Args:
            model (str): The name of the model to use
            prompt (str): The input prompt for text generation

        Yields:
            str: Generated text chunks

        Raises:
            ProviderError: If the provider module cannot be resolved, imported,
                or does not provide ``run_model_stream``.
            Exception: Errors from ``key_manager`` (e.g. an unknown model —
                presumably a ModelError; depends on key_manager) or from the
                provider's stream are logged and re-raised unchanged.
        """
        try:
            provider = key_manager.get_provider_for_model(model)
            logger.info("Using provider %s for model %s", provider, model)

            api_key = key_manager.get_next_key(provider)
            # Log only that a key was obtained, never the key itself.
            logger.info("Retrieved API key for provider %s", provider)

            provider_module = TextGenerator._load_provider_module(provider)

            async for chunk in provider_module.run_model_stream(api_key, model, prompt):
                yield chunk
        except Exception as e:
            # Record the traceback, then let the caller handle the original error.
            logger.exception("Error in generate_stream: %s", e)
            raise

    @staticmethod
    def get_available_models() -> Any:
        """
        Get all available models and their providers.

        Returns:
            Whatever ``key_manager.get_available_models()`` returns, passed through
            unchanged (the frontend expects the models collection directly — a
            list/"array" per the original code; confirm against key_manager).

        Raises:
            ModelError: If ``key_manager`` fails for any reason; the original
                exception is chained as ``__cause__``.
        """
        try:
            return key_manager.get_available_models()
        except Exception as e:
            logger.exception("Error getting available models: %s", e)
            raise ModelError(f"Failed to get available models: {str(e)}") from e
# Global instance. Both public methods are static, so this shared object is a
# convenience handle for callers that import an instance rather than the class.
text_generator = TextGenerator()