Spaces:
Sleeping
Sleeping
import time | |
import torch | |
from hpc_rll.origin.td import td_lambda_error, td_lambda_data | |
from hpc_rll.rl_utils.td import TDLambda | |
from testbase import mean_relative_error, times | |
assert torch.cuda.is_available() | |
use_cuda = True | |
T = 1024 | |
B = 64 | |
def td_val(): | |
ori_value = torch.randn(T + 1, B) | |
ori_reward = torch.randn(T, B) | |
ori_weight = torch.randn(T, B) | |
hpc_value = ori_value.clone().detach() | |
hpc_reward = ori_reward.clone().detach() | |
hpc_weight = ori_weight.clone().detach() | |
hpc_td = TDLambda(T, B) | |
if use_cuda: | |
ori_value = ori_value.cuda() | |
ori_reward = ori_reward.cuda() | |
ori_weight = ori_weight.cuda() | |
hpc_value = hpc_value.cuda() | |
hpc_reward = hpc_reward.cuda() | |
hpc_weight = hpc_weight.cuda() | |
hpc_td = hpc_td.cuda() | |
ori_value.requires_grad_(True) | |
ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight)) | |
ori_loss = ori_loss.mean() | |
ori_loss.backward() | |
if use_cuda: | |
torch.cuda.synchronize() | |
hpc_value.requires_grad_(True) | |
hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight) | |
hpc_loss = hpc_loss.mean() | |
hpc_loss.backward() | |
if use_cuda: | |
torch.cuda.synchronize() | |
mre = mean_relative_error( | |
torch.flatten(ori_loss).cpu().detach().numpy(), | |
torch.flatten(hpc_loss).cpu().detach().numpy() | |
) | |
print("td fp mean_relative_error: " + str(mre)) | |
mre = mean_relative_error( | |
torch.flatten(ori_value.grad).cpu().detach().numpy(), | |
torch.flatten(hpc_value.grad).cpu().detach().numpy() | |
) | |
print("td bp mean_relative_error: " + str(mre)) | |
def td_perf(): | |
ori_value = torch.randn(T + 1, B) | |
ori_reward = torch.randn(T, B) | |
ori_weight = torch.randn(T, B) | |
hpc_value = ori_value.clone().detach() | |
hpc_reward = ori_reward.clone().detach() | |
hpc_weight = ori_weight.clone().detach() | |
hpc_td = TDLambda(T, B) | |
if use_cuda: | |
ori_value = ori_value.cuda() | |
ori_reward = ori_reward.cuda() | |
ori_weight = ori_weight.cuda() | |
hpc_value = hpc_value.cuda() | |
hpc_reward = hpc_reward.cuda() | |
hpc_weight = hpc_weight.cuda() | |
hpc_td = hpc_td.cuda() | |
ori_value.requires_grad_(True) | |
for i in range(times): | |
t = time.time() | |
ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight)) | |
ori_loss = ori_loss.mean() | |
ori_loss.backward() | |
if use_cuda: | |
torch.cuda.synchronize() | |
print('epoch: {}, original td cost time: {}'.format(i, time.time() - t)) | |
hpc_value.requires_grad_(True) | |
for i in range(times): | |
t = time.time() | |
hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight) | |
hpc_loss = hpc_loss.mean() | |
hpc_loss.backward() | |
if use_cuda: | |
torch.cuda.synchronize() | |
print('epoch: {}, hpc td cost time: {}'.format(i, time.time() - t)) | |
if __name__ == '__main__': | |
print("target problem: T = {}, B = {}".format(T, B)) | |
print("================run td validation test================") | |
td_val() | |
print("================run td performance test================") | |
td_perf() | |