Create base_generator.py
Browse files- services/base_generator.py +52 -0
services/base_generator.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# base_generator.py
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from logging import getLogger
|
6 |
+
|
7 |
+
|
8 |
+
from config.config import GenerationConfig, ModelConfig
|
9 |
+
|
10 |
+
class BaseGenerator(ABC):
    """Abstract base class for all generator implementations.

    Concrete subclasses must implement a synchronous ``generate`` entry
    point, a streaming ``generate_stream`` entry point, and
    ``_get_generation_kwargs`` to translate a ``GenerationConfig`` into
    backend-specific keyword arguments.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        """Set up the infrastructure shared by all generators.

        Args:
            model_name: Identifier of the model this generator serves.
            device: Target device string (e.g. ``"cpu"``/``"cuda"``);
                ``None`` lets the model manager choose.
            default_generation_config: Fallback ``GenerationConfig`` used
                when a call supplies none.
            model_config: Model-level configuration; defaults to
                ``ModelConfig()``.
            cache_size: Maximum number of entries kept in the response cache.
            max_batch_size: Upper bound on requests batched together.
        """
        self.logger = getLogger(__name__)
        # Fix: model_name was previously accepted but never stored; keep it
        # so subclasses and diagnostics can reference the served model.
        self.model_name = model_name
        # NOTE(review): ModelManager, ResponseCache, BatchProcessor and
        # HealthCheck are not imported in this file's visible import block —
        # confirm they are brought into scope elsewhere (otherwise this is
        # a missing import and __init__ raises NameError).
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Yield generated text chunks for *prompt* as they are produced.

        Args:
            prompt: The input text to generate from.
            config: Per-call generation settings; implementations should
                fall back to ``self.default_config`` when ``None``.
        """

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Translate *config* into backend-specific generation kwargs."""

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Return the complete generated text for *prompt* in one call.

        Args:
            prompt: The input text to generate from.
            model_kwargs: Backend-specific keyword arguments.
            strategy: Name of the generation strategy to apply.
            **kwargs: Additional implementation-specific options.
        """