# Scraped Hugging Face Spaces page header (not part of the program):
# "Spaces: Runtime error" (status shown twice on the page).
import gradio as gr

from gpt2_story_gen import generate_story
from prefix_clip import download_pretrained_model, generate_caption
def main(pil_image, genre, model="Conceptual", use_beam_search=True):
    """Caption an image, then generate a story of the chosen genre from it.

    Args:
        pil_image: The uploaded image. The UI's image input uses
            ``type="pil"``, so this is a PIL image, or ``None`` when nothing
            was uploaded.
        genre: Genre selected in the dropdown, or ``None``/empty when
            nothing was selected. It is lower-cased before being passed to
            ``generate_story``.
        model: Name of the pretrained captioning weights. It is lower-cased
            before being passed to ``download_pretrained_model``.
        use_beam_search: Forwarded unchanged to ``generate_caption``.

    Returns:
        Whatever ``generate_story`` returns for the caption and genre; the
        UI displays it in a text box.

    Raises:
        gr.Error: If no image or no genre was provided. Gradio shows this
            message to the user instead of failing with an
            ``AttributeError`` / downstream traceback.
    """
    # Guard clauses: Gradio hands None to the function for an empty image
    # input or an unselected dropdown.
    if pil_image is None:
        raise gr.Error("Please upload an image.")
    if not genre:
        raise gr.Error("Please choose a story genre.")

    model_file = "pretrained_weights.pt"
    # NOTE(review): this runs on every request and always writes to the same
    # path, whatever `model` is. It presumably skips the download when the file
    # already exists; confirm in prefix_clip before relying on it.
    download_pretrained_model(model.lower(), file_to_save=model_file)
    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    return generate_story(image_caption, genre.lower())
if __name__ == "__main__":
    # Build and launch the web UI.
    # The components are the top-level gr.Image / gr.Dropdown / gr.Textbox
    # classes. The legacy gr.inputs.* / gr.outputs.* namespaces and
    # Interface(enable_queue=...) were removed in Gradio 4.0 and fail at
    # startup there. The classes used here exist in both Gradio 3.x and 4.x.
    interface = gr.Interface(
        main,
        title="image2story",
        inputs=[
            # "upload" is already the default source in Gradio 3.x. Gradio 4
            # replaced `source=` with `sources=`, so the argument is omitted
            # to keep this working on both versions.
            gr.Image(type="pil", label="Input"),
            gr.Dropdown(
                type="value",
                label="Story genre",
                choices=[
                    "superhero",
                    "action",
                    "drama",
                    "horror",
                    "thriller",
                    "sci_fi",
                ],
                # Preselect a genre so `main` is not handed None.
                value="superhero",
            ),
        ],
        outputs=gr.Textbox(label="Generated story"),
    )
    # queue() is the replacement for the removed enable_queue=True flag; it
    # returns the interface, so launch() can be chained.
    interface.queue().launch()