# Hugging Face Space app code (previously failing at runtime with a
# NameError: `outputs` was referenced before assignment — fixed below).
import torch
from transformers import pipeline, CLIPProcessor, CLIPModel
def generate_story(image_caption, image, genre):
    """Generate a short genre story from an image caption, ranked by CLIP.

    Samples three candidate stories from a genre-conditioned GPT-2 generator,
    then scores each candidate's similarity to the image with CLIP and returns
    the best-matching one.

    Args:
        image_caption: Caption text describing the image; seeds the generator.
        image: The image to rank candidates against (anything CLIPProcessor
            accepts, e.g. a PIL image) — TODO confirm expected type at caller.
        genre: Genre tag inserted into the generator prompt, e.g. "horror".

    Returns:
        The candidate story string with the highest CLIP image-text score.
    """
    clip_ranker_checkpoint = "openai/clip-vit-base-patch32"
    clip_ranker_processor = CLIPProcessor.from_pretrained(clip_ranker_checkpoint)
    clip_ranker_model = CLIPModel.from_pretrained(clip_ranker_checkpoint)
    story_gen = pipeline(
        "text-generation",
        "pranavpsv/genre-story-generator-v2"
    )
    # Renamed from `input` to avoid shadowing the builtin.
    prompt = f"<BOS> <{genre}> {image_caption}"
    # BUG FIX: the original used .strip(prompt), but str.strip(chars) removes
    # any of the given *characters* from both ends — it does not remove a
    # prefix string, and it can corrupt the story's real first/last words.
    # removeprefix drops the echoed prompt exactly; .strip() then trims
    # surrounding whitespace. (Assumes generated_text echoes the prompt,
    # which is standard for text-generation pipelines.)
    stories = [
        story_gen(prompt)[0]['generated_text'].removeprefix(prompt).strip()
        for _ in range(3)
    ]
    clip_ranker_inputs = clip_ranker_processor(
        text=stories,
        images=image,
        return_tensors='pt',
        padding=True
    )
    clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
    # BUG FIX: original referenced undefined `outputs` here (the NameError
    # behind the Space's "Runtime error").
    logits_per_image = clip_ranker_outputs.logits_per_image
    # Softmax over the text candidates: probs[0][i] = P(story i matches image).
    probs = logits_per_image.softmax(dim=1)
    story = stories[torch.argmax(probs).item()]
    return story