# Page header captured from the hosting site's file view (not part of the program):
#   diabolic6045's picture
#   Update app.py
#   bd69ebb verified
#   raw
#   history blame
#   3.13 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor
from PIL import Image
import spaces
import os
from huggingface_hub import login
# Authenticate against the Hugging Face Hub. The token is read from the
# HF_KEY environment variable; a missing variable raises KeyError here.
login(os.environ["HF_KEY"])

# Target device for the input tensors built in generate_text.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# All three artifacts are loaded from the same Hub repository.
_MODEL_ID = "stabilityai/japanese-stable-vlm"

# trust_remote_code=True lets transformers run the model code shipped in the repo.
model = AutoModelForVision2Seq.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    device_map='auto',
)
processor = AutoImageProcessor.from_pretrained(_MODEL_ID, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, device_map='auto')
# Define the helper function to build prompts
# Instruction sentence placed in the prompt for each supported task.
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}


def build_prompt(task="caption", input=None, sep="\n\n### "):
    """Build the text prompt for one of the tasks in ``TASK2INSTRUCTION``.

    The prompt is a fixed system message followed by role-labelled sections
    (instruction, optional input, empty response), each introduced by ``sep``.

    Args:
        task: A key of ``TASK2INSTRUCTION`` ("caption", "tag" or "vqa").
        input: Extra text for the "tag" and "vqa" tasks. For "tag" a list or
            tuple of words is joined with "、". For "caption" it must be
            ``None``; an empty or whitespace-only string is treated as
            ``None`` because UI text boxes deliver ``""`` when left blank.
            (The name shadows the builtin but is kept for keyword callers.)
        sep: Separator written before every role section.

    Returns:
        The assembled prompt string, ending with the empty response section.

    Raises:
        AssertionError: If ``task`` is unknown, if "tag"/"vqa" is requested
            with ``input=None``, or if "caption" is given non-blank input.
    """
    # Explicit raises instead of `assert` statements: asserts are stripped
    # under `python -O`, which would silently disable this validation. The
    # exception type is kept as AssertionError for existing callers.
    if task not in TASK2INSTRUCTION:
        raise AssertionError(f"Please choose from {list(TASK2INSTRUCTION.keys())}")

    if task in ("tag", "vqa"):
        if input is None:
            raise AssertionError("Please fill in `input`!")
        if task == "tag" and isinstance(input, (list, tuple)):
            input = "、".join(input)
    else:
        # A blank string means "no input" rather than an unsupported question.
        if isinstance(input, str) and not input.strip():
            input = None
        if input is not None:
            raise AssertionError(f"`{task}` mode doesn't support to input questions")

    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    roles = ["指示", "応答"]
    msgs = [": \n" + TASK2INSTRUCTION[task], ": \n"]
    if input:
        # The input section sits between the instruction and the response.
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + input)

    return sys_msg + "".join(sep + role + msg for role, msg in zip(roles, msgs))
# Define the function to generate text from the image and prompt
def generate_text(image, task, input_text=None):
    """Generate text for an image using the prompt built for ``task``.

    Args:
        image: The image as delivered by the Gradio image component; it is
            passed unchanged to ``processor``.
        task: Task name forwarded to ``build_prompt``
            ("caption", "tag" or "vqa").
        input_text: Text from the UI text box. Required for "tag" and "vqa";
            ignored for "caption".

    Returns:
        The first decoded sequence, with special tokens removed and
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was provided, or if "tag"/"vqa" is selected
            without any text.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # An untouched text box arrives as "" rather than None, which
    # build_prompt rejects for the caption task; normalise blank text first.
    if isinstance(input_text, str):
        input_text = input_text.strip() or None

    if task == "caption":
        # Captioning takes no extra text; drop anything left in the box
        # from a previous tag/vqa run instead of failing.
        input_text = None
    elif task in ("tag", "vqa") and input_text is None:
        raise gr.Error(f"Please enter text for the `{task}` task.")

    prompt = build_prompt(task=task, input=input_text)

    # Image features and prompt token ids are merged into one batch of
    # model inputs.
    inputs = processor(images=image, return_tensors="pt")
    inputs.update(tokenizer(prompt, add_special_tokens=False, return_tensors="pt"))

    outputs = model.generate(
        **inputs.to(device=device, dtype=model.dtype),
        do_sample=False,
        num_beams=5,
        max_new_tokens=128,
        min_length=1,
        repetition_penalty=1.5,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
# Define the Gradio interface
# Input widgets, declared in the order generate_text receives its arguments.
image_input = gr.Image(label="Upload an image")
task_input = gr.Radio(
    choices=["caption", "tag", "vqa"],
    value="caption",
    label="Select a task",
)
text_input = gr.Textbox(label="Enter text (for tag or vqa tasks)")

# Single text output showing the generated text.
output = gr.Textbox(label="Generated text")

# One row per example: image path, task, and text-box content.
example_rows = [
    ["examples/example_1.jpg", "caption", None],
    ["examples/example_2.jpg", "tag", "河津桜、青空"],
    ["examples/example_3.jpg", "vqa", "OCRはできますか?"],
]

interface = gr.Interface(
    fn=generate_text,
    inputs=[image_input, task_input, text_input],
    outputs=output,
    examples=example_rows,
)

# Start the web server for the app.
interface.launch()