"""Gradio demo: caption an uploaded image, then turn the caption into a story.

The caption comes from ``prefix_clip.generate_caption``. The story is
produced by ``gpt2_story_gen.generate_story`` in the genre picked in the UI.
"""

import functools

import gradio as gr
from prefix_clip import download_pretrained_model, generate_caption
from gpt2_story_gen import generate_story

# Genres offered in the UI dropdown; each is lower-cased and passed to
# generate_story.
GENRES = [
    "superhero",
    "action",
    "drama",
    "horror",
    "thriller",
    "sci_fi",
]


@functools.lru_cache(maxsize=None)
def _ensure_weights(model_name):
    """Download the weights for *model_name* once per process; return the path.

    Cached so the download helper runs at most once per model name instead of
    on every request. The file name is per model, so a cached path can never
    end up pointing at a file that a different model's download overwrote.
    A failed download raises and is not cached, so the next call retries.
    """
    model_file = f"pretrained_weights_{model_name}.pt"
    download_pretrained_model(model_name, file_to_save=model_file)
    return model_file


def main(pil_image, genre, model="Conceptual", use_beam_search=True):
    """Generate a short story describing an uploaded image.

    Args:
        pil_image: The uploaded image (a PIL image, per the ``type="pil"``
            input below). It may be None if the user submits without
            uploading anything.
        genre: Genre chosen in the dropdown. It is lower-cased before being
            passed to ``generate_story``.
        model: Name of the pretrained captioning model. It is lower-cased
            before being passed to ``download_pretrained_model``.
        use_beam_search: Forwarded unchanged to ``generate_caption``.

    Returns:
        The story produced by ``generate_story`` for the image's caption.

    Raises:
        ValueError: If no image was uploaded or no genre was selected.
    """
    # Fail with a readable message rather than an AttributeError (or an
    # opaque error inside the captioning model) when an input is missing.
    if pil_image is None:
        raise ValueError("Please upload an image.")
    if not genre:
        raise ValueError("Please choose a story genre.")

    model_file = _ensure_weights(model.lower())
    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    return generate_story(image_caption, genre.lower())


if __name__ == "__main__":
    interface = gr.Interface(
        main,
        title="image2story",
        # Only the first two parameters of main are UI inputs; `model` and
        # `use_beam_search` keep their defaults.
        inputs=[
            gr.inputs.Image(type="pil", source="upload", label="Input"),
            gr.inputs.Dropdown(
                type="value",
                label="Story genre",
                choices=GENRES,
            ),
        ],
        outputs=gr.outputs.Textbox(label="Generated story"),
        # NOTE(review): gr.inputs / gr.outputs / enable_queue belong to the
        # legacy Gradio API and are deprecated/removed in later releases;
        # confirm the pinned gradio version before upgrading.
        enable_queue=True,
    )
    interface.launch()