sachin committed
Commit a0887d0 · 1 Parent(s): 9e8036b

add quantization

Files changed (1)
  1. src/server/gemma_llm.py +94 -71
src/server/gemma_llm.py CHANGED
@@ -1,83 +1,101 @@
  import torch
  from logging_config import logger
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  from PIL import Image
  from fastapi import HTTPException
  from io import BytesIO


  class LLMManager:
      def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
          self.model_name = model_name
          self.device = torch.device(device)
-         self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
          self.model = None
          self.is_loaded = False
          self.processor = None

      def unload(self):
          if self.is_loaded:
-             # Delete the model and processor to free memory
              del self.model
              del self.processor
-             # If using CUDA, clear the cache to free GPU memory
              if self.device.type == "cuda":
                  torch.cuda.empty_cache()
              self.is_loaded = False
              logger.info(f"LLM {self.model_name} unloaded from {self.device}")
      def load(self):
          if not self.is_loaded:
-
-             #self.model_name = "google/gemma-3-4b-it"
-
-             self.model = Gemma3ForConditionalGeneration.from_pretrained(
-                 self.model_name, device_map="auto"
              ).eval()
-
-             self.processor = AutoProcessor.from_pretrained(self.model_name)
-
-             self.is_loaded = True
-             logger.info(f"LLM {self.model_name} loaded on {self.device}")
-
-     async def generate(self, prompt: str, max_tokens: int = 2048, temperature: float = 0.7) -> str:
          if not self.is_loaded:
              self.load()
-
          messages_vlm = [
              {
                  "role": "system",
-                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and karnataka as base state, Provide a concise response in one sentence maximum."}]
              },
              {
                  "role": "user",
-                 "content": []
              }
          ]

-         # Add text prompt to user content
-         messages_vlm[1]["content"].append({"type": "text", "text": prompt})
-
-         # Process the chat template with the processor
-         inputs_vlm = self.processor.apply_chat_template(
-             messages_vlm,
-             add_generation_prompt=True,
-             tokenize=True,
-             return_dict=True,
-             return_tensors="pt"
-         ).to(self.model.device, dtype=torch.bfloat16)

          input_len = inputs_vlm["input_ids"].shape[-1]

-         # Generate response
          with torch.inference_mode():
-             generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
              generation = generation[0][input_len:]

          # Decode the output
          response = self.processor.decode(generation, skip_special_tokens=True)
-
          return response
-
      async def vision_query(self, image: Image.Image, query: str) -> str:
          if not self.is_loaded:
              self.load()
@@ -85,7 +103,7 @@ class LLMManager:
          messages_vlm = [
              {
                  "role": "system",
-                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarise your answer in max 2 lines."}]
              },
              {
                  "role": "user",
@@ -93,18 +111,17 @@ class LLMManager:
              }
          ]

-         # Add text prompt to user content
          messages_vlm[1]["content"].append({"type": "text", "text": query})

-         # Handle image if provided and valid
-         if image and image.size[0] > 0 and image.size[1] > 0:  # Check for valid dimensions
-             # Image is already a PIL Image, no need to read or reopen
              messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
              logger.info(f"Received valid image for processing")
          else:
              logger.info("No valid image provided, processing text only")

-         # Process the chat template with the processor
          try:
              inputs_vlm = self.processor.apply_chat_template(
                  messages_vlm,
@@ -112,7 +129,8 @@ class LLMManager:
                  tokenize=True,
                  return_dict=True,
                  return_tensors="pt"
-             ).to(self.model.device, dtype=torch.bfloat16)
          except Exception as e:
              logger.error(f"Error in apply_chat_template: {str(e)}")
              raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
@@ -121,23 +139,27 @@ class LLMManager:

          # Generate response
          with torch.inference_mode():
-             generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
              generation = generation[0][input_len:]

          # Decode the output
          decoded = self.processor.decode(generation, skip_special_tokens=True)
-         logger.info(f"Chat Response: {decoded}")
-
          return decoded
-
      async def chat_v2(self, image: Image.Image, query: str) -> str:
          if not self.is_loaded:
              self.load()
-         # Construct the message structure
          messages_vlm = [
              {
                  "role": "system",
-                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and karnataka as base state"}]
              },
              {
                  "role": "user",
@@ -145,42 +167,43 @@ class LLMManager:
              }
          ]

-         # Add text prompt to user content
          messages_vlm[1]["content"].append({"type": "text", "text": query})

-         # Handle image only if provided and valid
-         if image and image.file and image.size > 0:  # Check for valid file with content
-             # Read the image file
-             image_data = await image.read()
-             if not image_data:
-                 raise HTTPException(status_code=400, detail="Uploaded image is empty")
-             # Open image with PIL for processing
-             img = Image.open(BytesIO(image_data))
-             # Add image to content (assuming processor accepts PIL images)
-             messages_vlm[1]["content"].insert(0, {"type": "image", "image": img})
-             logger.info(f"Received image: {image.filename}")
          else:
-             if image and (not image.file or image.size == 0):
-                 logger.warning("Received invalid or empty image parameter, treating as text-only")
              logger.info("No valid image provided, processing text only")

-         # Process the chat template with the processor
-         inputs_vlm = self.processor.apply_chat_template(
-             messages_vlm,
-             add_generation_prompt=True,
-             tokenize=True,
-             return_dict=True,
-             return_tensors="pt"
-         ).to(self.model.device, dtype=torch.bfloat16)

          input_len = inputs_vlm["input_ids"].shape[-1]

          # Generate response
          with torch.inference_mode():
-             generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
              generation = generation[0][input_len:]

          # Decode the output
          decoded = self.processor.decode(generation, skip_special_tokens=True)
-         logger.info(f"Chat Response: {decoded}")
          return decoded
 
  import torch
  from logging_config import logger
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
  from PIL import Image
  from fastapi import HTTPException
  from io import BytesIO

+ # Define 4-bit quantization config for better precision and performance
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",  # Normalized float 4-bit
+     bnb_4bit_use_double_quant=True,  # Double quantization for better accuracy
+     bnb_4bit_compute_dtype=torch.bfloat16  # Consistent compute dtype
+ )

  class LLMManager:
      def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
          self.model_name = model_name
          self.device = torch.device(device)
+         self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32  # Align dtype with quantization
          self.model = None
          self.is_loaded = False
          self.processor = None
+         logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

      def unload(self):
          if self.is_loaded:
              del self.model
              del self.processor
              if self.device.type == "cuda":
                  torch.cuda.empty_cache()
+                 logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
              self.is_loaded = False
              logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+
      def load(self):
          if not self.is_loaded:
+             try:
+                 self.model = Gemma3ForConditionalGeneration.from_pretrained(
+                     self.model_name,
+                     device_map="auto",
+                     quantization_config=quantization_config,
+                     torch_dtype=self.torch_dtype
                  ).eval()
+                 self.processor = AutoProcessor.from_pretrained(self.model_name)
+                 self.is_loaded = True
+                 logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization")
+             except Exception as e:
+                 logger.error(f"Failed to load model: {str(e)}")
+                 raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
+
+     async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
          if not self.is_loaded:
              self.load()
+
          messages_vlm = [
              {
                  "role": "system",
+                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
              },
              {
                  "role": "user",
+                 "content": [{"type": "text", "text": prompt}]
              }
          ]

+         # Process the chat template
+         try:
+             inputs_vlm = self.processor.apply_chat_template(
+                 messages_vlm,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_dict=True,
+                 return_tensors="pt"
+             ).to(self.device, dtype=torch.bfloat16)
+             logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+             logger.info(f"Decoded input: {self.processor.decode(inputs_vlm['input_ids'][0])}")
+         except Exception as e:
+             logger.error(f"Error in tokenization: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")

          input_len = inputs_vlm["input_ids"].shape[-1]

+         # Generate response with improved settings
          with torch.inference_mode():
+             generation = self.model.generate(
+                 **inputs_vlm,
+                 max_new_tokens=max_tokens,  # Increased for coherence
+                 do_sample=True,  # Enable sampling for variability
+                 temperature=temperature  # Control creativity
+             )
              generation = generation[0][input_len:]

          # Decode the output
          response = self.processor.decode(generation, skip_special_tokens=True)
+         logger.info(f"Generated response: {response}")
          return response
+
      async def vision_query(self, image: Image.Image, query: str) -> str:
          if not self.is_loaded:
              self.load()

          messages_vlm = [
              {
                  "role": "system",
+                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in max 2 lines."}]
              },
              {
                  "role": "user",
                  "content": []
              }
          ]

+         # Add text prompt
          messages_vlm[1]["content"].append({"type": "text", "text": query})

+         # Handle image if valid
+         if image and image.size[0] > 0 and image.size[1] > 0:
              messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
              logger.info(f"Received valid image for processing")
          else:
              logger.info("No valid image provided, processing text only")

+         # Process the chat template
          try:
              inputs_vlm = self.processor.apply_chat_template(
                  messages_vlm,
                  add_generation_prompt=True,
                  tokenize=True,
                  return_dict=True,
                  return_tensors="pt"
+             ).to(self.device, dtype=torch.bfloat16)
+             logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
          except Exception as e:
              logger.error(f"Error in apply_chat_template: {str(e)}")
              raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

          input_len = inputs_vlm["input_ids"].shape[-1]

          # Generate response
          with torch.inference_mode():
+             generation = self.model.generate(
+                 **inputs_vlm,
+                 max_new_tokens=512,  # Increased for coherence
+                 do_sample=True,  # Enable sampling
+                 temperature=0.7  # Control creativity
+             )
              generation = generation[0][input_len:]

          # Decode the output
          decoded = self.processor.decode(generation, skip_special_tokens=True)
+         logger.info(f"Vision query response: {decoded}")
          return decoded
+
      async def chat_v2(self, image: Image.Image, query: str) -> str:
          if not self.is_loaded:
              self.load()
+
          messages_vlm = [
              {
                  "role": "system",
+                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
              },
              {
                  "role": "user",
                  "content": []
              }
          ]

+         # Add text prompt
          messages_vlm[1]["content"].append({"type": "text", "text": query})

+         # Handle image if valid
+         if image and image.size[0] > 0 and image.size[1] > 0:
+             messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+             logger.info(f"Received valid image for processing")
          else:
              logger.info("No valid image provided, processing text only")

+         # Process the chat template
+         try:
+             inputs_vlm = self.processor.apply_chat_template(
+                 messages_vlm,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_dict=True,
+                 return_tensors="pt"
+             ).to(self.device, dtype=torch.bfloat16)
+             logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+         except Exception as e:
+             logger.error(f"Error in apply_chat_template: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

          input_len = inputs_vlm["input_ids"].shape[-1]

          # Generate response
          with torch.inference_mode():
+             generation = self.model.generate(
+                 **inputs_vlm,
+                 max_new_tokens=512,  # Increased for coherence
+                 do_sample=True,  # Enable sampling
+                 temperature=0.7  # Control creativity
+             )
              generation = generation[0][input_len:]

          # Decode the output
          decoded = self.processor.decode(generation, skip_special_tokens=True)
+         logger.info(f"Chat_v2 response: {decoded}")
          return decoded
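
For reference, a minimal usage sketch of the quantized LLMManager (not part of the commit): it assumes bitsandbytes and a CUDA GPU are available for the 4-bit config, the import path is hypothetical and should match the project layout, and "google/gemma-3-4b-it" is an assumption taken from the commented-out model name in the previous version of load().

import asyncio

from gemma_llm import LLMManager  # hypothetical import path; adjust to src/server/gemma_llm.py

async def main():
    # Model name is an assumption based on the commented-out value in the old load()
    llm = LLMManager("google/gemma-3-4b-it")
    # generate() calls load() lazily on first use; weights are loaded in 4-bit
    # via the module-level BitsAndBytesConfig
    reply = await llm.generate("What is the capital of Karnataka?", max_tokens=64, temperature=0.7)
    print(reply)
    llm.unload()  # frees GPU memory via torch.cuda.empty_cache()

if __name__ == "__main__":
    asyncio.run(main())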