Spaces:
Running
Running
import os
import re

import streamlit as st
from diffusers import DiffusionPipeline
from langchain import PromptTemplate, LLMChain
from langchain_cohere import ChatCohere
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
# Load models
# Streamlit re-executes this whole script on every widget interaction, so the
# BLIP weights are loaded through st.cache_resource to reuse one copy per process.
model_name = "Salesforce/blip-image-captioning-base"


@st.cache_resource
def _load_caption_model(name):
    """Return the (processor, model) pair for the BLIP captioning checkpoint *name*."""
    return (
        BlipProcessor.from_pretrained(name),
        BlipForConditionalGeneration.from_pretrained(name),
    )


processor, model = _load_caption_model(model_name)

# Never commit API keys to source: supply the Cohere key through the
# COHERE_API_KEY environment variable (e.g. a Space secret). A missing variable
# fails fast with KeyError('COHERE_API_KEY').
chat = ChatCohere(model="command", cohere_api_key=os.environ["COHERE_API_KEY"])
# Define templates and chains
# First-pass prompt: ask for six chapters (Chapter 0 to Chapter 5). Six are
# needed because the final-story prompt below is built around six image
# captions, and the chapters are later split apart on their "Chapter" headings.
story_template = PromptTemplate(
    input_variables=["prompt"],
    template="""
Create a chapter-wise, detailed and sequential story based on the following prompt: {prompt}. It must have exactly 6 chapters, numbered 0 to 5. Start each chapter on a new line with its heading, and separate chapters with a blank line. Here's an example of the output format:
Chapter 0:
Chapter 1:
Chapter 2:
Chapter 3:
Chapter 4:
Chapter 5:
After Chapter 5, there should not be any other text.
""",
)
story_chain = LLMChain(llm=chat, prompt=story_template)
def generate_story(prompt):
    """Run ``story_chain`` on *prompt* and return the generated story text.

    The chain result is a mapping; the model's completion is stored under
    the ``"text"`` key.
    """
    return story_chain.invoke(prompt)["text"]
# A chapter begins on a new line with the word "Chapter", optionally preceded by
# markdown "#"/"*" markers (e.g. "**Chapter 1:**").
_CHAPTER_BREAK = re.compile(r"\n\s*(?=[#*]*\s*Chapter\b)")
_LEADING_CHAPTER = re.compile(r"\s*[#*]*\s*Chapter\b")


def split_story(story):
    """Split generated story text into ``[preamble, chapter, chapter, ...]``.

    Element 0 is whatever text precedes the first "Chapter" heading. It is an
    empty string when the story starts directly with a chapter, so callers can
    always find the chapters at ``parts[1:]``. Every later element is one
    chapter, beginning with its heading. Chapters are recognised whether they
    are separated by one newline or several.

    Args:
        story: Raw text returned by the story LLM.

    Returns:
        List of strings; ``[story]`` when no chapter heading is found.
    """
    parts = _CHAPTER_BREAK.split(story)
    # Keep the preamble slot at index 0 even when there is no preamble, so the
    # chapter indices do not shift depending on how the model formatted output.
    if _LEADING_CHAPTER.match(parts[0]):
        parts.insert(0, "")
    return parts
@st.cache_resource
def _load_image_pipeline():
    """Load the SDXL-Turbo text-to-image pipeline once per process and reuse it."""
    return DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")


def generate_images(story_parts, save_dir="story_images"):
    """Render one image per text prompt and save each as a PNG file.

    Args:
        story_parts: Iterable of text prompts; each produces exactly one image.
        save_dir: Directory the images are written to (created if missing).

    Returns:
        List of file paths in the same order as ``story_parts``. The i-th part
        (0-based) is saved as ``story_part_{i + 1}.png``.
    """
    # Loading the pipeline weights dominates runtime; the cached loader avoids
    # re-reading them on every button click.
    # NOTE(review): the sdxl-turbo model card recommends num_inference_steps=1
    # and guidance_scale=0.0, and long prompts are truncated by the CLIP text
    # encoder. The pipeline defaults are kept here — confirm the intended settings.
    pipe = _load_image_pipeline()
    os.makedirs(save_dir, exist_ok=True)
    image_paths = []
    for number, part in enumerate(story_parts, start=1):
        image = pipe(part).images[0]
        image_path = os.path.join(save_dir, f"story_part_{number}.png")
        image.save(image_path)
        image_paths.append(image_path)
    return image_paths
