|
#pragma once |
|
#include <ATen/xpu/XPUContext.h>

#include <memory>
#include <optional>
#include <utility>
#include <vector>
|
|
|
namespace at::xpu { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct TORCH_XPU_API XPUEvent { |
|
|
|
XPUEvent(bool enable_timing = false) noexcept |
|
: enable_timing_{enable_timing} {} |
|
|
|
~XPUEvent() { |
|
if (isCreated()) { |
|
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
|
if (C10_UNLIKELY(interp)) { |
|
(*interp)->trace_gpu_event_deletion( |
|
at::kXPU, reinterpret_cast<uintptr_t>(event_.get())); |
|
} |
|
} |
|
} |
|
|
|
XPUEvent(const XPUEvent&) = delete; |
|
XPUEvent& operator=(const XPUEvent&) = delete; |
|
|
|
XPUEvent(XPUEvent&& other) = default; |
|
XPUEvent& operator=(XPUEvent&& other) = default; |
|
|
|
operator sycl::event&() const { |
|
return event(); |
|
} |
|
|
|
std::optional<at::Device> device() const { |
|
if (isCreated()) { |
|
return at::Device(at::kXPU, device_index_); |
|
} else { |
|
return std::nullopt; |
|
} |
|
} |
|
|
|
inline bool isCreated() const { |
|
return (event_.get() != nullptr); |
|
} |
|
|
|
DeviceIndex device_index() const { |
|
return device_index_; |
|
} |
|
|
|
sycl::event& event() const { |
|
return *event_; |
|
} |
|
|
|
bool query() const { |
|
using namespace sycl::info; |
|
if (!isCreated()) { |
|
return true; |
|
} |
|
|
|
return event().get_info<event::command_execution_status>() == |
|
event_command_status::complete; |
|
} |
|
|
|
void record() { |
|
record(getCurrentXPUStream()); |
|
} |
|
|
|
void recordOnce(const XPUStream& stream) { |
|
if (!isCreated()) { |
|
record(stream); |
|
} |
|
} |
|
|
|
void record(const XPUStream& stream) { |
|
if (!isCreated()) { |
|
device_index_ = stream.device_index(); |
|
event_ = std::make_unique<sycl::event>( |
|
stream.queue().ext_oneapi_submit_barrier()); |
|
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
|
if (C10_UNLIKELY(interp)) { |
|
(*interp)->trace_gpu_event_creation( |
|
at::kXPU, reinterpret_cast<uintptr_t>(event_.get())); |
|
} |
|
} else { |
|
TORCH_CHECK( |
|
device_index_ == stream.device_index(), |
|
"Event device ", |
|
device_index_, |
|
" does not match recording stream's device ", |
|
stream.device_index(), |
|
"."); |
|
event_.reset(); |
|
event_ = std::make_unique<sycl::event>( |
|
stream.queue().ext_oneapi_submit_barrier()); |
|
} |
|
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
|
if (C10_UNLIKELY(interp)) { |
|
(*interp)->trace_gpu_event_record( |
|
at::kXPU, |
|
reinterpret_cast<uintptr_t>(event_.get()), |
|
reinterpret_cast<uintptr_t>(&stream.queue())); |
|
} |
|
} |
|
|
|
void block(const XPUStream& stream) { |
|
if (isCreated()) { |
|
std::vector<sycl::event> event_list{event()}; |
|
|
|
stream.queue().ext_oneapi_submit_barrier(event_list); |
|
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
|
if (C10_UNLIKELY(interp)) { |
|
(*interp)->trace_gpu_event_wait( |
|
at::kXPU, |
|
reinterpret_cast<uintptr_t>(event_.get()), |
|
reinterpret_cast<uintptr_t>(&stream.queue())); |
|
} |
|
} |
|
} |
|
|
|
float elapsed_time(const XPUEvent& other) const { |
|
TORCH_CHECK( |
|
isCreated() && other.isCreated(), |
|
"Both events must be recorded before calculating elapsed time."); |
|
TORCH_CHECK( |
|
query() && other.query(), |
|
"Both events must be completed before calculating elapsed time."); |
|
TORCH_CHECK( |
|
enable_timing_ && other.enable_timing_, |
|
"Both events must be created with argument 'enable_timing=True'."); |
|
|
|
|
|
TORCH_CHECK_NOT_IMPLEMENTED( |
|
false, "elapsed_time is not supported by XPUEvent."); |
|
} |
|
|
|
void synchronize() const { |
|
if (isCreated()) { |
|
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
|
if (C10_UNLIKELY(interp)) { |
|
(*interp)->trace_gpu_event_synchronization( |
|
at::kXPU, reinterpret_cast<uintptr_t>(event_.get())); |
|
} |
|
event().wait_and_throw(); |
|
} |
|
} |
|
|
|
private: |
|
bool enable_timing_ = false; |
|
DeviceIndex device_index_ = -1; |
|
|
|
|
|
std::unique_ptr<sycl::event> event_; |
|
}; |
|
|
|
} |
|
|