|
import pybullet as p |
|
import PySimpleGUI as sg |
|
import pickle |
|
from os import getcwd |
|
from urdfpy import URDF |
|
from os.path import abspath, dirname, basename, splitext |
|
from transforms3d.affines import decompose |
|
from transforms3d.quaternions import mat2quat |
|
import numpy as np |
|
|
|
|
|
class PyBulletRecorder:
    """Record world poses of PyBullet bodies' URDF visual shapes and export them.

    Typical use: ``register_object`` once per loaded body, ``add_keyframe`` once
    per simulation step, then ``save`` (or ``prompt_save``) to pickle the
    per-shape trajectories built by ``get_formatted_output``.
    """

    class LinkTracker:
        """Track the world pose of one visual shape attached to a PyBullet link."""

        def __init__(self,
                     name,
                     body_id,
                     link_id,
                     link_origin,
                     mesh_path,
                     mesh_scale,
                     mesh_material=None):
            """Store the shape description and precompute its local offset.

            Args:
                name: Key identifying this shape in keyframes and exported output.
                body_id: PyBullet body unique id.
                link_id: PyBullet link index; -1 selects the base.
                link_origin: 4x4 homogeneous transform from the frame PyBullet
                    reports for this link to the visual shape.
                mesh_path: Mesh file path, or 'box' / 'cylinder' / 'sphere'
                    for primitive shapes.
                mesh_scale: Per-axis scale written to the exported output.
                mesh_material: Optional material whose ``name`` and ``color``
                    are exported.
            """
            self.body_id = body_id
            self.link_id = link_id
            self.mesh_path = mesh_path
            self.mesh_scale = mesh_scale
            self.mesh_material = mesh_material
            # Only translation and rotation are kept; zoom/shear are discarded.
            translation, rotation, _, _ = decompose(link_origin)
            # mat2quat yields (w, x, y, z); PyBullet quaternions are (x, y, z, w).
            w, x, y, z = mat2quat(rotation)
            self.link_pose = [translation, [x, y, z, w]]
            self.name = name

        def transform(self, position, orientation):
            """Apply this shape's local offset to a world pose.

            Returns:
                ``(position, orientation)`` as returned by ``p.multiplyTransforms``.
            """
            local_position, local_orientation = self.link_pose
            return p.multiplyTransforms(position, orientation,
                                        local_position, local_orientation)

        def get_keyframe(self):
            """Return the shape's current world pose as plain lists.

            Returns:
                ``{'position': [...], 'orientation': [...]}``.
            """
            if self.link_id == -1:
                frame_position, frame_orientation = \
                    p.getBasePositionAndOrientation(self.body_id)
            else:
                link_state = p.getLinkState(self.body_id,
                                            self.link_id,
                                            computeForwardKinematics=True)
                # Entries 4 and 5 are the world pose of the URDF link frame.
                frame_position, frame_orientation = link_state[4], link_state[5]
            position, orientation = self.transform(
                position=frame_position, orientation=frame_orientation)
            return {
                'position': list(position),
                'orientation': list(orientation),
            }

    # Exported 'type' tag for each primitive geometry kind.
    _PRIMITIVE_TYPES = {'box': 'cube', 'cylinder': 'cylinder', 'sphere': 'sphere'}

    def __init__(self):
        # One dict per add_keyframe() call, mapping tracker name -> pose dict.
        self.states = []
        # (geo_name, LinkTracker) pairs; geo_name is 'mesh', 'box',
        # 'cylinder' or 'sphere'.
        self.links = []

    @staticmethod
    def _link_id_map(body_id):
        """Map each link name of a loaded PyBullet body to its link index.

        The base link maps to -1; link ``i`` is named by field 12 of
        ``p.getJointInfo(body_id, i)``.
        """
        n = p.getNumJoints(body_id)
        link_id_map = {p.getBodyInfo(body_id)[0].decode('gb2312'): -1}
        for link_id in range(n):
            link_name = p.getJointInfo(body_id, link_id)[12].decode('gb2312')
            link_id_map[link_name] = link_id
        return link_id_map

    @staticmethod
    def _visual_material(link_visual, color=None):
        """Return the visual's material (or None), recolored when ``color`` is given.

        Recoloring mutates the material in place and appends a random suffix
        in [0, 100) to its name.
        NOTE(review): if several visuals share one material object, its name
        is suffixed once per visual; with only 100 suffixes, names can collide.
        """
        material = link_visual.material
        if material is not None and color is not None:
            material.name = material.name + f"_{np.random.randint(100)}"
            material.color = color
        return material

    @staticmethod
    def _visual_origin(link, link_id, link_visual, global_scaling=1):
        """Return the 4x4 offset from the PyBullet-reported frame to the visual."""
        # PyBullet reports the base pose (getBasePositionAndOrientation) at the
        # base's inertial frame, not the URDF link frame, so the inertial
        # offset is undone for the base only. Other links use the link-frame
        # pose from getLinkState, so no correction is needed.
        frame = (np.linalg.inv(link.inertial.origin)
                 if link_id == -1
                 else np.identity(4))
        # The whole homogeneous matrix is scaled. decompose() reads only the
        # top three rows and splits the uniform factor on the rotation block
        # off as zoom, so effectively only the translation is scaled.
        return (frame @ link_visual.origin) * global_scaling

    @staticmethod
    def _visual_shapes(geometry, dir_path, global_scaling=1):
        """Describe each geometry element set on a URDF visual geometry.

        Args:
            geometry: Object exposing ``mesh``, ``box``, ``cylinder`` and
                ``sphere`` attributes, each None when unused.
            dir_path: Directory that is prefixed to mesh filenames.
            global_scaling: Uniform scale applied to every shape.

        Returns:
            List of ``(geo_name, element, mesh_path, mesh_scale)`` tuples in
            the order mesh, box, cylinder, sphere.
        """
        shapes = []
        mesh = geometry.mesh
        if mesh is not None:
            scale = ([global_scaling] * 3 if mesh.scale is None
                     else mesh.scale * global_scaling)
            shapes.append(('mesh', mesh, dir_path + '/' + mesh.filename, scale))
        box = geometry.box
        if box is not None:
            # Exported box scale is half of the URDF box size.
            shapes.append(('box', box, 'box', box.size / 2 * global_scaling))
        cylinder = geometry.cylinder
        if cylinder is not None:
            # NOTE(review): unlike the box, z uses the full length, not half;
            # confirm against the importer's primitive dimensions.
            radius = cylinder.radius * global_scaling
            length = cylinder.length * global_scaling
            shapes.append(('cylinder', cylinder, 'cylinder',
                           [radius, radius, length]))
        sphere = geometry.sphere
        if sphere is not None:
            radius = sphere.radius * global_scaling
            shapes.append(('sphere', sphere, 'sphere', [radius, radius, radius]))
        return shapes

    def register_object(self, body_id, urdf_path, global_scaling=1, color=None):
        """Start tracking every visual shape of a body loaded from a URDF.

        URDF links whose names PyBullet does not report for ``body_id`` are
        skipped with a printed message.

        Args:
            body_id: Unique id of a body already loaded in PyBullet.
            urdf_path: URDF file the body was loaded from; its directory is
                prefixed to mesh filenames.
            global_scaling: Uniform scale applied to visual offsets, meshes and
                primitive shapes (presumably the ``globalScaling`` passed to
                ``p.loadURDF``).
            color: Optional color assigned to every visual that has a material
                (presumably RGBA); visuals without a material are unaffected.
        """
        link_id_map = self._link_id_map(body_id)
        dir_path = dirname(abspath(urdf_path))
        file_name = splitext(basename(urdf_path))[0]
        robot = URDF.load(urdf_path)

        for link in robot.links:
            if link.name not in link_id_map:
                print("skip links !! ", link.name, link_id_map, len(robot.links),
                      p.getBodyInfo(body_id)[0].decode('gb2312'))
                continue
            link_id = link_id_map[link.name]

            for i, link_visual in enumerate(link.visuals):
                mesh_material = self._visual_material(link_visual, color)
                link_origin = self._visual_origin(link, link_id, link_visual,
                                                  global_scaling)
                tracker_name = file_name + f'_{body_id}_{link.name}_{i}'
                for geo_name, element, mesh_path, mesh_scale in \
                        self._visual_shapes(link_visual.geometry, dir_path,
                                            global_scaling):
                    if geo_name == 'mesh':
                        print("use mesh", i, link_id_map.keys())
                    else:
                        print(f"use {geo_name}", i, link_id_map.keys(),
                              element.__dict__)
                    self.links.append((geo_name, PyBulletRecorder.LinkTracker(
                        name=tracker_name,
                        body_id=body_id,
                        link_id=link_id,
                        link_origin=link_origin,
                        mesh_path=mesh_path,
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)))

    def add_keyframe(self):
        """Append the current pose of every tracked shape as a new frame."""
        self.states.append(
            {link.name: link.get_keyframe() for _, link in self.links})

    def prompt_save(self):
        """Ask via dialogs whether, and where, to save the recording, then reset.

        Nothing is saved if the user answers No, closes a dialog, or leaves the
        path empty. Recorded frames are cleared in every case.
        """
        layout = [[sg.Text('Do you want to save previous episode?')],
                  [sg.Button('Yes'), sg.Button('No')]]
        window = sg.Window('PyBullet Recorder', layout)
        save = False
        while True:
            event, _ = window.read()
            if event in (None, 'No'):
                break
            if event == 'Yes':
                save = True
                break
        window.close()

        if save:
            layout = [[sg.Text('Where do you want to save it?')],
                      [sg.Text('Path'), sg.InputText(getcwd())],
                      [sg.Button('OK')]]
            window = sg.Window('PyBullet Recorder', layout)
            event, values = window.read()
            window.close()
            # Closing the dialog yields no values, which previously crashed
            # on values[0]. An empty path is routed to save()'s None branch.
            if event == 'OK' and values:
                self.save(values[0] or None)
        self.reset()

    def reset(self):
        """Drop recorded frames; registered shapes are kept."""
        self.states = []

    def get_formatted_output(self):
        """Build the export dict keyed by tracker name.

        Mesh entries hold 'type' ('mesh'), 'mesh_path', 'mesh_scale' and
        'frames'. Primitive entries hold 'type' ('cube', 'cylinder' or
        'sphere'), 'name', 'mesh_scale' and 'frames'. 'mesh_material_name' and
        'mesh_material_color' are added when the shape has a material.
        """
        retval = {}
        for geo_name, link in self.links:
            frames = [state[link.name] for state in self.states]
            if geo_name == 'mesh':
                entry = {
                    'type': 'mesh',
                    'mesh_path': link.mesh_path,
                    'mesh_scale': link.mesh_scale,
                    'frames': frames,
                }
            else:
                entry = {
                    'type': self._PRIMITIVE_TYPES[geo_name],
                    'name': link.name,
                    'mesh_scale': link.mesh_scale,
                    'frames': frames,
                }
            if link.mesh_material is not None:
                entry['mesh_material_name'] = link.mesh_material.name
                entry['mesh_material_color'] = link.mesh_material.color
            retval[link.name] = entry
        return retval

    def save(self, path):
        """Pickle ``get_formatted_output()`` to ``path``; skip when ``path`` is None."""
        if path is None:
            print("[Recorder] Path is None.. not saving")
            return
        print("[Recorder] Saving state to {}".format(path))
        output = self.get_formatted_output()
        # Context manager guarantees the file is flushed and closed.
        with open(path, 'wb') as f:
            pickle.dump(output, f)
|
|