ans123 committed
Commit 413e121 · verified
1 Parent(s): 94d08d9

Create app.py

Files changed (1)
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ from transformers import BlipProcessor, BlipForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC
+ import torch
+ import torchaudio
+ from diffusers import DiffusionPipeline, DDIMScheduler
+ from PIL import Image
+ import numpy as np
+ import imageio
+ import tempfile
+ import os
+
+ # Load ASR model (open-source wav2vec2, not Whisper)
+ asr_processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
+ asr_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
+
+ # Load image captioning model
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ # Load Zeroscope text-to-video pipeline (requires a CUDA GPU for fp16 inference).
+ # DiffusionPipeline resolves to the text-to-video pipeline class declared in the model repo;
+ # StableDiffusionPipeline is image-only and does not accept num_frames.
+ video_pipe = DiffusionPipeline.from_pretrained(
+     "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+ ).to("cuda")
+ video_pipe.scheduler = DDIMScheduler.from_config(video_pipe.scheduler.config)
+
+ # --- Helper functions ---
+ def transcribe_audio(audio_path):
+     waveform, rate = torchaudio.load(audio_path)
+     # wav2vec2 expects 16 kHz audio; resample uploads recorded at other rates.
+     if rate != 16000:
+         waveform = torchaudio.functional.resample(waveform, rate, 16000)
+         rate = 16000
+     input_values = asr_processor(waveform[0], sampling_rate=rate, return_tensors="pt").input_values
+     with torch.no_grad():
+         logits = asr_model(input_values).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = asr_processor.decode(predicted_ids[0])
+     return transcription
+
+ def describe_image(image):
+     inputs = blip_processor(image, return_tensors="pt")
+     out = blip_model.generate(**inputs)
+     description = blip_processor.decode(out[0], skip_special_tokens=True)
+     return description
+
+ def build_prompt(image_desc, voice_text, influencer_task):
+     return f"A cinematic video of {image_desc}. They are speaking about '{voice_text}'. Their daily routine: {influencer_task}."
+
+ def generate_video(prompt, job_id):
+     frames = video_pipe(prompt, num_inference_steps=25, height=512, width=768, num_frames=24).frames
+     temp_video_path = os.path.join(tempfile.gettempdir(), f"{job_id}_output.mp4")
+     # Convert frames (PIL images or arrays, depending on the diffusers version) to uint8 RGB;
+     # writing .mp4 requires the imageio-ffmpeg backend.
+     np_frames = [np.asarray(f) for f in frames]
+     np_frames = [(f * 255).astype(np.uint8) if f.dtype != np.uint8 else f for f in np_frames]
+     imageio.mimsave(temp_video_path, np_frames, fps=8)
+     return temp_video_path
+
+ # --- Gradio interface function ---
+ def process_inputs(user_image, voice, influencer_tasks, job_id):
+     image_desc = describe_image(user_image)
+     voice_text = transcribe_audio(voice)
+     final_prompt = build_prompt(image_desc, voice_text, influencer_tasks)
+     video_path = generate_video(final_prompt, job_id)
+     return video_path, final_prompt
+
+ # --- Gradio UI ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🧑‍🎤 Influencer Video Generator")
+
+     with gr.Row():
+         with gr.Column():
+             user_image = gr.Image(label="Upload Your Image", type="pil")
+             # Upload is the default audio source, so the deprecated source= argument is dropped.
+             voice = gr.Audio(label="Upload Your Voice (WAV/MP3)", type="filepath")
+             influencer_tasks = gr.Textbox(label="What does the influencer do daily?", placeholder="e.g., go to gym, film reels, drink coffee")
+             job_id = gr.Textbox(label="Job ID", placeholder="e.g., JOB-12345")
+             generate_btn = gr.Button("🎥 Generate Video")
+         with gr.Column():
+             output_video = gr.Video(label="Generated Video")
+             prompt_display = gr.Textbox(label="Generated Prompt")
+
+     generate_btn.click(fn=process_inputs,
+                        inputs=[user_image, voice, influencer_tasks, job_id],
+                        outputs=[output_video, prompt_display])
+
+ demo.launch()