# app.py
import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import numpy as np
import os
import tempfile
import moviepy.editor as mpe
import nltk
from pydub import AudioSegment
import warnings
import asyncio
import edge_tts
import random
from openai import OpenAI
warnings.filterwarnings("ignore", category=UserWarning)
# Ensure NLTK data is downloaded
nltk.download('punkt')
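# Newer NLTK releases (3.8.2+) may also look up the 'punkt_tab' resource when
# sent_tokenize runs; fetching it too is a harmless safeguard (assumption: the
# resource is available on the configured NLTK data server).
nltk.download('punkt_tab', quiet=True)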
# LLM Inference Class
class LLMInferenceNode:
    def __init__(self):
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
        self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")
        self.huggingface_client = OpenAI(
            base_url="https://api-inference.huggingface.co/v1/",
            api_key=self.huggingface_token,
        )
        self.sambanova_client = OpenAI(
            api_key=self.sambanova_api_key,
            base_url="https://api.sambanova.ai/v1",
        )

    def generate(self, input_text, long_talk=True, compress=False,
                 compression_level="medium", poster=False, prompt_type="Short",
                 provider="Hugging Face", model=None):
        try:
            # Define system message
            system_message = "You are a helpful assistant. Try your best to give the best response possible to the user."

            # Define base prompts based on type
            prompts = {
                "Short": """Create a brief, straightforward caption for this description, suitable for a text-to-image AI system.
Focus on the main elements, key characters, and overall scene without elaborate details.""",
                "Long": """Create a detailed visually descriptive caption of this description for a text-to-image AI system.
Include detailed visual descriptions, cinematography, and lighting setup."""
            }
            base_prompt = prompts.get(prompt_type, prompts["Short"])
            user_message = f"{base_prompt}\nDescription: {input_text}"

            # Generate with selected provider
            if provider == "Hugging Face":
                client = self.huggingface_client
            else:
                client = self.sambanova_client

            response = client.chat.completions.create(
                model=model or "meta-llama/Meta-Llama-3.1-70B-Instruct",
                max_tokens=1024,
                temperature=1.0,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message},
                ],
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"An error occurred: {e}")
            return f"Error occurred while processing the request: {str(e)}"
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# Story generator
story_generator = pipeline(
    'text-generation',
    model='gpt2-large',
    device=0 if device == 'cuda' else -1,
)
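# gpt2-large has a 1024-token context window, so the max_length=500 used in
# generate_story (prompt plus continuation) stays within the model's limit.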
# Stable Diffusion model
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch_dtype,
).to(device)
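# Optional memory saver on smaller GPUs (assumption: not part of the original
# configuration): attention slicing lowers VRAM use at a small speed cost.
# sd_pipe.enable_attention_slicing()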
# Text-to-Speech function using edge_tts
async def _text2speech_async(text):
    communicate = edge_tts.Communicate(text, voice="en-US-AriaNeural")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    await communicate.save(tmp_path)
    return tmp_path
def text2speech(text):
    try:
        output_path = asyncio.run(_text2speech_async(text))
        return output_path
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        raise
def generate_story(prompt):
    generated = story_generator(prompt, max_length=500, num_return_sequences=1)
    story = generated[0]['generated_text']
    return story
def split_story_into_sentences(story):
    sentences = nltk.sent_tokenize(story)
    return sentences
def generate_images(sentences):
    images = []
    for idx, sentence in enumerate(sentences):
        image = sd_pipe(sentence).images[0]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{idx}.png")
        image.save(temp_file.name)
        images.append(temp_file.name)
    return images
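# Sketch of optional tuning (assumption, not in the original call): standard
# diffusers arguments control the speed/quality trade-off per sentence, e.g.
#   image = sd_pipe(sentence, num_inference_steps=30, guidance_scale=7.5).images[0]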
def generate_audio(story_text):
    audio_path = text2speech(story_text)
    audio = AudioSegment.from_file(audio_path)
    total_duration = len(audio) / 1000
    return audio_path, total_duration
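# Note: pydub's AudioSegment lengths are in milliseconds (hence the /1000), and
# decoding the MP3 relies on ffmpeg/ffprobe being available on the system path.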
def compute_sentence_durations(sentences, total_duration):
    total_words = sum(len(sentence.split()) for sentence in sentences)
    return [total_duration * (len(sentence.split()) / total_words) for sentence in sentences]
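# Worked example: sentences of 5, 10, and 5 words over 40 s of narration get
# durations of 10 s, 20 s, and 10 s respectively (weighted by word count).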
def create_video(images, durations, audio_path):
    clips = [mpe.ImageClip(img).set_duration(dur) for img, dur in zip(images, durations)]
    video = mpe.concatenate_videoclips(clips, method='compose')
    audio = mpe.AudioFileClip(audio_path)
    video = video.set_audio(audio)
    output_path = os.path.join(tempfile.gettempdir(), "final_video.mp4")
    video.write_videofile(output_path, fps=1, codec='libx264')
    return output_path
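# Note: set_duration/set_audio and the moviepy.editor import are the moviepy 1.x
# API; moviepy 2.x renames these to with_duration/with_audio and drops
# moviepy.editor, so this code presumably assumes moviepy < 2.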
def process_pipeline(prompt, progress=gr.Progress()):
    try:
        story = generate_story(prompt)
        sentences = split_story_into_sentences(story)
        images = generate_images(sentences)
        audio_path, total_duration = generate_audio(story)
        durations = compute_sentence_durations(sentences, total_duration)
        video_path = create_video(images, durations, audio_path)
        return video_path
    except Exception as e:
        print(f"Error in process_pipeline: {str(e)}")
        raise gr.Error(f"An error occurred: {str(e)}")
# Gradio Interface
title = """<h1 align="center">AI Story Video Generator ๐ŸŽฅ</h1>
<p align="center">Generate a story from a prompt, create images for each sentence, and produce a video with narration!</p>"""
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
            generate_button = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
    generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)
demo.launch(debug=True)