# base_generator.py
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from logging import getLogger
from services.model_manager import ModelManager
from services.cache import ResponseCache
from services.batch_processor import BatchProcessor
from services.health_check import HealthCheck
from config.config import GenerationConfig, ModelConfig
class BaseGenerator(ABC):
    """Abstract base class for all generator implementations.

    Wires up the infrastructure shared by every concrete generator
    (model management, response caching, request batching, health checks)
    and declares the generation interface subclasses must implement.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        """Initialize the shared generator services.

        Args:
            model_name: Identifier of the model this generator serves.
            device: Device spec forwarded to ModelManager (presumably
                e.g. "cpu"/"cuda"; ModelManager chooses when None —
                TODO confirm against ModelManager).
            default_generation_config: Fallback settings used when a call
                supplies no GenerationConfig.
            model_config: Model-level configuration; defaults to ModelConfig().
            cache_size: Maximum number of entries in the response cache.
            max_batch_size: Upper bound on requests batched together.
        """
        self.logger = getLogger(__name__)
        # Fix: the original accepted model_name but never stored it, so
        # subclasses had no way to know which model they were built for.
        self.model_name = model_name
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None,
    ) -> AsyncGenerator[str, None]:
        """Yield generated text chunks for *prompt* as they are produced.

        Args:
            prompt: Input text to condition generation on.
            config: Per-call generation settings; implementations should
                fall back to ``self.default_config`` when None.

        Yields:
            Incremental pieces of the generated response.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Translate a GenerationConfig into backend-specific model kwargs."""
        raise NotImplementedError

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs,
    ) -> str:
        """Generate and return a complete response for *prompt*.

        Args:
            prompt: Input text to condition generation on.
            model_kwargs: Backend-specific keyword arguments for the model.
            strategy: Named decoding strategy selector (implementation-defined).
            **kwargs: Extra implementation-specific options.

        Returns:
            The generated text.
        """
        raise NotImplementedError