# Story_Board / app.py
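# Streamlit app: Cohere writes a six-chapter story, SDXL-Turbo illustrates each
# chapter, BLIP captions the generated images, and Cohere retells the story from
# those captions. Assumed (unpinned) dependencies: streamlit, langchain,
# langchain-cohere, transformers, diffusers, torch, pillow.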
import streamlit as st
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_cohere import ChatCohere
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from diffusers import DiffusionPipeline
import torch
import os
# Load models
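# BLIP produces the image captions that are later fed back to the LLM.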
model_name = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name)
# Read the Cohere API key from the environment rather than hard-coding it in the source.
chat = ChatCohere(model="command", cohere_api_key=os.getenv("COHERE_API_KEY"))
# Define templates and chains
story_template = PromptTemplate(
input_variables=["prompt"],
template="""
Create a chapter-wise, detailed, and sequential story based on the following prompt: {prompt}. It must have exactly six chapters, numbered Chapter 0 through Chapter 5. Here's an example of the output format:
Chapter 0:
Chapter 1:
Chapter 2:
Chapter 3:
Chapter 4:
Chapter 5:
After Chapter 5, there must not be any other text.
""",
)
story_chain = LLMChain(llm=chat, prompt=story_template)
def generate_story(prompt):
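    """Run the story chain on the user's prompt and return the generated text."""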
response = story_chain.invoke(prompt)
return response['text']
def split_story(story):
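    """Split the story into per-chapter strings; assumes chapters are separated by blank lines."""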
chapters = story.split("\n\nChapter ")
chapters = [chapters[0]] + [f"Chapter {chap}" for chap in chapters[1:]]
return chapters
def generate_images(story_parts, save_dir="story_images"):
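    """Generate one SDXL-Turbo image per story part and return the saved file paths."""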
    # SDXL-Turbo is built for single-step generation with guidance disabled;
    # use the GPU when one is available and fall back to CPU otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo").to(device)
    os.makedirs(save_dir, exist_ok=True)
    image_paths = []
    for i, part in enumerate(story_parts):
        # The CLIP text encoder truncates prompts at 77 tokens, so only the
        # beginning of each chapter actually conditions the image.
        image = pipe(part, num_inference_steps=1, guidance_scale=0.0).images[0]
        image_path = os.path.join(save_dir, f"story_part_{i+1}.png")
        image.save(image_path)
        image_paths.append(image_path)
    return image_paths
def generate_captions(image_paths):
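    """Caption each saved image with BLIP and return a {path: caption} mapping."""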
captions_dict = {}
for image_path in image_paths:
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs)
caption = processor.decode(out[0], skip_special_tokens=True)
captions_dict[image_path] = caption
return captions_dict
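# Second prompt: reconstruct the story image by image from the BLIP captions.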
story_prompt = PromptTemplate(
input_variables=[
"context",
"image_caption_1",
"image_caption_2",
"image_caption_3",
"image_caption_4",
"image_caption_5",
"image_caption_6"
],
template="""
You are given a sequence of 6 images. Your task is to create a coherent and engaging story based on these images. Use the provided context to set the scene and ensure that each part of the story corresponds to the captions of the images in the order they are provided.
Instructions:
1. Read the context to understand the background and setting of the story.
2. Carefully look at the image captions to capture the essence of each image.
3. Create a story that weaves together the elements from the context and the image captions.
4. Make sure the transitions between the images are smooth and the story maintains a logical flow.
Begin the story below:
{context}
1. Image 1: {image_caption_1}
2. Image 2: {image_caption_2}
3. Image 3: {image_caption_3}
4. Image 4: {image_caption_4}
5. Image 5: {image_caption_5}
6. Image 6: {image_caption_6}
Story:
"""
)
llm_chain = LLMChain(
llm=chat,
prompt=story_prompt
)
def final_story(image_captions, context):
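    """Retell the story from the six image captions, using the original story as context."""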
result = llm_chain.run({
"context": context,
"image_caption_1": image_captions[0],
"image_caption_2": image_captions[1],
"image_caption_3": image_captions[2],
"image_caption_4": image_captions[3],
"image_caption_5": image_captions[4],
"image_caption_6": image_captions[5]
})
return result
# Streamlit UI
st.title("Story Board")
prompt = st.text_input("Enter your story prompt:", "A heroic journey of a young adventurer in a mystical land.")
if st.button("Generate Story"):
with st.spinner("Generating story..."):
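        # Pipeline: story text -> per-chapter images -> BLIP captions -> caption-based retelling.
        # Note: final_story indexes captions 0-5, so at least six images must have been generated.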
story_context = generate_story(prompt)
story_parts = split_story(story_context)
image_paths = generate_images(story_parts)
captions_dict = generate_captions(image_paths)
captions_list = list(captions_dict.values())
final_story_text = final_story(image_captions=captions_list, context=story_context)
        # Show each chapter's text below its corresponding image
        for image_path, part in zip(image_paths, story_parts):
            st.image(image_path)
            st.write(part)
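        # The caption-based retelling from final_story is computed above but was never
        # displayed; a minimal addition to surface it after the storyboard.
        st.subheader("Final Story")
        st.write(final_story_text)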