File size: 6,624 Bytes
55c19e3
9e908c5
55c19e3
 
9e1ef69
9e908c5
 
df17f8f
a0010c7
9e1ef69
 
 
 
02a76d9
 
55c19e3
 
9e1ef69
 
 
 
 
 
55c19e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e1ef69
 
 
9e908c5
9e1ef69
02a76d9
 
 
 
 
9e1ef69
 
02a76d9
55c19e3
02a76d9
55c19e3
9e1ef69
02a76d9
55c19e3
 
 
 
 
 
 
83acbfc
a0010c7
02a76d9
a0010c7
 
 
 
83acbfc
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
02a76d9
9e1ef69
 
 
 
 
 
 
55c19e3
9e1ef69
 
 
 
55c19e3
df17f8f
9e1ef69
55c19e3
9e1ef69
 
 
 
 
 
55c19e3
bd79c84
a0010c7
dafecc5
 
 
 
 
 
9e1ef69
a0010c7
9e1ef69
a0010c7
55c19e3
 
9e1ef69
55c19e3
9e1ef69
 
 
 
 
 
 
 
 
 
 
 
 
55c19e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# app.py
import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import numpy as np
import os
import tempfile
import moviepy.editor as mpe
import nltk
from pydub import AudioSegment
import warnings
import asyncio
import edge_tts
import random
from openai import OpenAI

warnings.filterwarnings("ignore", category=UserWarning)

# Ensure the NLTK sentence-tokenizer data is available for nltk.sent_tokenize.
# NLTK 3.8.2 and later load it from the "punkt_tab" resource, while older
# releases use "punkt"; fetch both so sentence splitting works on either.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# LLM Inference Class
class LLMInferenceNode:
    """Generate text-to-image captions through OpenAI-compatible chat APIs.

    Two providers are wired up: the Hugging Face Inference API (key read from
    ``HUGGINGFACE_TOKEN``) and SambaNova (key read from ``SAMBANOVA_API_KEY``).
    Keys are read once, when the node is constructed.
    """

    def __init__(self):
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
        self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")

        # Clients are built on first use (see the properties below). The
        # OpenAI client raises at construction when it finds no API key
        # (neither the argument nor OPENAI_API_KEY), so building both eagerly
        # made the node unusable whenever only one provider's key was set.
        self._huggingface_client = None
        self._sambanova_client = None

    @property
    def huggingface_client(self):
        """OpenAI-compatible client for the Hugging Face Inference API (lazy)."""
        if self._huggingface_client is None:
            self._huggingface_client = OpenAI(
                base_url="https://api-inference.huggingface.co/v1/",
                api_key=self.huggingface_token,
            )
        return self._huggingface_client

    @huggingface_client.setter
    def huggingface_client(self, client):
        self._huggingface_client = client

    @property
    def sambanova_client(self):
        """OpenAI-compatible client for the SambaNova API (lazy)."""
        if self._sambanova_client is None:
            self._sambanova_client = OpenAI(
                api_key=self.sambanova_api_key,
                base_url="https://api.sambanova.ai/v1",
            )
        return self._sambanova_client

    @sambanova_client.setter
    def sambanova_client(self, client):
        self._sambanova_client = client

    def generate(self, input_text, long_talk=True, compress=False, 
                compression_level="medium", poster=False, prompt_type="Short",
                provider="Hugging Face", model=None):
        """Turn *input_text* into an image-generation caption via an LLM.

        Args:
            input_text: free-form description to caption.
            long_talk, compress, compression_level, poster: accepted for
                call-site compatibility; currently unused.
            prompt_type: ``"Short"`` or ``"Long"``; any other value falls
                back to ``"Short"``.
            provider: ``"Hugging Face"`` selects the Hugging Face client; any
                other value selects SambaNova.
            model: model id; defaults to
                ``"meta-llama/Meta-Llama-3.1-70B-Instruct"``.

        Returns:
            The stripped caption text. On any failure (including a client
            that cannot be created) a string starting with
            ``"Error occurred while processing the request: "`` is returned
            instead of raising.
        """
        try:
            # Define system message
            system_message = "You are a helpful assistant. Try your best to give the best response possible to the user."
            
            # Define base prompts based on type
            prompts = {
                "Short": """Create a brief, straightforward caption for this description, suitable for a text-to-image AI system. 
                        Focus on the main elements, key characters, and overall scene without elaborate details.""",
                "Long": """Create a detailed visually descriptive caption of this description for a text-to-image AI system. 
                        Include detailed visual descriptions, cinematography, and lighting setup."""
            }
            
            base_prompt = prompts.get(prompt_type, prompts["Short"])
            user_message = f"{base_prompt}\nDescription: {input_text}"

            # Any provider other than "Hugging Face" is routed to SambaNova.
            # Client construction happens here (lazily), inside the try, so a
            # missing key surfaces through the same error-string path.
            if provider == "Hugging Face":
                client = self.huggingface_client
            else:
                client = self.sambanova_client

            response = client.chat.completions.create(
                model=model or "meta-llama/Meta-Llama-3.1-70B-Instruct",
                max_tokens=1024,
                temperature=1.0,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message},
                ]
            )

            # message.content can be None (no text in the reply); treat that
            # as an empty caption rather than failing on None.strip().
            content = response.choices[0].message.content
            return (content or "").strip()

        except Exception as e:
            print(f"An error occurred: {e}")
            return f"Error occurred while processing the request: {str(e)}"

