import io
import ast
import base64
import asyncio

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import openai
from openai import AsyncOpenAI
from PIL import Image

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import StrOutputParser

from dotenv import load_dotenv

from prompts import (
    VISION_SYSTEM_PROMPT,
    USER_PROMPT_TEMPLATE,
    FINAL_EVALUATION_SYSTEM_PROMPT,
    FINAL_EVALUATION_USER_PROMPT,
    SUMMARY_AND_TABLE_PROMPT,
    AUDIO_SYSTEM_PROMPT,
)

# Load variables (e.g. a default OPENAI_API_KEY) from a local .env file.
load_dotenv()

# Mutable module-level state shared between the Gradio callbacks below.
global_dict = {}

# Reject videos with more frames than this (roughly one minute of footage).
VIDEO_FRAME_LIMIT = 2000


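# API-key validation: a minimal one-shot chat completion is used as a cheap
# probe; if the key is rejected, the error is surfaced to the UI as gr.Error.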
def validate_api_key(api_key):
    """Check that the key works by sending a minimal chat completion."""
    client = openai.OpenAI(api_key=api_key)

    response = None
    error = None

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "user", "content": "Hello world"},
            ],
        )
        global_dict['api_key'] = api_key
    except openai.RateLimitError as e:
        print(f"OpenAI API request exceeded rate limit: {e}")
        error = e
    except openai.APIConnectionError as e:
        print(f"Failed to connect to OpenAI API: {e}")
        error = e
    except openai.APIError as e:
        print(f"OpenAI API returned an API Error: {e}")
        error = e

    if response:
        return True
    raise gr.Error(f"OpenAI returned an API Error: {error}")


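# Frame extraction: OpenCV reads the upload frame by frame; each frame is
# JPEG-compressed and base64-encoded, the form the vision model accepts.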
def _process_video(video_file):
    """Read the uploaded video and return its frames as base64-encoded JPEGs."""
    video = cv2.VideoCapture(video_file.name)
    global_dict['video_file'] = video_file.name

    base64Frames = []
    while video.isOpened():
        success, frame = video.read()
        if not success:
            break
        _, buffer = cv2.imencode(".jpg", frame)
        base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
    video.release()

    if len(base64Frames) > VIDEO_FRAME_LIMIT:
        # gr.Warning is shown as a toast; it must be called, not raised.
        gr.Warning("The video is too long (over ~1 minute).")
    print(len(base64Frames), "frames read.")

    if not base64Frames:
        raise gr.Error("Cannot open the video.")
    return base64Frames


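# Frame sampling: sending every frame would exceed the vision model's limits,
# so roughly one frame in 300 is kept (about one per ten seconds of 30 fps
# footage), each wrapped in its own single-frame batch.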
def _make_video_batch(video_file):
    """Downsample the frames to about one in 300 and group them into batches."""
    frames = _process_video(video_file)

    TOTAL_FRAME_COUNT = len(frames)
    BATCH_SIZE = 1
    # max(1, ...) keeps short clips (< 300 frames) from dividing by zero below.
    TOTAL_BATCH_SIZE = max(1, TOTAL_FRAME_COUNT // 300)
    BATCH_STEP = TOTAL_FRAME_COUNT // TOTAL_BATCH_SIZE

    base64FramesBatch = []
    for idx in range(0, TOTAL_FRAME_COUNT, BATCH_STEP * BATCH_SIZE):
        temp = []
        for i in range(BATCH_SIZE):
            if (idx + BATCH_STEP * i) < TOTAL_FRAME_COUNT:
                temp.append(frames[idx + BATCH_STEP * i])
        base64FramesBatch.append(temp)

    for idx, batch in enumerate(base64FramesBatch):
        print(f'##{idx} - batch_size: {len(batch)}')

    global_dict['batched_frames'] = base64FramesBatch
    return base64FramesBatch


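# Gallery preview: shows the user exactly which snapshots were sampled.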
def show_batches(video_file):
    """Decode the batched frames into (PIL image, caption) pairs for the gallery."""
    batched_frames = _make_video_batch(video_file)

    images = []
    for i, batch in enumerate(batched_frames):
        print(f"#### Batch_{i + 1}")
        for j, img in enumerate(batch):
            print(f'## Image_{j + 1}')
            image_bytes = base64.b64decode(img.encode("utf-8"))
            image_stream = io.BytesIO(image_bytes)
            image = Image.open(image_stream)
            images.append((image, f"batch {i + 1}"))
        print("-" * 100)

    return images


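# The Whisper endpoint accepts common video containers directly, so the
# uploaded file is sent as-is; the transcript is cached per video file.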
def show_audio_transcript(video_file, api_key):
    """Transcribe the video's audio track with Whisper, reusing a cached result."""
    previous_video_file = global_dict.get('video_file')

    if global_dict.get('transcript') and previous_video_file == video_file.name:
        return global_dict['transcript']

    client = openai.OpenAI(api_key=api_key)
    with open(video_file.name, "rb") as audio_file:
        transcript = client.audio.transcriptions.create(
            model="whisper-1",
            file=audio_file,
            response_format="text",
        )
    global_dict['transcript'] = transcript
    return transcript


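# Rubrics for an elementary-school English speaking task. Items 1-4 are judged
# from the audio transcript, items 5-6 from the video frames. They are kept in
# Korean to match the Korean evaluation output parsed further below; English
# glosses are given in the inline comments.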
audio_rubric_subsets = {
    '1': '1. want to be ~ 라는 표현을 활용하여 장래희망을 말한다.',  # uses "want to be ~" to state a dream job
    '2': '2. (be) good at ~이라는 표현을 활용하여 장래희망과 관련된 자신이 잘하는 일을 말한다.',  # uses "(be) good at ~" for a related strength
    '3': '3. 직업을 나타내는 단어를 정확히 사용한다.',  # uses occupation words accurately
    '4': '4. 망설이지 않고 유창하게 말한다.',  # speaks fluently, without hesitation
}
rubric_subsets = {
    '5': '5. 자신감 있는 태도로 카메라를 보며 말한다.',  # speaks confidently, looking at the camera
    '6': '6. 적절한 손동작을 사용하여 말한다.',  # speaks with appropriate hand gestures
}
rubrics_keyword = ('"핵심표현(want to be) 활용", "핵심표현(be good at) 활용", '
                  '"직업을 나타내는 단어 활용", "유창성", "상대방 응시", "손동작"')

global_dict['audio_rubric_subsets'] = audio_rubric_subsets
global_dict['rubric_subsets'] = rubric_subsets
global_dict['rubrics_keyword'] = rubrics_keyword


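# Each frame is attached with the {"image": <base64>, "resize": <px>} shorthand
# that the gpt-4-vision-preview chat endpoint accepts for base64 images (the
# same form used in OpenAI's video-narration cookbook example).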
async def async_call_gpt_vision(client, batch, rubric_subset):
    """Score one batch of frames against one rubric subset with GPT-4V."""
    vision_prompt_messages = [
        {"role": "system", "content": VISION_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                PromptTemplate.from_template(USER_PROMPT_TEMPLATE).format(rubrics=rubric_subset),
                *map(lambda x: {"image": x, "resize": 300}, batch),
            ],
        },
    ]

    params = {
        "model": "gpt-4-vision-preview",
        "messages": vision_prompt_messages,
        "max_tokens": 1024,
    }

    try:
        result_raw = await client.chat.completions.create(**params)
        result = result_raw.choices[0].message.content
        print(result)
        return result
    except Exception as e:
        print(f"Error processing batch with rubric subset {rubric_subset}: {e}")
        return None


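# Fan-out: all frame batches for one rubric subset are scored concurrently via
# asyncio.gather; failed calls return None and are filtered out.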
async def process_rubrics_in_batches(client, frames, rubric_subsets):
    """Run the vision evaluation for every (batch, rubric subset) pair."""
    results = {}
    for key, rubric_subset in rubric_subsets.items():
        tasks = [async_call_gpt_vision(client, batch, rubric_subset) for batch in frames]
        subset_results = await asyncio.gather(*tasks)
        results[key] = [result for result in subset_results if result is not None]

    return results


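# Gradio callbacks are plain synchronous functions, so this wrapper spins up a
# fresh event loop to drive the async OpenAI calls to completion.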
def wrapper_call_gpt_vision():
    api_key = global_dict.get('api_key')
    frames = global_dict.get('batched_frames')
    rubric_subsets = global_dict.get('rubric_subsets')
    client = AsyncOpenAI(api_key=api_key)

    async def call_gpt_vision():
        async_full_result_vision = await process_rubrics_in_batches(client, frames, rubric_subsets)
        global_dict['full_result_vision'] = async_full_result_vision
        return async_full_result_vision

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(call_gpt_vision())


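# Second stage: the per-batch vision results for a rubric subset are condensed
# into a single evaluation text by another chat completion.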
async def async_get_evaluation_text(client, result_subset):
    """Condense the raw per-batch results for one rubric subset into one text."""
    result_subset_text = ' \n'.join(result_subset)
    print(result_subset_text)
    evaluation_text = PromptTemplate.from_template(FINAL_EVALUATION_USER_PROMPT).format(evals=result_subset_text)

    evaluation_text_message = [
        {"role": "system", "content": FINAL_EVALUATION_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": evaluation_text,
        },
    ]
    params = {
        "model": "gpt-4-vision-preview",
        "messages": evaluation_text_message,
        "max_tokens": 1024,
    }

    try:
        result_raw_2 = await client.chat.completions.create(**params)
        result_2 = result_raw_2.choices[0].message.content
        return result_2
    except Exception as e:
        print(f"Error getting evaluation text {result_subset}: {e}")
        return None


async def async_get_full_result(client, full_result_vision):
    """Summarize each rubric subset's results and join them into one string."""
    results_2 = {}
    for key, result_subset in full_result_vision.items():
        tasks_2 = [async_get_evaluation_text(client, result_subset)]
        text_results = await asyncio.gather(*tasks_2)
        results_2[key] = [result_2 for result_2 in text_results if result_2 is not None]

    results_2_val = ""
    for subset in results_2.values():
        results_2_val += subset[0] + "\n"

    return results_2_val


def wrapper_get_full_result():
    api_key = global_dict.get('api_key')
    full_result_vision = global_dict.get('full_result_vision')
    client = AsyncOpenAI(api_key=api_key)

    async def get_full_result():
        full_text = await async_get_full_result(client, full_result_vision)
        global_dict['full_text'] = full_text
        print("full_text: ")
        print(full_text)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(get_full_result())


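# Audio-side evaluation: the Whisper transcript is scored against the audio
# rubric items with a plain text-only GPT-4 call.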
def call_gpt_audio(api_key) -> str:
    """Evaluate the Whisper transcript against the audio rubric items."""
    audio_rubric_subsets = global_dict.get('audio_rubric_subsets')
    transcript = global_dict.get('transcript')
    openai.api_key = api_key

    full_text_audio = ""

    print(f"RUBRIC_AUDIO: {audio_rubric_subsets}")

    PROMPT_MESSAGES = [
        {
            "role": "system",
            "content": AUDIO_SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": PromptTemplate.from_template(USER_PROMPT_TEMPLATE).format(rubrics=audio_rubric_subsets)
            + "\n\n<TEXT>\n" + transcript,
        },
    ]
    params = {
        "model": "gpt-4",
        "messages": PROMPT_MESSAGES,
        "max_tokens": 1024,
    }

    try:
        result = openai.chat.completions.create(**params)
        full_text_audio = result.choices[0].message.content
        print(full_text_audio)
    except openai.OpenAIError as e:
        print(f"Failed to connect to OpenAI: {e}")

    global_dict['full_text_audio'] = full_text_audio
    return full_text_audio


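# Final summary: the audio and vision evaluations are concatenated and run
# through a LangChain LCEL pipeline (prompt | model | parser) that is expected
# to emit a summary plus a [[categories], [scores]] table.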
def get_final_answer(api_key):
    """Merge the audio and vision evaluations into one final summary."""
    rubrics_keyword = global_dict.get('rubrics_keyword')
    full_text_audio = global_dict.get('full_text_audio')
    full_text = global_dict.get('full_text')
    full = full_text_audio + full_text
    global_dict['full'] = full

    llm = ChatOpenAI(
        api_key=api_key,
        model="gpt-4",
        max_tokens=1024,
        temperature=0,
    )
    prompt = PromptTemplate.from_template(SUMMARY_AND_TABLE_PROMPT)

    runnable = prompt | llm | StrOutputParser()
    final_eval = runnable.invoke({"full": full, "rubrics_keyword": rubrics_keyword})

    print(final_eval)
    global_dict['final_eval'] = final_eval
    return final_eval


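# The [[...], [...]] table is pulled out of the model output with
# ast.literal_eval and rendered as a bar chart for the results page.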
def tablize_final_answer():
    """Plot the category scores from the final evaluation as a bar chart."""
    final_eval = global_dict.get('final_eval')
    pos3 = final_eval.find("[[")
    pos4 = final_eval.find("]]")
    tablize_final_eval = ast.literal_eval(final_eval[pos3:pos4 + 2])

    cat_final_eval, val_final_eval = tablize_final_eval[0], tablize_final_eval[1]
    val_final_eval = [int(score) for score in val_final_eval]

    fig, ax = plt.subplots()
    ax.bar(cat_final_eval, val_final_eval)
    ax.set_ylabel('Scores')
    ax.set_title('Scores by category')

    plt.rc('xtick', labelsize=3)
    ax.set_xticks(range(len(cat_final_eval)))
    ax.set_yticks([0, 2, 4, 6, 8, 10])
    ax.set_xticklabels(cat_final_eval)

    # Render the figure to an in-memory PNG and hand it back as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)

    image = Image.open(buf)
    return image


def brief_final_answer():
    """Cut the short summary out of the final evaluation text."""
    final_eval = global_dict.get('final_eval')
    pos1 = final_eval.find("**종합 점수**")    # "**overall score**" header
    pos2 = final_eval.find("----요약 끝----")  # "----end of summary----" marker
    return final_eval[pos1:pos2]


def fin_final_answer():
    """Return the full combined evaluation text."""
    return global_dict.get('full')


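# UI: a single Blocks page with an API-key box, a video upload, and a results
# area; the Process button chains every pipeline stage via .success().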
def mainpage():
    with gr.Blocks() as start_page:
        gr.Markdown("Title")
        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="Enter your OpenAI API Key",
                    info="Your API Key must be allowed to use GPT-4 Vision",
                    placeholder="sk-*********...",
                    lines=1,
                )

        gr.Markdown("Video Upload Page")
        with gr.Row():
            with gr.Column(scale=1):
                video_upload = gr.File(
                    label="Upload your video (a video under 1 minute is best!)",
                    file_types=["video"],
                )

        with gr.Row():
            with gr.Column(scale=1):
                process_button = gr.Button("Process")

        gr.Markdown("Results Page")
        with gr.Row():
            with gr.Column(scale=1):
                output_box_fin_table = gr.Image(type="pil", label="Score Chart")
            with gr.Column(scale=1):
                output_box_fin_brief = gr.Textbox(
                    label="Brief Evaluation",
                    lines=10,
                    interactive=True,
                    show_copy_button=True,
                )

        with gr.Row():
            with gr.Column(scale=1):
                output_box_fin_fin = gr.Textbox(
                    label="Detailed Evaluation",
                    lines=10,
                    interactive=True,
                    show_copy_button=True,
                )
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                    label="Batched Snapshots of Video",
                    columns=[3],
                    rows=[10],
                    object_fit="contain",
                    height="auto",
                )

        # Run the whole pipeline stage by stage; each step starts only after
        # the previous one succeeded.
        process_button.click(
            fn=validate_api_key, inputs=api_key_input, outputs=None
        ).success(
            fn=show_batches, inputs=[video_upload], outputs=[gallery]
        ).success(
            fn=show_audio_transcript, inputs=[video_upload, api_key_input], outputs=[]
        ).success(
            fn=call_gpt_audio, inputs=[api_key_input], outputs=[]
        ).success(
            fn=wrapper_call_gpt_vision, inputs=[], outputs=[]
        ).success(
            fn=wrapper_get_full_result, inputs=[], outputs=[]
        ).success(
            fn=get_final_answer, inputs=[api_key_input], outputs=[]
        ).success(
            fn=tablize_final_answer, inputs=[], outputs=[output_box_fin_table]
        ).success(
            fn=brief_final_answer, inputs=[], outputs=[output_box_fin_brief]
        ).success(
            fn=fin_final_answer, inputs=[], outputs=[output_box_fin_fin]
        )

    start_page.launch()


if __name__ == "__main__":
    mainpage()