File size: 1,875 Bytes
5fa6a5e
 
 
 
 
3d8d158
392ec24
4c3ed2c
54329ad
5fa6a5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# base_generator.py
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from logging import getLogger 
from services.model_manager import ModelManager
from services.cache import ResponseCache
from services.batch_processor import BatchProcessor
from services.health_check import HealthCheck

from config.config import GenerationConfig, ModelConfig

class BaseGenerator(ABC):
    """Base class for all generator implementations.

    Wires up the shared service objects every generator uses (model manager,
    response cache, batch processor, health check) plus default generation and
    model configuration. Concrete subclasses implement the streaming and
    non-streaming generation entry points and the translation from a
    ``GenerationConfig`` into backend-specific keyword arguments.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        """Initialise shared services and default configuration.

        Args:
            model_name: Identifier of the model this generator serves. Stored
                on ``self.model_name`` so subclasses can look it up.
            device: Passed through to ``ModelManager``; ``None`` lets the
                manager choose.
            default_generation_config: Generation settings used as the default;
                a fresh ``GenerationConfig()`` is created when omitted.
            model_config: Model settings; a fresh ``ModelConfig()`` is created
                when omitted.
            cache_size: Capacity argument forwarded to ``ResponseCache``.
            max_batch_size: Batch-size limit forwarded to ``BatchProcessor``.
        """
        self.logger = getLogger(__name__)
        # Previously accepted but dropped; keep it so subclasses can use it.
        self.model_name = model_name
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        # NOTE(review): tokenizer setup was stubbed out here (formerly
        # ``self.model_manager.tokenizers[model_name]``); presumably subclasses
        # are responsible for loading it via ``self.model_name`` — confirm.
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Asynchronously yield generated text for ``prompt`` as ``str`` chunks.

        Args:
            prompt: Input text to generate from.
            config: Optional per-call generation settings. Presumably
                implementations fall back to ``self.default_config`` when this
                is ``None`` — confirm in each subclass.
        """

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Translate ``config`` into the keyword arguments for the backend call.

        Args:
            config: Generation settings to convert.

        Returns:
            A mapping of keyword arguments for the implementation's model call.
        """

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Synchronously generate and return the full text for ``prompt``.

        Args:
            prompt: Input text to generate from.
            model_kwargs: Keyword arguments for the underlying model call.
            strategy: Name of the generation strategy; ``"default"`` unless
                the caller selects another one supported by the subclass.
            **kwargs: Extra implementation-specific options.

        Returns:
            The generated text.
        """