Rathapoom committed on
Commit
b7450f3
·
verified ·
1 Parent(s): e99fc14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -4,21 +4,25 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from PIL import Image
5
  import requests
6
  import gradio as gr
 
7
 
8
  # Load model and tokenizer
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
- model = AutoModelForCausalLM.from_pretrained(
11
- 'scb10x/llama-3-typhoon-v1.5-8b-instruct-vision-preview',
12
- revision='main', # or specify a commit hash if needed
13
- torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
14
- device_map='auto',
15
- trust_remote_code=True
16
- ).to(device)
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(
19
- 'scb10x/llama-3-typhoon-v1.5-8b-instruct-vision-preview',
20
- trust_remote_code=True
21
- )
 
 
 
 
 
 
 
 
 
22
 
23
  def prepare_inputs(text, image, device='cuda'):
24
  messages = [
@@ -38,7 +42,7 @@ def prepare_inputs(text, image, device='cuda'):
38
 
39
  return input_ids, attention_mask
40
 
41
- # Inference function
42
  def predict(prompt, img_url):
43
  try:
44
  image = Image.open(requests.get(img_url, stream=True).raw)
 
4
  from PIL import Image
5
  import requests
6
  import gradio as gr
7
+ import spaces # Import Hugging Face Spaces package
8
 
9
  # Load model and tokenizer
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+ model_name = 'scb10x/llama-3-typhoon-v1.5-8b-instruct-vision-preview'
 
 
 
 
 
 
12
 
13
+ @spaces.GPU(duration=60) # Decorate the function to dynamically request and release GPU
14
+ def load_model():
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ model_name,
17
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
18
+ device_map='auto',
19
+ trust_remote_code=True
20
+ )
21
+ return model
22
+
23
+ model = load_model()
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
26
 
27
  def prepare_inputs(text, image, device='cuda'):
28
  messages = [
 
42
 
43
  return input_ids, attention_mask
44
 
45
+ @spaces.GPU(duration=60) # Decorate the function for GPU use
46
  def predict(prompt, img_url):
47
  try:
48
  image = Image.open(requests.get(img_url, stream=True).raw)