image2story / gpt2_story_gen.py
bipin
added files
cae4936
raw
history blame
312 Bytes
from transformers import pipeline
_STORY_MODEL = "pranavpsv/genre-story-generator-v2"

# Lazily created text-generation pipeline, shared by all calls so the model
# is loaded once instead of on every request.
_story_generator = None


def _get_story_generator():
    """Return the shared text-generation pipeline, building it on first use."""
    global _story_generator
    if _story_generator is None:
        _story_generator = pipeline("text-generation", _STORY_MODEL)
    return _story_generator


def generate_story(image_caption: str, genre: str) -> str:
    """Generate a short story continuing *image_caption* in the given *genre*.

    The prompt sent to the model has the form ``"<BOS> <genre> caption"``.

    Args:
        image_caption: Text used as the opening of the story.
        genre: Genre tag inserted into the prompt between angle brackets.

    Returns:
        The generated text with the prompt removed from the front (when the
        model output echoes it) and surrounding whitespace trimmed.
    """
    prompt = f"<BOS> <{genre}> {image_caption}"
    generated = _get_story_generator()(prompt)[0]["generated_text"]

    # Remove the prompt as an exact prefix. ``str.strip(prompt)`` would treat
    # the prompt as a set of characters and eat matching characters from both
    # ends of the story, corrupting its first and last words.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()