File size: 5,178 Bytes
027e8a9
 
b345f1b
 
027e8a9
 
 
 
 
 
 
b345f1b
 
 
027e8a9
b345f1b
 
027e8a9
b345f1b
027e8a9
b345f1b
 
720f703
b345f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4675b5
 
 
 
b345f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b6a87c
b345f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
027e8a9
 
b345f1b
 
027e8a9
b345f1b
 
 
 
 
 
 
 
 
 
 
 
720f703
b345f1b
 
293b33a
b345f1b
 
293b33a
b345f1b
 
 
 
 
 
293b33a
 
027e8a9
 
 
 
 
 
 
b345f1b
293b33a
027e8a9
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
from gradio_client import Client
import cv2
from moviepy.editor import *

# 1. extract and store 1 image every 24 frames from the video input
# 2. extract the audio track
# 3. for each extracted image, get a caption from the captioning model and collect the captions into a list
# 4. ask an audio question-answering model to describe the sounds/scene in the audio
# 5. give everything to the LLM and ask it to summarize the video from the image captions combined with the audio description

import functools
import re
import torch
from transformers import pipeline

# Chat model used by llm_process to turn the captions into a video description.
# NOTE: the pipeline (and its weights) are loaded as soon as this module is imported.
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
pipe = pipeline("text-generation", model=zephyr_model, torch_dtype=torch.bfloat16, device_map="auto")

# System prompt sent by llm_process. A plain string: the former f-prefix had no
# placeholders to interpolate.
# TODO: this is currently blank, so the model receives no task instructions.
standard_sys = """

"""

def extract_frames(video_in, interval=24, output_format='.jpg'):
    """Save every `interval`-th frame of a video as an image file.

    Frames are written to the current working directory as
    ``frame_<k><output_format>``, where ``k`` counts the saved frames (0, 1, 2, ...).

    Args:
    - video_in: string or path-like object pointing to the video file
    - interval: keep one frame out of every `interval` frames, starting with
      frame 0 (default: 24)
    - output_format: file extension, including the dot, used for the saved
      images; cv2.imwrite picks the encoder from it (default: '.jpg')

    Returns:
    A list of the saved image filenames, in frame order.

    Raises:
    - ValueError: if interval is less than 1.
    - OSError: if the video cannot be opened or a frame cannot be written.
    """
    # Guard against the ZeroDivisionError / nonsense sampling a non-positive interval causes.
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    vidcap = cv2.VideoCapture(video_in)
    frames = []
    try:
        # Without this check an unreadable file silently produced an empty list.
        if not vidcap.isOpened():
            raise OSError(f"Could not open video: {video_in}")

        count = 0
        while True:
            success, image = vidcap.read()
            if not success:
                break  # no more frames (end of stream or read failure)

            if count % interval == 0:
                filename = f'frame_{count // interval}{output_format}'
                # imwrite reports failure through its return value, not an exception;
                # never hand back a path that was not actually written.
                if not cv2.imwrite(filename, image):
                    raise OSError(f"Could not write frame to {filename}")
                frames.append(filename)
                print(f'Saved {filename}')

            count += 1
    finally:
        # Release the capture even when reading or writing raised.
        vidcap.release()

    print('Done extracting frames!')
    return frames

@functools.lru_cache(maxsize=None)
def _moondream_client():
    """Return the shared gradio Client for the moondream2 Space, created on first use."""
    # Creating a Client connects to the remote Space; infer() calls process_image once
    # per extracted frame, so build the client once instead of once per frame.
    return Client("vikhyatk/moondream2")

def process_image(image_in):
    """Caption one image with the moondream2 Space.

    Args:
    - image_in: path to the image file to describe

    Returns:
    The answer returned by the Space's /answer_question endpoint (also printed).
    """
    result = _moondream_client().predict(
        image_in,  # filepath in 'image' Image component
        "Describe precisely the image in one sentence.",  # str in 'Question' Textbox component
        api_name="/answer_question",
    )
    print(result)
    return result

