File size: 5,336 Bytes
9e908c5
9e1ef69
 
 
9e908c5
 
df17f8f
a0010c7
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
 
9e908c5
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
 
9e908c5
83acbfc
a0010c7
9e1ef69
 
 
a0010c7
9e1ef69
 
a0010c7
 
 
 
83acbfc
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df17f8f
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
9e908c5
9e1ef69
a0010c7
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
a0010c7
9e1ef69
a0010c7
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import numpy as np
import os
import tempfile
import moviepy.editor as mpe
import soundfile as sf
import nltk
from pydub import AudioSegment
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Ensure NLTK data is downloaded. Newer NLTK releases load the sentence
# tokenizer used by nltk.sent_tokenize from the "punkt_tab" resource instead
# of the pickled "punkt" model, so fetch both to work across versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# Initialize models (loaded once at import time and shared by all requests).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU kernels generally expect float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Story generator (transformers pipelines take a GPU index, or -1 for CPU).
story_generator = pipeline('text-generation', model='gpt2-large', device=0 if device=='cuda' else -1)

# Stable Diffusion model
sd_model_id = "runwayml/stable-diffusion-v1-5"
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id, torch_dtype=torch_dtype)
sd_pipe = sd_pipe.to(device)

# Text-to-Speech model. NOTE: on CUDA these are float16, so any tensors fed
# to them (e.g. speaker embeddings) must use the same dtype.
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts", torch_dtype=torch_dtype)
tts_model = tts_model.to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", torch_dtype=torch_dtype)
vocoder = vocoder.to(device)

def text2speech(text):
    """Synthesize *text* with SpeechT5 and save the result as a WAV file.

    Args:
        text: The text to narrate.

    Returns:
        Path of a newly created 16 kHz WAV file in the system temp directory
        (a distinct file per call).

    Raises:
        Any exception from tokenization, synthesis or writing is printed and
        then re-raised unchanged.
    """
    try:
        inputs = tts_processor(text=text, return_tensors="pt").to(device)
        # Build the embedding in the model's dtype: on CUDA the model is loaded
        # in float16, and a float32 embedding causes a dtype mismatch.
        # NOTE(review): an all-zero speaker embedding is only a placeholder
        # voice; consider loading a real x-vector.
        speaker_embeddings = torch.zeros((1, 512), device=device, dtype=tts_model.dtype)
        speech = tts_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

        # One file per call so concurrent requests never overwrite each other.
        fd, output_path = tempfile.mkstemp(prefix="speech_output_", suffix=".wav")
        os.close(fd)
        # soundfile cannot write float16 arrays, so upcast to float32 first.
        sf.write(output_path, speech.float().cpu().numpy(), samplerate=16000)
        return output_path
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        raise

def generate_story(prompt):
    """Continue *prompt* into a story using the GPT-2 text-generation pipeline.

    Generation is capped by ``max_length=500`` tokens (which, for this
    pipeline, counts the prompt tokens too). Returns the ``generated_text``
    of the single returned sequence.
    """
    outputs = story_generator(prompt, max_length=500, num_return_sequences=1)
    best = outputs[0]
    return best['generated_text']

def split_story_into_sentences(story):
    """Return the sentences of *story*, as split by ``nltk.sent_tokenize``."""
    return nltk.sent_tokenize(story)

def generate_images(sentences):
    """Render one Stable Diffusion image per sentence and save each as PNG.

    Args:
        sentences: Iterable of prompt strings; one image is made per item.

    Returns:
        List of PNG file paths, in the same order as *sentences*. The files
        are created with ``delete=False`` and are not removed automatically.
    """
    images = []
    for idx, sentence in enumerate(sentences):
        image = sd_pipe(sentence).images[0]
        # Reserve a unique path, then close the handle before PIL reopens the
        # path by name; keeping it open leaks a descriptor per image and may
        # prevent the write on Windows.
        with tempfile.NamedTemporaryFile(suffix=f"_{idx}.png", delete=False) as temp_file:
            image_path = temp_file.name
        image.save(image_path)
        images.append(image_path)
    return images

def generate_audio(story_text):
    """Narrate *story_text* and measure how long the narration is.

    Returns:
        A ``(path, seconds)`` tuple: the WAV path produced by ``text2speech``
        and its duration in seconds (pydub reports length in milliseconds).
    """
    wav_path = text2speech(story_text)
    length_ms = len(AudioSegment.from_file(wav_path))
    return wav_path, length_ms / 1000

