Spaces:
Running
Running
File size: 4,151 Bytes
eb46e50 c9406a3 8b1dbc7 c9406a3 8b1dbc7 c9406a3 8b1dbc7 c9406a3 e8777df c9406a3 8b1dbc7 c9406a3 eb46e50 c9406a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces
# Flag to use GPU (set to False by default)
USE_GPU = False

# Load the processor and model
# Resolve the target device up front; CUDA is used only when both the flag
# is set and a GPU is actually available.
device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained(
    'allenai/MolmoE-1B-0924',
    trust_remote_code=True,  # Molmo ships custom processing code on the Hub
    torch_dtype='auto',
)
model = AutoModelForCausalLM.from_pretrained(
    'allenai/MolmoE-1B-0924',
    trust_remote_code=True,  # Molmo ships custom modeling code on the Hub
    torch_dtype='auto',
    # On GPU let accelerate place the weights; on CPU skip device_map and
    # move the model manually below.
    device_map='auto' if USE_GPU else None
)
if not USE_GPU:
    model.to(device)
# Predefined prompts
# Offered in the UI dropdown as quick-start questions; selecting one simply
# overwrites the free-text input (see prompt_dropdown.change below).
prompts = [
    "Describe this image in detail",
    "What objects can you see in this image?",
    "What's the main subject of this image?",
    "Describe the colors in this image",
    "What emotions does this image evoke?"
]
def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
    """Run MolmoE on one (image, prompt) pair and return the generated text.

    Args:
        image: H x W x C numpy array, as produced by gr.Image(type="numpy").
        text: the user's prompt string.
        max_new_tokens: cap on the number of tokens to generate.
        temperature: softmax temperature for sampling.
        top_p: nucleus-sampling probability mass.

    Returns:
        The decoded generation with the prompt tokens stripped.
    """
    # Process the image and text
    inputs = processor.process(
        images=[Image.fromarray(image)],
        text=text
    )
    # Move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    # Generate output.
    # do_sample=True is required for temperature/top_p to have any effect;
    # without it generation is greedy and the UI sliders are silently ignored.
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            stop_strings="<|endoftext|>"
        ),
        tokenizer=processor.tokenizer
    )
    # Only get generated tokens (skip the echoed prompt); decode them to text
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return generated_text
def chatbot(image, text, history, max_new_tokens, temperature, top_p):
    """Append one (user prompt, model reply) turn to the chat history.

    Returns the updated history list for the Chatbot component. If no image
    has been uploaded yet, a fresh list with an advisory entry is returned
    instead and the passed-in history is left untouched.
    """
    if image is not None:
        reply = process_image_and_text(image, text, max_new_tokens, temperature, top_p)
        history.append((text, reply))
        return history
    return history + [("Please upload an image first.", None)]
def update_textbox(prompt):
    """Push the dropdown's selected premade prompt into the text input box."""
    selected = prompt
    return gr.update(value=selected)
# Define the Gradio interface
# Define the Gradio interface: image + chat transcript on top, prompt
# controls below, generation hyper-parameters in a collapsible accordion.
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with MolmoE-1B-0924")
    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()
    with gr.Row():
        text_input = gr.Textbox(placeholder="Ask a question about the image...")
        prompt_dropdown = gr.Dropdown(choices=[""] + prompts, label="Select a premade prompt", value="")
    submit_button = gr.Button("Submit")
    clear_button = gr.ClearButton([text_input, chatbot_output])
    with gr.Accordion("Advanced options", open=False):
        max_new_tokens = gr.Slider(minimum=1, maximum=500, value=200, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    # Per-session chat history.
    # NOTE(review): `state` is fed into `chatbot` but never listed in any
    # event's outputs, so the returned history is never written back; history
    # appears to persist only because the handler mutates the list in place —
    # confirm this is intentional.
    state = gr.State([])
    # Add copy button for raw output
    with gr.Row():
        raw_output = gr.Textbox(label="Raw Output", interactive=False)
        copy_button = gr.Button("Copy Raw Output")
    def update_raw_output(history):
        """Mirror the latest bot reply into the raw-output textbox."""
        if history:
            return history[-1][1]
        return ""
    # Submit via button click or Enter in the textbox: run the model, then
    # refresh the raw-output textbox from the updated transcript.
    submit_button.click(
        chatbot,
        inputs=[image_input, text_input, state, max_new_tokens, temperature, top_p],
        outputs=[chatbot_output]
    ).then(
        update_raw_output,
        inputs=[chatbot_output],
        outputs=[raw_output]
    )
    text_input.submit(
        chatbot,
        inputs=[image_input, text_input, state, max_new_tokens, temperature, top_p],
        outputs=[chatbot_output]
    ).then(
        update_raw_output,
        inputs=[chatbot_output],
        outputs=[raw_output]
    )
    # Selecting a premade prompt overwrites the textbox contents.
    prompt_dropdown.change(update_textbox, inputs=[prompt_dropdown], outputs=[text_input])
    # NOTE(review): this routes the text into a hidden Textbox that is created
    # outside the layout and never rendered; it does not reach the user's
    # clipboard — confirm the intended "copy" behavior.
    copy_button.click(lambda x: gr.update(value=x), inputs=[raw_output], outputs=[gr.Textbox(visible=False)])
demo.launch()