def extract_audio(video_path, output_path="output_audio.mp3"):
    """Write the audio track of a video to an audio file.

    Args:
    - video_path: path to the input video
    - output_path: where to write the audio file (default: "output_audio.mp3")

    Returns:
    output_path, the path of the written audio file.

    Raises:
    - ValueError: if the video has no audio track.
    """
    # The context manager closes the clip (releasing the readers moviepy opened
    # for it) even if writing fails; previously the clip was never closed.
    with VideoFileClip(video_path) as video_clip:
        audio_clip = video_clip.audio
        # moviepy leaves .audio as None for videos without sound; fail clearly
        # instead of with "'NoneType' object has no attribute 'write_audiofile'".
        if audio_clip is None:
            raise ValueError(f"Video has no audio track: {video_path}")
        audio_clip.write_audiofile(output_path)
    return output_path

def get_salmonn(audio_in):
    """Ask the SALMONN-7B Space to list, in order, the events heard in an audio file.

    Args:
    - audio_in: path to the audio file to analyse

    Returns:
    The answer returned by the Space's /gradio_answer endpoint (also printed).
    """
    question = "Please list each event in the audio in order."
    # Positional slider values expected by the Space, in this order:
    # beam search numbers (1-10), temperature (0.8-2.0), top p (0.1-1.0)
    beam_count, temperature, top_p = 4, 1, 0.9

    salmonn_client = Client("fffiloni/SALMONN-7B-gradio")
    answer = salmonn_client.predict(
        audio_in,
        question,
        beam_count,
        temperature,
        top_p,
        api_name="/gradio_answer",
    )
    print(answer)
    return answer

def llm_process(user_prompt):
    """Ask the Zephyr pipeline to answer `user_prompt` under the module's system prompt.

    Args:
    - user_prompt: text sent as the user turn (e.g. the assembled captions)

    Returns:
    Only the model's reply (no echoed prompt), stripped of surrounding whitespace.
    """
    messages = [
        {"role": "system", "content": standard_sys},
        {"role": "user", "content": user_prompt},
    ]
    # Render with the model's own chat template. add_generation_prompt=True appends
    # the trailing "<|assistant|>" marker; the hand-built prompt omitted it, so the
    # model was not explicitly asked to answer.
    prompt = pipe.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # return_full_text=False makes the pipeline return only the newly generated
    # text, replacing the fragile regex that tried to cut the prompt back out.
    outputs = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        return_full_text=False,
    )
    description = outputs[0]["generated_text"].strip()

    print(f"SUGGESTED video description: {description}")
    return description

def infer(video_in):
    """Run the (work-in-progress) video pipeline and return the extracted audio path.

    For now this saves sampled frames, captions each one, prints the captions,
    and extracts the audio track. The audio-description and LLM-summary steps
    are not wired up yet.

    Args:
    - video_in: path to the uploaded video

    Returns:
    The path returned by extract_audio, which the UI shows in its audio output.
    """
    # Caption every sampled frame, preserving frame order.
    captions = [process_image(frame_path) for frame_path in extract_frames(video_in)]
    print(captions)

    # One caption per line; intended for the LLM prompt once that step is enabled.
    captions_text = '\n'.join(captions)

    audio_path = extract_audio(video_in)
    print(audio_path)

    # TODO: finish the pipeline, roughly:
    #   audio_events = get_salmonn(audio_path)
    #   prompt = f"### Visual events:\n{captions_text}\n ### Audio events:\n{audio_events}"
    #   return llm_process(prompt)
    # That also requires a text output component instead of gr.Audio.
    return audio_path

# Minimal UI: upload a video, press Submit, and show what infer() returns.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h2 style="text-align: center;">Video description</h2>
        """)
        video_in = gr.Video(label="Video input")
        submit_btn = gr.Button("Submit")
        # infer() currently returns the extracted audio file, hence an Audio component.
        video_description = gr.Audio(label="Video description")
    submit_btn.click(
        fn=infer,
        inputs=[video_in],
        outputs=[video_description],
    )

# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo` or the helper functions) does not block on launch().
if __name__ == "__main__":
    demo.queue().launch()