Gabriel commited on
Commit
eacc8d0
·
verified ·
1 Parent(s): 2830cf6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -45
handler.py CHANGED
@@ -1,40 +1,18 @@
1
  from typing import Dict, Any
2
  import torch
3
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
  from PIL import Image
5
  import io
6
  import base64
7
  import requests
8
- from qwen_vl_utils import process_vision_info
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- self.model = Qwen2VLForConditionalGeneration.from_pretrained(
14
- path,
15
- torch_dtype="auto",
16
- device_map="auto"
17
- ).to(self.device)
18
-
19
  self.processor = AutoProcessor.from_pretrained(path)
20
-
21
- # Optionally, adjust min_pixels and max_pixels if needed
22
- # min_pixels = 256*28*28
23
- # max_pixels = 1280*28*28
24
- # self.processor = AutoProcessor.from_pretrained(path, min_pixels=min_pixels, max_pixels=max_pixels)
25
 
26
  def __call__(self, data: Any) -> Dict[str, Any]:
27
- """
28
- Args:
29
- data (Any): The input data, which can be:
30
- - Binary image data in the request body.
31
- - A dictionary with 'image' and 'text' keys:
32
- - 'image': Base64-encoded image string or image URL.
33
- - 'text': The text prompt.
34
-
35
- Returns:
36
- Dict[str, Any]: The generated text output from the model.
37
- """
38
  default_prompt = "Describe this image."
39
 
40
  if isinstance(data, (bytes, bytearray)):
@@ -46,8 +24,7 @@ class EndpointHandler():
46
  if image_input is None:
47
  return {"error": "No image provided."}
48
  if image_input.startswith('http'):
49
- response = requests.get(image_input)
50
- image = Image.open(io.BytesIO(response.content)).convert('RGB')
51
  else:
52
  image_data = base64.b64decode(image_input)
53
  image = Image.open(io.BytesIO(image_data)).convert('RGB')
@@ -58,34 +35,24 @@ class EndpointHandler():
58
  {
59
  "role": "user",
60
  "content": [
61
- {
62
- "type": "image",
63
- "image": image,
64
- },
65
  {"type": "text", "text": text_input},
66
  ],
67
  }
68
  ]
69
 
70
- text = self.processor.apply_chat_template(
71
- messages, tokenize=False, add_generation_prompt=True
72
- )
73
- image_inputs, video_inputs = process_vision_info(messages)
74
  inputs = self.processor(
75
  text=[text],
76
- images=image_inputs,
77
- videos=video_inputs,
78
  padding=True,
79
  return_tensors="pt",
80
- )
81
- inputs = inputs.to(self.device)
82
 
83
- generated_ids = self.model.generate(**inputs, max_new_tokens=128)
84
- generated_ids_trimmed = [
85
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
86
- ]
87
  output_text = self.processor.batch_decode(
88
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
89
- )
90
 
91
- return {"generated_text": output_text[0]}
 
 
1
  from typing import Dict, Any
2
  import torch
3
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
4
  from PIL import Image
5
  import io
6
  import base64
7
  import requests
 
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ self.model = Qwen2VLForConditionalGeneration.from_pretrained(path).to(self.device)
 
 
 
 
 
13
  self.processor = AutoProcessor.from_pretrained(path)
 
 
 
 
 
14
 
15
  def __call__(self, data: Any) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
16
  default_prompt = "Describe this image."
17
 
18
  if isinstance(data, (bytes, bytearray)):
 
24
  if image_input is None:
25
  return {"error": "No image provided."}
26
  if image_input.startswith('http'):
27
+ image = Image.open(requests.get(image_input, stream=True).raw).convert('RGB')
 
28
  else:
29
  image_data = base64.b64decode(image_input)
30
  image = Image.open(io.BytesIO(image_data)).convert('RGB')
 
35
  {
36
  "role": "user",
37
  "content": [
38
+ {"type": "image", "image": image},
 
 
 
39
  {"type": "text", "text": text_input},
40
  ],
41
  }
42
  ]
43
 
44
+ text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
45
  inputs = self.processor(
46
  text=[text],
47
+ images=[image],
 
48
  padding=True,
49
  return_tensors="pt",
50
+ ).to(self.device)
 
51
 
52
+ generate_ids = self.model.generate(inputs.input_ids, max_length=30)
 
 
 
53
  output_text = self.processor.batch_decode(
54
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
55
+ )[0]
56
 
57
+ return {"generated_text": output_text}
58
+