|
import torch |
|
import transformers |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from PIL import Image |
|
import requests |
|
import gradio as gr |
|
import spaces |
|
|
|
|
|
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face Hub id of the checkpoint used for the model and tokenizer.
model_name = 'scb10x/llama-3-typhoon-v1.5-8b-instruct-vision-preview'
|
|
|
@spaces.GPU(duration=60)
def load_model():
    """Load the causal-LM checkpoint named by ``model_name``.

    Weights are requested in float16 when ``device`` is ``'cuda'`` and in
    float32 otherwise; layer placement is left to ``device_map='auto'``.

    Returns:
        The model object returned by
        ``AutoModelForCausalLM.from_pretrained``.
    """
    # Half precision only when a CUDA device was detected at import time.
    weights_dtype = torch.float16 if device == 'cuda' else torch.float32

    return AutoModelForCausalLM.from_pretrained(
        model_name,
        revision='main',
        torch_dtype=weights_dtype,
        device_map='auto',
        # Allows modeling code shipped in the model repository to run.
        trust_remote_code=True,
    )
|
|
|
# Load the model once at import time so every request reuses the same object.
model = load_model()

# Tokenizer comes from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
|
def prepare_inputs(text, image, device='cuda'):
    """Build the token ids and attention mask for one image + text prompt.

    The chat template is rendered as text with an ``<|image|>`` placeholder
    in front of the user message. The text on each side of that placeholder
    is tokenized separately and joined around a sentinel token id.

    Args:
        text: The user's prompt.
        image: Not used by this function; kept so existing callers that pass
            it keep working.
        device: Torch device the returned tensors are moved to.

    Returns:
        A ``(input_ids, attention_mask)`` tuple. Both are ``torch.long``
        tensors of shape ``(1, sequence_length)``; the mask is all ones.
    """
    image_placeholder = '<|image|>'
    # NOTE(review): -200 is presumably the index the model's own code swaps
    # for image features -- confirm against the checkpoint's modeling code.
    image_token_id = -200

    messages = [
        {"role": "system", "content": "You are a helpful vision-capable assistant who eagerly converses with the user in their language."},
        {"role": "user", "content": image_placeholder + "\n" + text},
    ]

    rendered_prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Split on the FIRST placeholder only. Without the limit, a user prompt
    # that happens to contain "<|image|>" would create extra chunks and
    # everything after the second occurrence would be silently dropped.
    before_image, after_image = rendered_prompt.split(image_placeholder, 1)
    before_ids = tokenizer(before_image).input_ids
    after_ids = tokenizer(after_image).input_ids

    # The first token of the second chunk is skipped -- presumably the BOS
    # token the tokenizer prepends to every call; TODO confirm.
    joined_ids = before_ids + [image_token_id] + after_ids[1:]
    input_ids = torch.tensor(joined_ids, dtype=torch.long).unsqueeze(0).to(device)

    # ones_like already allocates on input_ids' device.
    attention_mask = torch.ones_like(input_ids)

    return input_ids, attention_mask
|
|
|
@spaces.GPU(duration=60)
def predict(prompt, img_url):
    """Answer ``prompt`` about the image found at ``img_url``.

    Args:
        prompt: The user's question or instruction.
        img_url: HTTP(S) URL of the image to analyse.

    Returns:
        The decoded model output with surrounding whitespace stripped, or
        the text of the exception if any step fails (the UI shows whatever
        string is returned).
    """
    # Local import keeps this block self-contained.
    from io import BytesIO

    # Upper bound on the download so a stalled host cannot hold the request
    # (and the GPU allocation) open indefinitely.
    download_timeout_s = 15

    try:
        response = requests.get(img_url, timeout=download_timeout_s)
        # Surface 4xx/5xx responses instead of handing an error page to PIL.
        response.raise_for_status()
        # Use the fully read, content-decoded body rather than the raw
        # socket stream, and normalise palette/alpha/greyscale images to RGB.
        image = Image.open(BytesIO(response.content)).convert('RGB')

        image_tensor = model.process_images([image], model.config).to(
            dtype=model.dtype, device=device
        )

        # The attention mask is not passed to generate(), as in the original.
        input_ids, _ = prepare_inputs(prompt, image, device=device)

        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=100,
            use_cache=True,
            temperature=0.2,
            top_p=0.2,
            repetition_penalty=1.0,
        )[0]

        # Decode only what follows the first input_ids.shape[1] positions.
        result = tokenizer.decode(
            output_ids[input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        return result
    except Exception as e:
        # UI boundary: report the failure as text rather than crash the app.
        return str(e)
|
|
|
|
|
# Two free-text fields: the question and the URL of the image to analyse.
inputs = [
    gr.Textbox(label="Prompt", placeholder="Ask about the food in the image"),
    gr.Textbox(label="Image URL", placeholder="Enter an image URL"),
]

# Single text field that shows either the model's answer or an error message.
outputs = gr.Textbox(label="Generated Output")

# Keep a handle on the interface before launching it.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Food Image AI Assistant",
    description="This model can analyze food images and answer questions about them.",
)
demo.launch()
|
|