File size: 1,771 Bytes
5041f6c
 
 
 
 
 
 
 
1112f1b
 
 
 
 
5041f6c
04cc94e
5041f6c
1112f1b
 
 
5041f6c
 
1112f1b
5041f6c
 
d258d19
0f8e3cc
5041f6c
 
 
 
 
 
8e64bf0
5041f6c
1112f1b
5041f6c
 
 
 
8e64bf0
 
73a5bf0
8e64bf0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio as gr
import spaces
import torch
# Replace torch.jit.script with an identity function: any later call returns the
# function it was given, unscripted. The patch is process-wide and is placed
# before the t2v_metrics import below so it is already active while that import
# runs. Presumably some code pulled in by t2v_metrics calls torch.jit.script on
# something that fails to script -- TODO confirm which dependency needs this.
torch.jit.script = lambda f: f  # Avoid script error in lambda

# NOTE(review): list_all_vqascore_models is not used anywhere in this file's code.
from t2v_metrics import VQAScore, list_all_vqascore_models

def update_model(model_name, device="cuda"):
    """Build a VQAScore scorer for ``model_name``.

    Args:
        model_name: t2v_metrics VQAScore model identifier, e.g. ``"clip-flant5-xl"``.
        device: Device string forwarded to ``VQAScore``; defaults to ``"cuda"``.

    Returns:
        A newly constructed ``VQAScore`` instance (every call builds a fresh one).
    """
    return VQAScore(model=model_name, device=device)


# Global scorer state. update_model must be defined before this point: the
# original called it above its ``def``, which raised NameError at import time.
# The default scorer is built eagerly, on device "cuda", when the module is
# imported.
cur_model_name = "clip-flant5-xl"  # model name that model_pipe was built from
model_pipe = update_model(cur_model_name)

@spaces.GPU(duration=20)
def generate(model_name, image, text):
    """Score how well ``image`` matches ``text`` using a VQAScore model.

    Args:
        model_name: VQAScore model identifier from the UI dropdown. ``None`` (no
            selection) falls back to the currently loaded model.
        image: Image path, as produced by ``gr.Image(type="filepath")``; ``None``
            when nothing was uploaded.
        text: Prompt to score against the image.

    Returns:
        The score as a Python scalar (``.item()`` of the first output element).

    Raises:
        gr.Error: If no image was provided.
        RuntimeError: Re-raised from model inference after being logged.
    """
    # Without ``global`` the assignment below made ``model_pipe`` a local for the
    # whole function, so reading it when no model switch happened raised
    # UnboundLocalError; ``cur_model_name`` also never tracked the loaded model.
    global model_pipe, cur_model_name

    if image is None:
        raise gr.Error("Please upload an image.")
    if model_name is None:
        model_name = cur_model_name

    if model_name != cur_model_name:
        # Record the new name only after loading succeeds, so a failed load keeps
        # the previous scorer consistent with cur_model_name.
        # NOTE(review): model loading here counts against the 20 s GPU window
        # requested above; a large model may not fit -- TODO confirm.
        model_pipe = update_model(model_name)
        cur_model_name = model_name

    print("Image:", image)  # Debug: Print image path
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)
    # Log CUDA/runtime errors from inference, then propagate them unchanged.
    try:
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    print("Result", result)

    return result

# Build and launch the web UI. The dropdown is given a default value so that a
# submission without an explicit selection passes a real model name to
# generate() instead of None.
demo = gr.Interface(
    fn=generate,  # function to call
    # Model names listed by the original author; only two are exposed in the UI:
    # ['clip-flant5-xxl', 'clip-flant5-xl', 'clip-flant5-xxl-no-system', 'clip-flant5-xxl-no-system-no-user', 'llava-v1.5-13b', 'llava-v1.5-7b', 'sharegpt4v-7b', 'sharegpt4v-13b', 'llava-v1.6-13b', 'instructblip-flant5-xxl', 'instructblip-flant5-xl']
    inputs=[
        gr.Dropdown(
            ["clip-flant5-xl", "clip-flant5-xxl"],
            value="clip-flant5-xl",
            label="Model Name",
        ),
        gr.Image(type="filepath"),  # generate() receives a file path (or None)
        gr.Textbox(label="Prompt"),
    ],
    outputs="number",  # generate() returns a single numeric score
    title="VQAScore",
    description="This model evaluates the similarity between an image and a text prompt.",
)

demo.queue()
demo.launch()