def compute_sentence_durations(sentences, total_duration):
    """Split *total_duration* across *sentences* in proportion to word count.

    Each sentence is weighted by its whitespace-separated word count, with a
    minimum weight of 1. This means no sentence (and so no image) gets a zero
    duration, and input whose sentences contain no words cannot divide by zero.

    Args:
        sentences: Iterable of sentence strings (iterated exactly once).
        total_duration: Total time, in seconds, to distribute.

    Returns:
        List of per-sentence durations in the order of *sentences*, summing
        (up to float rounding) to *total_duration*; ``[]`` for no sentences.
    """
    weights = [max(1, len(sentence.split())) for sentence in sentences]
    total_weight = sum(weights)
    if total_weight == 0:
        return []
    return [total_duration * weight / total_weight for weight in weights]

def create_video(images, durations, audio_path):
    """Assemble a still-image slideshow with a narration soundtrack.

    Args:
        images: Paths of the still images, shown in order.
        durations: Seconds to show each image; paired with *images* via
            ``zip``, so surplus items in the longer sequence are ignored.
        audio_path: Path of the audio file attached as the soundtrack.

    Returns:
        Path of the written MP4 file (a distinct temp file per call).

    Raises:
        ValueError: If there is no (image, duration) pair to render.
    """
    clips = [
        mpe.ImageClip(image_path).set_duration(duration)
        for image_path, duration in zip(images, durations)
    ]
    if not clips:
        raise ValueError("create_video needs at least one image with a duration")

    composite = mpe.concatenate_videoclips(clips, method='compose')
    audio = None
    try:
        audio = mpe.AudioFileClip(audio_path)
        video = composite.set_audio(audio)
        # Reserve a unique output path so concurrent requests do not
        # overwrite each other's video (ffmpeg overwrites the empty file).
        with tempfile.NamedTemporaryFile(prefix="final_video_", suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        # The frames are still images, so 1 fps is enough.
        video.write_videofile(output_path, fps=1, codec='libx264')
    finally:
        # Release the file handles/ffmpeg readers held by moviepy clips.
        if audio is not None:
            audio.close()
        composite.close()
        for clip in clips:
            clip.close()
    return output_path

def process_pipeline(prompt, progress=gr.Progress(track_tqdm=True)):
    """Run prompt -> story -> images -> narration -> video end to end.

    Args:
        prompt: Text prompt that seeds the story.
        progress: Gradio progress tracker, injected by Gradio when this runs
            as an event handler; ``track_tqdm`` also mirrors inner tqdm bars.

    Returns:
        Path of the generated MP4 video.

    Raises:
        gr.Error: Wrapping any failure so the message is shown in the UI.
    """
    try:
        # gr.Progress is neither a context manager nor accepts ``desc`` in its
        # constructor; stages are reported by calling the tracker with a
        # (rough) completion fraction and a description.
        progress(0.0, desc="Generating Story")
        story = generate_story(prompt)

        progress(0.15, desc="Splitting Story into Sentences")
        sentences = split_story_into_sentences(story)
        if not sentences:
            raise ValueError("The generated story contains no sentences.")

        progress(0.2, desc="Generating Images for Sentences")
        images = generate_images(sentences)

        progress(0.7, desc="Generating Audio")
        audio_path, total_duration = generate_audio(story)

        progress(0.85, desc="Computing Durations")
        durations = compute_sentence_durations(sentences, total_duration)

        progress(0.9, desc="Creating Video")
        video_path = create_video(images, durations, audio_path)
        return video_path
    except Exception as e:
        print(f"Error in process_pipeline: {str(e)}")
        raise gr.Error(f"An error occurred: {str(e)}") from e

# Page header; the 🎥 emoji was previously stored as mis-decoded bytes.
title = """<h1 align="center">AI Story Video Generator 🎥</h1>
<p align="center">
Generate a story from a prompt, create images for each sentence, and produce a video with narration!
</p>
"""

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    gr.HTML(title)
    
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
            generate_button = gr.Button("Generate Video")
            # NOTE(review): this component is never updated by any event;
            # progress is shown by gr.Progress inside process_pipeline.
            progress_bar = gr.Markdown("")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
    
    # The whole pipeline runs in one handler and returns the video file path.
    generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)

demo.launch(debug=True)