File size: 8,611 Bytes
5f0b0b8
 
 
 
 
 
 
 
 
 
 
 
4226ecb
85ec17d
5f0b0b8
 
 
 
 
14984e1
 
 
 
 
 
 
 
 
 
 
 
5f0b0b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e19cf90
 
5f0b0b8
 
 
 
 
4226ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e19cf90
 
 
4226ecb
e19cf90
4226ecb
 
 
 
 
 
 
 
 
 
e19cf90
 
 
4226ecb
e19cf90
4226ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# Import libraries
import os
import gradio as gr
import torch
import soundfile as sf
import numpy as np
from PIL import Image
import torch.nn.functional as F
import logging
from scipy.io.wavfile import write as write_wav
from scipy import signal
from moviepy.editor import VideoFileClip, AudioFileClip
import requests
from audiocraft.models import AudioGen, MusicGen  # Use audiocraft for AudioGen and MusicGen

# Set up logging for better debug tracking
# DEBUG level is intentionally verbose; timestamps help correlate model-load
# and generation phases when diagnosing slow runs.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger()  # root logger; the rest of the file calls logging.* directly

# Download Places365 class labels
try:
    logging.info("Downloading Places365 class labels...")
    url = "http://places2.csail.mit.edu/models_places365/categories_places365.txt"
    # Bound the request and surface HTTP errors explicitly: the original had no
    # timeout (could hang forever) and never checked the status code, so a 404
    # error page would have been silently written to the labels file.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    with open("categories_places365.txt", "wb") as f:
        f.write(response.content)
    logging.info("Places365 class labels downloaded successfully.")
except Exception as e:
    logging.error(f"Error downloading Places365 class labels: {e}")
    raise

# Load Places365 model for scene detection (on CPU to save GPU memory)
try:
    logging.info("Loading Places365 model for scene detection...")
    # BUG FIX: the original loaded an ImageNet-pretrained resnet18 (1000 ImageNet
    # classes), so its output indices never corresponded to the Places365 labels
    # parsed below. Build the architecture with 365 outputs and load the official
    # CSAIL Places365 checkpoint instead.
    places365 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=365)
    checkpoint_url = "http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar"
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    # The checkpoint was saved from a DataParallel wrapper; strip the 'module.' prefix.
    state_dict = {k.replace("module.", "", 1): v for k, v in checkpoint["state_dict"].items()}
    places365.load_state_dict(state_dict)
    places365.eval()
    places365.to("cpu")  # Move to CPU
    logging.info("Places365 model loaded successfully.")
except Exception as e:
    logging.error(f"Error loading Places365 model: {e}")
    raise

# Load Places365 class labels
with open("categories_places365.txt", "r") as f:
    # Each line looks like "/a/airfield 0"; keep the category name, dropping
    # the leading "/x/" prefix (3 characters) and the trailing index.
    SCENE_CLASSES = [ln.split(" ")[0][3:] for ln in f.read().splitlines()]

# Load AudioGen Medium and MusicGen Medium models
try:
    logging.info("Loading AudioGen Medium and MusicGen Medium models...")
    # First run downloads the checkpoints from the Hugging Face hub
    # (several GB); subsequent runs use the local cache.
    audiogen_model = AudioGen.get_pretrained("facebook/audiogen-medium")
    musicgen_model = MusicGen.get_pretrained("facebook/musicgen-medium")
    logging.info("AudioGen Medium and MusicGen Medium models loaded successfully.")
except Exception as e:
    logging.error(f"Error loading AudioGen/MusicGen models: {e}")
    raise

# Function to classify a frame using Places365
def _preprocess_frame(frame):
    """Convert an HxWx3 uint8 RGB frame into a normalized 1x3x128x128 tensor.

    Mirrors torchvision's Resize(128) + CenterCrop(128) + ToTensor + Normalize
    using only PIL/numpy/torch, which this file already imports.
    """
    img = Image.fromarray(frame).convert("RGB")
    # Resize the shorter side to 128 (preserving aspect ratio), then center-crop.
    w, h = img.size
    if w <= h:
        new_w, new_h = 128, max(128, round(h * 128 / w))
    else:
        new_w, new_h = max(128, round(w * 128 / h)), 128
    img = img.resize((new_w, new_h), Image.BILINEAR)
    left = (new_w - 128) // 2
    top = (new_h - 128) // 2
    img = img.crop((left, top, left + 128, top + 128))

    # HWC uint8 -> CHW float32 in [0, 1], then ImageNet mean/std normalization.
    tensor = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return ((tensor - mean) / std).unsqueeze(0)