def generate_captions(image_paths):
    """Caption each image file with the module-level BLIP processor/model.

    Args:
        image_paths: Iterable of paths to image files.

    Returns:
        Dict mapping each image path to its decoded caption, in input order.
    """
    captions_dict = {}
    for image_path in image_paths:
        # Context manager releases the file handle once the processor has read
        # the pixel data; Image.open alone leaves the file open.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        out = model.generate(**inputs)
        captions_dict[image_path] = processor.decode(out[0], skip_special_tokens=True)
    return captions_dict
# Second-pass prompt: combine the story context with six image captions (in
# order) into one narrative. Variables: "context" plus image_caption_1..6.
story_prompt = PromptTemplate(
    input_variables=["context"] + [f"image_caption_{n}" for n in range(1, 7)],
    template="""
You are given a sequence of 6 images. Your task is to create a coherent and engaging story based on these images. Use the provided context to set the scene and ensure that each part of the story corresponds to the captions of the images in the order they are provided.
Instructions:
1. Read the context to understand the background and setting of the story.
2. Carefully look at the image captions to capture the essence of each image.
3. Create a story that weaves together the elements from the context and the image captions.
4. Make sure the transitions between the images are smooth and the story maintains a logical flow.
Begin the story below:
{context}
1. Image 1: {image_caption_1}
2. Image 2: {image_caption_2}
3. Image 3: {image_caption_3}
4. Image 4: {image_caption_4}
5. Image 5: {image_caption_5}
6. Image 6: {image_caption_6}
Story:
""",
)

# Chain that fills story_prompt and sends it to the Cohere chat model.
llm_chain = LLMChain(llm=chat, prompt=story_prompt)
def final_story(image_captions, context):
    """Weave the story context and the first six image captions into one narrative.

    Args:
        image_captions: Sequence of caption strings; only the first six are used.
        context: Background text substituted for ``{context}`` in the prompt.

    Returns:
        The text generated by ``llm_chain``.

    Raises:
        ValueError: If fewer than six captions are supplied.
    """
    if len(image_captions) < 6:
        raise ValueError(
            f"final_story needs 6 image captions, got {len(image_captions)}"
        )
    inputs = {"context": context}
    inputs.update(
        {
            f"image_caption_{number}": caption
            for number, caption in enumerate(image_captions[:6], start=1)
        }
    )
    # invoke() replaces the deprecated run(); LLMChain puts its completion
    # under the "text" key (same value run() returned).
    return llm_chain.invoke(inputs)["text"]
# Streamlit UI
st.title("Story Board")
prompt = st.text_input("Enter your story prompt:", "A heroic journey of a young adventurer in a mystical land.")
if st.button("Generate Story"):
    with st.spinner("Generating story..."):
        story_context = generate_story(prompt)
        story_parts = split_story(story_context)
        # story_parts[0] holds any text before the first chapter heading. Only the
        # chapters are illustrated, so image i is generated from chapter i and the
        # two stay paired when displayed. If no heading was found, fall back to
        # the whole text.
        chapters = story_parts[1:] or story_parts
        image_paths = generate_images(chapters)
        captions_dict = generate_captions(image_paths)
        captions_list = list(captions_dict.values())
        # final_story needs six captions; skip it rather than crash when the
        # model produced fewer chapters.
        final_story_text = None
        if len(captions_list) >= 6:
            final_story_text = final_story(image_captions=captions_list, context=story_context)
    for image_path, chapter in zip(image_paths, chapters):
        st.image(image_path)
        st.write(chapter)  # Display chapters below corresponding images
    if final_story_text:
        st.subheader("Final Story")
        st.write(final_story_text)
    else:
        st.warning("The model returned fewer than 6 chapters, so the combined story was skipped.")