# image2story / app.py
# bipin
# added files
# cae4936
# raw
# history blame
# 1.2 kB
import gradio as gr
from prefix_clip import download_pretrained_model, generate_caption
from gpt2_story_gen import generate_story
def main(pil_image, genre, model="Conceptual", use_beam_search=True):
    """Generate a story for an image by captioning it and expanding the caption.

    The image is captioned with ``generate_caption`` (using weights fetched by
    ``download_pretrained_model``), and the caption is passed to
    ``generate_story`` together with the requested genre.

    Args:
        pil_image: The input image, forwarded unchanged as ``pil_image`` to
            ``generate_caption``. The Gradio interface in this module requests
            it with ``type="pil"``.
        genre: Story genre name; lower-cased before being passed to
            ``generate_story``.
        model: Name of the pretrained captioning weights; lower-cased before
            being passed to ``download_pretrained_model``.
        use_beam_search: Forwarded unchanged to ``generate_caption``.

    Returns:
        The value returned by ``generate_story`` for the generated caption.

    Raises:
        ValueError: If no image was supplied, or no genre was selected.
    """
    # Gradio may hand us None when the user submits without uploading an image
    # or picking a genre. Fail early with a readable message instead of an
    # AttributeError on ``genre.lower()`` or an obscure error inside the model.
    if pil_image is None:
        raise ValueError("Please upload an image.")
    if not genre:
        raise ValueError("Please select a story genre.")

    model_file = "pretrained_weights.pt"
    # NOTE(review): weights are fetched on every call into one shared path;
    # confirm download_pretrained_model skips redundant downloads and that
    # concurrent requests cannot clobber the file mid-read.
    download_pretrained_model(model.lower(), file_to_save=model_file)
    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    return generate_story(image_caption, genre.lower())
if __name__ == "__main__":
    # Genres offered in the dropdown; main() lower-cases the chosen value.
    genre_choices = [
        "superhero",
        "action",
        "drama",
        "horror",
        "thriller",
        "sci_fi",
    ]

    # Input widgets: an uploaded image (requested as PIL) and a genre picker.
    image_input = gr.inputs.Image(type="pil", source="upload", label="Input")
    genre_input = gr.inputs.Dropdown(
        choices=genre_choices,
        type="value",
        label="Story genre",
    )
    # Output widget: the generated story as plain text.
    story_output = gr.outputs.Textbox(label="Generated story")

    # Bind main() to the widgets above and start the web UI.
    interface = gr.Interface(
        main,
        title="image2story",
        inputs=[image_input, genre_input],
        outputs=story_output,
        enable_queue=True,
    )
    interface.launch()