def classify_frame(frame):
    """Classify a single video frame into a Places365 scene category.

    Args:
        frame: HxWx3 uint8 RGB numpy array.

    Returns:
        The predicted scene label, or "nature" if the predicted index falls
        outside SCENE_CLASSES.

    Raises:
        Re-raises any preprocessing/inference error after logging it.
    """
    try:
        # BUG FIX: the original called transforms.Compose(...) without ever
        # importing torchvision.transforms, which raised NameError on every
        # call. Preprocessing is now done by _preprocess_frame above.
        img = _preprocess_frame(frame)
        with torch.no_grad():
            output = places365(img.to("cpu"))  # Ensure inference on CPU
        probabilities = F.softmax(output, dim=1)
        _, predicted = torch.max(probabilities, 1)
        predicted_index = predicted.item()

        # Ensure the predicted index is within the range of SCENE_CLASSES
        if predicted_index >= len(SCENE_CLASSES) or predicted_index < 0:
            logging.warning(f"Predicted class index {predicted_index} is out of range. Defaulting to 'nature'.")
            return "nature"  # Default scene type

        scene_type = SCENE_CLASSES[predicted_index]
        logging.info(f"Predicted scene: {scene_type}")
        return scene_type
    except Exception as e:
        logging.error(f"Error classifying frame: {e}")
        raise

# Function to analyze video content and return the scene type using Places365
def analyze_video(video_path):
    """Return the Places365 scene label for the first frame of *video_path*.

    Re-raises any decoding/classification error after logging it.
    """
    try:
        logging.info(f"Analyzing video: {video_path}")
        clip = VideoFileClip(video_path)
        first_frame = Image.fromarray(clip.get_frame(0))  # first frame as PIL image
        small_frame = np.array(first_frame.resize((128, 128)))  # shrink to limit memory use

        # Classify the downsized frame with the Places365 model
        scene_type = classify_frame(small_frame)
        logging.info(f"Scene type detected: {scene_type}")
        return scene_type
    except Exception as e:
        logging.error(f"Error analyzing video: {e}")
        raise

# Function to generate audio using AudioGen Medium
def generate_audio_audiogen(scene, duration=10):
    """Generate ambient audio for *scene* with AudioGen and save it as WAV.

    Args:
        scene: scene label used in the text prompt.
        duration: length of the generated clip in seconds.

    Returns:
        Path of the written WAV file.

    Raises:
        Re-raises any generation/write error after logging it.
    """
    try:
        logging.info(f"Generating audio for scene: {scene} using AudioGen Medium...")
        audiogen_model.set_generation_params(duration=duration)
        descriptions = [f"Ambient sounds of {scene}"]
        wav = audiogen_model.generate(descriptions)  # [batch, channels, time] tensor
        audio_path = "generated_audio_audiogen.wav"
        # BUG FIX: the sample rate was hard-coded to 32000, but AudioGen medium
        # generates 16 kHz audio, so files played back at the wrong speed/pitch.
        # Use the model's own sample_rate attribute instead.
        sf.write(audio_path, wav.squeeze().cpu().numpy(), audiogen_model.sample_rate)
        logging.info(f"Audio generated and saved to: {audio_path}")
        return audio_path
    except Exception as e:
        logging.error(f"Error generating audio with AudioGen Medium: {e}")
        raise

