Ahmadkhan12 committed
Commit 1a9c8fd · verified · 1 Parent(s): e809fa2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
+ # Import libraries
+ import logging
+ import requests
+ import gradio as gr
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision import models, transforms  # transforms is required by classify_frame below
+ from moviepy.editor import VideoFileClip, AudioFileClip
+ from audiocraft.models import AudioGen, MusicGen  # Use audiocraft for AudioGen and MusicGen
+
+ # Set up logging for better debug tracking
+ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger()
+
+ # Download Places365 class labels
+ try:
+     logging.info("Downloading Places365 class labels...")
+     url = "http://places2.csail.mit.edu/models_places365/categories_places365.txt"
+     response = requests.get(url)
+     response.raise_for_status()  # Fail early instead of silently writing an error page to disk
+     with open("categories_places365.txt", "wb") as f:
+         f.write(response.content)
+     logging.info("Places365 class labels downloaded successfully.")
+ except Exception as e:
+     logging.error(f"Error downloading Places365 class labels: {e}")
+     raise
+
+ # Load the Places365 model for scene detection (kept on CPU to save GPU memory).
+ # A stock torch.hub resnet18 carries ImageNet weights, whose 1000 classes do not
+ # match the 365 Places365 labels, so load the ResNet-18 checkpoint published by
+ # the Places365 project instead.
+ try:
+     logging.info("Loading Places365 model for scene detection...")
+     checkpoint_url = "http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar"
+     response = requests.get(checkpoint_url)
+     response.raise_for_status()
+     with open("resnet18_places365.pth.tar", "wb") as f:
+         f.write(response.content)
+     places365 = models.resnet18(num_classes=365)
+     checkpoint = torch.load("resnet18_places365.pth.tar", map_location="cpu")
+     state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
+     places365.load_state_dict(state_dict)
+     places365.eval()
+     places365.to("cpu")  # Move to CPU
+     logging.info("Places365 model loaded successfully.")
+ except Exception as e:
+     logging.error(f"Error loading Places365 model: {e}")
+     raise
+
+ # Load Places365 class labels. Each line looks like "/a/abbey 0", so take the
+ # first token and strip the leading "/x/" prefix to get the plain scene name.
+ with open("categories_places365.txt", "r") as f:
+     SCENE_CLASSES = [line.strip().split(" ")[0][3:] for line in f]
+
+ # Load AudioGen Medium and MusicGen Medium models
+ try:
+     logging.info("Loading AudioGen Medium and MusicGen Medium models...")
+     audiogen_model = AudioGen.get_pretrained("facebook/audiogen-medium")
+     musicgen_model = MusicGen.get_pretrained("facebook/musicgen-medium")
+     logging.info("AudioGen Medium and MusicGen Medium models loaded successfully.")
+ except Exception as e:
+     logging.error(f"Error loading AudioGen/MusicGen models: {e}")
+     raise
+
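+ # Note (assumption, not in the original file): audiocraft's get_pretrained
+ # selects a GPU automatically when one is visible; it also accepts an explicit
+ # device argument if the generators need to be pinned, e.g.:
+ #   audiogen_model = AudioGen.get_pretrained("facebook/audiogen-medium", device="cpu")
+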
+ # Function to classify a frame using Places365
+ def classify_frame(frame):
+     try:
+         preprocess = transforms.Compose([
+             transforms.Resize(128),  # Smaller resolution to keep memory use down
+             transforms.CenterCrop(128),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+         img = Image.fromarray(frame)
+         img = preprocess(img).unsqueeze(0)
+         with torch.no_grad():
+             output = places365(img.to("cpu"))  # Ensure inference on CPU
+         probabilities = F.softmax(output, dim=1)
+         _, predicted = torch.max(probabilities, 1)
+         predicted_index = predicted.item()
+
+         # Safety net: ensure the predicted index is within the range of SCENE_CLASSES
+         if predicted_index < 0 or predicted_index >= len(SCENE_CLASSES):
+             logging.warning(f"Predicted class index {predicted_index} is out of range. Defaulting to 'nature'.")
+             return "nature"  # Default scene type
+
+         scene_type = SCENE_CLASSES[predicted_index]
+         logging.info(f"Predicted scene: {scene_type}")
+         return scene_type
+     except Exception as e:
+         logging.error(f"Error classifying frame: {e}")
+         raise
+
+ # Function to analyze video content and return the scene type using Places365
+ def analyze_video(video_path):
+     try:
+         logging.info(f"Analyzing video: {video_path}")
+         clip = VideoFileClip(video_path)
+         frame = clip.get_frame(0)  # Grab the first frame as a numpy array
+         frame = Image.fromarray(frame)  # Convert to PIL image
+         frame = np.array(frame.resize((128, 128)))  # Resize to reduce memory usage
+
+         # Classify the frame using Places365
+         scene_type = classify_frame(frame)
+         logging.info(f"Scene type detected: {scene_type}")
+         return scene_type
+     except Exception as e:
+         logging.error(f"Error analyzing video: {e}")
+         raise
+
+ # Function to generate audio using AudioGen Medium
+ def generate_audio_audiogen(scene, duration=10):
+     try:
+         logging.info(f"Generating audio for scene: {scene} using AudioGen Medium...")
+         audiogen_model.set_generation_params(duration=duration)
+         descriptions = [f"Ambient sounds of {scene}"]
+         wav = audiogen_model.generate(descriptions)  # Returns a [batch, channels, samples] tensor
+         audio_path = "generated_audio_audiogen.wav"
+         # Save at the model's own sample rate; AudioGen Medium generates 16 kHz
+         # audio, so a hard-coded 32000 would play back at the wrong speed.
+         sf.write(audio_path, wav.squeeze().cpu().numpy(), audiogen_model.sample_rate)
+         logging.info(f"Audio generated and saved to: {audio_path}")
+         return audio_path
+     except Exception as e:
+         logging.error(f"Error generating audio with AudioGen Medium: {e}")
+         raise
+
+ # Function to generate music using MusicGen Medium
+ def generate_music_musicgen(scene, duration=10):
+     try:
+         logging.info(f"Generating music for scene: {scene} using MusicGen Medium...")
+         musicgen_model.set_generation_params(duration=duration)
+         descriptions = [f"Calm music for {scene}"]
+         wav = musicgen_model.generate(descriptions)  # Returns a [batch, channels, samples] tensor
+         music_path = "generated_music_musicgen.wav"
+         # MusicGen Medium generates at 32 kHz; use the model's sample_rate
+         # attribute rather than hard-coding the value.
+         sf.write(music_path, wav.squeeze().cpu().numpy(), musicgen_model.sample_rate)
+         logging.info(f"Music generated and saved to: {music_path}")
+         return music_path
+     except Exception as e:
+         logging.error(f"Error generating music with MusicGen Medium: {e}")
+         raise
+
+ # Function to merge audio and video into a final video file using moviepy
+ def merge_audio_video(video_path, audio_path, output_path="output.mp4"):
+     try:
+         logging.info("Merging audio and video using moviepy...")
+         video_clip = VideoFileClip(video_path)
+         audio_clip = AudioFileClip(audio_path)
+         # Trim the generated track if it is longer than the video
+         if audio_clip.duration and audio_clip.duration > video_clip.duration:
+             audio_clip = audio_clip.subclip(0, video_clip.duration)
+         final_clip = video_clip.set_audio(audio_clip)
+         final_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
+         logging.info(f"Final video saved to: {output_path}")
+         return output_path
+     except Exception as e:
+         logging.error(f"Error merging audio and video: {e}")
+         return None
+
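+ # Optional helper (a sketch, not part of the original flow): process_video below
+ # generates both an ambient track and a music track but only merges the music.
+ # To hear both in the final video, the two tracks could be mixed with moviepy's
+ # CompositeAudioClip before merging. mix_audio_tracks and its volume levels are
+ # illustrative assumptions, not the original author's method.
+ def mix_audio_tracks(ambient_path, music_path, mixed_path="mixed_audio.wav"):
+     from moviepy.editor import CompositeAudioClip
+     ambient = AudioFileClip(ambient_path).volumex(0.6)  # Duck the ambience slightly
+     music = AudioFileClip(music_path).volumex(0.8)
+     mixed = CompositeAudioClip([ambient, music])
+     mixed.write_audiofile(mixed_path, fps=44100)  # CompositeAudioClip needs an explicit fps
+     return mixed_path
+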
+ # Main processing function to handle video upload, scene analysis, and video output
+ def process_video(video_path, progress=gr.Progress()):
+     try:
+         progress(0.1, desc="Starting video processing...")
+         logging.info("Starting video processing...")
+
+         # Analyze the video to determine the scene type
+         progress(0.3, desc="Analyzing video...")
+         scene_type = analyze_video(video_path)
+
+         # Generate audio using AudioGen Medium
+         progress(0.5, desc="Generating audio...")
+         audio_path = generate_audio_audiogen(scene_type, duration=10)
+
+         # Generate music using MusicGen Medium
+         progress(0.7, desc="Generating music...")
+         music_path = generate_music_musicgen(scene_type, duration=10)
+
+         # Merge the generated music with the video. Only the music track is
+         # merged here; audio_path goes unused unless the two tracks are mixed
+         # first (see the mix_audio_tracks sketch above).
+         progress(0.9, desc="Merging audio and video...")
+         output_path = merge_audio_video(video_path, music_path)
+         if not output_path:
+             return "Error: Failed to merge audio and video.", "Logs: Merge failed."
+
+         logging.info("Video processing completed successfully.")
+         return output_path, "Logs: Processing completed."
+     except Exception as e:
+         logging.error(f"Error in process_video: {e}")
+         return f"An error occurred during processing: {e}", f"Logs: {e}"
+
+ # Gradio UI for video upload
+ def gradio_interface(video_file, progress=gr.Progress()):
+     try:
+         progress(0.1, desc="Starting video processing...")
+         logging.info("Gradio interface triggered.")
+         output_video, logs = process_video(video_file, progress)
+         return output_video, logs
+     except Exception as e:
+         logging.error(f"Error in Gradio interface: {e}")
+         return f"An error occurred: {e}", f"Logs: {e}"
+
+ # Launch Gradio app
+ try:
+     logging.info("Launching Gradio app...")
+     interface = gr.Interface(
+         fn=gradio_interface,
+         inputs=[gr.Video(label="Upload Video")],
+         outputs=[gr.Video(label="Output Video with Generated Audio"), gr.Textbox(label="Logs", lines=10)],
+         title="Video to Video with Generated Audio and Music",
+         description="Upload a video, and this app will analyze it and generate matching audio and music using AudioGen Medium and MusicGen Medium.",
+     )
+     interface.queue()  # Enable queue for long-running tasks
+     interface.launch(share=True)  # Launch the app
+ except Exception as e:
+     logging.error(f"Error launching Gradio app: {e}")
+     raise
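+
+ # Dependency note (an assumption; the original file does not pin versions):
+ # this app needs audiocraft, torch, torchvision, gradio, moviepy 1.x (which
+ # still ships moviepy.editor), soundfile, Pillow, numpy, and requests, e.g.:
+ #   pip install audiocraft gradio "moviepy<2" soundfile torch torchvision Pillow numpy requests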