|
import gradio as gr |
|
from train import TrainingLoop |
|
from scipy.special import softmax |
|
import numpy as np |
|
|
|
|
|
# Active TrainingLoop instance; stays None until create_training_loop() builds one.
train = None

# Results of the most recent generate_output() run; both stay None until then.
frames = attributions = None
|
|
|
# Human-readable names for integer indices 0-7.  display_softmax() pairs
# these names, in this order, with the attribution scores it receives.
lunar_lander_spec_conversion = dict(
    enumerate(
        (
            "X-coordinate",
            "Y-coordinate",
            "Linear velocity in the X-axis",
            "Linear velocity in the Y-axis",
            "Angle",
            "Angular velocity",
            "Left leg touched the floor",
            "Right leg touched the floor",
        )
    )
)
|
|
|
def create_training_loop(env_spec):
    """Build a training loop for ``env_spec`` and publish it as the global ``train``.

    Args:
        env_spec: Environment specification passed to ``TrainingLoop`` as
            ``env_spec`` (the UI hints at values such as "LunarLander-v2").

    Returns:
        The ``spec`` attribute of the new loop's ``env``.
    """
    global train

    loop = TrainingLoop(env_spec=env_spec)
    # Publish the loop before creating the agent, so the global is already
    # replaced even if create_agent() raises.
    train = loop
    loop.create_agent()

    return loop.env.spec
|
|
|
def display_softmax(inputs, labels=None):
    """Turn raw attribution scores into a ``{feature name: probability}`` dict.

    The scores are flattened to one dimension and normalised with a softmax,
    so the returned values are non-negative and sum to 1.

    Args:
        inputs: Array-like of raw scores, one per feature.  Any shape is
            accepted; it is flattened first, so a ``(1, n)`` batch holding a
            single sample is handled like a length-``n`` vector instead of
            failing when a whole row is converted to ``float``.
        labels: Optional iterable of feature names, paired positionally with
            the scores.  Defaults to the values of
            ``lunar_lander_spec_conversion``.

    Returns:
        A dict mapping each name to its probability as a built-in ``float``.
        Pairing stops at the shorter of the names and the scores.  Empty
        input yields an empty dict.
    """
    if labels is None:
        labels = lunar_lander_spec_conversion.values()

    scores = np.asarray(inputs).ravel()
    if scores.size == 0:
        # softmax() cannot reduce an empty array; there is nothing to show.
        return {}

    probabilities = softmax(scores)

    # float() converts NumPy scalars to plain Python floats.
    return {name: float(prob) for name, prob in zip(labels, probabilities)}
|
|
|
def generate_output(num_iterations, option):
    """Run the attribution pass and cache its frames and attributions.

    The result of ``train.explain_trained`` is stored in the module-level
    ``frames`` and ``attributions``, and the key-frame slider's ``maximum``
    is set to the index of the last frame.

    Args:
        num_iterations: Forwarded to ``explain_trained`` as
            ``num_iterations``; the UI passes the "Number of Baseline
            Iterations" slider value.
        option: Forwarded to ``explain_trained`` as ``option``; the UI passes
            the index of the selected baseline dropdown entry.

    Raises:
        gr.Error: If no training loop has been created yet.
    """
    global frames, attributions

    if train is None:
        # Without this guard the call below fails with an opaque
        # AttributeError on None.
        raise gr.Error("Create the environment before running the attribution.")

    frames, attributions = train.explain_trained(num_iterations=num_iterations, option=option)

    # Frames are indexed 0 .. len(frames) - 1, so the last valid slider
    # position is len(frames) - 1 (never below the slider's minimum of 0).
    slider.maximum = max(len(frames) - 1, 0)
    # NOTE(review): assigning the attribute may not refresh a slider that is
    # already rendered in the browser; confirm, and consider returning a
    # component update through the click event's `outputs` instead.
|
|
|
def get_frame_and_attribution(slider_value):
    """Return the cached frame and its attribution probabilities for a key frame.

    Args:
        slider_value: Position of the key-frame slider, used as an index into
            the cached ``frames`` and ``attributions``.  Positions outside the
            available range are clamped to the nearest valid index.

    Returns:
        A ``(frame, attribution)`` tuple, where ``attribution`` is the
        ``display_softmax`` mapping for that frame.  ``(None, None)`` is
        returned while no attribution results are cached.
    """
    # Nothing to display until generate_output() has filled the cache.
    if frames is None or attributions is None:
        return None, None

    available = min(len(frames), len(attributions))
    if available == 0:
        return None, None

    # The slider allows positions up to 20000, far beyond a typical run, so
    # clamp instead of raising IndexError.
    index = max(0, min(int(slider_value), available - 1))

    return frames[index], display_softmax(attributions[index])
|
|
|
# Build the Gradio UI and start the app.
# NOTE(review): indentation was lost in this copy of the file, so which of the
# components below sit inside the Tab and the Row cannot be confirmed from
# here -- verify the nesting against the original before editing the layout.
with gr.Blocks() as demo: |
|
gr.Markdown("# Introspection in Deep Reinforcement Learning") |
|
|
|
with gr.Tab(label="Attribute"): |
|
# Environment setup: the textbox content is passed to create_training_loop()
# and the returned spec is displayed as JSON; flagging is disabled.
env_spec = gr.Textbox(label="Environment Specification (e.g.: LunarLander-v2)", lines=1) |
|
env = gr.Interface(title="Create the Environment", allow_flagging="never", inputs=env_spec, fn=create_training_loop, outputs=gr.JSON()) |
|
|
|
with gr.Row(): |
|
# type="index" makes the dropdown pass the position of the selected choice
# rather than its text.
option = gr.Dropdown(choices=["Torch Tensor of 0's", "Running Average"], type="index") |
|
baselines = gr.Slider(label="Number of Baseline Iterations", interactive=True, minimum=0, maximum=100, value=10, step=5, info="Baseline inputs to collect for the average", render=True) |
|
# Clicking runs generate_output(baselines, option); no outputs are declared,
# so its results are only stored in module-level globals.
gr.Button("ATTRIBUTE").click(fn=generate_output, inputs=[baselines, option]) |
|
# Key-frame selector; generate_output() reassigns slider.maximum after a run.
slider = gr.Slider(label="Key Frame", minimum=0, maximum=20000, step=1, value=0) |
|
|
|
# Live viewer: get_frame_and_attribution() is called with the slider value and
# its two results are shown as an image and a label.
gr.Interface(fn=get_frame_and_attribution, inputs=slider, live=True, outputs=[gr.Image(), gr.Label()]) |
|
|
|
|
|
|
|
# Runs whenever this module is executed or imported (there is no __main__ guard).
demo.launch() |