# base_generator.py
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from logging import getLogger
from services.model_manager import ModelManager
from services.cache import ResponseCache
from services.batch_processor import BatchProcessor
from services.health_check import HealthCheck
from config.config import GenerationConfig, ModelConfig
class BaseGenerator(ABC):
    """Abstract base class for all generator implementations.

    Wires up the services shared by every generator (model management,
    response caching, batch processing, health checks) and stores default
    generation/model configuration. Subclasses supply the concrete
    generation logic via the abstract methods below.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        """Initialize shared services and configuration.

        Args:
            model_name: Identifier of the model this generator serves.
            device: Device spec forwarded to ``ModelManager``; ``None`` lets
                the manager choose. (Semantics defined by ModelManager —
                not visible from this file.)
            default_generation_config: Fallback config used when a call does
                not supply its own; defaults to ``GenerationConfig()``.
            model_config: Model-level configuration; defaults to
                ``ModelConfig()``.
            cache_size: Capacity passed to ``ResponseCache``.
            max_batch_size: Capacity passed to ``BatchProcessor``.
        """
        self.logger = getLogger(__name__)
        # Fix: the original accepted `model_name` but never stored it
        # (the commented-out tokenizer lookups hinted subclasses need it),
        # so keep it available as an attribute.
        self.model_name = model_name
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Yield generated text chunks for *prompt* as they are produced.

        Args:
            prompt: Input text to generate from.
            config: Per-call generation config; implementations are expected
                to fall back to ``self.default_config`` when ``None``.
        """

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Translate *config* into backend-specific generation kwargs."""

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Synchronously generate and return a complete response for *prompt*.

        Args:
            prompt: Input text to generate from.
            model_kwargs: Backend-specific keyword arguments (typically built
                by ``_get_generation_kwargs``).
            strategy: Named generation strategy; meaning is subclass-defined.
            **kwargs: Additional implementation-specific options.
        """