# Gen1/misc/pyBulletSimRecorder.py
# (gensim2, commit ff66cf3: "init")
import pybullet as p
import PySimpleGUI as sg
import pickle
from os import getcwd
from urdfpy import URDF
from os.path import abspath, dirname, basename, splitext
from transforms3d.affines import decompose
from transforms3d.quaternions import mat2quat
import numpy as np
class PyBulletRecorder:
    """Record the per-step world poses of PyBullet bodies' visual geometry.

    Register bodies with ``register_object``, call ``add_keyframe`` after
    each simulation step, then ``save`` (or ``prompt_save``) to pickle the
    recorded animation produced by ``get_formatted_output``.
    """

    # Geometry kind stored in ``self.links`` -> ``'type'`` tag written by
    # get_formatted_output (boxes are exported under the name 'cube').
    _OUTPUT_TYPES = {
        'mesh': 'mesh',
        'box': 'cube',
        'cylinder': 'cylinder',
        'sphere': 'sphere',
    }

    class LinkTracker:
        """Track the world pose of one visual element attached to a link."""

        def __init__(self,
                     name,
                     body_id,
                     link_id,
                     link_origin,
                     mesh_path,
                     mesh_scale,
                     mesh_material=None):
            """
            Args:
                name: Unique key for this visual in the recorded output.
                body_id: PyBullet body unique id.
                link_id: PyBullet link index; -1 denotes the base link.
                link_origin: Affine matrix (as accepted by transforms3d
                    ``decompose``) from the PyBullet link frame to the
                    visual frame. Only its translation and rotation are kept.
                mesh_path: Mesh file path, or the primitive name
                    ('box', 'cylinder', 'sphere').
                mesh_scale: Per-axis scale exported with the visual.
                mesh_material: Optional material object exposing ``name``
                    and ``color``.
            """
            self.body_id = body_id
            self.link_id = link_id
            self.mesh_path = mesh_path
            self.mesh_scale = mesh_scale
            self.mesh_material = mesh_material
            translation, rotation, *_ = decompose(link_origin)
            # transforms3d quaternions are (w, x, y, z); PyBullet expects
            # (x, y, z, w).
            w, x, y, z = mat2quat(rotation)
            self.link_pose = [translation, [x, y, z, w]]
            self.name = name

        def transform(self, position, orientation):
            """Compose a link world pose with this visual's local offset.

            Returns the ``(position, orientation)`` pair produced by
            ``p.multiplyTransforms``.
            """
            return p.multiplyTransforms(
                position, orientation,
                self.link_pose[0], self.link_pose[1],
            )

        def get_keyframe(self):
            """Return this visual's current world pose as plain lists."""
            if self.link_id == -1:
                position, orientation = p.getBasePositionAndOrientation(
                    self.body_id)
            else:
                link_state = p.getLinkState(self.body_id,
                                            self.link_id,
                                            computeForwardKinematics=True)
                # Entries 4 and 5 are the world link-frame position/orientation.
                position, orientation = link_state[4], link_state[5]
            position, orientation = self.transform(
                position=position, orientation=orientation)
            return {
                'position': list(position),
                'orientation': list(orientation),
            }

    def __init__(self):
        # One dict per keyframe, mapping visual name -> pose dict.
        self.states = []
        # (geometry kind, LinkTracker) pairs, in registration order.
        self.links = []

    @staticmethod
    def _link_id_map(body_id):
        """Map each PyBullet link name of ``body_id`` to its index (-1 = base)."""
        # NOTE(review): names are decoded as GB2312 (an ASCII superset); a
        # UTF-8 URDF with non-ASCII link names would not match urdfpy's names.
        link_id_map = {p.getBodyInfo(body_id)[0].decode('gb2312'): -1}
        for link_id in range(p.getNumJoints(body_id)):
            child_name = p.getJointInfo(body_id, link_id)[12].decode('gb2312')
            link_id_map[child_name] = link_id
        return link_id_map

    @staticmethod
    def _visual_shapes(geometry, dir_path, global_scaling):
        """List ``(kind, mesh_path, mesh_scale)`` for each shape set on ``geometry``.

        Every scale, including primitive shapes, is multiplied by
        ``global_scaling`` so primitives stay consistent with meshes and with
        the scaled visual offsets.
        """
        shapes = []
        mesh = geometry.mesh
        if mesh is not None:
            if mesh.scale is None:
                scale = [global_scaling, global_scaling, global_scaling]
            else:
                scale = mesh.scale * global_scaling
            shapes.append(('mesh', dir_path + '/' + mesh.filename, scale))
        if geometry.box is not None:
            # Exported as half-extents.
            shapes.append(('box', 'box', geometry.box.size / 2 * global_scaling))
        if geometry.cylinder is not None:
            radius = geometry.cylinder.radius * global_scaling
            # NOTE(review): z uses the full length (not half) - confirm the
            # consumer's unit cylinder matches.
            length = geometry.cylinder.length * global_scaling
            shapes.append(('cylinder', 'cylinder', [radius, radius, length]))
        if geometry.sphere is not None:
            radius = geometry.sphere.radius * global_scaling
            shapes.append(('sphere', 'sphere', [radius, radius, radius]))
        return shapes

    def register_object(self, body_id, urdf_path, global_scaling=1, color=None):
        """Start tracking every visual of a body loaded from ``urdf_path``.

        Args:
            body_id: PyBullet unique id of the body.
            urdf_path: URDF the body was loaded from; mesh paths are resolved
                relative to its directory.
            global_scaling: Scale applied to visual offsets, meshes and
                primitive shapes (presumably the ``globalScaling`` passed to
                ``p.loadURDF``).
            color: Optional colour overriding the material of every visual
                that declares one.
        """
        link_id_map = self._link_id_map(body_id)
        dir_path = dirname(abspath(urdf_path))
        file_name = splitext(basename(urdf_path))[0]
        robot = URDF.load(urdf_path)
        for link in robot.links:
            if link.name not in link_id_map:
                print("skip links !! ", link.name, link_id_map, len(robot.links),
                      p.getBodyInfo(body_id)[0].decode('gb2312'))
                continue
            link_id = link_id_map[link.name]
            # For the base link (id -1) PyBullet reports
            # inertial_origin @ visual_origin, so the inertial origin is undone.
            frame_fix = (np.linalg.inv(link.inertial.origin)
                         if link_id == -1 else np.identity(4))
            for i, link_visual in enumerate(link.visuals):
                mesh_material = link_visual.material
                # NOTE(review): ``color`` is ignored for visuals without a material.
                if mesh_material is not None and color is not None:
                    # Random suffix marks the recoloured material as distinct.
                    mesh_material.name = mesh_material.name + f"_{np.random.randint(100)}"
                    mesh_material.color = color
                # Evaluates as (frame_fix @ origin) * global_scaling.
                link_origin = frame_fix @ link_visual.origin * global_scaling
                shapes = self._visual_shapes(
                    link_visual.geometry, dir_path, global_scaling)
                for kind, mesh_path, mesh_scale in shapes:
                    self.links.append((kind, PyBulletRecorder.LinkTracker(
                        name=file_name + f'_{body_id}_{link.name}_{i}',
                        body_id=body_id,
                        link_id=link_id,
                        link_origin=link_origin,
                        mesh_path=mesh_path,
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)))

    def add_keyframe(self):
        """Snapshot the current pose of every tracked visual.

        Ideally called after every ``p.stepSimulation()``.
        """
        self.states.append(
            {link.name: link.get_keyframe() for _, link in self.links})

    def prompt_save(self):
        """Ask via dialogs whether and where to save, then reset the recording.

        Closing the path dialog or leaving the path empty skips saving.
        """
        layout = [[sg.Text('Do you want to save previous episode?')],
                  [sg.Button('Yes'), sg.Button('No')]]
        window = sg.Window('PyBullet Recorder', layout)
        save = False
        while True:
            event, values = window.read()
            if event in (None, 'No'):
                break
            elif event == 'Yes':
                save = True
                break
        window.close()
        if save:
            layout = [[sg.Text('Where do you want to save it?')],
                      [sg.Text('Path'), sg.InputText(getcwd())],
                      [sg.Button('OK')]]
            window = sg.Window('PyBullet Recorder', layout)
            event, values = window.read()
            window.close()
            # ``values`` may be None when the window is closed instead of
            # confirmed; indexing it unguarded used to raise TypeError.
            path = values[0] if values else None
            self.save(path or None)
        self.reset()

    def reset(self):
        """Discard all recorded keyframes (registered visuals are kept)."""
        self.states = []

    def get_formatted_output(self):
        """Build the serialisable recording, keyed by visual name.

        Each entry holds ``type``, ``mesh_scale`` and per-keyframe ``frames``.
        Meshes also carry ``mesh_path``, and primitives carry ``name``.
        Visuals with a material additionally get ``mesh_material_name`` and
        ``mesh_material_color``.
        """
        retval = {}
        for geo_name, link in self.links:
            output_type = self._OUTPUT_TYPES.get(geo_name)
            if output_type is None:
                continue  # unknown kind: nothing to export
            entry = {'type': output_type}
            if geo_name == 'mesh':
                entry['mesh_path'] = link.mesh_path
            else:
                entry['name'] = link.name
            entry['mesh_scale'] = link.mesh_scale
            entry['frames'] = [state[link.name] for state in self.states]
            if link.mesh_material is not None:
                entry['mesh_material_name'] = link.mesh_material.name
                entry['mesh_material_color'] = link.mesh_material.color
            retval[link.name] = entry
        return retval

    def save(self, path):
        """Pickle ``get_formatted_output()`` to ``path``; no-op when ``path`` is None."""
        if path is None:
            print("[Recorder] Path is None.. not saving")
            return
        print("[Recorder] Saving state to {}".format(path))
        # Build the output before opening so a failure cannot truncate the file.
        output = self.get_formatted_output()
        with open(path, 'wb') as f:
            pickle.dump(output, f)