"""MuJoCo rendering helpers: offscreen rendering and an interactive GLFW viewer."""

import collections
import os
import time
from threading import Lock

import glfw
import imageio
import mujoco
import numpy as np


def _import_egl(width, height):
    """Create an OpenGL context of the given size using MuJoCo's EGL backend."""
    from mujoco.egl import GLContext

    return GLContext(width, height)


def _import_glfw(width, height):
    """Create an OpenGL context of the given size using MuJoCo's GLFW backend."""
    from mujoco.glfw import GLContext

    return GLContext(width, height)


def _import_osmesa(width, height):
    """Create an OpenGL context of the given size using MuJoCo's OSMesa backend."""
    from mujoco.osmesa import GLContext

    return GLContext(width, height)


# Backend name -> context factory. When MUJOCO_GL is unset, the backends are
# tried in this insertion order and the first one that succeeds is used.
_ALL_RENDERERS = collections.OrderedDict(
    [
        ("glfw", _import_glfw),
        ("egl", _import_egl),
        ("osmesa", _import_osmesa),
    ]
)


class RenderContext:
    """Render context superclass for offscreen and window rendering."""

    def __init__(self, model, data, offscreen=True, max_geom=1000):
        """Build the MuJoCo visualization/rendering state for ``model``/``data``.

        Args:
            model: the ``mujoco.MjModel`` to render.
            data: the ``mujoco.MjData`` holding the state to render.
            offscreen: if True, select the offscreen framebuffer; otherwise the
                window framebuffer.
            max_geom: capacity of the ``MjvScene`` (scene geoms plus markers
                added through :meth:`add_marker`).

        Raises:
            RuntimeError: if the requested framebuffer cannot be selected.
        """
        self.model = model
        self.data = data
        self.offscreen = offscreen
        self.offwidth = model.vis.global_.offwidth
        self.offheight = model.vis.global_.offheight

        # Compute derived quantities (e.g. geom_xpos, read by _init_camera).
        mujoco.mj_forward(self.model, self.data)

        self.scn = mujoco.MjvScene(self.model, max_geom)
        self.cam = mujoco.MjvCamera()
        self.vopt = mujoco.MjvOption()
        self.pert = mujoco.MjvPerturb()
        self.con = mujoco.MjrContext(self.model, mujoco.mjtFontScale.mjFONTSCALE_150)

        # Marker parameter dicts, re-added to the scene on every render.
        self._markers = []
        # gridpos -> [left column text, right column text]
        self._overlays = {}

        self._init_camera()
        self._set_mujoco_buffers()

    def _set_mujoco_buffers(self):
        """Select the offscreen or window framebuffer and verify it took effect.

        Raises:
            RuntimeError: if MuJoCo did not switch to the requested buffer.
        """
        if self.offscreen:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN, self.con)
            if self.con.currentBuffer != mujoco.mjtFramebuffer.mjFB_OFFSCREEN:
                raise RuntimeError("Offscreen rendering not supported")
        else:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_WINDOW, self.con)
            if self.con.currentBuffer != mujoco.mjtFramebuffer.mjFB_WINDOW:
                raise RuntimeError("Window rendering not supported")

    def _offscreen_rect(self):
        """Return an ``MjrRect`` covering the full offscreen buffer."""
        return mujoco.MjrRect(
            left=0, bottom=0, width=self.offwidth, height=self.offheight
        )

    def render(self, camera_id=None, segmentation=False):
        """Render the current state into a rect of size offwidth x offheight.

        Args:
            camera_id: None keeps the current camera; -1 selects the free
                camera; any other value selects that fixed camera id.
            segmentation: if True, render with the segmentation and ID-color
                flags enabled (they are switched off again afterwards).
        """
        rect = self._offscreen_rect()

        if camera_id is not None:
            if camera_id == -1:
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
            else:
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            self.cam.fixedcamid = camera_id

        mujoco.mjv_updateScene(
            self.model,
            self.data,
            self.vopt,
            self.pert,
            self.cam,
            mujoco.mjtCatBit.mjCAT_ALL,
            self.scn,
        )

        if segmentation:
            self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 1
            self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 1
        try:
            for marker_params in self._markers:
                self._add_marker_to_scene(marker_params)

            mujoco.mjr_render(rect, self.scn, self.con)

            for gridpos, (text1, text2) in self._overlays.items():
                mujoco.mjr_overlay(
                    mujoco.mjtFontScale.mjFONTSCALE_150,
                    gridpos,
                    rect,
                    text1.encode(),
                    text2.encode(),
                    self.con,
                )
        finally:
            # Always restore the flags so a failure here (e.g. running out of
            # geoms) does not leave later renders in segmentation mode.
            if segmentation:
                self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 0
                self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 0

    def read_pixels(self, depth=True, segmentation=False):
        """Read the offscreen buffer back into NumPy arrays.

        Rows come back in the order ``mjr_readPixels`` writes them (the Viewer
        applies ``np.flipud`` before saving screenshots; callers may need to
        do the same).

        Args:
            depth: if True, also return the depth image.
            segmentation: if True, decode the ID colors of a segmentation
                render into per-pixel ``(objtype, objid)`` pairs.

        Returns:
            ``(image, depth_image)`` if ``depth`` else ``image``, where
            ``image`` is ``(height, width, 3)`` uint8 RGB, or
            ``(height, width, 2)`` int32 ``(objtype, objid)`` when
            ``segmentation`` is True (``-1`` for unassigned pixels), and
            ``depth_image`` is ``(height, width)`` float32.
        """
        rect = self._offscreen_rect()

        rgb_arr = np.zeros(3 * rect.width * rect.height, dtype=np.uint8)
        depth_arr = np.zeros(rect.width * rect.height, dtype=np.float32)

        mujoco.mjr_readPixels(rgb_arr, depth_arr, rect, self.con)

        rgb_img = rgb_arr.reshape(rect.height, rect.width, 3)
        ret_img = rgb_img

        if segmentation:
            # Widen before combining channels: uint8 * 2**8 overflows (and
            # raises OverflowError under NumPy 2 promotion rules).
            rgb_wide = rgb_img.astype(np.int32)
            seg_img = (
                rgb_wide[:, :, 0]
                + rgb_wide[:, :, 1] * (2**8)
                + rgb_wide[:, :, 2] * (2**16)
            )
            # Ids outside the scene map to slot 0, the "unassigned" entry.
            seg_img[seg_img >= (self.scn.ngeom + 1)] = 0
            # Slot segid + 1 holds (objtype, objid); slot 0 stays (-1, -1).
            seg_ids = np.full((self.scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)

            for i in range(self.scn.ngeom):
                geom = self.scn.geoms[i]
                if geom.segid != -1:
                    seg_ids[geom.segid + 1, 0] = geom.objtype
                    seg_ids[geom.segid + 1, 1] = geom.objid
            ret_img = seg_ids[seg_img]

        if depth:
            depth_img = depth_arr.reshape(rect.height, rect.width)
            return (ret_img, depth_img)
        return ret_img

    def _init_camera(self):
        """Point a free camera at the per-axis median geom position."""
        self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        self.cam.fixedcamid = -1
        self.cam.lookat[:] = np.median(self.data.geom_xpos, axis=0)
        self.cam.distance = self.model.stat.extent

    def add_overlay(self, gridpos: int, text1: str, text2: str):
        """Overlays text on the scene.

        Each call appends one line to the two text columns at ``gridpos``.
        """
        texts = self._overlays.setdefault(gridpos, ["", ""])
        texts[0] += text1 + "\n"
        texts[1] += text2 + "\n"

    def add_marker(self, **marker_params):
        """Queue a decorative geom; keyword names are ``MjvGeom`` fields."""
        self._markers.append(marker_params)

    def _add_marker_to_scene(self, marker):
        """Append one marker geom to the scene, starting from default values.

        Args:
            marker: mapping of ``MjvGeom`` field name -> value. Numbers and
                ``mjtGeom`` values are assigned directly; sequences/arrays
                are reshaped to the field's shape; ``label`` takes a string
                or None (None clears it).

        Raises:
            RuntimeError: if the scene has no free geom slot.
            ValueError: for unknown fields or unsupported value types.
        """
        if self.scn.ngeom >= self.scn.maxgeom:
            raise RuntimeError("Ran out of geoms. maxgeom: %d" % self.scn.maxgeom)

        g = self.scn.geoms[self.scn.ngeom]
        # default values.
        g.dataid = -1
        g.objtype = mujoco.mjtObj.mjOBJ_UNKNOWN
        g.objid = -1
        g.category = mujoco.mjtCatBit.mjCAT_DECOR
        g.texid = -1
        g.texuniform = 0
        g.texrepeat[0] = 1
        g.texrepeat[1] = 1
        g.emission = 0
        g.specular = 0.5
        g.shininess = 0.5
        g.reflectance = 0
        g.type = mujoco.mjtGeom.mjGEOM_BOX
        g.size[:] = np.ones(3) * 0.1
        g.mat[:] = np.eye(3)
        g.rgba[:] = np.ones(4)
        # This slot is reused across renders; clear any label left behind by
        # a marker drawn in a previous frame.
        g.label = ""

        for key, value in marker.items():
            if isinstance(value, (int, float, np.integer, np.floating, mujoco.mjtGeom)):
                setattr(g, key, value)
            elif isinstance(value, (tuple, list, np.ndarray)):
                attr = getattr(g, key)
                attr[:] = np.asarray(value).reshape(attr.shape)
            elif key == "label" and (value is None or isinstance(value, str)):
                g.label = "" if value is None else value
            elif isinstance(value, str):
                raise ValueError("Only label is a string in mjtGeom.")
            elif hasattr(g, key):
                raise ValueError(
                    "mjtGeom has attr {} but type {} is invalid".format(
                        key, type(value)
                    )
                )
            else:
                raise ValueError("mjtGeom doesn't have field %s" % key)

        self.scn.ngeom += 1

    def close(self):
        """Override close in your rendering subclass to perform any necessary cleanup
        after env.close() is called.
        """
        pass


class RenderContextOffscreen(RenderContext):
    """Offscreen rendering class with opengl context."""

    def __init__(self, model, data):
        # We must make GLContext before MjrContext
        width = model.vis.global_.offwidth
        height = model.vis.global_.offheight
        self._get_opengl_backend(width, height)
        self.opengl_context.make_current()

        super().__init__(model, data, offscreen=True)

    def _get_opengl_backend(self, width, height):
        """Create ``self.opengl_context`` with the selected OpenGL backend.

        If ``MUJOCO_GL`` is set it must name a key of ``_ALL_RENDERERS``;
        otherwise each backend is tried in order until one succeeds.

        Raises:
            RuntimeError: if ``MUJOCO_GL`` names an unknown backend, or no
                backend could be created when it is unset.
        """
        backend = os.environ.get("MUJOCO_GL")
        if backend is not None:
            # Look the factory up separately so a KeyError raised while
            # constructing the context is not misreported as a bad MUJOCO_GL.
            try:
                make_context = _ALL_RENDERERS[backend]
            except KeyError:
                raise RuntimeError(
                    "Environment variable {} must be one of {!r}: got {!r}.".format(
                        "MUJOCO_GL", _ALL_RENDERERS.keys(), backend
                    )
                ) from None
            self.opengl_context = make_context(width, height)
            return

        for name, make_context in _ALL_RENDERERS.items():
            try:
                self.opengl_context = make_context(width, height)
            except Exception:
                # Backend unavailable (import or context creation failed);
                # fall through to the next one.
                continue
            backend = name
            break

        if backend is None:
            raise RuntimeError(
                "No OpenGL backend could be imported. Attempting to create a "
                "rendering context will result in a RuntimeError."
            )


class Viewer(RenderContext):
    """Class for window rendering in all MuJoCo environments.

    Keys (acted on when released): Tab cycles cameras, Space pauses,
    Right arrow steps once, S/F slow down/speed up, D toggles rendering every
    frame, T saves a screenshot, C toggles contacts, E toggles frames,
    H hides the menu, R toggles transparency, 0-4 toggle geom groups,
    Esc closes the window.
    """

    def __init__(self, model, data):
        """Open a GLFW window at half the primary monitor's resolution.

        Raises:
            RuntimeError: if GLFW cannot be initialized or the window cannot
                be created.
        """
        self._gui_lock = Lock()
        self._button_left_pressed = False
        self._button_right_pressed = False
        self._last_mouse_x = 0
        self._last_mouse_y = 0
        self._paused = False
        self._transparent = False
        self._contacts = False
        self._render_every_frame = True
        self._image_idx = 0
        self._image_path = "/tmp/frame_%07d.png"
        # Exponential moving average of wall-clock seconds per rendered frame.
        self._time_per_render = 1 / 60.0
        self._run_speed = 1.0
        # Frames still owed to keep up with simulated time (may be fractional).
        self._loop_count = 0
        self._advance_by_one_step = False
        self._hide_menu = False

        # glfw init
        if not glfw.init():
            raise RuntimeError("Failed to initialize GLFW")
        width, height = glfw.get_video_mode(glfw.get_primary_monitor()).size
        self.window = glfw.create_window(width // 2, height // 2, "mujoco", None, None)
        if not self.window:
            glfw.terminate()
            raise RuntimeError("Failed to create GLFW window")
        glfw.make_context_current(self.window)
        glfw.swap_interval(1)

        framebuffer_width, framebuffer_height = glfw.get_framebuffer_size(self.window)
        window_width, _ = glfw.get_window_size(self.window)
        # Framebuffer pixels per window (screen) coordinate, e.g. on HiDPI.
        self._scale = framebuffer_width * 1.0 / window_width

        # set callbacks
        glfw.set_cursor_pos_callback(self.window, self._cursor_pos_callback)
        glfw.set_mouse_button_callback(self.window, self._mouse_button_callback)
        glfw.set_scroll_callback(self.window, self._scroll_callback)
        glfw.set_key_callback(self.window, self._key_callback)

        # get viewport
        self.viewport = mujoco.MjrRect(0, 0, framebuffer_width, framebuffer_height)

        super().__init__(model, data, offscreen=False)

    def _key_callback(self, window, key, scancode, action, mods):
        """GLFW key callback; every binding fires on key release."""
        if action != glfw.RELEASE:
            return

        if key == glfw.KEY_TAB:
            # Switch cameras: cycle fixed cameras, then back to the free one.
            self.cam.fixedcamid += 1
            self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            if self.cam.fixedcamid >= self.model.ncam:
                self.cam.fixedcamid = -1
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        elif key == glfw.KEY_SPACE and self._paused is not None:
            # Pause simulation
            self._paused = not self._paused
        elif key == glfw.KEY_RIGHT and self._paused is not None:
            # Advances simulation by one step.
            self._advance_by_one_step = True
            self._paused = True
        elif key == glfw.KEY_S:
            # Slows down simulation
            self._run_speed /= 2.0
        elif key == glfw.KEY_F:
            # Speeds up simulation
            self._run_speed *= 2.0
        elif key == glfw.KEY_D:
            # Turn off / turn on rendering every frame.
            self._render_every_frame = not self._render_every_frame
        elif key == glfw.KEY_T:
            # Capture screenshot
            fb_width, fb_height = glfw.get_framebuffer_size(self.window)
            img = np.zeros((fb_height, fb_width, 3), dtype=np.uint8)
            mujoco.mjr_readPixels(img, None, self.viewport, self.con)
            imageio.imwrite(self._image_path % self._image_idx, np.flipud(img))
            self._image_idx += 1
        elif key == glfw.KEY_C:
            # Display contact forces
            self._contacts = not self._contacts
            self.vopt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = self._contacts
            self.vopt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = self._contacts
        elif key == glfw.KEY_E:
            # Display coordinate frames
            self.vopt.frame = 1 - self.vopt.frame
        elif key == glfw.KEY_H:
            # Hide overlay menu
            self._hide_menu = not self._hide_menu
        elif key == glfw.KEY_R:
            # Make transparent
            self._transparent = not self._transparent
            if self._transparent:
                self.model.geom_rgba[:, 3] /= 5.0
            else:
                self.model.geom_rgba[:, 3] *= 5.0
        elif key in (glfw.KEY_0, glfw.KEY_1, glfw.KEY_2, glfw.KEY_3, glfw.KEY_4):
            # Geom group visibility
            self.vopt.geomgroup[key - glfw.KEY_0] ^= 1
        elif key == glfw.KEY_ESCAPE:
            # Quit. GLFW forbids destroying windows or terminating from inside
            # a callback, so only request the close here; render() performs
            # the actual teardown on its next frame.
            print("Pressed ESC")
            print("Quitting.")
            glfw.set_window_should_close(self.window, True)

    def _cursor_pos_callback(self, window, xpos, ypos):
        """Drag with a mouse button held: right moves, left rotates the camera.

        Holding Shift selects the horizontal variant of the motion.
        """
        if not (self._button_left_pressed or self._button_right_pressed):
            return

        mod_shift = (
            glfw.get_key(window, glfw.KEY_LEFT_SHIFT) == glfw.PRESS
            or glfw.get_key(window, glfw.KEY_RIGHT_SHIFT) == glfw.PRESS
        )
        if self._button_right_pressed:
            action = (
                mujoco.mjtMouse.mjMOUSE_MOVE_H
                if mod_shift
                else mujoco.mjtMouse.mjMOUSE_MOVE_V
            )
        else:
            # The guard above guarantees the left button is pressed here.
            action = (
                mujoco.mjtMouse.mjMOUSE_ROTATE_H
                if mod_shift
                else mujoco.mjtMouse.mjMOUSE_ROTATE_V
            )

        x = int(self._scale * xpos)
        y = int(self._scale * ypos)
        dx = x - self._last_mouse_x
        dy = y - self._last_mouse_y
        # Both deltas are normalized by the framebuffer height.
        _, height = glfw.get_framebuffer_size(window)

        with self._gui_lock:
            mujoco.mjv_moveCamera(
                self.model, action, dx / height, dy / height, self.scn, self.cam
            )

        self._last_mouse_x = x
        self._last_mouse_y = y

    def _mouse_button_callback(self, window, button, act, mods):
        """Track button state and reset the drag origin to the cursor."""
        self._button_left_pressed = (
            glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_LEFT) == glfw.PRESS
        )
        self._button_right_pressed = (
            glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_RIGHT) == glfw.PRESS
        )

        x, y = glfw.get_cursor_pos(window)
        self._last_mouse_x = int(self._scale * x)
        self._last_mouse_y = int(self._scale * y)

    def _scroll_callback(self, window, x_offset, y_offset):
        """Zoom the camera by the vertical scroll offset."""
        with self._gui_lock:
            mujoco.mjv_moveCamera(
                self.model,
                mujoco.mjtMouse.mjMOUSE_ZOOM,
                0,
                -0.05 * y_offset,
                self.scn,
                self.cam,
            )

    def _create_overlay(self):
        """Fill the help menu (top-left) and status info (bottom-left)."""
        topleft = mujoco.mjtGridPos.mjGRID_TOPLEFT
        bottomleft = mujoco.mjtGridPos.mjGRID_BOTTOMLEFT

        if self._render_every_frame:
            self.add_overlay(topleft, "", "")
        else:
            self.add_overlay(
                topleft,
                "Run speed = %.3f x real time" % self._run_speed,
                "[S]lower, [F]aster",
            )
        self.add_overlay(
            topleft, "Ren[d]er every frame", "On" if self._render_every_frame else "Off"
        )
        self.add_overlay(
            topleft,
            "Switch camera (#cams = %d)" % (self.model.ncam + 1),
            "[Tab] (camera ID = %d)" % self.cam.fixedcamid,
        )
        self.add_overlay(topleft, "[C]ontact forces", "On" if self._contacts else "Off")
        self.add_overlay(topleft, "T[r]ansparent", "On" if self._transparent else "Off")
        if self._paused is not None:
            if not self._paused:
                self.add_overlay(topleft, "Stop", "[Space]")
            else:
                self.add_overlay(topleft, "Start", "[Space]")
                self.add_overlay(
                    topleft, "Advance simulation by one step", "[right arrow]"
                )
        self.add_overlay(
            topleft, "Referenc[e] frames", "On" if self.vopt.frame == 1 else "Off"
        )
        self.add_overlay(topleft, "[H]ide Menu", "")
        if self._image_idx > 0:
            fname = self._image_path % (self._image_idx - 1)
            self.add_overlay(topleft, "Cap[t]ure frame", "Saved as %s" % fname)
        else:
            self.add_overlay(topleft, "Cap[t]ure frame", "")
        self.add_overlay(topleft, "Toggle geomgroup visibility", "0-4")

        self.add_overlay(bottomleft, "FPS", "%d" % (1 / self._time_per_render))
        self.add_overlay(
            bottomleft, "Solver iterations", str(self.data.solver_iter + 1)
        )
        self.add_overlay(
            bottomleft, "Step", str(round(self.data.time / self.model.opt.timestep))
        )
        self.add_overlay(bottomleft, "timestep", "%.5f" % self.model.opt.timestep)

    def render(self):
        """Draw frames to the window, honouring pause, single-step and speed.

        Does nothing visible once the window has been closed; queued markers
        are cleared at the end of every call.
        """

        # mjv_updateScene, mjr_render, mjr_overlay
        def update():
            if self.window is None:
                return
            if glfw.window_should_close(self.window):
                # Tear down here (outside any GLFW callback) and stop drawing.
                self.close()
                return

            # fill overlay items
            self._create_overlay()

            render_start = time.time()
            self.viewport.width, self.viewport.height = glfw.get_framebuffer_size(
                self.window
            )
            with self._gui_lock:
                # update scene
                mujoco.mjv_updateScene(
                    self.model,
                    self.data,
                    self.vopt,
                    self.pert,
                    self.cam,
                    mujoco.mjtCatBit.mjCAT_ALL.value,
                    self.scn,
                )
                # marker items
                for marker in self._markers:
                    self._add_marker_to_scene(marker)
                # render
                mujoco.mjr_render(self.viewport, self.scn, self.con)
                # overlay items
                if not self._hide_menu:
                    for gridpos, [t1, t2] in self._overlays.items():
                        mujoco.mjr_overlay(
                            mujoco.mjtFontScale.mjFONTSCALE_150,
                            gridpos,
                            self.viewport,
                            t1,
                            t2,
                            self.con,
                        )
                glfw.swap_buffers(self.window)
            glfw.poll_events()
            self._time_per_render = 0.9 * self._time_per_render + 0.1 * (
                time.time() - render_start
            )

            # clear overlay
            self._overlays.clear()

        if self._paused:
            # Keep redrawing while paused; stop if the window goes away so a
            # closed, paused viewer cannot spin forever.
            while self._paused and self.window is not None:
                update()
                if self._advance_by_one_step:
                    self._advance_by_one_step = False
                    break
        else:
            # Render enough frames to track simulated time at the run speed.
            self._loop_count += self.model.opt.timestep / (
                self._time_per_render * self._run_speed
            )
            if self._render_every_frame:
                self._loop_count = 1
            while self._loop_count > 0:
                update()
                self._loop_count -= 1

        # clear markers
        self._markers[:] = []

    def close(self):
        """Destroy the window and terminate GLFW; safe to call repeatedly."""
        if self.window is None:
            return
        glfw.destroy_window(self.window)
        glfw.terminate()
        self.window = None