# app.py
import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import numpy as np
import os
import tempfile
import moviepy.editor as mpe
import nltk
from pydub import AudioSegment
import warnings
import asyncio
import edge_tts
import random
from openai import OpenAI

warnings.filterwarnings("ignore", category=UserWarning)

# Ensure NLTK data is downloaded
nltk.download('punkt')
# LLM Inference Class
class LLMInferenceNode:
    def __init__(self):
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
        self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")
        self.huggingface_client = OpenAI(
            base_url="https://api-inference.huggingface.co/v1/",
            api_key=self.huggingface_token,
        )
        self.sambanova_client = OpenAI(
            api_key=self.sambanova_api_key,
            base_url="https://api.sambanova.ai/v1",
        )

    def generate(self, input_text, long_talk=True, compress=False,
                 compression_level="medium", poster=False, prompt_type="Short",
                 provider="Hugging Face", model=None):
        try:
            # Define system message
            system_message = "You are a helpful assistant. Try your best to give the best response possible to the user."

            # Define base prompts based on type
            prompts = {
                "Short": """Create a brief, straightforward caption for this description, suitable for a text-to-image AI system.
Focus on the main elements, key characters, and overall scene without elaborate details.""",
                "Long": """Create a detailed visually descriptive caption of this description for a text-to-image AI system.
Include detailed visual descriptions, cinematography, and lighting setup.""",
            }
            base_prompt = prompts.get(prompt_type, prompts["Short"])
            user_message = f"{base_prompt}\nDescription: {input_text}"

            # Generate with the selected provider
            if provider == "Hugging Face":
                client = self.huggingface_client
            else:
                client = self.sambanova_client

            response = client.chat.completions.create(
                model=model or "meta-llama/Meta-Llama-3.1-70B-Instruct",
                max_tokens=1024,
                temperature=1.0,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message},
                ],
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"An error occurred: {e}")
            return f"Error occurred while processing the request: {str(e)}"

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Story generator (transformers pipeline: device=0 is the first GPU, -1 is CPU)
story_generator = pipeline(
    'text-generation',
    model='gpt2-large',
    device=0 if device == 'cuda' else -1
)

# Stable Diffusion model
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch_dtype
).to(device)
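
# Optional (an assumption, depending on GPU memory and diffusers version):
# attention slicing trades a little speed for a smaller memory footprint, e.g.
#   sd_pipe.enable_attention_slicing()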

# Text-to-Speech function using edge_tts
async def _text2speech_async(text):
    communicate = edge_tts.Communicate(text, voice="en-US-AriaNeural")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    await communicate.save(tmp_path)
    return tmp_path

def text2speech(text):
    # asyncio.run spins up a fresh event loop, so this helper must not be
    # called from code that is already running inside an event loop.
    try:
        output_path = asyncio.run(_text2speech_async(text))
        return output_path
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        raise

def generate_story(prompt):
    generated = story_generator(prompt, max_length=500, num_return_sequences=1)
    story = generated[0]['generated_text']
    return story

def split_story_into_sentences(story):
    sentences = nltk.sent_tokenize(story)
    return sentences
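
# For example, nltk.sent_tokenize("It was dark. The fox ran.") returns
# ['It was dark.', 'The fox ran.']; one image is generated per sentence below.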

def generate_images(sentences):
    images = []
    for idx, sentence in enumerate(sentences):
        image = sd_pipe(sentence).images[0]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{idx}.png")
        image.save(temp_file.name)
        images.append(temp_file.name)
    return images

def generate_audio(story_text):
    audio_path = text2speech(story_text)
    audio = AudioSegment.from_file(audio_path)
    total_duration = len(audio) / 1000  # pydub reports length in milliseconds
    return audio_path, total_duration

def compute_sentence_durations(sentences, total_duration):
    # Allocate narration time to each sentence in proportion to its word count.
    total_words = sum(len(sentence.split()) for sentence in sentences)
    return [total_duration * (len(sentence.split()) / total_words) for sentence in sentences]
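
# Worked example: sentences of 4 and 6 words over a 10 s narration get
# 10 * (4/10) = 4.0 s and 10 * (6/10) = 6.0 s, so each image stays on
# screen for roughly as long as its sentence takes to read aloud.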

def create_video(images, durations, audio_path):
    clips = [mpe.ImageClip(img).set_duration(dur) for img, dur in zip(images, durations)]
    video = mpe.concatenate_videoclips(clips, method='compose')
    audio = mpe.AudioFileClip(audio_path)
    video = video.set_audio(audio)
    output_path = os.path.join(tempfile.gettempdir(), "final_video.mp4")
    # fps=1 is enough for a still-image slideshow and keeps encoding fast
    video.write_videofile(output_path, fps=1, codec='libx264')
    return output_path
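
# Note: this uses the moviepy 1.x API (moviepy.editor, set_duration, set_audio);
# moviepy 2.x removed the editor module and renamed these to with_duration /
# with_audio, so the pinned version matters here.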

def process_pipeline(prompt, progress=gr.Progress()):
    try:
        story = generate_story(prompt)
        sentences = split_story_into_sentences(story)
        images = generate_images(sentences)
        audio_path, total_duration = generate_audio(story)
        durations = compute_sentence_durations(sentences, total_duration)
        video_path = create_video(images, durations, audio_path)
        return video_path
    except Exception as e:
        print(f"Error in process_pipeline: {str(e)}")
        raise gr.Error(f"An error occurred: {str(e)}")

# Gradio Interface
title = """<h1 align="center">AI Story Video Generator 🎥</h1>
<p align="center">Generate a story from a prompt, create images for each sentence, and produce a video with narration!</p>"""

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
            generate_button = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
    generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)

demo.launch(debug=True)