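# Gradio demo: upload a radiology image and generate a findings description
# with the fine-tuned 0llheaven/llama-3.2-11B-Vision-Instruct-Finetune model.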
import random

import numpy as np
import torch

# Import unsloth before transformers so Unsloth's runtime patches are applied.
from unsloth import FastVisionModel

import gradio as gr
from transformers import TextStreamer
# 4-bit loading requires a CUDA GPU; fall back to CPU only when none is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def set_seed(seed_value=42):
    """Seed every RNG in use so generation is reproducible across runs."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)
# Load the fine-tuned vision model in 4-bit (bitsandbytes) to cut memory use.
model, tokenizer = FastVisionModel.from_pretrained(
    "0llheaven/llama-3.2-11B-Vision-Instruct-Finetune",
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)
FastVisionModel.for_inference(model)  # switch Unsloth to inference mode before generating
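# Prompt sent along with every uploaded image.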
instruction = "You are an expert radiographer. Describe accurately what you see in this image."
def predict_radiology_description(image, temperature, use_top_p, top_p_value, use_min_p, min_p_value):
    try:
        set_seed(42)  # reseed per request so identical inputs give identical outputs
        # Single-turn chat prompt: an image placeholder followed by the instruction.
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": instruction},
        ]}]
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(device)
        # Stream tokens to stdout as they are generated (visible in the server log).
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
        generate_kwargs = {
            "max_new_tokens": 512,
            "use_cache": True,
            "temperature": temperature,
        }
        # top-p and min-p are optional filters; pass them only when enabled in the UI.
        if use_top_p:
            generate_kwargs["top_p"] = top_p_value
        if use_min_p:
            generate_kwargs["min_p"] = min_p_value
        output_ids = model.generate(
            **inputs,
            streamer=text_streamer,
            **generate_kwargs,
        )
        # Decode the full sequence; the replace() inserts a blank line before the
        # "assistant" role marker so the reply is visually separated from the prompt.
        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return generated_text.replace("assistant", "\n\nassistant").strip()
    except Exception as e:
        return f"Error: {e}"
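# Sampling notes (transformers semantics): top-p keeps the smallest set of
# tokens whose cumulative probability reaches top_p; min-p discards tokens
# whose probability falls below min_p times that of the most likely token.
# When both are enabled, generate() applies both filters to the logits.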
with gr.Blocks() as interface:
    gr.Markdown("<h1><center>Radiology Image Description Generator</center></h1>")
    gr.Markdown("Upload a radiology image, adjust the sampling settings, and the model will describe the findings in the image.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Description")
    with gr.Row():
        with gr.Column(scale=1):  # scale must be an integer in Gradio 4.x
            temperature_slider = gr.Slider(0.1, 2.0, step=0.1, value=1.0, label="Temperature")
            use_top_p_checkbox = gr.Checkbox(label="Use top-p", value=True)
            top_p_slider = gr.Slider(0.1, 1.0, step=0.05, value=0.9, label="Top-p")
            use_min_p_checkbox = gr.Checkbox(label="Use min-p", value=False)
            min_p_slider = gr.Slider(0.0, 1.0, step=0.05, value=0.1, label="Min-p", visible=False)
# Update visibility of sliders
use_top_p_checkbox.change(
lambda use_top_p: gr.update(visible=use_top_p),
inputs=use_top_p_checkbox,
outputs=top_p_slider
)
use_min_p_checkbox.change(
lambda use_min_p: gr.update(visible=use_min_p),
inputs=use_min_p_checkbox,
outputs=min_p_slider
)
generate_button = gr.Button("Generate Description")
# Link function to UI
generate_button.click(
predict_radiology_description,
inputs=[image_input, temperature_slider, use_top_p_checkbox, top_p_slider, use_min_p_checkbox, min_p_slider],
outputs=output_text
)
# Launch the Gradio app.
interface.launch(share=True, debug=True)
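# Run locally with `python app.py`; share=True requests a public gradio.live
# link and debug=True surfaces errors in the console. On Hugging Face Spaces
# the app is served at the Space URL and the share flag typically has no effect.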