Rathapoom's picture
Update app.py
b7450f3 verified
raw
history blame
2.73 kB
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import requests
import gradio as gr
import spaces # Import Hugging Face Spaces package
# Runtime configuration: compute device and the checkpoint to serve.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model_name = 'scb10x/llama-3-typhoon-v1.5-8b-instruct-vision-preview'
@spaces.GPU(duration=60)  # Dynamically request and release a GPU around this call
def load_model():
    """Load the causal-LM checkpoint named by ``model_name``.

    Weights are loaded as float16 when ``device`` is ``'cuda'`` and as
    float32 otherwise. Placement is left to ``device_map='auto'``, and
    ``trust_remote_code=True`` allows code shipped with the checkpoint to run.

    Returns:
        The object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    # NOTE(review): this returns the model out of a GPU-decorated function;
    # confirm that pattern is supported by the Spaces runtime in use.
    if device == 'cuda':
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    load_options = {
        'torch_dtype': weights_dtype,
        'device_map': 'auto',
        'trust_remote_code': True,
    }
    return AutoModelForCausalLM.from_pretrained(model_name, **load_options)
# Load the model once at import time; predict() reads this module-level instance.
model = load_model()
# Tokenizer for the same checkpoint. trust_remote_code=True allows code shipped
# with the checkpoint to run during loading.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
def prepare_inputs(text, image, device='cuda'):
    """Build token ids and an attention mask for one image + text prompt.

    The conversation (a fixed system message followed by a user turn that
    starts with the ``<|image|>`` marker) is rendered to a string with the
    tokenizer's chat template. The rendered string is split at the marker,
    each side is tokenized on its own, and the two id lists are joined with
    the sentinel id ``-200`` in the marker's position.

    Args:
        text: The user's question about the image.
        image: Unused by this function; kept so existing callers keep working.
        device: Device on which the returned tensors are placed.

    Returns:
        A tuple ``(input_ids, attention_mask)`` of ``torch.long`` tensors with
        shape ``(1, seq_len)``. The attention mask is all ones.
    """
    image_marker = '<|image|>'
    # Presumably the placeholder id the model's own code swaps for image
    # features -- TODO confirm against the checkpoint's remote code.
    image_token_id = -200

    messages = [
        {"role": "system", "content": "You are a helpful vision-capable assistant who eagerly converses with the user in their language."},
        {"role": "user", "content": image_marker + "\n" + text},
    ]
    inputs_formatted = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Split only at the first marker (the one inserted above). Without the
    # limit, a user prompt that itself contains "<|image|>" would produce extra
    # chunks, and everything after the second one would be silently dropped.
    text_chunks = [
        tokenizer(chunk).input_ids
        for chunk in inputs_formatted.split(image_marker, 1)
    ]

    # The first id of the second chunk is skipped -- presumably the BOS token
    # the tokenizer prepends to every call; verify for this tokenizer.
    token_ids = text_chunks[0] + [image_token_id] + text_chunks[1][1:]
    input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
    # ones_like already allocates on input_ids' device, so no extra move is needed.
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
def _fetch_image(img_url, timeout=15):
    """Download ``img_url`` and return it as a fully loaded RGB PIL image.

    Args:
        img_url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-success HTTP status.
        OSError: If the downloaded bytes cannot be decoded as an image.
    """
    import io  # local import: only this helper needs it

    # A timeout keeps a stalled server from hanging the request while the GPU
    # allocation is held. Not streaming lets requests release the connection.
    response = requests.get(img_url, timeout=timeout)
    # Surface HTTP errors (404, 403, ...) directly instead of letting PIL fail
    # later with an unrelated "cannot identify image file" message.
    response.raise_for_status()
    # Normalise palette/alpha/greyscale inputs to RGB; assumes the model's
    # image preprocessing expects 3-channel input -- TODO confirm.
    return Image.open(io.BytesIO(response.content)).convert('RGB')


@spaces.GPU(duration=60)  # Decorate the function for GPU use
def predict(prompt, img_url):
    """Answer ``prompt`` about the image found at ``img_url``.

    Args:
        prompt: The user's question.
        img_url: URL of the image to analyse.

    Returns:
        The decoded text generated after the prompt tokens, stripped of
        surrounding whitespace. If any step raises, the exception message is
        returned as the output string instead of propagating.
    """
    try:
        image = _fetch_image(img_url)
        image_tensor = model.process_images([image], model.config).to(
            dtype=model.dtype, device=device
        )
        # The returned attention mask is not forwarded to generate(), matching
        # the original call, so it is discarded here.
        input_ids, _ = prepare_inputs(prompt, image, device=device)
        # NOTE(review): temperature/top_p normally only take effect when
        # sampling is enabled; confirm the checkpoint's generation config.
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=100,
            use_cache=True,
            temperature=0.2,
            top_p=0.2,
            repetition_penalty=1.0
        )[0]
        # Decode only the ids that follow the prompt.
        prompt_length = input_ids.shape[1]
        return tokenizer.decode(
            output_ids[prompt_length:], skip_special_tokens=True
        ).strip()
    except Exception as e:
        # UI boundary: show the failure in the output textbox rather than crash.
        return str(e)
# Gradio Interface: two text inputs (question and image URL), one text output.
inputs = [
    gr.Textbox(placeholder="Ask about the food in the image", label="Prompt"),
    gr.Textbox(placeholder="Enter an image URL", label="Image URL"),
]
outputs = gr.Textbox(label="Generated Output")

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Food Image AI Assistant",
    description="This model can analyze food images and answer questions about them.",
)
# Start the web UI as soon as the module is executed.
demo.launch()