# video-gen / app.py
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torchaudio
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image
import numpy as np
import imageio
import tempfile
import os
# Load ASR model (Wav2Vec2, used here instead of Whisper)
asr_processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
asr_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
# Load image captioning model
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Load the Zeroscope text-to-video pipeline. Zeroscope is a text-to-video model, so it is
# loaded through DiffusionPipeline rather than StableDiffusionPipeline (which produces
# single images and does not accept num_frames). torch_dtype=torch.float16 casts the
# weights at load time, so no separate fp16 variant is required.
video_pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16
).to("cuda")
video_pipe.scheduler = DDIMScheduler.from_config(video_pipe.scheduler.config)
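# Optional memory savings (assumes enough host RAM is available): call
# video_pipe.enable_model_cpu_offload() instead of .to("cuda"), and/or
# video_pipe.enable_vae_slicing(), to fit the XL checkpoint on smaller GPUs.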
# --- Helper functions ---
def transcribe_audio(audio_path):
    waveform, rate = torchaudio.load(audio_path)
    # Wav2Vec2 expects 16 kHz audio; resample if the uploaded file uses a different rate.
    if rate != 16000:
        waveform = torchaudio.functional.resample(waveform, orig_freq=rate, new_freq=16000)
        rate = 16000
    input_values = asr_processor(waveform[0], sampling_rate=rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0])
    return transcription
def describe_image(image):
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    description = blip_processor.decode(out[0], skip_special_tokens=True)
    return description
def build_prompt(image_desc, voice_text, influencer_task):
    return f"A cinematic video of {image_desc}. They are speaking about '{voice_text}'. Their daily routine: {influencer_task}."
def generate_video(prompt, job_id):
    frames = video_pipe(prompt, num_inference_steps=25, height=512, width=768, num_frames=24).frames
    if len(frames) == 1:  # newer diffusers versions batch frames per prompt; unwrap the batch
        frames = frames[0]
    temp_video_path = os.path.join(tempfile.gettempdir(), f"{job_id}_output.mp4")
    # Scale float frames in [0, 1] to uint8 so imageio can encode them.
    frames = [(np.asarray(f) * 255).astype(np.uint8) if np.asarray(f).dtype != np.uint8 else np.asarray(f) for f in frames]
    imageio.mimsave(temp_video_path, frames, fps=8)
    return temp_video_path
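# Alternative: diffusers.utils.export_to_video can also write the frames to an .mp4 file
# and handles the uint8 conversion itself.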
# --- Gradio interface function ---
def process_inputs(user_image, voice, influencer_tasks, job_id):
    image_desc = describe_image(user_image)
    voice_text = transcribe_audio(voice)
    final_prompt = build_prompt(image_desc, voice_text, influencer_tasks)
    video_path = generate_video(final_prompt, job_id)
    return video_path, final_prompt
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧑‍🎤 Influencer Video Generator")
    with gr.Row():
        with gr.Column():
            user_image = gr.Image(label="Upload Your Image", type="pil")
            # Note: Gradio 4.x renamed `source="upload"` to `sources=["upload"]`.
            voice = gr.Audio(source="upload", label="Upload Your Voice (WAV/MP3)", type="filepath")
            influencer_tasks = gr.Textbox(label="What does the influencer do daily?", placeholder="e.g., go to gym, film reels, drink coffee")
            job_id = gr.Textbox(label="Job ID", placeholder="e.g., JOB-12345")
            generate_btn = gr.Button("🎥 Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            prompt_display = gr.Textbox(label="Generated Prompt")
    generate_btn.click(fn=process_inputs,
                       inputs=[user_image, voice, influencer_tasks, job_id],
                       outputs=[output_video, prompt_display])
demo.launch()