# Kano001's picture
# Upload 919 files
# 375a1cf verified
import collections
import os
import time
from threading import Lock
import glfw
import imageio
import mujoco
import numpy as np
def _import_egl(width, height):
    """Build an EGL OpenGL context of ``width`` x ``height`` pixels.

    The backend module is imported lazily, inside this function, so importing
    this file does not import ``mujoco.egl``.
    """
    from mujoco import egl as mujoco_egl

    context = mujoco_egl.GLContext(width, height)
    return context
def _import_glfw(width, height):
    """Build a GLFW OpenGL context of ``width`` x ``height`` pixels.

    The backend module is imported lazily, inside this function; it is
    aliased so it does not shadow the module-level ``glfw`` package.
    """
    from mujoco import glfw as mujoco_glfw

    context = mujoco_glfw.GLContext(width, height)
    return context
def _import_osmesa(width, height):
    """Build an OSMesa (software) OpenGL context of ``width`` x ``height`` pixels.

    The backend module is imported lazily, inside this function, so importing
    this file does not import ``mujoco.osmesa``.
    """
    from mujoco import osmesa as mujoco_osmesa

    context = mujoco_osmesa.GLContext(width, height)
    return context
# OpenGL backend factories keyed by MUJOCO_GL name. Each factory takes
# (width, height) and returns a GL context. Insertion order is the order in
# which backends are tried when MUJOCO_GL is not set.
_ALL_RENDERERS = collections.OrderedDict()
_ALL_RENDERERS["glfw"] = _import_glfw
_ALL_RENDERERS["egl"] = _import_egl
_ALL_RENDERERS["osmesa"] = _import_osmesa
class RenderContext:
    """Render context superclass for offscreen and window rendering.

    Holds the MuJoCo visualization state (scene, camera, visualization
    options, perturbation and ``MjrContext``). It implements scene rendering,
    pixel read-back (RGB, depth, segmentation), text overlays and
    user-defined marker geoms.
    """

    def __init__(self, model, data, offscreen=True, max_geom=1000):
        """Create the visualization objects for ``model`` / ``data``.

        An OpenGL context must already be current, because an ``MjrContext``
        is created here.

        Args:
            model: the ``mujoco.MjModel`` to render.
            data: the ``mujoco.MjData`` holding the state to render.
            offscreen: select the offscreen framebuffer if True, otherwise the
                window framebuffer.
            max_geom: geom capacity of the ``MjvScene``. It bounds the model
                geoms plus markers that can be drawn in one frame.

        Raises:
            RuntimeError: if the requested framebuffer is not supported.
        """
        self.model = model
        self.data = data
        self.offscreen = offscreen
        self.offwidth = model.vis.global_.offwidth
        self.offheight = model.vis.global_.offheight

        # Bring derived quantities (e.g. data.geom_xpos, read by _init_camera)
        # up to date before building the scene and camera.
        mujoco.mj_forward(self.model, self.data)
        self.scn = mujoco.MjvScene(self.model, max_geom)
        self.cam = mujoco.MjvCamera()
        self.vopt = mujoco.MjvOption()
        self.pert = mujoco.MjvPerturb()
        self.con = mujoco.MjrContext(self.model, mujoco.mjtFontScale.mjFONTSCALE_150)

        self._markers = []
        self._overlays = {}

        self._init_camera()
        self._set_mujoco_buffers()

    def _set_mujoco_buffers(self):
        """Select the offscreen or window framebuffer on ``self.con``.

        Raises:
            RuntimeError: if MuJoCo did not switch to the requested buffer.
        """
        if self.offscreen:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN, self.con)
            if self.con.currentBuffer != mujoco.mjtFramebuffer.mjFB_OFFSCREEN:
                raise RuntimeError("Offscreen rendering not supported")
        else:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_WINDOW, self.con)
            if self.con.currentBuffer != mujoco.mjtFramebuffer.mjFB_WINDOW:
                raise RuntimeError("Window rendering not supported")

    def _offscreen_rect(self):
        """Return an ``MjrRect`` covering the full offscreen buffer size."""
        return mujoco.MjrRect(
            left=0, bottom=0, width=self.offwidth, height=self.offheight
        )

    def render(self, camera_id=None, segmentation=False):
        """Render the current state into an offscreen-sized viewport.

        Args:
            camera_id: ``None`` keeps the current camera; ``-1`` switches to
                the free camera; any other value selects that fixed camera.
            segmentation: if True, render with the segmentation / ID-colour
                scene flags set, for ``read_pixels(segmentation=True)``. The
                flags are reset afterwards, even if rendering raises.
        """
        rect = self._offscreen_rect()

        if camera_id is not None:
            if camera_id == -1:
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
            else:
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            self.cam.fixedcamid = camera_id

        mujoco.mjv_updateScene(
            self.model,
            self.data,
            self.vopt,
            self.pert,
            self.cam,
            mujoco.mjtCatBit.mjCAT_ALL,
            self.scn,
        )

        if segmentation:
            self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 1
            self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 1
        try:
            # Markers are appended after mjv_updateScene has populated the
            # scene's geoms for this frame.
            for marker_params in self._markers:
                self._add_marker_to_scene(marker_params)

            mujoco.mjr_render(rect, self.scn, self.con)

            for gridpos, (text1, text2) in self._overlays.items():
                mujoco.mjr_overlay(
                    mujoco.mjtFontScale.mjFONTSCALE_150,
                    gridpos,
                    rect,
                    text1.encode(),
                    text2.encode(),
                    self.con,
                )
        finally:
            # Never leave the scene stuck in segmentation mode (e.g. when
            # _add_marker_to_scene raises "Ran out of geoms").
            if segmentation:
                self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 0
                self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 0

    @staticmethod
    def _decode_segmentation_ids(rgb_img):
        """Combine the three colour channels of an ID-colour image into ints.

        Args:
            rgb_img: ``(H, W, 3)`` uint8 image.

        Returns:
            ``(H, W)`` int32 array holding ``r + g * 2**8 + b * 2**16`` per pixel.
        """
        # Widen first. Under NumPy 2 promotion rules (NEP 50) a uint8 array
        # multiplied by the Python int 256 raises OverflowError instead of
        # upcasting.
        rgb = np.asarray(rgb_img).astype(np.int32)
        return rgb[..., 0] + rgb[..., 1] * (2**8) + rgb[..., 2] * (2**16)

    def read_pixels(self, depth=True, segmentation=False):
        """Read back the last offscreen render.

        Args:
            depth: also return the depth buffer.
            segmentation: decode the image as segmentation IDs. This expects
                the preceding ``render`` call to have used ``segmentation=True``.

        Returns:
            The ``(H, W, 3)`` uint8 RGB image, or when ``segmentation`` is set an
            ``(H, W, 2)`` int32 image of ``(objtype, objid)`` per pixel, with
            ``(-1, -1)`` for background. When ``depth`` is True, a tuple of
            that image and the ``(H, W)`` float32 depth image.
            Rows are in the order ``mjr_readPixels`` returns them; no flip is
            applied here.
        """
        rect = self._offscreen_rect()
        rgb_arr = np.zeros(3 * rect.width * rect.height, dtype=np.uint8)
        depth_arr = np.zeros(rect.width * rect.height, dtype=np.float32)
        mujoco.mjr_readPixels(rgb_arr, depth_arr, rect, self.con)
        rgb_img = rgb_arr.reshape(rect.height, rect.width, 3)

        ret_img = rgb_img
        if segmentation:
            seg_img = self._decode_segmentation_ids(rgb_img)
            # IDs beyond the scene's geoms are treated as background (row 0).
            seg_img[seg_img >= (self.scn.ngeom + 1)] = 0
            # Row k+1 holds (objtype, objid) for scene geom segid k; row 0
            # stays (-1, -1) for background.
            seg_ids = np.full((self.scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)
            for i in range(self.scn.ngeom):
                geom = self.scn.geoms[i]
                if geom.segid != -1:
                    seg_ids[geom.segid + 1, 0] = geom.objtype
                    seg_ids[geom.segid + 1, 1] = geom.objid
            ret_img = seg_ids[seg_img]

        if depth:
            depth_img = depth_arr.reshape(rect.height, rect.width)
            return (ret_img, depth_img)
        return ret_img

    def _init_camera(self):
        """Set up the free camera looking at the median geom position.

        The camera distance is set to ``model.stat.extent``.
        """
        self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        self.cam.fixedcamid = -1
        geom_xpos = self.data.geom_xpos
        # np.median of an empty array is NaN (with a warning). With no geoms,
        # leave lookat untouched rather than poisoning the camera.
        if geom_xpos.size:
            center = np.median(geom_xpos, axis=0)
            for i in range(3):
                self.cam.lookat[i] = center[i]
        self.cam.distance = self.model.stat.extent

    def add_overlay(self, gridpos: int, text1: str, text2: str):
        """Overlays text on the scene.

        Text for the same ``gridpos`` accumulates line by line. ``text1`` and
        ``text2`` go to the two overlay columns.
        """
        if gridpos not in self._overlays:
            self._overlays[gridpos] = ["", ""]
        self._overlays[gridpos][0] += text1 + "\n"
        self._overlays[gridpos][1] += text2 + "\n"

    def add_marker(self, **marker_params):
        """Queue a marker geom (``mjvGeom`` field -> value) for the next render."""
        self._markers.append(marker_params)

    def _add_marker_to_scene(self, marker):
        """Append one marker geom to ``self.scn``.

        Unspecified fields get the defaults set below: a box of size 0.1, an
        identity rotation and white, opaque colour. Plain numbers and
        ``mjtGeom`` values are assigned directly. Tuples, lists and arrays are
        reshaped into the existing field. ``label`` accepts a string, or
        ``None`` to clear it.

        Raises:
            RuntimeError: if the scene has no free geom slot.
            ValueError: if a key is not a geom field or its value type is
                unsupported.
        """
        if self.scn.ngeom >= self.scn.maxgeom:
            raise RuntimeError("Ran out of geoms. maxgeom: %d" % self.scn.maxgeom)

        g = self.scn.geoms[self.scn.ngeom]
        # default values.
        g.dataid = -1
        g.objtype = mujoco.mjtObj.mjOBJ_UNKNOWN
        g.objid = -1
        g.category = mujoco.mjtCatBit.mjCAT_DECOR
        g.texid = -1
        g.texuniform = 0
        g.texrepeat[0] = 1
        g.texrepeat[1] = 1
        g.emission = 0
        g.specular = 0.5
        g.shininess = 0.5
        g.reflectance = 0
        g.type = mujoco.mjtGeom.mjGEOM_BOX
        g.size[:] = np.ones(3) * 0.1
        g.mat[:] = np.eye(3)
        g.rgba[:] = np.ones(4)

        for key, value in marker.items():
            if isinstance(
                value, (int, float, np.integer, np.floating, mujoco.mjtGeom)
            ):
                setattr(g, key, value)
            elif isinstance(value, (tuple, list, np.ndarray)):
                attr = getattr(g, key)
                attr[:] = np.asarray(value).reshape(attr.shape)
            elif key == "label" and (value is None or isinstance(value, str)):
                # ``label`` is a string field; None clears it.
                g.label = "" if value is None else value
            elif isinstance(value, str):
                # Validate explicitly; ``assert`` is stripped under ``python -O``.
                raise ValueError("Only label is a string in mjtGeom.")
            elif hasattr(g, key):
                raise ValueError(
                    "mjtGeom has attr {} but type {} is invalid".format(
                        key, type(value)
                    )
                )
            else:
                raise ValueError("mjtGeom doesn't have field %s" % key)

        self.scn.ngeom += 1

    def close(self):
        """Override close in your rendering subclass to perform any necessary cleanup
        after env.close() is called.
        """
        pass
class RenderContextOffscreen(RenderContext):
    """Offscreen rendering class with opengl context."""

    def __init__(self, model, data):
        """Create an OpenGL context sized to the model's offscreen buffer, make
        it current, then set up offscreen rendering.

        Raises:
            RuntimeError: if no usable OpenGL backend can be created, or
                offscreen rendering is not supported.
        """
        # We must make GLContext before MjrContext
        width = model.vis.global_.offwidth
        height = model.vis.global_.offheight
        self._get_opengl_backend(width, height)
        self.opengl_context.make_current()
        super().__init__(model, data, offscreen=True)

    def _get_opengl_backend(self, width, height):
        """Create ``self.opengl_context`` of ``width`` x ``height`` pixels.

        If the ``MUJOCO_GL`` environment variable is set, that backend is used
        and any error it raises propagates unchanged. Otherwise the backends in
        ``_ALL_RENDERERS`` are tried in order, and the first one that can be
        created is used.

        Raises:
            RuntimeError: if ``MUJOCO_GL`` names an unknown backend, or no
                backend could be created. In the latter case the last backend
                error is chained as the cause.
        """
        backend = os.environ.get("MUJOCO_GL")
        if backend is not None:
            # Look up the factory separately so a KeyError raised *inside* the
            # backend is not misreported as an invalid MUJOCO_GL value.
            try:
                factory = _ALL_RENDERERS[backend]
            except KeyError:
                raise RuntimeError(
                    "Environment variable {} must be one of {!r}: got {!r}.".format(
                        "MUJOCO_GL", list(_ALL_RENDERERS), backend
                    )
                ) from None
            self.opengl_context = factory(width, height)
            return

        last_error = None
        for factory in _ALL_RENDERERS.values():
            try:
                self.opengl_context = factory(width, height)
            except Exception as err:  # best effort: fall through to the next backend
                last_error = err
                continue
            return

        raise RuntimeError(
            "No OpenGL backend could be imported. Attempting to create a "
            "rendering context will result in a RuntimeError."
        ) from last_error
class Viewer(RenderContext):
    """Class for window rendering in all MuJoCo environments.

    Opens a GLFW window sized to half the primary monitor's video mode and
    renders the simulation into it. Keyboard and mouse callbacks control the
    camera, pausing and single-stepping, run speed, screenshots and several
    visualization toggles. A text overlay lists the hotkeys.
    """

    def __init__(self, model, data):
        """Create the window, register input callbacks and set up rendering.

        Args:
            model: the ``mujoco.MjModel`` to display.
            data: the ``mujoco.MjData`` holding the state to display.

        Raises:
            RuntimeError: if GLFW cannot be initialized, the window cannot be
                created, or window rendering is not supported.
        """
        self._gui_lock = Lock()
        self._button_left_pressed = False
        self._button_right_pressed = False
        self._last_mouse_x = 0
        self._last_mouse_y = 0
        self._paused = False
        self._transparent = False
        self._contacts = False
        self._render_every_frame = True
        self._image_idx = 0
        self._image_path = "/tmp/frame_%07d.png"
        self._time_per_render = 1 / 60.0
        self._run_speed = 1.0
        self._loop_count = 0
        self._advance_by_one_step = False
        self._hide_menu = False

        # glfw init
        if not glfw.init():
            raise RuntimeError("Failed to initialize GLFW")
        width, height = glfw.get_video_mode(glfw.get_primary_monitor()).size
        self.window = glfw.create_window(width // 2, height // 2, "mujoco", None, None)
        if not self.window:
            glfw.terminate()
            raise RuntimeError("Failed to create GLFW window")
        glfw.make_context_current(self.window)
        glfw.swap_interval(1)

        framebuffer_width, framebuffer_height = glfw.get_framebuffer_size(self.window)
        window_width, _ = glfw.get_window_size(self.window)
        # Framebuffer pixels per window coordinate (can differ from 1, e.g. on
        # HiDPI displays); mouse positions are scaled by it.
        self._scale = framebuffer_width * 1.0 / window_width

        # set callbacks
        glfw.set_cursor_pos_callback(self.window, self._cursor_pos_callback)
        glfw.set_mouse_button_callback(self.window, self._mouse_button_callback)
        glfw.set_scroll_callback(self.window, self._scroll_callback)
        glfw.set_key_callback(self.window, self._key_callback)

        # get viewport
        self.viewport = mujoco.MjrRect(0, 0, framebuffer_width, framebuffer_height)

        super().__init__(model, data, offscreen=False)

    def _destroy_window(self):
        """Destroy the window and terminate GLFW, at most once.

        GLFW forbids calling this from inside a GLFW callback.
        """
        if self.window is not None:
            glfw.destroy_window(self.window)
            glfw.terminate()
            self.window = None

    def _key_callback(self, window, key, scancode, action, mods):
        """GLFW key callback: handle the viewer hotkeys on key release."""
        if action != glfw.RELEASE:
            return
        # Switch cameras
        elif key == glfw.KEY_TAB:
            self.cam.fixedcamid += 1
            self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            if self.cam.fixedcamid >= self.model.ncam:
                self.cam.fixedcamid = -1
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        # Pause simulation
        elif key == glfw.KEY_SPACE and self._paused is not None:
            self._paused = not self._paused
        # Advances simulation by one step.
        elif key == glfw.KEY_RIGHT and self._paused is not None:
            self._advance_by_one_step = True
            self._paused = True
        # Slows down simulation
        elif key == glfw.KEY_S:
            self._run_speed /= 2.0
        # Speeds up simulation
        elif key == glfw.KEY_F:
            self._run_speed *= 2.0
        # Turn off / turn on rendering every frame.
        elif key == glfw.KEY_D:
            self._render_every_frame = not self._render_every_frame
        # Capture screenshot
        elif key == glfw.KEY_T:
            # Size the buffer from the viewport that is actually read, so a
            # resize since the last frame cannot cause a size mismatch.
            img = np.zeros(
                (self.viewport.height, self.viewport.width, 3), dtype=np.uint8
            )
            mujoco.mjr_readPixels(img, None, self.viewport, self.con)
            imageio.imwrite(self._image_path % self._image_idx, np.flipud(img))
            self._image_idx += 1
        # Display contact forces
        elif key == glfw.KEY_C:
            self._contacts = not self._contacts
            self.vopt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = self._contacts
            self.vopt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = self._contacts
        # Display coordinate frames
        elif key == glfw.KEY_E:
            self.vopt.frame = 1 - self.vopt.frame
        # Hide overlay menu
        elif key == glfw.KEY_H:
            self._hide_menu = not self._hide_menu
        # Make transparent
        elif key == glfw.KEY_R:
            self._transparent = not self._transparent
            if self._transparent:
                self.model.geom_rgba[:, 3] /= 5.0
            else:
                self.model.geom_rgba[:, 3] *= 5.0
        # Geom group visibility
        elif key in (glfw.KEY_0, glfw.KEY_1, glfw.KEY_2, glfw.KEY_3, glfw.KEY_4):
            self.vopt.geomgroup[key - glfw.KEY_0] ^= 1

        # Quit
        if key == glfw.KEY_ESCAPE:
            print("Pressed ESC")
            print("Quitting.")
            # Destroying the window or terminating GLFW from inside a callback
            # is not allowed. Request closure; render() tears the window down.
            glfw.set_window_should_close(window, True)

    def _cursor_pos_callback(self, window, xpos, ypos):
        """GLFW cursor callback: move the camera while a mouse button is held.

        Right-drag moves and left-drag rotates the camera. Holding Shift
        selects the ``_H`` variant of the action instead of ``_V``.
        """
        if not (self._button_left_pressed or self._button_right_pressed):
            return

        mod_shift = (
            glfw.get_key(window, glfw.KEY_LEFT_SHIFT) == glfw.PRESS
            or glfw.get_key(window, glfw.KEY_RIGHT_SHIFT) == glfw.PRESS
        )
        # The early return above guarantees one of the two buttons is pressed.
        if self._button_right_pressed:
            action = (
                mujoco.mjtMouse.mjMOUSE_MOVE_H
                if mod_shift
                else mujoco.mjtMouse.mjMOUSE_MOVE_V
            )
        else:
            action = (
                mujoco.mjtMouse.mjMOUSE_ROTATE_H
                if mod_shift
                else mujoco.mjtMouse.mjMOUSE_ROTATE_V
            )

        dx = int(self._scale * xpos) - self._last_mouse_x
        dy = int(self._scale * ypos) - self._last_mouse_y
        _, height = glfw.get_framebuffer_size(window)

        # Both offsets are normalized by the framebuffer height.
        with self._gui_lock:
            mujoco.mjv_moveCamera(
                self.model, action, dx / height, dy / height, self.scn, self.cam
            )

        self._last_mouse_x = int(self._scale * xpos)
        self._last_mouse_y = int(self._scale * ypos)

    def _mouse_button_callback(self, window, button, act, mods):
        """GLFW mouse-button callback: record button state and cursor position."""
        self._button_left_pressed = (
            glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_LEFT) == glfw.PRESS
        )
        self._button_right_pressed = (
            glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_RIGHT) == glfw.PRESS
        )

        x, y = glfw.get_cursor_pos(window)
        self._last_mouse_x = int(self._scale * x)
        self._last_mouse_y = int(self._scale * y)

    def _scroll_callback(self, window, x_offset, y_offset):
        """GLFW scroll callback: zoom the camera with the vertical wheel offset."""
        with self._gui_lock:
            mujoco.mjv_moveCamera(
                self.model,
                mujoco.mjtMouse.mjMOUSE_ZOOM,
                0,
                -0.05 * y_offset,
                self.scn,
                self.cam,
            )

    def _create_overlay(self):
        """Queue the hotkey help (top-left) and simulation stats (bottom-left)."""
        topleft = mujoco.mjtGridPos.mjGRID_TOPLEFT
        bottomleft = mujoco.mjtGridPos.mjGRID_BOTTOMLEFT

        if self._render_every_frame:
            self.add_overlay(topleft, "", "")
        else:
            self.add_overlay(
                topleft,
                "Run speed = %.3f x real time" % self._run_speed,
                "[S]lower, [F]aster",
            )
        self.add_overlay(
            topleft, "Ren[d]er every frame", "On" if self._render_every_frame else "Off"
        )
        self.add_overlay(
            topleft,
            "Switch camera (#cams = %d)" % (self.model.ncam + 1),
            "[Tab] (camera ID = %d)" % self.cam.fixedcamid,
        )
        self.add_overlay(topleft, "[C]ontact forces", "On" if self._contacts else "Off")
        self.add_overlay(topleft, "T[r]ansparent", "On" if self._transparent else "Off")
        if self._paused is not None:
            if not self._paused:
                self.add_overlay(topleft, "Stop", "[Space]")
            else:
                self.add_overlay(topleft, "Start", "[Space]")
                self.add_overlay(
                    topleft, "Advance simulation by one step", "[right arrow]"
                )
        self.add_overlay(
            topleft, "Referenc[e] frames", "On" if self.vopt.frame == 1 else "Off"
        )
        self.add_overlay(topleft, "[H]ide Menu", "")
        if self._image_idx > 0:
            fname = self._image_path % (self._image_idx - 1)
            self.add_overlay(topleft, "Cap[t]ure frame", "Saved as %s" % fname)
        else:
            self.add_overlay(topleft, "Cap[t]ure frame", "")
        self.add_overlay(topleft, "Toggle geomgroup visibility", "0-4")

        self.add_overlay(bottomleft, "FPS", "%d%s" % (1 / self._time_per_render, ""))
        self.add_overlay(
            bottomleft, "Solver iterations", str(self.data.solver_iter + 1)
        )
        self.add_overlay(
            bottomleft, "Step", str(round(self.data.time / self.model.opt.timestep))
        )
        self.add_overlay(bottomleft, "timestep", "%.5f" % self.model.opt.timestep)

    def render(self):
        """Draw frames to the window according to the pause and speed settings.

        While paused, this keeps redrawing until the viewer is unpaused, a
        single step is requested, or the window is closed. Otherwise it adds
        ``timestep / (time_per_render * run_speed)`` to a frame counter and
        draws while the counter is positive. When rendering every frame it
        draws exactly one frame. Queued markers are cleared afterwards.
        Once the window is closed, this draws nothing.
        """

        # mjv_updateScene, mjr_render, mjr_overlay
        def update():
            # Nothing to draw into once the window is gone. Checking before
            # _create_overlay also stops overlay text accumulating forever.
            if self.window is None:
                return
            if glfw.window_should_close(self.window):
                self._destroy_window()
                return

            # fill overlay items
            self._create_overlay()

            render_start = time.time()
            self.viewport.width, self.viewport.height = glfw.get_framebuffer_size(
                self.window
            )
            with self._gui_lock:
                # update scene
                mujoco.mjv_updateScene(
                    self.model,
                    self.data,
                    self.vopt,
                    mujoco.MjvPerturb(),
                    self.cam,
                    mujoco.mjtCatBit.mjCAT_ALL.value,
                    self.scn,
                )
                # marker items
                for marker in self._markers:
                    self._add_marker_to_scene(marker)
                # render
                mujoco.mjr_render(self.viewport, self.scn, self.con)
                # overlay items
                if not self._hide_menu:
                    for gridpos, [t1, t2] in self._overlays.items():
                        mujoco.mjr_overlay(
                            mujoco.mjtFontScale.mjFONTSCALE_150,
                            gridpos,
                            self.viewport,
                            t1,
                            t2,
                            self.con,
                        )
                glfw.swap_buffers(self.window)
            glfw.poll_events()
            # Exponential moving average of wall-clock time per frame.
            self._time_per_render = 0.9 * self._time_per_render + 0.1 * (
                time.time() - render_start
            )

            # clear overlay
            self._overlays.clear()

        if self._paused:
            while self._paused:
                update()
                if self.window is None:
                    # Window closed while paused: stop blocking instead of
                    # spinning forever.
                    break
                if self._advance_by_one_step:
                    self._advance_by_one_step = False
                    break
        else:
            self._loop_count += self.model.opt.timestep / (
                self._time_per_render * self._run_speed
            )
            if self._render_every_frame:
                self._loop_count = 1
            while self._loop_count > 0:
                update()
                self._loop_count -= 1

        # clear markers
        self._markers[:] = []

    def close(self):
        """Destroy the window and terminate GLFW; safe to call more than once."""
        self._destroy_window()