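"""Streamlit storyboard app.

Generates a six-chapter story from a user prompt (Cohere via LangChain), renders one
SDXL-Turbo image per chapter, captions each image with BLIP, and composes a final
narrative from those captions.
"""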
import os
import re

import streamlit as st
import torch
from diffusers import DiffusionPipeline
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_cohere import ChatCohere
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Load the BLIP image-captioning model (used to caption the generated images)
model_name = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name)

# Read the Cohere API key from the environment instead of hard-coding it in source.
chat = ChatCohere(model="command", cohere_api_key=os.environ.get("COHERE_API_KEY"))

# Define templates and chains
story_template = PromptTemplate(
    input_variables=["prompt"],
    template="""
    Create a detailed, sequential, chapter-wise story based on the following prompt: {prompt}.
    The story must have exactly six chapters, numbered Chapter 0 through Chapter 5, formatted like this:
    Chapter 0:
    Chapter 1:
    Chapter 2:
    Chapter 3:
    Chapter 4:
    Chapter 5:

    Do not include any text after Chapter 5.
    """,
)
story_chain = LLMChain(llm=chat, prompt=story_template)

def generate_story(prompt):
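    """Run the story-generation chain and return the raw, chapter-formatted text."""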
    response = story_chain.invoke(prompt)
    return response['text']

def split_story(story):
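    """Split the generated story into its chapter strings ("Chapter 0" .. "Chapter 5")."""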
    # Split wherever a new line starts with "Chapter N"; drop any preamble the model added.
    parts = re.split(r"\n+(?=Chapter \d)", story.strip())
    return [part.strip() for part in parts if part.strip().startswith("Chapter")]


def generate_images(story_parts, save_dir="story_images"):
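    """Generate one SDXL-Turbo image per story part and save it under save_dir."""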
    # SDXL-Turbo is distilled for very few denoising steps without classifier-free
    # guidance, hence num_inference_steps=1 and guidance_scale=0.0 below.
    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(save_dir, exist_ok=True)
    image_paths = []
    for i, part in enumerate(story_parts):
        image = pipe(prompt=part, num_inference_steps=1, guidance_scale=0.0).images[0]
        image_path = os.path.join(save_dir, f"story_part_{i+1}.png")
        image.save(image_path)
        image_paths.append(image_path)
    return image_paths

def generate_captions(image_paths):
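    """Caption each saved image with BLIP and return a {image_path: caption} mapping."""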
    captions_dict = {}
    for image_path in image_paths:
        image = Image.open(image_path)
        inputs = processor(images=image, return_tensors="pt")
        out = model.generate(**inputs)
        caption = processor.decode(out[0], skip_special_tokens=True)
        captions_dict[image_path] = caption
    return captions_dict
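
# Prompt that composes the final narrative from the six image captions,
# using the originally generated story as context.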
    
story_prompt = PromptTemplate(
    input_variables=[
        "context",
        "image_caption_1",
        "image_caption_2",
        "image_caption_3",
        "image_caption_4",
        "image_caption_5",
        "image_caption_6"
    ],
    template="""
    You are given a sequence of 6 images. Your task is to create a coherent and engaging story based on these images. Use the provided context to set the scene and ensure that each part of the story corresponds to the captions of the images in the order they are provided.

    Instructions:
    1. Read the context to understand the background and setting of the story.
    2. Carefully look at the image captions to capture the essence of each image.
    3. Create a story that weaves together the elements from the context and the image captions.
    4. Make sure the transitions between the images are smooth and the story maintains a logical flow.

    Begin the story below:

    {context}

    1. Image 1: {image_caption_1}
    2. Image 2: {image_caption_2}
    3. Image 3: {image_caption_3}
    4. Image 4: {image_caption_4}
    5. Image 5: {image_caption_5}
    6. Image 6: {image_caption_6}

    Story:
    """
)
llm_chain = LLMChain(
    llm=chat,
    prompt=story_prompt
)

def final_story(image_captions, context):
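    """Compose the final story from the six image captions, with the original story as context."""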
    result = llm_chain.run({
        "context": context,
        "image_caption_1": image_captions[0],
        "image_caption_2": image_captions[1],
        "image_caption_3": image_captions[2],
        "image_caption_4": image_captions[3],
        "image_caption_5": image_captions[4],
        "image_caption_6": image_captions[5]
    })
    return result


# Streamlit UI
st.title("Story Board")
prompt = st.text_input("Enter your story prompt:", "A heroic journey of a young adventurer in a mystical land.")
if st.button("Generate Story"):
    with st.spinner("Generating story..."):
        story_context = generate_story(prompt)
        story_parts = split_story(story_context)
        image_paths = generate_images(story_parts)
        captions_dict = generate_captions(image_paths)
        captions_list = list(captions_dict.values())
        final_story_text = final_story(image_captions=captions_list, context=story_context)
        # Pair each generated image with the chapter text it was generated from.
        for image_path, chapter in zip(image_paths, story_parts):
            st.image(image_path)
            st.write(chapter)
        st.subheader("Final Story")
        st.write(final_story_text)