Spaces:
Runtime error
Runtime error
// Copyright (c) Meta Platforms, Inc. and affiliates. | |
// All rights reserved. | |
// | |
// This source code is licensed under the license found in the | |
// LICENSE file in the root directory of this source tree. | |
// Common base for sampler Data structs. Any struct deriving from it inherits
// a `base` member alias naming this type; e.g. PrimSamplerTW<...>::Data
// derives from it, so PrimSamplerTW<...>::Data::base is PrimSamplerDataBase.
struct PrimSamplerDataBase {
    using base = PrimSamplerDataBase;
};
// Sampler for one primitive: evaluates a 4-channel template volume at point
// y0 (optionally remapped through a 3-channel warp volume first) and scales
// the .w channel of the result by a fade factor
//
//   fade(y0) = exp(-fadescale * (|y0.x|^fadeexp + |y0.y|^fadeexp + |y0.z|^fadeexp))
//
// Template parameters:
//   dowarp       - if true, y0 is first mapped through primitive k's warp
//                  volume and the template is sampled at the result;
//                  otherwise the template is sampled at y0 directly.
//   GridSamplerT - grid-sampler policy providing static forward()/backward()
//                  for float3 and float4 outputs (default GridSamplerChlast,
//                  defined elsewhere).
//
// forward() caches fade, yy1, tplate_ptr and (if dowarp) warp_ptr in member
// variables, and backward() reads them, so backward() must be called on the
// same object after forward() with the same data and k.
template<
    bool dowarp,
    template<typename> class GridSamplerT=GridSamplerChlast>
struct PrimSamplerTW {
    // Argument bundle: fade parameters, template/warp volumes and their
    // gradient buffers.
    struct Data : public PrimSamplerDataBase {
        float fadescale, fadeexp;   // fade = exp(-fadescale * sum_i |y0_i|^fadeexp)

        int tplate_nstride;         // element stride between consecutive n in tplate / grad_tplate
        int TD, TH, TW;             // template volume extents, passed to GridSamplerT in this order
        float * tplate;             // 4 * TD*TH*TW floats per primitive k
        float * grad_tplate;        // same per-primitive layout as tplate

        int warp_nstride;           // element stride between consecutive n in warp / grad_warp
        int WD, WH, WW;             // warp volume extents, passed to GridSamplerT in this order
        float * warp;               // 3 * WD*WH*WW floats per primitive k (read only if dowarp)
        float * grad_warp;          // same per-primitive layout as warp

        // Advance all four pointers to slice n (presumably the batch index --
        // confirm against the caller). The products are formed in 64-bit so
        // n * nstride cannot overflow int for large tensors.
        __forceinline__ __device__ void n_stride(int n) {
            const long long toff = static_cast<long long>(n) * tplate_nstride;
            const long long woff = static_cast<long long>(n) * warp_nstride;
            tplate      += toff;
            grad_tplate += toff;
            warp        += woff;
            grad_warp   += woff;
        }
    };

    // State written by forward() and consumed by backward().
    float fade;          // fade factor that forward() multiplied into sample.w
    float * tplate_ptr;  // start of primitive k's template slab
    float * warp_ptr;    // start of primitive k's warp slab (set only if dowarp)
    float3 yy1;          // coordinate at which the template was sampled

    // Offset, in floats, of primitive k's slab in a tensor that stores
    // nchannels * D * H * W floats per primitive. Evaluated in 64-bit so large
    // primitive counts or volumes cannot overflow 32-bit int arithmetic.
    static __forceinline__ __device__ long long prim_offset(
            int k, int nchannels, int D, int H, int W) {
        return static_cast<long long>(k) * nchannels * D * H * W;
    }

    // Sample primitive k at y0.
    // Returns the 4-channel template value at yy1 (the warped y0, or y0 itself
    // when !dowarp) with .w multiplied by fade. Caches fade, yy1, tplate_ptr
    // and warp_ptr for backward().
    __forceinline__ __device__ float4 forward(
            const Data & data,
            int k,
            float3 y0) {
        // Fast-math intrinsics (__expf/__powf): reduced accuracy for speed.
        fade = __expf(-data.fadescale * (
            __powf(fabsf(y0.x), data.fadeexp) +
            __powf(fabsf(y0.y), data.fadeexp) +
            __powf(fabsf(y0.z), data.fadeexp)));

        if (dowarp) {
            // Remap y0 through this primitive's 3-channel warp volume.
            warp_ptr = data.warp + prim_offset(k, 3, data.WD, data.WH, data.WW);
            yy1 = GridSamplerT<float3>::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false);
        } else {
            yy1 = y0;
        }

        tplate_ptr = data.tplate + prim_offset(k, 4, data.TD, data.TH, data.TW);
        float4 sample = GridSamplerT<float4>::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false);
        sample.w *= fade;
        return sample;
    }

    // Backpropagate through forward() for primitive k.
    //   sample      - presumably the value forward() returned (confirm with the
    //                 caller); the fade gradient below is only correct if
    //                 sample.w already includes the fade factor.
    //   dL_sample   - upstream gradient w.r.t. sample.
    //   validthread - when false, zero upstream gradients are handed to the
    //                 GridSamplerT::backward calls instead of the real ones.
    // Returns dL/dy0. Template (and, if dowarp, warp) gradients are handled by
    // GridSamplerT::backward, which receives grad_tplate / grad_warp slabs.
    // NOTE(review): the fade contribution to the returned dL/dy0 is not masked
    // by validthread.
    __forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0,
            float4 sample, float4 dL_sample, bool validthread) {
        // d fade / d y0_i = -fadescale * fadeexp * |y0_i|^(fadeexp-1) * sign(y0_i) * fade.
        // The trailing "* fade" is omitted here because sample.w (= raw_w * fade)
        // supplies it: dL/dy0 via fade = dL_sample.w * raw_w * d fade / d y0.
        float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3(
            __powf(fabsf(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f),
            __powf(fabsf(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f),
            __powf(fabsf(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f));
        float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w;

        // Chain rule through `sample.w *= fade` in forward().
        dL_sample.w *= fade;

        float * grad_tplate_ptr = data.grad_tplate + prim_offset(k, 4, data.TD, data.TH, data.TW);
        // Gradient w.r.t. the template sample coordinate yy1.
        float3 dL_y1 = GridSamplerT<float4>::backward(4, data.TD, data.TH, data.TW,
            tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false);

        if (dowarp) {
            float * grad_warp_ptr = data.grad_warp + prim_offset(k, 3, data.WD, data.WH, data.WW);
            dL_y0 += GridSamplerT<float3>::backward(3, data.WD, data.WH, data.WW,
                warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false);
        } else {
            // No warp: yy1 == y0, so the coordinate gradient passes straight through.
            dL_y0 += dL_y1;
        }
        return dL_y0;
    }
};