# OpenBiDexHand / app.py
import gradio as gr
import os
import random
import numpy as np
import gdown
from time import gmtime, strftime
from csv import writer
from datasets import load_dataset
from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver
# download data from huggingface dataset
# dataset = load_dataset("quantumiracle-git/robotinder-data")
# download data from google drive
# url = 'https://drive.google.com/drive/folders/10UmNM2YpvNSkdLMgYiIAxk5IbS4dUezw?usp=sharing'
# output = './'
# id = url.split('/')[-1]
# os.system(f"gdown --id {id} -O {output} --folder --no-cookies")
def video_identity(video):
    return video


def nan():
    return None

# demo = gr.Interface(video_identity,
#                     gr.Video(),
#                     "playable_video",
#                     examples=[
#                         os.path.join(os.path.dirname(__file__),
#                                      "videos/rl-video-episode-0.mp4")],
#                     cache_examples=True)

FORMAT = ['mp4', 'gif'][1]  # index 1 selects 'gif'; set to 0 to serve mp4 files instead
def update(user_choice, data_folder='videos'):
    """Log the user's choice, then pick a random pair of videos from a random environment."""
    envs = parse_envs()
    env_name = envs[random.randint(0, len(envs) - 1)]
    # collect candidate videos for this environment
    videos = os.listdir(os.path.join(data_folder, env_name))
    video_files = []
    for f in videos:
        if f.endswith(f'.{FORMAT}'):
            video_files.append(os.path.join(data_folder, env_name, f))
    # choose two distinct videos
    selected_video_ids = np.random.choice(len(video_files), 2, replace=False)
    left = video_files[selected_video_ids[0]]
    right = video_files[selected_video_ids[1]]
    # log the pairing and the user's choice with a timestamp
    current_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    info = [env_name, user_choice, left, right, current_time]
    print(info)
    with open('data.csv', 'a') as file:
        writer_object = writer(file)
        writer_object.writerow(info)  # the with-block closes the file; no explicit close() needed
    return left, right
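
# --- Hedged sketch (not part of the original app): a small reader for the local
# `data.csv` log written by update() above. `summarize_choices` is a hypothetical
# helper; it assumes each row has the layout [env_name, user_choice, left, right, time]
# used in update().
def summarize_choices(log_path='data.csv'):
    from csv import reader
    counts = {}
    with open(log_path) as f:
        for row in reader(f):
            if len(row) == 5:
                env_name, user_choice = row[0], row[1]
                counts.setdefault(env_name, {}).setdefault(user_choice, 0)
                counts[env_name][user_choice] += 1
    return counts  # e.g. {'some_env': {'Left': 3, 'Right': 1, 'Not Sure': 2}}
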
# def update(left, right):
#     if FORMAT == 'mp4':
#         left = os.path.join(os.path.dirname(__file__),
#                             "videos/rl-video-episode-2.mp4")
#         right = os.path.join(os.path.dirname(__file__),
#                             "videos/rl-video-episode-3.mp4")
#     else:
#         left = os.path.join(os.path.dirname(__file__),
#                             "videos/rl-video-episode-2.gif")
#         right = os.path.join(os.path.dirname(__file__),
#                             "videos/rl-video-episode-3.gif")
#     print(left, right)
#     return left, right

def replay(left, right):
    # return the same pair unchanged so both videos restart
    return left, right


def parse_envs(folder='./videos'):
    # each subdirectory of `folder` is treated as one environment
    envs = []
    for f in os.listdir(folder):
        if os.path.isdir(os.path.join(folder, f)):
            envs.append(f)
    return envs
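
# --- Assumed on-disk layout (inferred from parse_envs() and update(), not stated
# explicitly elsewhere in the repo): one subfolder per environment, each holding the
# rendered rollouts in the selected FORMAT, e.g.
#   videos/
#     <env_name_1>/rl-video-episode-0.gif
#     <env_name_1>/rl-video-episode-1.gif
#     <env_name_2>/...
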
def build_interface(iter=3, data_folder='./videos'):
    HF_TOKEN = os.getenv('HF_TOKEN')
    print(HF_TOKEN)
    HF_TOKEN = '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'  # my HF token; hard-coded value overrides the env var above
    # hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "crowdsourced-robotinder-demo")  # HuggingFace logger instead of local one: https://github.com/gradio-app/gradio/blob/master/gradio/flagging.py
    hf_writer = HuggingFaceDatasetSaver(HF_TOKEN, "crowdsourced-robotinder-demo")
    # callback = gr.CSVLogger()
    callback = hf_writer

    # build gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("Here is RoboTinder!")
        gr.Markdown("Select the robot behaviour you like best!")
        with gr.Row():
            # some initial videos
            if FORMAT == 'mp4':
                left_video_path = os.path.join(os.path.dirname(__file__),
                                               "videos/rl-video-episode-0.mp4")
                right_video_path = os.path.join(os.path.dirname(__file__),
                                                "videos/rl-video-episode-1.mp4")
                left = gr.PlayableVideo(left_video_path, label="left_video")
                right = gr.PlayableVideo(right_video_path, label="right_video")
            else:
                left_video_path = os.path.join(os.path.dirname(__file__),
                                               "videos/rl-video-episode-0.gif")
                right_video_path = os.path.join(os.path.dirname(__file__),
                                                "videos/rl-video-episode-1.gif")
                left = gr.Image(left_video_path, shape=(1024, 768), label="left_video")
                # right = gr.Image(right_video_path).style(height=768, width=1024)
                right = gr.Image(right_video_path, label="right_video")

        btn1 = gr.Button("Replay")
        user_choice = gr.Radio(["Left", "Right", "Not Sure"], label="Which one is your favorite?")
        btn2 = gr.Button("Next")

        # This needs to be called at some point prior to the first call to callback.flag()
        callback.setup([user_choice, left, right], "flagged_data_points")

        btn1.click(fn=replay, inputs=[left, right], outputs=[left, right])
        btn2.click(fn=update, inputs=[user_choice], outputs=[left, right])
        # We can choose which components to flag -- in this case, we'll flag all of them
        btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False)

    return demo

if __name__ == "__main__":
    demo = build_interface()
    # demo.launch(share=True)
    demo.launch(share=False)