# Initialize models: use the GPU when PyTorch can see one, and switch to half
# precision only on that path; the CPU path stays in full float32.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# Story generator: GPT-2 Large behind a transformers text-generation pipeline.
# The pipeline takes an integer device index: -1 for the CPU, 0 for the GPU.
story_generator = pipeline(
    task="text-generation",
    model="gpt2-large",
    device=-1 if device != "cuda" else 0,
)

# Stable Diffusion model.
# The former "runwayml/stable-diffusion-v1-5" repository was removed from the
# Hugging Face Hub; the same SD 1.5 weights are published as
# "stable-diffusion-v1-5/stable-diffusion-v1-5". Set SD_MODEL_ID in the
# environment to load a different (or locally cached) checkpoint instead.
SD_MODEL_ID = os.getenv("SD_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    SD_MODEL_ID,
    torch_dtype=torch_dtype,
).to(device)

# Text-to-Speech function using edge_tts
async def _text2speech_async(text):
    """Synthesize *text* with the ``en-US-AriaNeural`` voice into a new MP3.

    Returns:
        Path of a temporary ``.mp3`` file. It is created with
        ``delete=False``, so the caller owns it and must remove it.

    Exceptions from ``communicate.save`` propagate after the temp file has
    been removed.
    """
    communicate = edge_tts.Communicate(text, voice="en-US-AriaNeural")
    # Only reserve a unique path here and close our handle before edge_tts
    # writes to it by name: a still-open NamedTemporaryFile cannot be reopened
    # by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    try:
        await communicate.save(tmp_path)
    except BaseException:
        # Don't leave an empty or partial file behind on failure/cancellation.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return tmp_path

def text2speech(text):
    """Blocking wrapper around ``_text2speech_async``.

    Runs the coroutine to completion with ``asyncio.run`` and returns the path
    of the MP3 it produced. Any failure is printed, then re-raised unchanged.
    """
    try:
        mp3_path = asyncio.run(_text2speech_async(text))
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        raise
    return mp3_path

def generate_story(prompt):
    """Expand *prompt* into a story with the module-level GPT-2 pipeline.

    Requests a single sequence with ``max_length=500`` and returns the
    ``generated_text`` field of that sequence.
    """
    candidates = story_generator(
        prompt,
        num_return_sequences=1,
        max_length=500,
    )
    first = candidates[0]
    return first["generated_text"]

def split_story_into_sentences(story):
    """Return *story* broken into a list of sentences via ``nltk.sent_tokenize``."""
    return nltk.sent_tokenize(story)

def generate_images(sentences):
    """Render one Stable Diffusion image per sentence and save each as a PNG.

    Args:
        sentences: iterable of prompt strings, one image per entry.

    Returns:
        List of temporary PNG paths in the same order as *sentences*. The
        files are created with ``delete=False``; the caller owns them.
    """
    images = []
    for idx, sentence in enumerate(sentences):
        image = sd_pipe(sentence).images[0]
        # Reserve a unique path, then close the handle before saving by name:
        # leaving it open leaks one file descriptor per image, and on Windows
        # a still-open temporary file cannot be reopened by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{idx}.png") as temp_file:
            image_path = temp_file.name
        image.save(image_path)
        images.append(image_path)
    return images

def generate_audio(story_text):
    """Narrate *story_text* and measure how long the narration runs.

    Returns:
        ``(audio_path, total_duration)``: the MP3 path produced by
        ``text2speech`` and its length in seconds (pydub reports length in
        milliseconds, hence the division by 1000).
    """
    audio_path = text2speech(story_text)
    duration_seconds = len(AudioSegment.from_file(audio_path)) / 1000
    return audio_path, duration_seconds

