|
import gradio as gr |
|
import os |
|
import random |
|
import numpy as np |
|
import pandas as pd |
|
import gdown |
|
import base64 |
|
from time import gmtime, strftime |
|
from csv import writer |
|
import json |
|
import zipfile |
|
from os import listdir |
|
from os.path import isfile, join, isdir |
|
from datasets import load_dataset |
|
from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver |
|
|
|
# Environment names offered (after "Random") in the "Choose the next task"
# selector of the interface.
ENVS = [
    'ShadowHand',
    'ShadowHandCatchAbreast',
    'ShadowHandOver',
    'ShadowHandBlockStack',
    'ShadowHandCatchUnderarm',
    'ShadowHandCatchOver2Underarm',
    'ShadowHandBottleCap',
    'ShadowHandLiftUnderarm',
    'ShadowHandTwoCatchUnderarm',
    'ShadowHandDoorOpenInward',
    'ShadowHandDoorOpenOutward',
    'ShadowHandDoorCloseInward',
    'ShadowHandDoorCloseOutward',
    'ShadowHandPushBlock',
    'ShadowHandKettle',
    'ShadowHandScissors',
    'ShadowHandPen',
    'ShadowHandSwingCup',
    'ShadowHandGraspAndPlace',
    'ShadowHandSwitch',
]
|
|
|
|
|
|
|
|
|
# When True the video archives are downloaded from Google Drive at import
# time; when False they are expected to already be in ./processed-data.
LOAD_DATA_GOOGLE_DRIVE = False


def _extract_and_remove_zips(source_dir, target_dir):
    """Unpack every ``.zip`` directly inside *source_dir* into *target_dir*.

    Each archive is deleted after it has been extracted, so a later start-up
    does not unpack it a second time.
    """
    for entry in listdir(source_dir):
        if not entry.endswith(".zip"):
            continue
        archive = join(source_dir, entry)
        print(f'extract data {archive} to {target_dir}')
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(target_dir)
        os.remove(archive)


if LOAD_DATA_GOOGLE_DRIVE:
    urls = [
        'https://drive.google.com/drive/folders/1SF5jQ7HakO3lFXBon57VP83-AwfnrM3F?usp=share_link',
        'https://drive.google.com/drive/folders/13WuS6ow6sm7ws7A5xzCEhR-2XX_YiIu5?usp=share_link',
        'https://drive.google.com/drive/folders/1GWLffJDOyLkubF2C03UFcB7iFpzy1aDy?usp=share_link',
        'https://drive.google.com/drive/folders/1UKAntA7WliD84AUhRN224PkW4vt9agZW?usp=share_link',
        'https://drive.google.com/drive/folders/11cCQw3qb1vJbviVPfBnOVWVzD_VzHdWs?usp=share_link',
        'https://drive.google.com/drive/folders/1Wvy604wCxEdXAwE7r3sE0L0ieXvM__u8?usp=share_link',
        'https://drive.google.com/drive/folders/1BTv_pMTNGm7m3hD65IgBrX880v-rLIaf?usp=share_link',
        'https://drive.google.com/drive/folders/12x7F11ln2VQkqi8-Mu3kng74eLgifM0N?usp=share_link',
        'https://drive.google.com/drive/folders/1OWkOul2CCrqynqpt44Fu1CBxzNNfOFE2?usp=share_link',
        'https://drive.google.com/drive/folders/1ukwsfrbSEqCBNmRSuAYvYBHijWCQh2OU?usp=share_link',
        'https://drive.google.com/drive/folders/1EO7zumR6sVfsWQWCS6zfNs5WuO2Se6WX?usp=share_link',
        'https://drive.google.com/drive/folders/1aw0iBWvvZiSKng0ejRK8xbNoHLVUFCFu?usp=share_link',
        'https://drive.google.com/drive/folders/1szIcxlVyT5WJtzpqYWYlue0n82A6-xtk?usp=share_link',
    ]

    output = './'

    VIDEO_PATH = 'split_processed_zip'
    for i, url in enumerate(urls):
        # Keep only the folder id: drop the "?usp=share_link" query string.
        folder_id = url.split('/')[-1].split('?')[0]
        # Use the already-imported gdown library instead of building a shell
        # command; the options mirror "--folder --no-cookies --remaining-ok".
        gdown.download_folder(
            id=folder_id,
            output=output,
            use_cookies=False,
            remaining_ok=True,
        )
        # The archives of the i-th URL are looked up in a local directory
        # named "1", "2", ... (str(i + 1)), as the original loop did.
        _extract_and_remove_zips(str(i + 1), VIDEO_PATH)
