fredaddy committed on
Commit
e76d7b2
·
verified ·
1 Parent(s): 830d07f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +47 -15
handler.py CHANGED
@@ -1,27 +1,59 @@
1
  import torch
2
  from PIL import Image
 
 
3
  from transformers import AutoModel, AutoTokenizer
4
 
5
class EndpointHandler:
    """Inference endpoint wrapper around a remote-code vision-language model.

    Loads the model and tokenizer from ``path`` and answers image+question
    requests through the model's ``chat`` interface.
    """

    def __init__(self, path):
        # Checkpoint is loaded with remote code enabled, in fp16.
        self.model = AutoModel.from_pretrained(
            path,
            trust_remote_code=True,
            attn_implementation='sdpa', # Using sdpa instead of flash_attention_2
            torch_dtype=torch.float16
        )
        # NOTE(review): unconditionally moves the model to CUDA — this will
        # fail on a CPU-only host; confirm the deployment always has a GPU.
        self.model = self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    def __call__(self, data):
        # Expects data['inputs']['image'] to expose a readable ``.file``
        # attribute (presumably an uploaded multipart file object — verify
        # against the serving framework); raises KeyError/AttributeError
        # on malformed payloads.
        image = Image.open(data['inputs']['image'].file).convert('RGB')
        # Optional question; defaults to a full-page data-extraction prompt.
        question = data['inputs'].get("question", "Extract all data in the image. Be extremely careful to ensure that you don't miss anything. It's imperative that you extract and digitize everything on that page.")
        msgs = [{'role': 'user', 'content': [image, question]}]

        # The image travels inside ``msgs``; the ``image`` kwarg stays None.
        res = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer
        )

        return {"generated_text": res}
 
1
  import torch
2
  from PIL import Image
3
+ import base64
4
+ from io import BytesIO
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
class EndpointHandler:
    """Inference endpoint wrapper around a remote-code vision-language model.

    Expects a JSON payload of the form::

        {"inputs": {"image": "<base64 or data-URI string>", "text": "<prompt>"}}

    and returns ``{"generated_text": ...}`` on success or ``{"error": ...}``
    for malformed input.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; bf16 on CUDA, fp32 on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model (the checkpoint requires remote code).
        self.model = AutoModel.from_pretrained(
            path,
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)
        self.model.eval()

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )

    def __call__(self, data):
        """Run one chat-style generation over a base64-encoded image.

        Args:
            data: request payload; must contain ``inputs.image`` (a base64
                string, optionally wrapped in a ``data:...;base64,`` URI)
                and ``inputs.text`` (the user prompt).

        Returns:
            dict with either ``generated_text`` (the model output) or
            ``error`` (a human-readable description of the input problem).
        """
        # Look up the payload once instead of re-fetching per field.
        inputs = data.get("inputs", {})
        image_data = inputs.get("image", "")
        text_prompt = inputs.get("text", "")

        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}

        # Accept both raw base64 and data-URI payloads
        # ("data:image/png;base64,...."): keep only the part after the comma.
        if isinstance(image_data, str) and image_data.startswith("data:"):
            image_data = image_data.split(",", 1)[-1]

        # Decode and open the image; report (rather than raise) bad payloads.
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}

        # The model expects the image embedded in the message content;
        # the ``image`` kwarg of ``chat`` stays None.
        msgs = [{'role': 'user', 'content': [image, text_prompt]}]

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            res = self.model.chat(
                image=None,
                msgs=msgs,
                tokenizer=self.tokenizer,
                sampling=True,
                temperature=0.7,
                top_p=0.95,
                max_length=2000,
            )

        return {"generated_text": res}