fredaddy commited on
Commit
b1cc8b6
·
verified ·
1 Parent(s): acc0847

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +45 -1
handler.py CHANGED
@@ -1 +1,45 @@
1
- #Handler.py file needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Handler.py file needed
2
+
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import AutoProcessor, AutoModelForVision2Seq
6
+
7
+ class ModelHandler:
8
+ def __init__(self):
9
+ self.model = None
10
+ self.processor = None
11
+
12
+ def initialize(self, model_dir):
13
+ # Load the processor and model
14
+ self.processor = AutoProcessor.from_pretrained(model_dir)
15
+ self.model = AutoModelForVision2Seq.from_pretrained(model_dir)
16
+
17
+ def preprocess(self, inputs):
18
+ # Process the input image
19
+ image = Image.open(inputs["image"].file)
20
+ pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
21
+
22
+ # Process the text context (if provided)
23
+ text_context = inputs.get("text_context", "")
24
+ if text_context:
25
+ context_inputs = self.processor(text=text_context, return_tensors="pt").input_ids
26
+ else:
27
+ context_inputs = None
28
+
29
+ return pixel_values, context_inputs
30
+
31
+ def inference(self, pixel_values, context_inputs=None):
32
+ # Run inference on the image with or without text context
33
+ with torch.no_grad():
34
+ if context_inputs is not None:
35
+ outputs = self.model.generate(pixel_values, input_ids=context_inputs)
36
+ else:
37
+ outputs = self.model.generate(pixel_values)
38
+ return outputs
39
+
40
+ def postprocess(self, outputs):
41
+ # Decode the output to text
42
+ decoded_text = self.processor.batch_decode(outputs, skip_special_tokens=True)
43
+ return {"digitized_text": decoded_text[0]}
44
+
45
+ service = ModelHandler()