# Function to generate music using MusicGen Medium
def generate_music_musicgen(scene, duration=10):
    """Generate calm background music for *scene* with MusicGen and save it as WAV.

    Args:
        scene: scene label used in the text prompt.
        duration: length of the generated clip in seconds.

    Returns:
        Path of the written WAV file.

    Raises:
        Re-raises any generation/write error after logging it.
    """
    try:
        logging.info(f"Generating music for scene: {scene} using MusicGen Medium...")
        musicgen_model.set_generation_params(duration=duration)
        descriptions = [f"Calm music for {scene}"]
        wav = musicgen_model.generate(descriptions)  # [batch, channels, time] tensor
        music_path = "generated_music_musicgen.wav"
        # Use the model's own sample rate instead of a magic 32000 constant
        # (correct for musicgen-medium today, but robust if the model changes).
        sf.write(music_path, wav.squeeze().cpu().numpy(), musicgen_model.sample_rate)
        logging.info(f"Music generated and saved to: {music_path}")
        return music_path
    except Exception as e:
        logging.error(f"Error generating music with MusicGen Medium: {e}")
        raise

# Function to merge audio and video into a final video file using moviepy
def merge_audio_video(video_path, audio_path, output_path="output.mp4"):
    """Attach the audio track at *audio_path* to the video at *video_path*.

    Returns the output path on success, or None if merging fails (errors are
    logged rather than raised so the caller can report a friendly message).
    """
    try:
        logging.info("Merging audio and video using moviepy...")
        video_clip = VideoFileClip(video_path)
        generated_track = AudioFileClip(audio_path)
        merged = video_clip.set_audio(generated_track)
        merged.write_videofile(output_path, codec="libx264", audio_codec="aac")
        logging.info(f"Final video saved to: {output_path}")
        return output_path
    except Exception as e:
        logging.error(f"Error merging audio and video: {e}")
        return None

# Main processing function to handle video upload, scene analysis, and video output
def process_video(video_path, progress=gr.Progress()):
    """Run the full pipeline: scene analysis -> audio/music generation -> merge.

    Returns a (video_path_or_error_message, log_string) tuple for the Gradio UI.
    """
    try:
        progress(0.1, desc="Starting video processing...")
        logging.info("Starting video processing...")

        # Analyze the video to determine the scene type
        progress(0.3, desc="Analyzing video...")
        scene_type = analyze_video(video_path)

        # Generate audio using AudioGen Medium
        progress(0.5, desc="Generating audio...")
        audio_path = generate_audio_audiogen(scene_type, duration=10)

        # Generate music using MusicGen Medium
        progress(0.7, desc="Generating music...")
        music_path = generate_music_musicgen(scene_type, duration=10)

        # Merge the generated audio with the video and output the final video
        # NOTE(review): only music_path is merged below — audio_path (the
        # AudioGen output) is generated but never used. Confirm whether the
        # ambient audio should be mixed in or the AudioGen step dropped.
        progress(0.9, desc="Merging audio and video...")
        output_path = merge_audio_video(video_path, music_path)
        if not output_path:
            return "Error: Failed to merge audio and video.", "Logs: Merge failed."

        logging.info("Video processing completed successfully.")
        return output_path, "Logs: Processing completed."
    except Exception as e:
        logging.error(f"Error in process_video: {e}")
        return f"An error occurred during processing: {e}", f"Logs: {e}"

# Gradio UI for video upload
def gradio_interface(video_file, progress=gr.Progress()):
    """Gradio entry point: delegate to process_video and surface any errors.

    Returns a (video_or_error_message, log_string) pair matching the UI outputs.
    """
    try:
        progress(0.1, desc="Starting video processing...")
        logging.info("Gradio interface triggered.")
        return process_video(video_file, progress)
    except Exception as e:
        logging.error(f"Error in Gradio interface: {e}")
        return f"An error occurred: {e}", f"Logs: {e}"

# Launch Gradio app
try:
    logging.info("Launching Gradio app...")
    interface = gr.Interface(
        fn=gradio_interface,
        inputs=[gr.Video(label="Upload Video")],
        outputs=[gr.Video(label="Output Video with Generated Audio"), gr.Textbox(label="Logs", lines=10)],
        title="Video to Video with Generated Audio and Music",
        description="Upload a video, and this app will analyze it and generate matching audio and music using AudioGen Medium and MusicGen Medium."
    )
    interface.queue()  # Enable queue for long-running tasks
    interface.launch(share=True)  # Launch the app; share=True exposes a public URL
except Exception as e:
    logging.error(f"Error launching Gradio app: {e}")
    raise