fredaddy committed
Commit cd5795f · verified · 1 Parent(s): b1cc8b6

Update handler.py

Files changed (1):
  1. handler.py +22 -34
handler.py CHANGED
@@ -1,45 +1,33 @@
-#Handler.py file needed
-
 from PIL import Image
 import torch
-from transformers import AutoProcessor, AutoModelForVision2Seq
+from transformers import AutoModel, AutoTokenizer
 
 class ModelHandler:
     def __init__(self):
-        self.model = None
-        self.processor = None
-
-    def initialize(self, model_dir):
-        # Load the processor and model
-        self.processor = AutoProcessor.from_pretrained(model_dir)
-        self.model = AutoModelForVision2Seq.from_pretrained(model_dir)
+        # Load the model and tokenizer with appropriate weights
+        self.model = AutoModel.from_pretrained(
+            'openbmb/MiniCPM-V-2_6',
+            trust_remote_code=True,
+            attn_implementation='sdpa',
+            torch_dtype=torch.bfloat16
+        ).eval().cuda()
+
+        self.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
 
     def preprocess(self, inputs):
-        # Process the input image
-        image = Image.open(inputs["image"].file)
-        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
-
-        # Process the text context (if provided)
-        text_context = inputs.get("text_context", "")
-        if text_context:
-            context_inputs = self.processor(text=text_context, return_tensors="pt").input_ids
-        else:
-            context_inputs = None
-
-        return pixel_values, context_inputs
+        # Preprocess image input
+        image = Image.open(inputs['image'].file).convert('RGB')
+        question = inputs.get("question", "What is in the image?")
+        msgs = [{'role': 'user', 'content': [image, question]}]
+        return msgs
 
-    def inference(self, pixel_values, context_inputs=None):
-        # Run inference on the image with or without text context
-        with torch.no_grad():
-            if context_inputs is not None:
-                outputs = self.model.generate(pixel_values, input_ids=context_inputs)
-            else:
-                outputs = self.model.generate(pixel_values)
-        return outputs
+    def inference(self, msgs):
+        # Run inference on the model
+        result = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
+        return result
 
-    def postprocess(self, outputs):
-        # Decode the output to text
-        decoded_text = self.processor.batch_decode(outputs, skip_special_tokens=True)
-        return {"digitized_text": decoded_text[0]}
+    def postprocess(self, result):
+        # Postprocess the output from the model
+        return {"generated_text": result}
 
 service = ModelHandler()
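
For reference, a minimal driver sketch showing how the updated handler's stages chain together. This is not part of the commit: the FileInput wrapper and run helper are hypothetical, mimicking the file-like inputs['image'].file object that preprocess expects, and it requires a CUDA GPU because __init__ calls .cuda().

# Hypothetical driver for the updated handler (sketch, not part of the commit).
from dataclasses import dataclass
from typing import BinaryIO

from handler import service  # module-level ModelHandler instance defined above

@dataclass
class FileInput:
    file: BinaryIO  # preprocess() reads inputs['image'].file

def run(image_path: str, question: str) -> dict:
    # Chain the handler stages: preprocess -> inference -> postprocess.
    with open(image_path, "rb") as f:
        msgs = service.preprocess({"image": FileInput(file=f), "question": question})
        result = service.inference(msgs)
    return service.postprocess(result)

# Example call:
# print(run("sample.jpg", "What is in the image?"))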