Chris4K committed (verified)
Commit 457a598 · Parent: 9891585

Update app.py

Files changed (1):
  1. app.py +579 -258

app.py CHANGED
@@ -1,24 +1,7 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from typing import List, Tuple, Optional, Dict, Any, Union
4
- from dataclasses import dataclass
5
- from enum import Enum
6
- import logging
7
-
8
- from typing import List, Tuple, Optional, Dict, Any, Union, AsyncGenerator
9
- from dataclasses import dataclass
10
  from enum import Enum
11
- import logging
12
- import torch
13
- from transformers import AutoModelForCausalLM, AutoTokenizer
14
- from llama_cpp import Llama
15
-
16
- from huggingface_hub import hf_hub_download
17
-
18
- prm_model_path = hf_hub_download(
19
- repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
20
- filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
21
- )
22
 
23
  class GenerationStrategy(str, Enum):
24
  DEFAULT = "default"
@@ -26,189 +9,236 @@ class GenerationStrategy(str, Enum):
26
  BEST_OF_N = "best_of_n"
27
  BEAM_SEARCH = "beam_search"
28
  DVTS = "dvts"
  @dataclass
31
  class GenerationConfig:
32
  num_samples: int = 5
33
  depth: int = 3
34
  breadth: int = 2
35
- max_history_turns: int = 3
36
  max_new_tokens: int = 50
37
  temperature: float = 0.7
38
  top_p: float = 0.9
  strategy: GenerationStrategy = GenerationStrategy.DEFAULT
40
 
41
- class LlamaGenerator:
42
- def __init__(
43
- self,
44
- llama_model_name: str,
45
- prm_model_path: str,
46
- device: str = None,
47
- default_generation_config: Optional[GenerationConfig] = None
48
- ):
49
- """Initialize the LlamaGenerator with specified models."""
 
50
  self.logger = logging.getLogger(__name__)
51
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
52
- self.default_config = default_generation_config or GenerationConfig()
53
-
54
- self.logger.info(f"Initializing LlamaGenerator on device: {self.device}")
55
-
 
56
  try:
57
- self._initialize_models(llama_model_name, prm_model_path)
  except Exception as e:
59
- self.logger.error(f"Failed to initialize models: {str(e)}")
60
  raise
61
 
62
- def _initialize_models(self, llama_model_name: str, prm_model_path: str):
63
- """Initialize models with error handling and logging."""
64
- # Initialize LLaMA model and tokenizer
65
- self.llama_tokenizer = AutoTokenizer.from_pretrained(
66
- llama_model_name,
67
- padding_side='left',
68
- trust_remote_code=True
69
- )
70
- if self.llama_tokenizer.pad_token is None:
71
- self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
72
-
73
- self.llama_model = AutoModelForCausalLM.from_pretrained(
74
- llama_model_name,
75
- device_map="auto",
76
- trust_remote_code=True
77
- )
78
-
79
- # Initialize PRM model
80
- self.prm_model = self._load_quantized_model(prm_model_path)
81
-
82
- # Enable token streaming
83
- self.supports_streaming = hasattr(self.llama_model, "streamer")
84
-
85
- async def generate_stream(
86
- self,
87
- prompt: str,
88
- config: Optional[GenerationConfig] = None
89
- ) -> AsyncGenerator[str, None]:
90
- """Stream tokens as they're generated."""
91
- if not self.supports_streaming:
92
- raise NotImplementedError("This model doesn't support streaming")
93
-
94
- config = config or self.default_config
95
- input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
96
-
97
- async for token in self.llama_model.streamer(input_ids, **self._get_generation_kwargs(config)):
98
- yield self.llama_tokenizer.decode([token])
99
 
100
- def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
101
- """Get generation kwargs based on config."""
102
- return {
103
- "max_new_tokens": config.max_new_tokens,
104
- "temperature": config.temperature,
105
- "top_p": config.top_p,
106
- "do_sample": config.temperature > 0,
107
- }
108
-
109
- def _load_quantized_model(self, model_path: str) -> Llama:
110
- """Load a quantized GGUF model using llama-cpp-python.
111
-
112
- Args:
113
- model_path (str): Path to the GGUF model file
114
-
115
- Returns:
116
- Llama: Loaded model instance
117
- """
118
  try:
119
- # Configure GPU layers if CUDA is available
120
  n_gpu_layers = -1 if torch.cuda.is_available() else 0
121
-
122
- # Load the model
123
  model = Llama(
124
  model_path=model_path,
125
- n_ctx=2048, # Context window
126
- n_batch=512, # Batch size for prompt processing
127
- n_gpu_layers=n_gpu_layers, # Number of layers to offload to GPU
128
- verbose=False
129
  )
130
-
131
- self.logger.info(f"Successfully loaded GGUF model from {model_path}")
132
  return model
133
-
134
  except Exception as e:
135
  self.logger.error(f"Failed to load GGUF model: {str(e)}")
136
  raise
137
 
138
- def _score_with_prm(self, text: str) -> float:
139
- """Score text using the PRM model.
140
-
141
- Args:
142
- text (str): Text to score
143
-
144
- Returns:
145
- float: Model score
146
- """
147
- try:
148
- # For GGUF models, we need to use the proper scoring interface
149
- result = self.prm_model.eval(text)
150
- return result['logprobs'] # Or another appropriate scoring metric
151
-
152
- except Exception as e:
153
- self.logger.error(f"Error scoring text with PRM: {str(e)}")
154
- return float('-inf') # Return very low score on error
155
- def _construct_prompt(
158
- self,
159
- context: str,
160
- user_input: str,
161
- chat_history: List[Tuple[str, str]],
162
- max_history_turns: int = 3
163
- ) -> str:
164
- """Construct a formatted prompt from the input components."""
165
- system_message = f"Please assist based on the following context: {context}"
166
- prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
167
 
168
- for user_msg, assistant_msg in chat_history[-max_history_turns:]:
169
- prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
170
- prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
171
 
172
- prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
173
- prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
174
- return prompt
175
 
176
- def generate(
  self,
178
  prompt: str,
179
- model_kwargs: Dict[str, Any],
180
- strategy: str = "default",
181
- num_samples: int = 5,
182
- depth: int = 3,
183
- breadth: int = 2
184
- ) -> str:
185
- """Generate a response using the specified strategy.
186
 
187
- Args:
188
- prompt (str): The input prompt
189
- model_kwargs (dict): Additional arguments for model.generate()
190
- strategy (str): Generation strategy ('default', 'majority_voting', 'best_of_n', 'beam_search', 'dvts')
191
- num_samples (int): Number of samples for applicable strategies
192
- depth (int): Depth for DVTS strategy
193
- breadth (int): Breadth for DVTS strategy
194
-
195
- Returns:
196
- str: Generated response
197
- """
198
- if strategy == "default":
199
- input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
200
- output = self.llama_model.generate(input_ids, **model_kwargs)
201
- return self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
202
-
203
- elif strategy == "majority_voting":
204
- outputs = []
205
- for _ in range(num_samples):
206
- input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
207
- output = self.llama_model.generate(input_ids, **model_kwargs)
208
- outputs.append(self.llama_tokenizer.decode(output[0], skip_special_tokens=True))
209
- return max(set(outputs), key=outputs.count)
210
-
211
- elif strategy == "best_of_n":
  scored_outputs = []
213
  for _ in range(num_samples):
214
  input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
@@ -217,8 +247,9 @@ class LlamaGenerator:
217
  score = self.prm_model(**self.llama_tokenizer(response, return_tensors="pt").to(self.device)).logits.mean().item()
218
  scored_outputs.append((response, score))
219
  return max(scored_outputs, key=lambda x: x[1])[0]
220
-
221
- elif strategy == "beam_search":
 
222
  input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
223
  outputs = self.llama_model.generate(
224
  input_ids,
@@ -227,8 +258,9 @@ class LlamaGenerator:
227
  **model_kwargs
228
  )
229
  return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
230
-
231
- elif strategy == "dvts":
 
232
  results = []
233
  for _ in range(breadth):
234
  input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
@@ -247,97 +279,322 @@ class LlamaGenerator:
247
  results.append((extended_response, score))
248
  return max(results, key=lambda x: x[1])[0]
249
 
250
- else:
251
- raise ValueError(f"Unknown strategy: {strategy}")
- def generate_with_context(
  self,
255
- context: str,
256
- user_input: str,
257
- chat_history: List[Tuple[str, str]],
  model_kwargs: Dict[str, Any],
259
- max_history_turns: int = 3,
260
  strategy: str = "default",
261
- num_samples: int = 5,
262
- depth: int = 3,
263
- breadth: int = 2
264
  ) -> str:
265
- """Generate a response using context and chat history.
266
-
267
- Args:
268
- context (str): Context for the conversation
269
- user_input (str): Current user input
270
- chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
271
- model_kwargs (dict): Additional arguments for model.generate()
272
- max_history_turns (int): Maximum number of history turns to include
273
- strategy (str): Generation strategy
274
- num_samples (int): Number of samples for applicable strategies
275
- depth (int): Depth for DVTS strategy
276
- breadth (int): Breadth for DVTS strategy
277
 
278
- Returns:
279
- str: Generated response
280
- """
281
- prompt = self._construct_prompt(
282
- context,
283
- user_input,
284
- chat_history,
285
- max_history_turns
286
- )
287
- return self.generate(
288
  prompt,
289
  model_kwargs,
290
- strategy,
291
- num_samples,
292
- depth,
293
- breadth
294
  )
295
-
296
- ######################
297
- #########
298
- #################
299
- from fastapi import FastAPI, HTTPException, BackgroundTasks
  from fastapi.middleware.cors import CORSMiddleware
301
- from pydantic import BaseModel, Field
302
- from typing import List, Optional, Dict
 
303
  import asyncio
304
  import uuid
305
  from datetime import datetime
306
  import json
  class ChatMessage(BaseModel):
309
- role: str = Field(..., description="Role of the message sender (user/assistant)")
  content: str = Field(..., description="Content of the message")
 
312
  class GenerationRequest(BaseModel):
313
- context: Optional[str] = Field(None, description="Context for the conversation")
314
- messages: List[ChatMessage] = Field(..., description="Chat history")
315
- config: Optional[Dict] = Field(None, description="Generation configuration")
316
- stream: bool = Field(False, description="Whether to stream the response")
 
318
  class GenerationResponse(BaseModel):
319
- id: str = Field(..., description="Generation ID")
320
- content: str = Field(..., description="Generated content")
321
- created_at: datetime = Field(default_factory=datetime.now)
322
-
323
- app = FastAPI(title="LLaMA Generation Service")
324
-
325
- # Add CORS middleware
326
- app.add_middleware(
327
- CORSMiddleware,
328
- allow_origins=["*"],
329
- allow_credentials=True,
330
- allow_methods=["*"],
331
- allow_headers=["*"],
332
- )
 
334
- # Store generator instance
335
- generator = None
336
-
337
- @app.on_event("startup")
338
- async def startup_event():
339
  global generator
340
  try:
 
341
  generator = LlamaGenerator(
342
  llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
343
  prm_model_path=prm_model_path,
@@ -346,30 +603,87 @@ async def startup_event():
346
  temperature=0.7
347
  )
348
  )
349
- except Exception as e:
350
- print(f"Failed to initialize generator: {str(e)}")
351
- raise
 
353
- @app.post("/generate", response_model=GenerationResponse)
354
- async def generate(request: GenerationRequest):
355
  if not generator:
356
- raise HTTPException(status_code=503, detail="Generator not initialized")
  try:
359
- # Format chat history
360
  chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
361
  user_input = request.messages[-1].content
362
 
363
- # Create generation config
364
- config = GenerationConfig(**request.config) if request.config else None
365
 
366
- # Generate response
367
  response = await asyncio.to_thread(
368
  generator.generate_with_context,
369
  context=request.context or "",
370
  user_input=user_input,
371
  chat_history=chat_history,
372
- model_kwargs={}, # Add any model-specific kwargs here
373
  config=config
374
  )
375
 
@@ -381,23 +695,30 @@ async def generate(request: GenerationRequest):
381
  raise HTTPException(status_code=500, detail=str(e))
382
 
383
  @app.websocket("/generate/stream")
384
- async def generate_stream(websocket):
  await websocket.accept()
386
 
387
  try:
388
  while True:
389
- # Receive and parse request
390
  request_data = await websocket.receive_text()
391
  request = GenerationRequest.parse_raw(request_data)
392
 
393
- # Format chat history
394
  chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
395
  user_input = request.messages[-1].content
396
 
397
- # Create generation config
398
- config = GenerationConfig(**request.config) if request.config else None
399
 
400
- # Stream response
401
  async for token in generator.generate_stream(
402
  prompt=generator._construct_prompt(
403
  context=request.context or "",
@@ -410,8 +731,7 @@ async def generate_stream(websocket):
410
  "token": token,
411
  "finished": False
412
  }))
413
-
414
- # Send finished message
415
  await websocket.send_text(json.dumps({
416
  "token": "",
417
  "finished": True
@@ -426,4 +746,5 @@ async def generate_stream(websocket):
426
 
427
  if __name__ == "__main__":
428
  import uvicorn
429
- uvicorn.run(app, host="0.0.0.0", port=8000)
+ # config.py
2
+ from dataclasses import dataclass, field
  from enum import Enum
4
+ from typing import Dict, Any, Optional
 
6
  class GenerationStrategy(str, Enum):
7
  DEFAULT = "default"
  MAJORITY_VOTING = "majority_voting"
  BEST_OF_N = "best_of_n"
10
  BEAM_SEARCH = "beam_search"
11
  DVTS = "dvts"
12
+ COT = "chain_of_thought"
13
+ REACT = "react"
14
+
15
+ @dataclass
16
+ class ModelConfig:
17
+ model_kwargs: Dict[str, Any] = field(default_factory=dict)
18
+ tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
19
+ quantization_kwargs: Dict[str, Any] = field(default_factory=dict)
20
 
21
  @dataclass
22
  class GenerationConfig:
23
  num_samples: int = 5
24
  depth: int = 3
25
  breadth: int = 2
26
+ max_history_turns: int = 1
27
  max_new_tokens: int = 50
28
  temperature: float = 0.7
29
  top_p: float = 0.9
30
+ top_k: int = 50
31
+ repetition_penalty: float = 1.1
32
+ length_penalty: float = 1.0
33
+ do_sample: bool = True
34
  strategy: GenerationStrategy = GenerationStrategy.DEFAULT
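For illustration (not part of the commit), a config for best-of-N sampling can be built by overriding the defaults of the dataclass above:
# Illustrative only - not part of the commit.
config = GenerationConfig(
    strategy=GenerationStrategy.BEST_OF_N,
    num_samples=8,
    max_new_tokens=128,
    temperature=0.8,
)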
35
 
36
+ # model_manager.py
37
+ import torch
38
+ from transformers import AutoModelForCausalLM, AutoTokenizer
39
+ from llama_cpp import Llama
40
+ from typing import Optional, Dict
41
+ import logging
42
+ from functools import lru_cache
43
+
44
+ class ModelManager:
45
+ def __init__(self, device: Optional[str] = None):
46
  self.logger = logging.getLogger(__name__)
47
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
48
+ self.models: Dict[str, Any] = {}
49
+ self.tokenizers: Dict[str, Any] = {}
50
+
51
+ def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
52
+ """Load a model with specified configuration."""
53
  try:
54
+ # The model type can differ (textgen, llama, gguf, text2video, text2image, etc.), so a factory pattern can be used to load the correct one.
55
+ if model_type == "llama":
56
+ self.tokenizers[model_id] = AutoTokenizer.from_pretrained(
57
+ model_path,
58
+ padding_side='left',
59
+ trust_remote_code=True,
60
+ **config.tokenizer_kwargs
61
+ )
62
+ if self.tokenizers[model_id].pad_token is None:
63
+ self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token
64
+
65
+ self.models[model_id] = AutoModelForCausalLM.from_pretrained(
66
+ model_path,
67
+ device_map="auto",
68
+ trust_remote_code=True,
69
+ **config.model_kwargs
70
+ )
71
+ elif model_type == "gguf":
72
+ #TODO load the model first from the cache, if not found load the model and save it in the cache
73
+ #from huggingface_hub import hf_hub_download
74
+ #prm_model_path = hf_hub_download(
75
+ # repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
76
+ # filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
77
+ #)
78
+
79
+
80
+ self.models[model_id] = self._load_quantized_model(
81
+ model_path,
82
+ **config.quantization_kwargs
83
+ )
84
  except Exception as e:
85
+ self.logger.error(f"Failed to load model {model_id}: {str(e)}")
86
  raise
87
 
88
+ def unload_model(self, model_id: str) -> None:
89
+ """Unload a model and free resources."""
90
+ if model_id in self.models:
91
+ del self.models[model_id]
92
+ if model_id in self.tokenizers:
93
+ del self.tokenizers[model_id]
94
+ torch.cuda.empty_cache()
 
96
+ def _load_quantized_model(self, model_path: str, **kwargs) -> Llama:
97
+ """Load a quantized GGUF model."""
  try:
 
99
  n_gpu_layers = -1 if torch.cuda.is_available() else 0
 
 
100
  model = Llama(
101
  model_path=model_path,
102
+ n_ctx=kwargs.get('n_ctx', 2048),
103
+ n_batch=kwargs.get('n_batch', 512),
104
+ n_gpu_layers=kwargs.get('n_gpu_layers', n_gpu_layers),
105
+ verbose=kwargs.get('verbose', False)
106
  )
 
 
107
  return model
 
108
  except Exception as e:
109
  self.logger.error(f"Failed to load GGUF model: {str(e)}")
110
  raise
111
 
 
113
+ # cache.py
114
+ from functools import lru_cache
115
+ from typing import Tuple, Any
116
+
117
+ # TODO: explain how to use the cache (a usage sketch follows the class below)
118
+ class ResponseCache:
119
+ def __init__(self, cache_size: int = 1000):
120
+ self.cache_size = cache_size
121
+ self._initialize_cache()
122
+
123
+ def _initialize_cache(self):
124
+ @lru_cache(maxsize=self.cache_size)
125
+ def cached_response(prompt: str, config_hash: str) -> Tuple[str, float]:
126
+ pass
127
+ self.get_cached_response = cached_response
128
+
129
+ def cache_response(self, prompt: str, config: GenerationConfig, response: str, score: float) -> None:
130
+ config_hash = hash(str(config.__dict__))
131
+ self.get_cached_response(prompt, str(config_hash))
132
 
133
+ def get_response(self, prompt: str, config: GenerationConfig) -> Optional[Tuple[str, float]]:
134
+ config_hash = hash(str(config.__dict__))
135
+ return self.get_cached_response(prompt, str(config_hash))
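A minimal usage sketch (not part of the commit), assuming the GenerationConfig dataclass defined above. Note that the class as committed is a placeholder: the lru_cache wraps a no-op, so get_response returns None until cached_response is given a real body.
# Illustrative usage only - not part of the commit.
cache = ResponseCache(cache_size=256)
config = GenerationConfig(temperature=0.2, max_new_tokens=64)

cached = cache.get_response("What is the capital of France?", config)
if cached is None:
    response, score = "Paris.", 0.95  # in practice these come from the generator and the PRM
    cache.cache_response("What is the capital of France?", config, response, score)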
+ # batch_processor.py
139
+ from typing import List, Dict
140
+ import asyncio
141
 
142
+ # TODO: explain how to use the batch processor (a usage sketch follows the class below)
143
+ class BatchProcessor:
144
+ def __init__(self, max_batch_size: int = 32, max_wait_time: float = 0.1):
145
+ self.max_batch_size = max_batch_size
146
+ self.max_wait_time = max_wait_time
147
+ self.pending_requests: List[Dict] = []
148
+ self.lock = asyncio.Lock()
149
+
150
+ async def add_request(self, request: Dict) -> Any:
151
+ async with self.lock:
152
+ self.pending_requests.append(request)
153
+ if len(self.pending_requests) >= self.max_batch_size:
154
+ return await self._process_batch()
155
+ else:
156
+ await asyncio.sleep(self.max_wait_time)
157
+ if self.pending_requests:
158
+ return await self._process_batch()
159
+
160
+ async def _process_batch(self) -> List[Any]:
161
+ batch = self.pending_requests[:self.max_batch_size]
162
+ self.pending_requests = self.pending_requests[self.max_batch_size:]
163
+ # TODO implement the batch processing logic
164
+ return batch
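A minimal usage sketch (not part of the commit). Because add_request holds the lock across the sleep, concurrent callers are effectively serialized until the TODO in _process_batch is filled in.
# Illustrative usage only - not part of the commit.
import asyncio

async def demo():
    processor = BatchProcessor(max_batch_size=4, max_wait_time=0.05)
    batch = await processor.add_request({"prompt": "What is the capital of France?"})
    print(batch)  # currently returns the pending batch unchanged

asyncio.run(demo())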
165
+
166
+
167
+
168
+ # base_generator.py
169
+ from abc import ABC, abstractmethod
170
+ from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
171
+ from dataclasses import dataclass
172
+ from logging import getLogger
173
+
174
+ from .config import GenerationConfig, ModelConfig
175
+
176
+ class BaseGenerator(ABC):
177
+ """Base class for all generator implementations."""
178
+
179
+ def __init__(
180
+ self,
181
+ model_name: str,
182
+ device: Optional[str] = None,
183
+ default_generation_config: Optional[GenerationConfig] = None,
184
+ model_config: Optional[ModelConfig] = None,
185
+ cache_size: int = 1000,
186
+ max_batch_size: int = 32
187
+ ):
188
+ self.logger = getLogger(__name__)
189
+ self.model_manager = ModelManager(device)
190
+ self.cache = ResponseCache(cache_size)
191
+ self.batch_processor = BatchProcessor(max_batch_size)
192
+ self.health_check = HealthCheck()
193
+
194
+ self.default_config = default_generation_config or GenerationConfig()
195
+ self.model_config = model_config or ModelConfig()
196
+
197
+ @abstractmethod
198
+ async def generate_stream(
199
  self,
200
  prompt: str,
201
+ config: Optional[GenerationConfig] = None
202
+ ) -> AsyncGenerator[str, None]:
203
+ pass
 
205
+ @abstractmethod
206
+ def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
207
+ pass
208
+
209
+ @abstractmethod
210
+ def generate(self, prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
211
+ pass
212
+
213
+ # strategy.py
214
+ # TODO: update the import paths once these sections are split into separate modules
215
+ from abc import ABC, abstractmethod
216
+ from typing import List, Tuple
217
+
218
+ # Note: this abstract base shares its name with, and shadows, the GenerationStrategy enum defined in the config section above.
+ class GenerationStrategy(ABC):
219
+ """Base class for generation strategies."""
220
+
221
+ @abstractmethod
222
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
223
+ pass
224
+
225
+ class DefaultStrategy(GenerationStrategy):
226
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
227
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
228
+ output = generator.model.generate(input_ids, **model_kwargs)
229
+ return generator.tokenizer.decode(output[0], skip_special_tokens=True)
230
+
231
+ class MajorityVotingStrategy(GenerationStrategy):
232
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
233
+ outputs = []
234
+ for _ in range(num_samples):
235
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
236
+ output = generator.model.generate(input_ids, **model_kwargs)
237
+ outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
238
+ return max(set(outputs), key=outputs.count)
239
+
240
+ class BestOfN(GenerationStrategy):
241
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
242
  scored_outputs = []
243
  for _ in range(num_samples):
244
  input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
 
247
  score = self.prm_model(**self.llama_tokenizer(response, return_tensors="pt").to(self.device)).logits.mean().item()
248
  scored_outputs.append((response, score))
249
  return max(scored_outputs, key=lambda x: x[1])[0]
250
+
251
+ class BeamSearch(GenerationStrategy):
252
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
253
  input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
254
  outputs = self.llama_model.generate(
255
  input_ids,
 
258
  **model_kwargs
259
  )
260
  return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
261
+
262
+ class DVT(GenerationStrategy):
263
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
264
  results = []
265
  for _ in range(breadth):
266
  input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
 
279
  results.append((extended_response, score))
280
  return max(results, key=lambda x: x[1])[0]
281
 
282
+ class COT(GenerationStrategy):
283
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
284
+ #TODO implement the chain of thought strategy
285
+
286
+ return "Not implemented yet"
287
+
288
+ class ReAct(GenerationStrategy):
289
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
290
+ #TODO implement the ReAct framework
291
+ return "Not implemented yet"
292
+ # Add other strategy implementations...
293
+
294
+ # prompt_builder.py
295
+ from typing import Protocol, List, Tuple
296
+ from transformers import AutoTokenizer
297
+
298
+ class PromptTemplate(Protocol):
299
+ """Protocol for prompt templates."""
300
+ def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
301
+ pass
302
+
303
+ class LlamaPromptTemplate:
304
+ def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
305
+ system_message = f"Please assist based on the following context: {context}"
306
+ prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
307
+
308
+ for user_msg, assistant_msg in chat_history[-max_history_turns:]:
309
+ prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
310
+ prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
311
+
312
+ prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
313
+ prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
314
+ return prompt
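For illustration (not part of the commit), a single-turn call produces a Llama-3-style prompt:
# Illustrative usage only - not part of the commit.
template = LlamaPromptTemplate()
prompt = template.format(
    context="You are a helpful assistant",
    user_input="What is the capital of France?",
    chat_history=[],
)
# prompt begins with:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# Please assist based on the following context: You are a helpful assistant<|eot_id|>
# ...and ends with an open assistant header, ready for generation.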
315
 
316
+ class TransformersPromptTemplate:
317
+ def __init__(self, model_path: str):
318
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
319
+
320
+ def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
321
+ messages = [
322
+ {
323
+ "role": "system",
324
+ "content": f"Please assist based on the following context: {context}",
325
+ }
326
+ ]
327
+
328
+ for user_msg, assistant_msg in chat_history:
329
+ messages.extend([
330
+ {"role": "user", "content": user_msg},
331
+ {"role": "assistant", "content": assistant_msg}
332
+ ])
333
+
334
+ messages.append({"role": "user", "content": user_input})
335
+
336
+ tokenized_chat = self.tokenizer.apply_chat_template(
337
+ messages,
338
+ tokenize=False,
339
+ add_generation_prompt=True
340
+ )
341
+ return tokenized_chat
342
+
343
+ # health_check.py
344
+ import psutil
345
+ from dataclasses import dataclass
346
+ from typing import Dict, Any
347
+
348
+ @dataclass
349
+ class HealthStatus:
350
+ status: str
351
+ gpu_memory: Dict[str, float]
352
+ cpu_usage: float
353
+ ram_usage: float
354
+ model_status: Dict[str, str]
355
+
356
+ class HealthCheck:
357
+ @staticmethod
358
+ def check_gpu_memory() -> Dict[str, float]:
359
+ if torch.cuda.is_available():
360
+ return {
361
+ f"gpu_{i}": torch.cuda.memory_allocated(i) / 1024**3
362
+ for i in range(torch.cuda.device_count())
363
+ }
364
+ return {}
365
+
366
+ @staticmethod
367
+ def check_system_resources() -> HealthStatus:
368
+ return HealthStatus(
369
+ status="healthy",
370
+ gpu_memory=HealthCheck.check_gpu_memory(),
371
+ cpu_usage=psutil.cpu_percent(),
372
+ ram_usage=psutil.virtual_memory().percent,
373
+ #TODO add more system resources like disk, network, etc.
374
+ model_status={} # To be filled by the model manager
375
+ )
376
+
377
+
378
+ # llama_generator.py
379
+ class LlamaGenerator(BaseGenerator):
380
+ def __init__(
381
  self,
382
+ llama_model_name: str,
383
+ prm_model_path: str,
384
+ device: Optional[str] = None,
385
+ default_generation_config: Optional[GenerationConfig] = None,
386
+ model_config: Optional[ModelConfig] = None,
387
+ cache_size: int = 1000,
388
+ max_batch_size: int = 32
389
+ ):
390
+ super().__init__(
391
+ llama_model_name,
392
+ device,
393
+ default_generation_config,
394
+ model_config,
395
+ cache_size,
396
+ max_batch_size
397
+ )
398
+
399
+ # Initialize models
400
+ self.model_manager.load_model(
401
+ "llama",
402
+ llama_model_name,
403
+ "llama",
404
+ self.model_config
405
+ )
406
+ self.model_manager.load_model(
407
+ "prm",
408
+ prm_model_path,
409
+ "gguf",
410
+ self.model_config
411
+ )
412
+
413
+ self.prompt_builder = LlamaPromptTemplate()
414
+ self._init_strategies()
415
+
416
+ def _init_strategies(self):
417
+ self.strategies = {
418
+ "default": DefaultStrategy(),
419
+ "majority_voting": MajorityVotingStrategy(),
420
+ "best_of_n": BestOfN(),
421
+ "beam_search": BeamSearch(),
422
+ "dvts": DVT(),
423
+ }
424
+
425
+ def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
426
+ """Get generation kwargs based on config."""
427
+ return {
428
+ key: getattr(config, key)
429
+ for key in [
430
+ "max_new_tokens",
431
+ "temperature",
432
+ "top_p",
433
+ "top_k",
434
+ "repetition_penalty",
435
+ "length_penalty",
436
+ "do_sample"
437
+ ]
438
+ if hasattr(config, key)
439
+ }
440
+
441
+ def generate(
442
+ self,
443
+ prompt: str,
444
  model_kwargs: Dict[str, Any],
 
445
  strategy: str = "default",
446
+ **kwargs
 
 
447
  ) -> str:
448
+ if strategy not in self.strategies:
449
+ raise ValueError(f"Unknown strategy: {strategy}")
 
451
+ return self.strategies[strategy].generate(
452
+ self,
  prompt,
454
  model_kwargs,
455
+ **kwargs
  )
457
+
458
+ def check_health(self) -> HealthStatus:
459
+ """Check the health status of the generator."""
460
+ return self.health_check.check_system_resources() # TODO add model status
461
+
462
+
463
+ ###################
464
+ #################
465
+
466
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, Depends
467
  from fastapi.middleware.cors import CORSMiddleware
468
+ from fastapi.responses import StreamingResponse
469
+ from pydantic import BaseModel, Field, ConfigDict
470
+ from typing import List, Optional, Dict, Any, AsyncGenerator
471
  import asyncio
472
  import uuid
473
  from datetime import datetime
474
  import json
475
+ from huggingface_hub import hf_hub_download
476
+ from contextlib import asynccontextmanager
477
+
478
+
479
 
480
  class ChatMessage(BaseModel):
481
+ """A single message in the chat history."""
482
+ role: str = Field(
483
+ ...,
484
+ description="Role of the message sender",
485
+ examples=["user", "assistant"]
486
+ )
487
  content: str = Field(..., description="Content of the message")
488
+
489
+ model_config = ConfigDict(
490
+ json_schema_extra={
491
+ "example": {
492
+ "role": "user",
493
+ "content": "What is the capital of France?"
494
+ }
495
+ }
496
+ )
497
+
498
+ class GenerationConfig(BaseModel):
499
+ """Configuration for text generation."""
500
+ temperature: float = Field(
501
+ 0.7,
502
+ ge=0.0,
503
+ le=2.0,
504
+ description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random, lower values (e.g., 0.2) make it more focused and deterministic."
505
+ )
506
+ max_new_tokens: int = Field(
507
+ 100,
508
+ ge=1,
509
+ le=2048,
510
+ description="Maximum number of tokens to generate"
511
+ )
512
+ top_p: float = Field(
513
+ 0.9,
514
+ ge=0.0,
515
+ le=1.0,
516
+ description="Nucleus sampling parameter. Only tokens with cumulative probability < top_p are considered."
517
+ )
518
+ top_k: int = Field(
519
+ 50,
520
+ ge=0,
521
+ description="Only consider the top k tokens for text generation"
522
+ )
523
+ strategy: str = Field(
524
+ "default",
525
+ description="Generation strategy to use",
526
+ examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
527
+ )
528
+ num_samples: int = Field(
529
+ 5,
530
+ ge=1,
531
+ le=10,
532
+ description="Number of samples to generate (used in majority_voting and best_of_n strategies)"
533
+ )
534
 
535
  class GenerationRequest(BaseModel):
536
+ """Request model for text generation."""
537
+ context: Optional[str] = Field(
538
+ None,
539
+ description="Additional context to guide the generation",
540
+ examples=["You are a helpful assistant skilled in Python programming"]
541
+ )
542
+ messages: List[ChatMessage] = Field(
543
+ ...,
544
+ description="Chat history including the current message",
545
+ min_items=1
546
+ )
547
+ config: Optional[GenerationConfig] = Field(
548
+ None,
549
+ description="Generation configuration parameters"
550
+ )
551
+ stream: bool = Field(
552
+ False,
553
+ description="Whether to stream the response token by token"
554
+ )
555
+
556
+ model_config = ConfigDict(
557
+ json_schema_extra={
558
+ "example": {
559
+ "context": "You are a helpful assistant",
560
+ "messages": [
561
+ {"role": "user", "content": "What is the capital of France?"}
562
+ ],
563
+ "config": {
564
+ "temperature": 0.7,
565
+ "max_new_tokens": 100
566
+ },
567
+ "stream": False
568
+ }
569
+ }
570
+ )
571
 
572
  class GenerationResponse(BaseModel):
573
+ """Response model for text generation."""
574
+ id: str = Field(..., description="Unique generation ID")
575
+ content: str = Field(..., description="Generated text content")
576
+ created_at: datetime = Field(
577
+ default_factory=datetime.now,
578
+ description="Timestamp of generation"
579
+ )
580
+
581
+
582
+ # Model and cache management
583
+ async def get_prm_model_path():
584
+ """Download and cache the PRM model."""
585
+ return await asyncio.to_thread(
586
+ hf_hub_download,
587
+ repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
588
+ filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
589
+ )
590
 
591
+ @asynccontextmanager
592
+ async def lifespan(app: FastAPI):
593
+ """Lifecycle management for the FastAPI application."""
594
+ # Startup: Initialize generator
 
595
  global generator
596
  try:
597
+ prm_model_path = await get_prm_model_path()
598
  generator = LlamaGenerator(
599
  llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
600
  prm_model_path=prm_model_path,
 
603
  temperature=0.7
604
  )
605
  )
606
+ yield
607
+ finally:
608
+ # Shutdown: Clean up resources
609
+ if generator:
610
+ await asyncio.to_thread(generator.cleanup)
611
+
612
+ # FastAPI application
613
+ app = FastAPI(
614
+ title="Inference Deluxe Service",
615
+ description="""
616
+ A service for generating text using LLaMA models with various generation strategies.
617
+
618
+ Generation Strategies:
619
+ - default: Standard autoregressive generation
620
+ - majority_voting: Generates multiple responses and selects the most common one
621
+ - best_of_n: Generates multiple responses and selects the best based on a scoring metric
622
+ - beam_search: Uses beam search for more coherent text generation
623
+ - dvts: Dynamic vocabulary tree search for efficient generation
624
+ """,
625
+ version="1.0.0",
626
+ lifespan=lifespan
627
+ )
628
+
629
+ # CORS middleware
630
+ app.add_middleware(
631
+ CORSMiddleware,
632
+ allow_origins=["*"],
633
+ allow_credentials=True,
634
+ allow_methods=["*"],
635
+ allow_headers=["*"],
636
+ )
637
 
638
+ async def get_generator():
639
+ """Dependency to get the generator instance."""
640
  if not generator:
641
+ raise HTTPException(
642
+ status_code=503,
643
+ detail="Generator not initialized"
644
+ )
645
+ return generator
646
+
647
+ @app.post(
648
+ "/generate",
649
+ response_model=GenerationResponse,
650
+ tags=["generation"],
651
+ summary="Generate text response",
652
+ response_description="Generated text with unique identifier"
653
+ )
654
+ async def generate(
655
+ request: GenerationRequest,
656
+ generator: Any = Depends(get_generator)
657
+ ):
658
+ """
659
+ Generate a text response based on the provided context and chat history.
660
+
661
+ The generation process can be customized using various parameters in the config:
662
+ - temperature: Controls randomness (0.0 to 2.0)
663
+ - max_new_tokens: Maximum length of generated text
664
+ - top_p: Nucleus sampling parameter
665
+ - top_k: Top-k sampling parameter
666
+ - strategy: Generation strategy to use
667
+ - num_samples: Number of samples for applicable strategies
668
 
669
+ Generation Strategies:
670
+ - default: Standard generation
671
+ - majority_voting: Generates multiple responses and uses the most common one
672
+ - best_of_n: Generates multiple responses and picks the best
673
+ - beam_search: Uses beam search for coherent generation
674
+ - dvts: Dynamic vocabulary tree search
675
+ """
676
  try:
 
677
  chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
678
  user_input = request.messages[-1].content
679
 
680
+ config = request.config or GenerationConfig()
 
681
 
 
682
  response = await asyncio.to_thread(
683
  generator.generate_with_context,
684
  context=request.context or "",
685
  user_input=user_input,
686
  chat_history=chat_history,
 
687
  config=config
688
  )
689
 
 
695
  raise HTTPException(status_code=500, detail=str(e))
696
 
697
  @app.websocket("/generate/stream")
698
+ async def generate_stream(
699
+ websocket: WebSocket,
700
+ generator: Any = Depends(get_generator)
701
+ ):
702
+ """
703
+ Stream generated text tokens over a WebSocket connection.
704
+
705
+ The stream sends JSON messages with the following format:
706
+ - During generation: {"token": "generated_token", "finished": false}
707
+ - End of generation: {"token": "", "finished": true}
708
+ - Error: {"error": "error_message"}
709
+ """
710
  await websocket.accept()
711
 
712
  try:
713
  while True:
 
714
  request_data = await websocket.receive_text()
715
  request = GenerationRequest.parse_raw(request_data)
716
 
 
717
  chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
718
  user_input = request.messages[-1].content
719
 
720
+ config = request.config or GenerationConfig()
 
721
 
 
722
  async for token in generator.generate_stream(
723
  prompt=generator._construct_prompt(
724
  context=request.context or "",
 
731
  "token": token,
732
  "finished": False
733
  }))
734
+
 
735
  await websocket.send_text(json.dumps({
736
  "token": "",
737
  "finished": True
 
746
 
747
  if __name__ == "__main__":
748
  import uvicorn
749
+ uvicorn.run(app, host="0.0.0.0", port=8000)
750
+
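A minimal client sketch (not part of the commit), assuming the service is running locally on the port configured above; the payload follows the GenerationRequest schema defined in this file.
# Illustrative client call - not part of the commit.
import requests

payload = {
    "context": "You are a helpful assistant",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "config": {"temperature": 0.7, "max_new_tokens": 100, "strategy": "default"},
    "stream": False,
}
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["content"])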