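"""AI Story Video Generator: a Hugging Face Spaces app.

Pipeline (as wired up in the Gradio UI below): expand a one-sentence prompt
into a 15-20 sentence story with Qwen2.5-1.5B-Instruct, animate each sentence
with AnimateDiff-Lightning, synthesize a sentiment-aware narration with
edge-tts, and mux video and audio together with moviepy.
"""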
# Import all required libraries
import gradio as gr
import asyncio
import os
import traceback
import re

import numpy as np
import torch
import imageio
import cv2
from PIL import Image
import edge_tts
from transformers import AutoTokenizer, pipeline
from moviepy.editor import VideoFileClip, AudioFileClip
# Initialize the Qwen model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
text_pipe = pipeline(
    "text-generation",
    model="Qwen/Qwen2.5-1.5B-Instruct",
    tokenizer=tokenizer
)

# Initialize the sentiment analyzer (the pipeline default is DistilBERT
# fine-tuned on SST-2, which emits POSITIVE/NEGATIVE labels)
sentiment_analyzer = pipeline("sentiment-analysis")
# Load diffusers libraries after tokenizer to avoid GPU memory conflicts
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Initialize video generation components
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
step = 8
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"
print(f"Using device: {device} with dtype: {dtype}")
# Load motion adapter and pipeline in a function to handle errors gracefully
def load_models():
    try:
        print("Loading motion adapter...")
        adapter = MotionAdapter().to(device, dtype)
        adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
        print("Loading diffusion pipeline...")
        pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
        )
        return adapter, pipe
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        traceback.print_exc()
        return None, None

# We'll load the models on first use to avoid startup errors
adapter, pipe = None, None
# Define all required functions
def summarize(text):
    messages = [
        {
            "role": "system",
            "content": (
                "You are an expert summarizer focused on efficiency and clarity. "
                "Create concise narrative summaries that: "
                "1. Capture all key points and main ideas "
                "2. Omit examples, repetitions, and secondary details "
                "3. Maintain logical flow and coherence "
                "4. Use clear, direct language without markdown formatting"
            )
        },
        {
            "role": "user",
            "content": (
                "Please summarize the following text in 10-15 sentences. "
                "Focus on essential information, exclude non-critical details, "
                f"and maintain natural storytelling flow:\n\n{text}"
            )
        }
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    response = text_pipe(
        prompt,
        max_new_tokens=512,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )
    result = response[0]['generated_text']
    # The pipeline echoes the full prompt; keep only the assistant's reply
    summary = result.split("assistant\n")[-1].strip()
    return summary
def generate_story(prompt):
    messages = [
        {
            "role": "system",
            "content": (
                "You are a skilled storyteller specializing in tight, impactful narratives. "
                "Create engaging stories that:\n"
                "1. Contain exactly 15-20 sentences\n"
                "2. Keep each sentence under 77 tokens\n"
                "3. Maintain strong narrative flow and pacing\n"
                "4. Focus on vivid imagery and concrete details\n"
                "5. Avoid filler words and redundant phrases\n"
                "6. Use simple, direct language without markdown"
            )
        },
        {
            "role": "user",
            "content": (
                f"Craft a compelling short story based on this premise: {prompt}\n"
                "Structure requirements:\n"
                "- Strict 15-20 sentence count\n"
                "- Maximum 77 tokens per sentence\n"
                "- Clear beginning-middle-end structure\n"
                "- Emphasis on showing rather than telling\n"
                "Output plain text only, no markdown formatting."
            )
        }
    ]
    chat_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # First attempt to generate the story
    generated = text_pipe(
        chat_prompt,
        max_new_tokens=1024,
        num_beams=5,
        early_stopping=True,
        no_repeat_ngram_size=4,
        temperature=0.65,
        top_k=30,
        top_p=0.90,
        do_sample=True,
        length_penalty=0.9
    )
    full_output = generated[0]['generated_text']
    story = full_output.split("assistant\n")[-1].strip()

    # Split on sentence-ending punctuation (., !, ?) and check constraints
    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', story) if s.strip()]

    # Check sentence count constraint
    sentence_count = len(sentences)
    if sentence_count < 15 or sentence_count > 20:
        # Regenerate with stricter parameters if constraints were not met
        enhanced_prompt = (
            f"{prompt} (IMPORTANT: Story MUST have EXACTLY 15-20 sentences, "
            f"and each sentence MUST be under 77 tokens. "
            f"Current attempt had {sentence_count} sentences.)"
        )
        messages[1]["content"] = (
            f"Craft a compelling short story based on this premise: {enhanced_prompt}\n"
            "Structure requirements:\n"
            "- CRITICAL: Output EXACTLY 15-20 sentences, not more, not less\n"
            "- CRITICAL: Maximum 77 tokens per sentence\n"
            "- Clear beginning-middle-end structure\n"
            "- Emphasis on showing rather than telling\n"
            "Output plain text only, no markdown formatting."
        )
        chat_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Retry with stricter decoding parameters
        generated = text_pipe(
            chat_prompt,
            max_new_tokens=1024,
            num_beams=7,
            early_stopping=True,
            no_repeat_ngram_size=4,
            temperature=0.5,
            top_k=20,
            top_p=0.85,
            do_sample=True,
            length_penalty=1.0
        )
        full_output = generated[0]['generated_text']
        story = full_output.split("assistant\n")[-1].strip()
        sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', story) if s.strip()]

    # Enforce the per-sentence token budget with a rough words-to-tokens estimate
    word_to_token_ratio = 1.3
    constrained_sentences = []
    for sentence in sentences:
        words = sentence.split()
        estimated_tokens = len(words) * word_to_token_ratio
        if estimated_tokens > 77:
            # Truncate to ~75 estimated tokens to leave a small safety margin
            max_words = int(75 / word_to_token_ratio)
            truncated = ' '.join(words[:max_words])
            constrained_sentences.append(truncated)
        else:
            constrained_sentences.append(sentence)

    # Pad or trim to stay within the 15-20 sentence range
    while len(constrained_sentences) < 15:
        constrained_sentences.append("The story continued with unexpected twists and turns.")
    constrained_sentences = constrained_sentences[:20]

    # Ensure every sentence ends with terminal punctuation
    formatted_sentences = []
    for s in constrained_sentences:
        if not s.endswith(('.', '!', '?')):
            s += '.'
        formatted_sentences.append(s)
    final_story = '\n'.join(formatted_sentences)
    return final_story
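
# Note: the 77-token-per-sentence cap enforced above matches the 77-token
# context window of the CLIP text encoder used by Stable-Diffusion-based
# pipelines such as AnimateDiff, so each sentence can later serve as a video
# prompt without silent truncation.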
def generate_video(summary):
    global adapter, pipe
    # Load models if not already loaded
    if adapter is None or pipe is None:
        adapter, pipe = load_models()
    if adapter is None or pipe is None:
        raise Exception("Failed to load models. Please check the logs for errors.")

    def crossfade_transition(frames1, frames2, transition_length=10):
        # Guard against clips shorter than the requested transition
        transition_length = min(transition_length, len(frames1), len(frames2))
        blended_frames = []
        frames1_np = [np.array(frame) for frame in frames1[-transition_length:]]
        frames2_np = [np.array(frame) for frame in frames2[:transition_length]]
        for i in range(transition_length):
            alpha = i / transition_length
            beta = 1.0 - alpha
            blended = cv2.addWeighted(frames1_np[i], beta, frames2_np[i], alpha, 0)
            blended_frames.append(Image.fromarray(blended))
        return blended_frames
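
    # The crossfade is a per-frame linear blend, out_i = (1 - alpha_i) * f1_i + alpha_i * f2_i,
    # with alpha_i ramping from 0 to 1; that is exactly what cv2.addWeighted computes above.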
    # Sentence splitting
    sentences = []
    current_sentence = ""
    for char in summary:
        current_sentence += char
        if char in {'.', '!', '?'}:
            sentences.append(current_sentence.strip())
            current_sentence = ""
    sentences = [s.strip() for s in sentences if s.strip()]
    print(f"Total scenes: {len(sentences)}")

    # For development/testing purposes, limit the number of sentences
    max_sentences = 5
    if len(sentences) > max_sentences:
        print(f"Limiting to first {max_sentences} sentences for faster testing")
        sentences = sentences[:max_sentences]

    # Output config
    output_dir = "generated_frames"
    video_path = "generated_video.mp4"
    os.makedirs(output_dir, exist_ok=True)

    # Generate animation
    all_frames = []
    previous_frames = None
    transition_frames = 10
    batch_size = 1
    for i in range(0, len(sentences), batch_size):
        batch_prompts = sentences[i : i + batch_size]
        for idx, prompt in enumerate(batch_prompts):
            print(f"Generating animation for prompt {i+idx+1}/{len(sentences)}: {prompt}")
            try:
                output = pipe(
                    prompt=prompt,
                    guidance_scale=1.0,
                    num_inference_steps=step,
                    width=256,
                    height=256,
                )
                frames = output.frames[0]
                if previous_frames is not None:
                    transition = crossfade_transition(previous_frames, frames, transition_frames)
                    all_frames.extend(transition)
                all_frames.extend(frames)
                previous_frames = frames
            except Exception as e:
                # Continue with the next prompt if one fails
                print(f"Error generating frames for prompt: {prompt}")
                print(f"Error details: {str(e)}")

    # Save video
    if not all_frames:
        raise Exception("No frames were generated. Video creation failed.")
    print(f"Saving video with {len(all_frames)} frames")
    imageio.mimsave(video_path, all_frames, fps=8)
    print(f"Video saved at {video_path}")
    return video_path
def estimate_voiceover_words(video_path):
    try:
        # Get video duration in seconds
        video = VideoFileClip(video_path)
        duration_minutes = video.duration / 60
        # Estimate word count from an average speaking rate of 150 words per minute
        estimated_words = int(duration_minutes * 150)
        # Ensure a minimum word count
        return max(estimated_words, 30)
    except Exception as e:
        print(f"Error estimating voiceover words: {str(e)}")
        return 50  # Default fallback
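
# Worked example: a 20-second clip is 20 / 60 ≈ 0.33 minutes, so the target
# narration length comes out to about 0.33 * 150 = 50 words.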
def summary_of_summary(text, video_path):
    target_word_count = estimate_voiceover_words(video_path)
    messages_2 = [
        {
            "role": "system",
            "content": (
                "You are an expert summarizer focused on brevity and clarity. "
                f"Create a summary of approximately {target_word_count} words that will:\n"
                "1. Capture the most essential information\n"
                "2. Omit unnecessary details and examples\n"
                "3. Maintain logical flow and coherence\n"
                "4. Use clear, direct language"
            )
        },
        {
            "role": "user",
            "content": (
                f"Please summarize the following text in approximately {target_word_count} words:\n\n{text}"
            )
        }
    ]
    # Generate prompt
    prompt_for_resummarization = tokenizer.apply_chat_template(
        messages_2,
        tokenize=False,
        add_generation_prompt=True
    )
    # Generate response; budget roughly 1.3 tokens per word plus headroom
    response = text_pipe(
        prompt_for_resummarization,
        max_new_tokens=int(target_word_count * 1.3) + 20,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )
    # Extract result
    summary = response[0]['generated_text'].split("assistant\n")[-1].strip()
    return summary
async def generate_audio_with_sentiment(text, sentiment_analyzer):
    # Perform sentiment analysis on the text
    sentiment = sentiment_analyzer(text)[0]
    label = sentiment['label']
    confidence = sentiment['score']
    print(f"Sentiment: {label} with confidence {confidence:.2f}")

    # Set voice parameters based on sentiment, using edge-tts' signed
    # percent / Hz offset formats for rate and pitch
    if label == "POSITIVE":
        voice = "en-US-AriaNeural"  # Cheerful and energetic tone for positive sentiment
        rate = "+20%"               # Faster speech
        pitch = "+2Hz"              # Slightly higher pitch for a more positive tone
    else:
        voice = "en-US-GuyNeural"   # Neutral tone for negative sentiment
        rate = "-10%"               # Slower speech
        pitch = "-2Hz"              # Lower pitch for a more somber tone

    # Generate speech with edge-tts, applying the sentiment-driven rate and pitch
    communicate = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)
    # Save the audio to a file and return its path
    await communicate.save("output.mp3")
    return "output.mp3"
def combine_video_with_audio(video_path, audio_path, output_path):
    # Load video and audio
    video = VideoFileClip(video_path)
    audio = AudioFileClip(audio_path)
    # Set the audio track on the video
    video = video.set_audio(audio)
    # Save the final video
    video.write_videofile(output_path, codec='libx264', audio_codec='aac')
    print("Video with audio saved successfully!")
# Main processing function
def create_story_video(prompt, progress=gr.Progress()):
    if not prompt or len(prompt.strip()) < 5:
        return "Please enter a longer prompt (at least 5 characters).", None, None
    try:
        print("Step 1: Generating story...")
        progress(0, desc="Starting story generation...")
        story = generate_story(prompt)
        print("Story generation complete.")
        # gr.Progress expects fractions in [0, 1]
        progress(0.2, desc="Story generated successfully!")

        print("Step 2: Generating video...")
        progress(0.25, desc="Creating video animation (this may take several minutes)...")
        video_path = generate_video(story)
        print("Video generation complete.")
        progress(0.6, desc="Video created successfully!")

        print("Step 3: Summarizing for audio...")
        progress(0.65, desc="Creating audio summary...")
        audio_summary = summary_of_summary(story, video_path)
        print("Audio summary complete.")
        progress(0.8, desc="Creating audio narration...")

        print("Step 4: Generating audio...")
        try:
            # Reuse the thread's event loop if one exists; otherwise create one
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            audio_file = loop.run_until_complete(
                generate_audio_with_sentiment(audio_summary, sentiment_analyzer)
            )
            print(f"Audio generated at: {audio_file}")
            progress(0.9, desc="Audio created successfully!")
        except Exception as e:
            # Return None for the audio output; gr.Audio cannot render an error string
            print(f"Audio generation error: {str(e)}")
            return story, None, None

        print("Step 5: Combining video and audio...")
        progress(0.95, desc="Combining video and audio...")
        output_path = 'final_video_with_audio.mp4'
        combine_video_with_audio(video_path, audio_file, output_path)
        print("Combination complete.")
        progress(1.0, desc="Process complete!")
        return story, output_path, audio_file  # Story text, final video path, narration audio path
    except Exception as e:
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return f"An error occurred: {str(e)}", None, None
# Sample prompt examples based on realistic scenarios
EXAMPLE_PROMPTS = [
    "A nurse discovers an unusual pattern in patient symptoms that leads to an important medical breakthrough.",
    "During a home renovation, a family uncovers a time capsule from the previous owners.",
    "A struggling local restaurant owner finds an innovative way to save their business during an economic downturn.",
    "An environmental scientist tracks mysterious wildlife behavior that reveals concerning climate changes.",
    "A community comes together to rebuild after a devastating natural disaster.",
]
# Create the Gradio interface
with gr.Blocks(title="AI Story Video Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 AI Story Video Generator")
    gr.Markdown("Enter a one-sentence prompt to generate a complete story with video and narration.")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Your Story Idea",
            placeholder="Enter a one-sentence prompt (e.g., 'A detective discovers a hidden room in an abandoned mansion')",
            lines=2
        )

    gr.Markdown("### Try these example prompts:")
    with gr.Row():
        examples = gr.Examples(
            examples=[[prompt] for prompt in EXAMPLE_PROMPTS],
            inputs=prompt_input,
            label="Click any example to load it"
        )

    with gr.Row():
        generate_button = gr.Button("Generate Story Video", variant="primary")
        clear_button = gr.Button("Clear", variant="secondary")

    status_indicator = gr.Markdown("Ready to generate your story video...")

    with gr.Tabs():
        with gr.TabItem("Results"):
            with gr.Row():
                with gr.Column(scale=2):
                    video_output = gr.Video(label="Generated Video with Narration")
                with gr.Column(scale=1):
                    story_output = gr.TextArea(label="Generated Story", lines=15, max_lines=30)
                    audio_output = gr.Audio(label="Audio Narration")
        with gr.TabItem("Help & Information"):
            gr.Markdown("""
            ## How to use this tool
            1. Enter a creative one-sentence story idea in the input box
            2. Click "Generate Story Video" and wait for processing to complete
            3. View your story, narration audio, and final video

            ## Processing Steps
            - Story Generation: Expands your idea into a 15-20 sentence story
            - Video Creation: Visualizes sentences with AI animation
            - Audio Narration: Creates a voiceover with sentiment analysis
            - Final Compilation: Combines video and audio
            """)
    def clear_outputs():
        return "", None, None

    generate_button.click(
        fn=create_story_video,
        inputs=prompt_input,
        outputs=[story_output, video_output, audio_output],
        api_name="generate"
    )
    clear_button.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[story_output, video_output, audio_output]
    )

if __name__ == "__main__":
    demo.launch()