Spaces:
Sleeping
Sleeping
File size: 2,224 Bytes
ee1c253 f2f8639 ee1c253 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 d116e72 f2f8639 95a5a2a d116e72 f2f8639 d116e72 f2f8639 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import gradio as gr
from train import TrainingLoop
from scipy.special import softmax
import numpy as np
# Active TrainingLoop instance; remains None until create_training_loop() runs.
train = None

# Results of the most recent attribution run; both remain None until
# generate_output() has stored them.
frames = None
attributions = None

# Display label for each position of the attribution vector; display_softmax()
# pairs these labels, in order, with the scores it is given.
# NOTE(review): presumably the eight LunarLander observation components --
# confirm against the environment actually used.
lunar_lander_spec_conversion = dict(
    enumerate(
        (
            "X-coordinate",
            "Y-coordinate",
            "Linear velocity in the X-axis",
            "Linear velocity in the Y-axis",
            "Angle",
            "Angular velocity",
            "Left leg touched the floor",
            "Right leg touched the floor",
        )
    )
)
def create_training_loop(env_spec):
    """Build a TrainingLoop for ``env_spec`` and publish it as the global ``train``.

    The new loop is stored in the module-level ``train`` *before* its agent is
    created, so the global is set even if ``create_agent()`` fails.

    Args:
        env_spec: Environment identifier, forwarded unchanged to
            ``TrainingLoop`` as its ``env_spec`` keyword argument.

    Returns:
        The ``spec`` attribute of the new loop's ``env``.
    """
    global train
    # Chained assignment binds both names to the same instance, left to right.
    loop = train = TrainingLoop(env_spec=env_spec)
    loop.create_agent()
    return loop.env.spec
def display_softmax(inputs, names=None):
    """Convert raw attribution scores into a ``{label: probability}`` mapping.

    The scores are flattened to one dimension before the softmax is taken, so
    a single batched observation of shape ``(1, N)`` (or a column of shape
    ``(N, 1)``) is handled the same way as a flat sequence of ``N`` scores.

    Args:
        inputs: Array-like of attribution scores (anything ``np.asarray``
            accepts).
        names: Optional iterable of labels, one per score, in order. When
            omitted, the values of the module-level
            ``lunar_lander_spec_conversion`` table are used.

    Returns:
        A dict mapping each label to its softmax probability as a plain
        ``float``. Labels and scores are paired positionally; if their counts
        differ, the surplus on the longer side is ignored. An empty input
        yields an empty dict.
    """
    # Flatten first: iterating a 2-D array would yield whole rows, and
    # float(row) fails for rows longer than one element.
    scores = np.asarray(inputs).reshape(-1)
    if scores.size == 0:
        # softmax() cannot reduce over an empty array.
        return {}

    probabilities = softmax(scores)
    if names is None:
        names = lunar_lander_spec_conversion.values()
    return {name: float(prob) for name, prob in zip(names, probabilities)}
def generate_output(num_iterations, option):
    """Run attribution on the current training loop and cache the results.

    The frames and attributions returned by ``train.explain_trained`` are
    stored in the module-level ``frames`` and ``attributions``, and the
    key-frame slider's ``maximum`` is set to the last valid frame index.

    Args:
        num_iterations: Forwarded to ``explain_trained`` as ``num_iterations``.
        option: Forwarded to ``explain_trained`` as ``option``.

    Raises:
        gr.Error: If no training loop has been created yet (``train`` is
            still ``None``).
    """
    global frames, attributions

    if train is None:
        # Without this guard the call below fails with an opaque
        # AttributeError on None.
        raise gr.Error("Create the environment before running the attribution.")

    frames, attributions = train.explain_trained(num_iterations=num_iterations, option=option)

    # Frames are indexed 0 .. len(frames) - 1, so the largest selectable key
    # frame is len(frames) - 1 (never below the slider's minimum of 0).
    # NOTE(review): this only mutates the server-side component object; it may
    # not reach a slider that is already rendered in the browser -- confirm.
    slider.maximum = max(len(frames) - 1, 0)
def get_frame_and_attribution(slider_value):
    """Look up the cached frame and attribution distribution for one key frame.

    Args:
        slider_value: Position selected on the key-frame slider. It is
            converted to ``int`` and clamped to the range of cached results.

    Returns:
        A ``(frame, attribution)`` tuple: the cached frame at that position
        and the ``{label: probability}`` dict produced by ``display_softmax``
        for the matching attribution. Returns ``(None, None)`` when no
        attribution results have been cached yet.
    """
    # Reading the module-level caches needs no `global` declaration.
    if frames is None or attributions is None:
        return None, None

    available = min(len(frames), len(attributions))
    if available == 0:
        return None, None

    # The slider's range can exceed the number of cached frames (it starts
    # with a fixed upper bound), so clamp instead of raising IndexError.
    index = min(max(int(slider_value), 0), available - 1)
    return frames[index], display_softmax(attributions[index])
def _attribute_and_resize_slider(num_iterations, option):
    """Run ``generate_output`` and return an update for the key-frame slider.

    Returning the update through the event's ``outputs`` is what lets the
    rendered slider pick up the new upper bound; the bound is the last valid
    index into the module-level ``frames`` populated by ``generate_output``.
    """
    generate_output(num_iterations, option)
    return gr.update(maximum=max(len(frames) - 1, 0))


# NOTE(review): the nesting below is reconstructed -- the original indentation
# was lost. Component creation order is unchanged; confirm that the ATTRIBUTE
# button belongs inside the row with the two baseline controls.
with gr.Blocks() as demo:
    gr.Markdown("# Introspection in Deep Reinforcement Learning")

    with gr.Tab(label="Attribute"):
        env_spec = gr.Textbox(
            label="Environment Specification (e.g.: LunarLander-v2)",
            lines=1,
        )
        env = gr.Interface(
            title="Create the Environment",
            allow_flagging="never",
            inputs=env_spec,
            fn=create_training_loop,
            outputs=gr.JSON(),
        )

        with gr.Row():
            option = gr.Dropdown(
                choices=["Torch Tensor of 0's", "Running Average"],
                type="index",
            )
            baselines = gr.Slider(
                label="Number of Baseline Iterations",
                interactive=True,
                minimum=0,
                maximum=100,
                value=10,
                step=5,
                info="Baseline inputs to collect for the average",
                render=True,
            )
            attribute_button = gr.Button("ATTRIBUTE")

        slider = gr.Slider(label="Key Frame", minimum=0, maximum=300, step=1, value=0)

        # Wired up only after the slider exists so it can be named as the
        # event's output and receive the new maximum.
        attribute_button.click(
            fn=_attribute_and_resize_slider,
            inputs=[baselines, option],
            outputs=slider,
        )

        gr.Interface(
            fn=get_frame_and_attribution,
            inputs=slider,
            live=True,
            outputs=[gr.Image(), gr.Label()],
        )

demo.launch()