davanstrien (HF staff) committed
Commit 6c4cdba · verified · 1 Parent(s): d880504

Update handler.py

Files changed (1): handler.py (+14, -4)
handler.py CHANGED
@@ -3,23 +3,29 @@ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 from PIL import Image
 import requests
 import torch
+import gc
 
 class EndpointHandler:
     def __init__(self, path=""):
         self.processor = AutoProcessor.from_pretrained(
             path,
             trust_remote_code=True,
-            torch_dtype='auto',
+            torch_dtype=torch.bfloat16,
             device_map='auto'
         )
         self.model = AutoModelForCausalLM.from_pretrained(
             path,
             trust_remote_code=True,
-            torch_dtype='auto',
-            device_map='auto'
+            torch_dtype=torch.bfloat16,
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        # Clear CUDA cache
+        torch.cuda.empty_cache()
+        gc.collect()
+
         # Extract inputs from the request data
         inputs = data.get("inputs", {})
         image_url = inputs.get("image_url")
@@ -38,7 +44,7 @@ class EndpointHandler:
 
         # Process the image and text
         try:
-            with torch.cuda.amp.autocast(enabled=True):
+            with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
                 inputs = self.processor.process(
                     images=[image],
                     text=text_prompt
@@ -58,6 +64,10 @@ class EndpointHandler:
             generated_tokens = output[0, inputs['input_ids'].size(1):]
             generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
+            # Clear CUDA cache again
+            torch.cuda.empty_cache()
+            gc.collect()
+
             return [{"generated_text": generated_text}]
         except Exception as e:
             return [{"error": f"Error during generation: {str(e)}"}]