def compute_sentence_durations(sentences, total_duration):
    """Split *total_duration* across *sentences* in proportion to word count.

    Args:
        sentences: iterable of sentence strings (consumed once).
        total_duration: total time in seconds to distribute.

    Returns:
        One duration per sentence, in order; they sum (up to float rounding)
        to *total_duration*. Empty *sentences* yields ``[]``. If no sentence
        contains any whitespace-separated word, the time is split evenly
        instead of raising ``ZeroDivisionError``.
    """
    # Materialize counts once so generators work and each sentence is split once.
    word_counts = [len(sentence.split()) for sentence in sentences]
    total_words = sum(word_counts)
    if total_words == 0:
        if not word_counts:
            return []
        return [total_duration / len(word_counts)] * len(word_counts)
    return [total_duration * (count / total_words) for count in word_counts]

def create_video(images, durations, audio_path):
    """Assemble still images into a narrated MP4 slideshow at 1 fps.

    Args:
        images: paths of the frames, shown in order.
        durations: seconds to show each frame, paired with *images* by
            position (extra items in the longer sequence are ignored, as
            with ``zip``).
        audio_path: audio file attached as the video's soundtrack.

    Returns:
        Path of the rendered MP4 in the system temp directory. Every call
        writes a fresh, uniquely named file.

    Raises:
        ValueError: if there is no image/duration pair to render.
    """
    clips = [mpe.ImageClip(img).set_duration(dur) for img, dur in zip(images, durations)]
    if not clips:
        raise ValueError("create_video needs at least one image/duration pair")

    # A unique output file per call: a fixed name in the shared temp dir lets
    # concurrent requests overwrite (or serve half-written) each other's videos.
    fd, output_path = tempfile.mkstemp(prefix="final_video_", suffix=".mp4")
    os.close(fd)

    video = None
    audio = None
    try:
        video = mpe.concatenate_videoclips(clips, method='compose')
        audio = mpe.AudioFileClip(audio_path)
        video = video.set_audio(audio)
        video.write_videofile(output_path, fps=1, codec='libx264')
    except BaseException:
        # Don't leave an empty or truncated video behind.
        try:
            os.remove(output_path)
        except OSError:
            pass
        raise
    finally:
        # Release the readers/handles MoviePy keeps open until clips are closed.
        for clip in (video, audio, *clips):
            if clip is not None:
                clip.close()
    return output_path

def process_pipeline(prompt, progress=gr.Progress()):
    """Run prompt -> story -> images + narration -> video, end to end.

    Args:
        prompt: the user's story opening.
        progress: Gradio progress tracker; Gradio supplies it when this runs
            as an event handler.

    Returns:
        Path of the rendered MP4.

    Raises:
        gr.Error: for an empty prompt, or when any stage fails (the message
            is shown in the Gradio UI).
    """
    # Validate before the try so this message isn't re-wrapped below.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    images = []
    audio_path = None
    try:
        progress(0.0, desc="Writing story")
        story = generate_story(prompt)
        sentences = split_story_into_sentences(story)
        progress(0.2, desc="Generating images")
        images = generate_images(sentences)
        progress(0.7, desc="Narrating story")
        audio_path, total_duration = generate_audio(story)
        durations = compute_sentence_durations(sentences, total_duration)
        progress(0.85, desc="Rendering video")
        video_path = create_video(images, durations, audio_path)
        progress(1.0, desc="Done")
        return video_path
    except Exception as e:
        print(f"Error in process_pipeline: {str(e)}")
        raise gr.Error(f"An error occurred: {str(e)}") from e
    finally:
        # The per-sentence PNGs and the narration MP3 are only needed to build
        # the video; remove them so repeated runs don't fill the temp dir.
        # Best-effort: a file still held open elsewhere is simply left behind.
        for path in (*images, audio_path):
            if path:
                try:
                    os.remove(path)
                except OSError:
                    pass

# Gradio Interface
title = """<h1 align="center">AI Story Video Generator 🎥</h1>
<p align="center">Generate a story from a prompt, create images for each sentence, and produce a video with narration!</p>"""

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    gr.HTML(title)
    
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
            generate_button = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
    
    generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)

# Each request runs story, image, speech and video generation back to back.
# Enabling the queue keeps such long-running events from hitting request
# timeouts, and older Gradio releases require it for gr.Progress tracking.
demo.queue()
demo.launch(debug=True)