Razzaqi3143 commited on
Commit
61a445a
·
verified ·
1 Parent(s): e8e26cf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import whisper
4
+ import spacy
5
+ from PIL import Image
6
+ from diffusers import StableDiffusionPipeline #stable model to be updated if used
7
+ import torch
8
+ import logging
9
+ import os
10
+ import io
11
+
12
+ # Disable WANDB logging and configure logging level
13
+ logging.disable(logging.WARNING)
14
+ os.environ["WANDB_DISABLED"] = "true"
15
+
16
+ # Load models
17
+ whisper_model = whisper.load_model("base")
18
+ spacy.prefer_gpu()
19
+ spacy_nlp = spacy.load("en_core_web_sm")
20
+
21
+ #Initialize the model
22
+ stable_diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
23
+ "CompVis/stable-diffusion-v1-2",
24
+ torch_dtype=torch.float16
25
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+
28
+
29
+ def extract_keyframes(video_path, frame_interval=30, num_frames=5):
30
+ try:
31
+ cap = cv2.VideoCapture(video_path)
32
+ frames = []
33
+ success, frame = cap.read()
34
+ count = 0
35
+ while success and count < num_frames:
36
+ if count % frame_interval == 0:
37
+ frames.append(frame)
38
+ success, frame = cap.read()
39
+ count += 1
40
+ cap.release()
41
+ return frames
42
+ except Exception as e:
43
+ logging.error("Error extracting keyframes:", exc_info=e)
44
+ return None
45
+ def test_extract_keyframes():
46
+ video_path = "video.mp4"
47
+ frames = extract_keyframes(video_path)
48
+
49
+ assert frames is not None, "Keyframe extraction failed"
50
+ assert len(frames) > 0, "No keyframes extracted"
51
+ print("Keyframe extraction test passed")
52
+
53
+ test_extract_keyframes()
54
+ def transcribe_audio(video_path):
55
+ try:
56
+ result = whisper_model.transcribe(video_path)
57
+ return result['text']
58
+ except Exception as e:
59
+ logging.error("Error transcribing audio:", exc_info=e)
60
+ return None
61
+ def test_transcribe_audio():
62
+ video_path = "video.mp4"
63
+ transcription = transcribe_audio(video_path)
64
+
65
+ assert transcription is not None, "Transcription failed"
66
+ assert len(transcription) > 0, "Empty transcription"
67
+ print("Transcription test passed")
68
+
69
+ test_transcribe_audio()
70
+ def extract_keywords(text):
71
+ try:
72
+ if not text or not text.strip():
73
+ logging.warning("Empty or whitespace-only text: No keywords extracted")
74
+ return []
75
+
76
+ doc = spacy_nlp(text)
77
+ keywords = [chunk.text for chunk in doc.noun_chunks]
78
+
79
+ if not keywords:
80
+ logging.warning("No keywords extracted from the text")
81
+
82
+ return keywords
83
+ except Exception as e:
84
+ logging.error("Error extracting keywords:", exc_info=e)
85
+ return []
86
+ def test_extract_keywords():
87
+ text = "This is a test text for keyword extraction."
88
+ keywords = extract_keywords(text)
89
+
90
+ assert keywords is not None, "Keyword extraction failed"
91
+ assert len(keywords) > 0, "No keywords extracted"
92
+ print("Keyword extraction test passed")
93
+
94
+ test_extract_keywords()
95
+ def generate_thumbnails(frames, keywords, num_thumbnails=3):
96
+ try:
97
+ thumbnails = []
98
+ for frame in frames:
99
+ for _ in range(num_thumbnails):
100
+ prompt = "A visually striking image of " + ", ".join(keywords)
101
+ generated_image = stable_diffusion_pipeline(prompt, init_image=frame).images[0]
102
+ thumbnails.append(generated_image)
103
+ return thumbnails
104
+ except Exception as e:
105
+ logging.exception("Error generating thumbnails:", exc_info=e)
106
+ return None
107
+ def process_video(video):
108
+ try:
109
+ # Determine the video path based on the type of input
110
+ video_path = video.name if hasattr(video, 'name') else video
111
+
112
+ # Extract Keyframes
113
+ frames = extract_keyframes(video_path)
114
+ if frames is None:
115
+ return handle_error("Error extracting keyframes. Please check the video file.")
116
+
117
+ # Transcribe Audio
118
+ transcription = transcribe_audio(video_path)
119
+ if transcription is None:
120
+ return handle_error("Error transcribing audio. Please check the audio quality.")
121
+
122
+ # Extract Keywords
123
+ keywords = extract_keywords(transcription)
124
+ if not keywords:
125
+ return handle_error("Error extracting keywords. Please check the transcription.")
126
+
127
+ # Use the first keyword as title, the full transcription as text, and a generic text placement description
128
+ title = keywords[0] if keywords else "Thumbnail"
129
+ text = transcription
130
+ text_placement = "white letter center at bottom, modern and dynamic"
131
+
132
+ # Generate Thumbnails
133
+ thumbnail_images = generate_thumbnails(frames, keywords)
134
+ if not thumbnail_images:
135
+ return handle_error("Error generating thumbnails. Please try again later.")
136
+
137
+ return thumbnail_images, "Thumbnails generated successfully."
138
+ except Exception as e:
139
+ logging.exception("Unexpected error:", exc_info=e)
140
+ return handle_error("An unexpected error occurred. Please try again later.")
141
+ def handle_error(error_message):
142
+ # Return a placeholder image and the error message
143
+ placeholder = Image.new('RGB', (512, 512), color = (255, 0, 0)) # Placeholder image (red square)
144
+ return [placeholder], error_message
145
+ # Gradio interface
146
+ interface = gr.Interface(
147
+ fn=process_video,
148
+ inputs=gr.Video(label="Upload Video"),
149
+ outputs=[
150
+ gr.Gallery(label="Generated Thumbnails"),
151
+ gr.Textbox(label="Status", lines=2, placeholder="Status message will appear here...")
152
+ ],
153
+ title="YouTube Thumbnail Generator",
154
+ description="Upload a video and generate multiple thumbnails using the video content and transcription.",
155
+ live=True
156
+ )
157
+
158
+ interface.launch()