File size: 312 Bytes
cae4936
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from functools import lru_cache

from transformers import pipeline


@lru_cache(maxsize=1)
def _get_story_generator():
    """Build the story text-generation pipeline once and reuse it.

    Constructing a pipeline loads the model, so it is cached instead of being
    rebuilt on every call to ``generate_story``.
    """
    return pipeline("text-generation", "pranavpsv/genre-story-generator-v2")


def generate_story(image_caption, genre):
    """Generate a story continuation for *image_caption* in the given *genre*.

    The prompt sent to the model is ``"<BOS> <{genre}> {image_caption}"``.
    Only the text the model adds after that prompt is returned.

    Args:
        image_caption: Text used as the seed/opening of the story prompt.
        genre: Genre tag inserted into the prompt. The accepted values depend
            on the model; NOTE(review): confirm the tag set it was trained on.

    Returns:
        The newly generated text, with surrounding whitespace removed.
    """
    prompt = f"<BOS> <{genre}> {image_caption}"

    # return_full_text=False makes the pipeline drop the prompt itself. The
    # previous `story.strip(prompt)` removed any *characters* present in the
    # prompt from both ends, which could truncate the story's own words.
    outputs = _get_story_generator()(prompt, return_full_text=False)
    return outputs[0]["generated_text"].strip()