davanstrien (HF staff) committed
Commit bf4ec84 · verified · 1 Parent(s): d9af75a

Update handler.py

Files changed (1)
  1. handler.py +39 -23
handler.py CHANGED
@@ -4,44 +4,60 @@ from PIL import Image
 import requests
 import torch
 
-
 class EndpointHandler:
     def __init__(self, path=""):
         self.processor = AutoProcessor.from_pretrained(
-            path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
+            path,
+            trust_remote_code=True,
+            torch_dtype='auto',
+            device_map='auto'
         )
         self.model = AutoModelForCausalLM.from_pretrained(
-            path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
+            path,
+            trust_remote_code=True,
+            torch_dtype='auto',
+            device_map='auto'
         )
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         # Extract inputs from the request data
-        image_url = data.get("image_url")
-        text_prompt = data.get("text_prompt", "Describe this image.")
+        inputs = data.get("inputs", {})
+        image_url = inputs.get("image_url")
+        text_prompt = inputs.get("text_prompt", "Describe this image.")
+
+        if not image_url:
+            return [{"error": "No image_url provided in inputs"}]
 
         # Download and process the image
-        image = Image.open(requests.get(image_url, stream=True).raw)
-        if image.mode != "RGB":
-            image = image.convert("RGB")
+        try:
+            image = Image.open(requests.get(image_url, stream=True).raw)
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+        except Exception as e:
+            return [{"error": f"Failed to load image: {str(e)}"}]
 
         # Process the image and text
-        inputs = self.processor.process(images=[image], text=text_prompt)
+        inputs = self.processor.process(
+            images=[image],
+            text=text_prompt
+        )
 
         # Move inputs to the correct device and make a batch of size 1
         inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
 
         # Generate output
-        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
-            output = self.model.generate_from_batch(
-                inputs,
-                GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
-                tokenizer=self.processor.tokenizer,
-            )
-
-        # Decode the generated tokens
-        generated_tokens = output[0, inputs["input_ids"].size(1):]
-        generated_text = self.processor.tokenizer.decode(
-            generated_tokens, skip_special_tokens=True
-        )
-
-        return [{"generated_text": generated_text}]
+        try:
+            with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+                output = self.model.generate_from_batch(
+                    inputs,
+                    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+                    tokenizer=self.processor.tokenizer
+                )
+
+            # Decode the generated tokens
+            generated_tokens = output[0, inputs['input_ids'].size(1):]
+            generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+            return [{"generated_text": generated_text}]
+        except Exception as e:
+            return [{"error": f"Error during generation: {str(e)}"}]