Create app.py
app.py
ADDED
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torchaudio
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image
import numpy as np
import imageio
import tempfile
import os

# Load ASR model (open-source, not Whisper)
asr_processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
asr_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")

# Load image captioning model
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load Zeroscope text-to-video pipeline
# Note: Zeroscope is a text-to-video model, so it is loaded through the generic
# DiffusionPipeline entry point rather than StableDiffusionPipeline (which is
# image-only and does not accept num_frames).
video_pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16
).to("cuda")
video_pipe.scheduler = DDIMScheduler.from_config(video_pipe.scheduler.config)

# --- Helper functions ---
def transcribe_audio(audio_path):
    waveform, rate = torchaudio.load(audio_path)
    # Downmix to mono and resample: the XLSR-53 Wav2Vec2 model expects 16 kHz audio
    waveform = waveform.mean(dim=0) if waveform.shape[0] > 1 else waveform[0]
    if rate != 16000:
        waveform = torchaudio.functional.resample(waveform, rate, 16000)
        rate = 16000
    input_values = asr_processor(waveform.numpy(), sampling_rate=rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0])
    return transcription

def describe_image(image):
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    description = blip_processor.decode(out[0], skip_special_tokens=True)
    return description

def build_prompt(image_desc, voice_text, influencer_task):
    return f"A cinematic video of {image_desc}. They are speaking about '{voice_text}'. Their daily routine: {influencer_task}."

def generate_video(prompt, job_id):
    # 24 frames at 8 fps gives roughly a 3-second clip
    # (recent diffusers releases nest frames per prompt; use .frames[0] there)
    frames = video_pipe(prompt, num_inference_steps=25, height=512, width=768, num_frames=24).frames
    temp_video_path = os.path.join(tempfile.gettempdir(), f"{job_id}_output.mp4")
    imageio.mimsave(temp_video_path, [np.array(f) for f in frames], fps=8)
    return temp_video_path

# --- Gradio interface function ---
def process_inputs(user_image, voice, influencer_tasks, job_id):
    image_desc = describe_image(user_image)
    voice_text = transcribe_audio(voice)
    final_prompt = build_prompt(image_desc, voice_text, influencer_tasks)
    video_path = generate_video(final_prompt, job_id)
    return video_path, final_prompt

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧑‍🎤 Influencer Video Generator")

    with gr.Row():
        with gr.Column():
            user_image = gr.Image(label="Upload Your Image", type="pil")
            # Gradio 3.x syntax; on Gradio 4.x use sources=["upload"] instead of source="upload"
            voice = gr.Audio(source="upload", label="Upload Your Voice (WAV/MP3)", type="filepath")
            influencer_tasks = gr.Textbox(label="What does the influencer do daily?", placeholder="e.g., go to gym, film reels, drink coffee")
            job_id = gr.Textbox(label="Job ID", placeholder="e.g., JOB-12345")
            generate_btn = gr.Button("🎥 Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            prompt_display = gr.Textbox(label="Generated Prompt")

    generate_btn.click(fn=process_inputs,
                       inputs=[user_image, voice, influencer_tasks, job_id],
                       outputs=[output_video, prompt_display])

demo.launch()
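
Note: to run this app.py as a Space, a requirements.txt roughly like the sketch below is assumed. The package names are inferred from the imports above; imageio-ffmpeg is an assumption, added because imageio needs its ffmpeg plugin to write the .mp4 output.

requirements.txt (assumed):
gradio
transformers
torch
torchaudio
diffusers
Pillow
numpy
imageio
imageio-ffmpeg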