File size: 4,151 Bytes
eb46e50
c9406a3
 
 
 
 
8b1dbc7
 
c9406a3
 
8b1dbc7
c9406a3
8b1dbc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9406a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8777df
c9406a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b1dbc7
c9406a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb46e50
c9406a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces

# Flag to use GPU (set to False by default)
USE_GPU = False

# Load the processor and model
# Resolve the target device once: CUDA is used only when the flag is on AND a
# GPU is actually present, so the app degrades gracefully to CPU.
device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")

# Multimodal processor (tokenizer + image preprocessing) for Molmo.
# trust_remote_code is required because Molmo ships custom modeling code.
processor = AutoProcessor.from_pretrained(
    'allenai/MolmoE-1B-0924',
    trust_remote_code=True,
    torch_dtype='auto',
)

# device_map='auto' lets accelerate shard/place the model when on GPU;
# on CPU we load normally and move the model ourselves below.
model = AutoModelForCausalLM.from_pretrained(
    'allenai/MolmoE-1B-0924',
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto' if USE_GPU else None
)

if not USE_GPU:
    model.to(device)

# Predefined prompts offered in the UI dropdown; "" (no selection) is
# prepended to this list where the dropdown is built.
prompts = [
    "Describe this image in detail",
    "What objects can you see in this image?",
    "What's the main subject of this image?",
    "Describe the colors in this image",
    "What emotions does this image evoke?"
]

def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
    """Run one image + text generation pass through the Molmo model.

    Args:
        image: RGB image as a numpy array (H, W, C), as produced by
            ``gr.Image(type="numpy")``.
        text: User prompt to condition the generation on.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Returns:
        The decoded model continuation as a string (special tokens stripped).
    """
    # Preprocess the image and tokenize the prompt into model-ready tensors.
    inputs = processor.process(
        images=[Image.fromarray(image)],
        text=text
    )

    # Move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    # Inference only: disabling autograd avoids tracking gradients, which
    # substantially reduces memory use and speeds up generation.
    with torch.inference_mode():
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                stop_strings="<|endoftext|>"
            ),
            tokenizer=processor.tokenizer
        )

    # Slice off the prompt tokens; decode only the newly generated ones.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text

def chatbot(image, text, history, max_new_tokens, temperature, top_p):
    """Handle one chat turn: validate input, query the model, extend history.

    Args:
        image: numpy image from the Gradio image component, or None.
        text: The user's question for this turn.
        history: Shared chat history — a list of (user, bot) string pairs.
            Mutated in place so the ``gr.State`` copy stays in sync.
        max_new_tokens / temperature / top_p: Generation settings.

    Returns:
        The updated history list (also the value shown in the Chatbot widget).
    """
    if image is None:
        # Keep the user's text in the user slot and put the notice in the
        # bot slot; append in place so the shared state list also records it
        # (the original built a new list here, so state silently dropped it).
        history.append((text, "Please upload an image first."))
        return history

    response = process_image_and_text(image, text, max_new_tokens, temperature, top_p)
    history.append((text, response))
    return history

def update_textbox(prompt):
    """Copy the chosen premade prompt into the free-text question box."""
    new_value = gr.update(value=prompt)
    return new_value

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with MolmoE-1B-0924")
    
    # Top row: image upload on the left, conversation view on the right.
    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()
    
    # Free-text question plus a dropdown of premade prompts ("" = none).
    with gr.Row():
        text_input = gr.Textbox(placeholder="Ask a question about the image...")
        prompt_dropdown = gr.Dropdown(choices=[""] + prompts, label="Select a premade prompt", value="")
    
    submit_button = gr.Button("Submit")
    clear_button = gr.ClearButton([text_input, chatbot_output])

    # Generation hyper-parameters, collapsed by default.
    with gr.Accordion("Advanced options", open=False):
        max_new_tokens = gr.Slider(minimum=1, maximum=500, value=200, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    # Shared chat history. NOTE(review): `chatbot` appends to this list in
    # place, which is how state stays current even though only
    # `chatbot_output` is declared as the event output.
    state = gr.State([])

    # Add copy button for raw output
    with gr.Row():
        raw_output = gr.Textbox(label="Raw Output", interactive=False)
        copy_button = gr.Button("Copy Raw Output")

    def update_raw_output(history):
        """Return the latest bot reply from the chat history ("" if empty)."""
        if history:
            return history[-1][1]
        return ""

    # Button click: run the model, then mirror the newest reply into the
    # raw-output textbox via the chained .then() handler.
    submit_button.click(
        chatbot,
        inputs=[image_input, text_input, state, max_new_tokens, temperature, top_p],
        outputs=[chatbot_output]
    ).then(
        update_raw_output,
        inputs=[chatbot_output],
        outputs=[raw_output]
    )

    # Pressing Enter in the textbox triggers the same pipeline as the button.
    text_input.submit(
        chatbot,
        inputs=[image_input, text_input, state, max_new_tokens, temperature, top_p],
        outputs=[chatbot_output]
    ).then(
        update_raw_output,
        inputs=[chatbot_output],
        outputs=[raw_output]
    )

    # Selecting a premade prompt overwrites the free-text input.
    prompt_dropdown.change(update_textbox, inputs=[prompt_dropdown], outputs=[text_input])
    
    # NOTE(review): the output here is a brand-new hidden Textbox created at
    # wiring time, so this handler does not reach the OS clipboard — a JS
    # callback (e.g. the `js=` parameter) would be needed for a real copy.
    copy_button.click(lambda x: gr.update(value=x), inputs=[raw_output], outputs=[gr.Textbox(visible=False)])

demo.launch()