else:
    VIDEO_PATH = 'processed-data'
    # Archives shipped alongside the app are unpacked in place.
    _extract_and_remove_zips(VIDEO_PATH, VIDEO_PATH)

# JSON index (env name -> list of video paths) written by parse_envs() and
# read back by get_env_names() / randomly_select_videos().
VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
|
|
|
def inference(video_path):
    """Return an HTML ``<video>`` snippet with the given file embedded inline.

    The file at *video_path* is read in full and base64-encoded into a
    ``data:`` URI, so the returned markup is self-contained. The MIME type
    is always declared as ``video/mp4``, whatever the file extension is.

    Args:
        video_path: Path of the video file to embed.

    Returns:
        str: HTML markup for an auto-playing, muted, looping video element.
    """
    with open(video_path, "rb") as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    html = (
        f"""
        <video controls autoplay muted loop>
        <source src="data:video/mp4;base64,{b64}" type="video/mp4">
        </video>
        """
    )
    return html
|
|
|
def video_identity(video):
    """Return *video* unchanged (identity pass-through)."""
    return video
|
|
|
def nan():
    """Return ``None``; a no-argument placeholder that yields no value."""
    return None
|
|
|
# File extension of the clips that are indexed and shown. Index 0 selects
# 'mp4'; switch the index to 1 to use 'gif' instead.
FORMAT = ['mp4', 'gif'][0]
|
|
|
def get_huggingface_dataset():
    """Clone and refresh the Hugging Face dataset repo that stores the votes.

    The repo ``crowdsourced-robotinder-demo`` is created if it does not exist
    yet (``exist_ok=True``), cloned into ``crowdsourced-robotinder-demo/flag/``
    and pulled (including LFS files).

    The access token is read from the ``HF_TOKEN`` environment variable.
    Credentials must never be committed to the source file. When the
    variable is unset ``None`` is passed on to ``huggingface_hub``.

    Returns:
        tuple: ``(repo, log_file)`` where ``repo`` is the
        ``huggingface_hub.Repository`` for the local clone and ``log_file``
        is the path of ``flag_data.csv`` inside that clone.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        import huggingface_hub
    except ImportError as err:  # ModuleNotFoundError is a subclass
        raise ImportError(
            "Package `huggingface_hub` not found is needed "
            "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
        ) from err

    hf_token = os.getenv('HF_TOKEN')
    dataset_name = 'crowdsourced-robotinder-demo'
    flagging_dir = 'flag/'

    path_to_dataset_repo = huggingface_hub.create_repo(
        repo_id=dataset_name,
        token=hf_token,
        private=False,
        repo_type="dataset",
        exist_ok=True,
    )
    dataset_dir = os.path.join(dataset_name, flagging_dir)
    repo = huggingface_hub.Repository(
        local_dir=dataset_dir,
        clone_from=path_to_dataset_repo,
        use_auth_token=hf_token,
    )
    # Sync with the remote first so that appended rows go on top of the
    # latest copy of the log.
    repo.git_pull(lfs=True)
    log_file = os.path.join(dataset_dir, "flag_data.csv")
    return repo, log_file
|
|
|
def update(user_choice, left, right, choose_env, data_folder=VIDEO_PATH, flag_to_huggingface=False):
    """Handle a click on "Next": optionally log the vote, then serve a new pair.

    Args:
        user_choice: Value of the preference radio for the pair just shown.
        left: Current value of the left component (not used here).
        right: Current value of the right component (not used here).
        choose_env: Task requested for the next pair. ``None``, ``''``,
            ``'Random'`` or a name that has no entry in the video index all
            result in a randomly chosen task.
        data_folder: Not used here; kept for interface compatibility.
        flag_to_huggingface: When True, append the vote for the previous pair
            to the dataset log and push it when the current UTC minute is a
            multiple of five.

    Returns:
        tuple: ``(left_html, right_html, env_name)`` for the new pair.

    NOTE(review): the "last shown pair" lives in module globals, so it is
    shared by every browser session served by this process.
    """
    global last_left_video_path, last_right_video_path
    global last_infer_left_video_path, last_infer_right_video_path

    if flag_to_huggingface:
        # Videos are stored as <folder>/<env_name>/<video>, so the env name is
        # the name of the directory that contains the video.
        env_name = os.path.basename(os.path.dirname(str(last_left_video_path)))
        current_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
        info = [env_name, user_choice, last_left_video_path, last_right_video_path, current_time]
        print(info)
        repo, log_file = get_huggingface_dataset()
        # newline='' is what the csv module expects for files it writes to.
        with open(log_file, 'a', newline='') as file:
            writer(file).writerow(info)
        # Push only when the minute field is a multiple of 5.
        if int(current_time.split('-')[-2]) % 5 == 0:
            commit_message = f"Flagged sample at {current_time}"
            try:
                repo.push_to_hub(commit_message=commit_message)
            except Exception:
                # Retry once after pulling, in case the remote moved ahead.
                repo.git_pull(lfs=True)
                repo.push_to_hub(commit_message=commit_message)

    # A radio with nothing selected yields None; together with '', 'Random'
    # and names missing from the index this falls back to a random task
    # instead of failing with a KeyError further down.
    available_envs = get_env_names()
    if choose_env in available_envs:
        env_name = choose_env
    else:
        env_name = random.choice(available_envs)

    left, right = randomly_select_videos(env_name)

    last_left_video_path = left
    last_right_video_path = right
    last_infer_left_video_path = inference(left)
    last_infer_right_video_path = inference(right)

    return last_infer_left_video_path, last_infer_right_video_path, env_name
|
|
|
def replay(left, right):
    """Return ``(left, right)`` unchanged, i.e. show the same pair again."""
    return left, right
|
|
|
def _iteration_window(succeed_iteration, default_iter, max_iter):
    """Turn a ``succeed_iteration`` CSV cell into ``(min_iter, max_iter)``.

    Returns ``None`` when the cell marks the run as unsolved. An empty (NaN)
    cell yields ``(default_iter, max_iter)``, a ``"a-b"`` cell yields
    ``(a, b)`` and a single number ``a`` yields ``(a, max_iter)``.
    """
    # The NaN test must come first: a NaN cell is a float, and a substring
    # test on it raises TypeError.
    if pd.isnull(succeed_iteration):
        return int(default_iter), int(max_iter)
    text = str(succeed_iteration).strip()
    if 'unsolved' in text:
        return None
    if '-' in text:
        low, high = text.split('-')
        return int(float(low)), int(float(high))
    return int(float(text)), int(max_iter)


def _passes_checkpoint_filter(df, env_name, video, default_iter, max_iter):
    """Tell whether *video* is a checkpoint that the CSV marks as valid."""
    parts = video.split('_')
    if len(parts) < 6:
        print(f'{video} is wrongly named.')
        return False
    # Fields 2 and 4 of the underscore-separated name are read as seed and
    # checkpoint iteration.
    seed, checkpoint = parts[2], parts[4]
    try:
        matches = df.loc[(df['seed'] == int(seed)) & (df['env_name'] == str(env_name))]
        succeed_iteration = matches['succeed_iteration'].iloc[0]
    except (ValueError, IndexError):
        # Non-numeric seed, or no row for this env/seed combination.
        print(f'Env {env_name} with seed {seed} not found in Bidexhands_Video.csv')
        return False

    window = _iteration_window(succeed_iteration, default_iter, max_iter)
    if window is None:
        return False
    min_iter, upper_iter = window
    try:
        checkpoint_iter = int(checkpoint)
    except ValueError:
        print(f'{video} is wrongly named.')
        return False
    # Same membership rule as np.arange(min, max + 1000, 1000), without
    # allocating an array for every video.
    return checkpoint_iter in range(min_iter, upper_iter + 1000, 1000)


def parse_envs(folder=VIDEO_PATH, filter=True, MAX_ITER=20000, DEFAULT_ITER=20000):
    """Index the videos under *folder* and save the index to ``VIDEO_INFO``.

    Every sub-directory of *folder* is treated as one env; files in it whose
    name ends with ``.{FORMAT}`` are its videos.

    Args:
        folder: Directory that holds one sub-directory per env.
        filter: When True, keep only videos whose checkpoint lies in the
            iteration window listed for that env and seed in
            ``Bidexhands_Video.csv``. Wrongly named files and videos without
            a matching CSV row are reported and skipped.
        MAX_ITER: Upper end of the window when the CSV gives no upper bound.
        DEFAULT_ITER: Lower end of the window when the CSV cell is empty.

    Returns:
        dict: ``{env_name: [video_path, ...]}``; the same mapping is written
        as JSON to ``VIDEO_INFO``.
    """
    files = {}
    df = pd.read_csv('Bidexhands_Video.csv') if filter else None
    suffix = f'.{FORMAT}'

    for env_name in os.listdir(folder):
        env_path = os.path.join(folder, env_name)
        if not os.path.isdir(env_path):
            continue

        video_files = []
        for video in os.listdir(env_path):
            if not video.endswith(suffix):
                continue
            if filter and not _passes_checkpoint_filter(df, env_name, video, DEFAULT_ITER, MAX_ITER):
                continue
            video_files.append(os.path.join(folder, env_name, video))

        files[env_name] = video_files

    with open(VIDEO_INFO, 'w') as fp:
        json.dump(files, fp)

    return files
|
|
|
def get_env_names():
    """Return the env names stored as keys of the JSON index ``VIDEO_INFO``."""
    with open(VIDEO_INFO, 'r') as fp:
        files = json.load(fp)
    return list(files.keys())
|
|
|
def randomly_select_videos(env_name):
    """Pick two distinct videos of *env_name* at random.

    The candidates are read from the JSON index ``VIDEO_INFO``.

    Args:
        env_name: Key of the env in the index.

    Returns:
        tuple: ``(left_video_path, right_video_path)``.

    Raises:
        KeyError: If *env_name* has no entry in the index.
        ValueError: If the env has fewer than two videos, so no pair exists.
    """
    with open(VIDEO_INFO, 'r') as fp:
        files = json.load(fp)
    env_files = files[env_name]

    if len(env_files) < 2:
        # np.random.choice would also raise ValueError here, but without
        # naming the env that is short of videos.
        raise ValueError(
            f'Env {env_name} has {len(env_files)} video(s); '
            'at least two are needed to form a pair.'
        )

    # replace=False guarantees that the two indices differ.
    left_index, right_index = np.random.choice(len(env_files), 2, replace=False)
    return env_files[left_index], env_files[right_index]
|
|
|
def build_interface(iter=3, data_folder=VIDEO_PATH):
    """Build the RoboTinder Gradio app and return it, not yet launched.

    The video index is rebuilt with ``parse_envs()``, a random env and a
    random pair of its videos are chosen for the first screen, and the "Next"
    button is wired to ``update``. The first pair is also stored in the
    module globals that ``update`` reads.

    Args:
        iter: Not used; kept for interface compatibility.
        data_folder: Not used; kept for interface compatibility.

    Returns:
        gr.Blocks: The assembled interface.
    """
    import sys
    import csv

    global last_left_video_path, last_right_video_path
    global last_infer_left_video_path, last_infer_right_video_path

    # Raise the csv field size limit as far as the platform allows;
    # sys.maxsize does not fit where the C long is 32 bits wide.
    try:
        csv.field_size_limit(sys.maxsize)
    except OverflowError:
        csv.field_size_limit(2 ** 31 - 1)

    # No access token is read, printed or stored here: it was never used in
    # this function and must not end up in logs or in the source.

    files = parse_envs()

    with gr.Blocks() as demo:
        gr.Markdown("## Here is <span style=color:cyan>RoboTinder</span>!")
        gr.Markdown("### Select the best robot behaviour in your choice!")

        env_name = random.choice(list(files.keys()))
        with gr.Row():
            str_env_name = gr.Markdown(f"{env_name}")

        left_video_path, right_video_path = randomly_select_videos(env_name)

        with gr.Row():
            if FORMAT == 'mp4':
                infer_left_video_path = inference(left_video_path)
                infer_right_video_path = inference(right_video_path)
                # NOTE(review): `right` is created before `left`, so it is the
                # first component of this row. Confirm that this matches the
                # side the "Left"/"Right" answers are meant to refer to.
                right = gr.HTML(infer_right_video_path, label="right_video")
                left = gr.HTML(infer_left_video_path, label="left_video")
            else:
                # Images are shown straight from disk, so there is no inlined
                # HTML; bind the names anyway because they are stored below.
                infer_left_video_path = None
                infer_right_video_path = None
                left = gr.Image(left_video_path, shape=(1024, 768), label="left_video")
                right = gr.Image(right_video_path, label="right_video")

        # Remember the pair on screen so that update() can log the vote.
        last_left_video_path = left_video_path
        last_right_video_path = right_video_path
        last_infer_left_video_path = infer_left_video_path
        last_infer_right_video_path = infer_right_video_path

        user_choice = gr.Radio(["Left", "Right", "Not Sure", "Both Good", "Both Bad"], label="Which one is your favorite?")
        choose_env = gr.Radio(["Random"]+ENVS, label="Choose the next task:")
        btn2 = gr.Button("Next")

        btn2.click(fn=update, inputs=[user_choice, left, right, choose_env], outputs=[left, right, str_env_name])

    return demo
|
|
|
if __name__ == "__main__":
    # Module-level record of the pair currently on screen; build_interface()
    # fills it in and update() reads it when a vote is logged.
    last_left_video_path = None
    last_right_video_path = None

    demo = build_interface()

    # share=False: do not request a public share link.
    demo.launch(share=False)
|
|