File size: 1,003 Bytes
0843a80
ef62727
cae4936
 
0843a80
 
 
 
cae4936
0843a80
 
 
 
 
cae4936
0843a80
 
 
 
cf0876f
0843a80
 
 
 
ef62727
0843a80
 
cae4936
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import pipeline, CLIPProcessor, CLIPModel
import torch


def generate_story(image_caption, image, genre):
    """Generate a short genre story from an image caption, ranked by CLIP.

    Generates three candidate stories conditioned on ``image_caption`` and
    ``genre``, then uses a CLIP model to score each candidate against
    ``image`` and returns the best-matching story.

    Args:
        image_caption: Text caption describing the image; used as the
            generation prompt.
        image: The image (PIL-style, as accepted by CLIPProcessor) used to
            rank the candidate stories.
        genre: Genre tag inserted into the prompt (e.g. as ``<genre>``);
            presumably one of the tags the generator model was trained on —
            TODO confirm against the model card.

    Returns:
        The candidate story whose CLIP image-text similarity is highest.
    """
    # NOTE(review): both models are re-loaded on every call; if this is hot,
    # hoist loading to module level or cache — left as-is to keep the
    # interface and module import-time behavior unchanged.
    clip_ranker_checkpoint = "openai/clip-vit-base-patch32"
    clip_ranker_processor = CLIPProcessor.from_pretrained(clip_ranker_checkpoint)
    clip_ranker_model = CLIPModel.from_pretrained(clip_ranker_checkpoint)

    story_gen = pipeline(
        "text-generation",
        "pranavpsv/genre-story-generator-v2",
    )

    # Renamed from `input`, which shadowed the builtin.
    prompt = f"<BOS> <{genre}> {image_caption}"

    # NOTE(review): without explicit sampling args (do_sample/temperature)
    # the three generations may be identical — verify the pipeline's
    # defaults for this model.
    stories = []
    for _ in range(3):
        text = story_gen(prompt)[0]['generated_text']
        # BUG FIX: the original used text.strip(prompt), but str.strip
        # treats its argument as a *set of characters* to remove from both
        # ends — corrupting the story text. Remove the exact prompt prefix
        # instead, then trim surrounding whitespace.
        if text.startswith(prompt):
            text = text[len(prompt):]
        stories.append(text.strip())

    clip_ranker_inputs = clip_ranker_processor(
        text=stories,
        images=image,
        truncation=True,
        return_tensors='pt',
        padding=True,
    )
    clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
    # One image vs. N candidate texts: softmax over the text axis gives a
    # per-candidate match probability; pick the argmax.
    logits_per_image = clip_ranker_outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    story = stories[torch.argmax(probs).item()]

    return story