|
|
|
|
|
#pragma once |
|
|
|
#include <ATen/mps/MPSStream.h>

#include <condition_variable>
#include <cstdint>
#include <ctime>
#include <functional>
#include <memory>
#include <mutex>
#include <stack>
#include <unordered_map>
|
|
|
namespace at::mps { |
|
|
|
|
|
|
|
class MPSEvent { |
|
public: |
|
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing); |
|
~MPSEvent(); |
|
|
|
|
|
void record(bool needsLock, bool syncEvent = false); |
|
|
|
bool wait(bool needsLock, bool syncEvent = false); |
|
|
|
bool notify(bool needsLock, MTLSharedEventNotificationBlock block); |
|
|
|
bool query() const; |
|
|
|
|
|
bool synchronize(); |
|
|
|
void reset(MPSStream* stream, bool enable_timing); |
|
|
|
id_t getID() const { return m_id; } |
|
|
|
uint64_t getCompletionTime() const { return m_completion_time; } |
|
|
|
void waitForCpuSync(); |
|
|
|
private: |
|
id_t m_id; |
|
|
|
bool m_enable_timing; |
|
uint64_t m_signalCounter = 0; |
|
MPSStream* m_stream = nullptr; |
|
MTLSharedEvent_t m_event = nullptr; |
|
MTLSharedEventListener* m_listener = nullptr; |
|
|
|
std::mutex m_cpu_sync_mutex{}; |
|
std::condition_variable m_cpu_sync_cv{}; |
|
|
|
bool m_cpu_sync_completed = false; |
|
|
|
uint64_t m_completion_time = 0; |
|
|
|
void recordLocked(bool syncEvent); |
|
bool waitLocked(bool syncEvent); |
|
bool notifyLocked(MTLSharedEventNotificationBlock block); |
|
void notifyCpuSync(); |
|
static uint64_t getTime() { |
|
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); |
|
} |
|
}; |
|
|
|
// Owning handle to an MPSEvent. The deleter is a type-erased std::function, so
// the owner can choose at runtime what "deleting" means. Presumably
// MPSEventPool installs a deleter that returns the event to its cache instead
// of destroying it.
using MPSEventPtr = std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>>;
|
|
|
class MPSEventPool { |
|
public: |
|
explicit MPSEventPool(MPSStream* default_stream); |
|
~MPSEventPool(); |
|
|
|
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream); |
|
void emptyCache(); |
|
|
|
|
|
id_t acquireEvent(bool enable_timing); |
|
void releaseEvent(id_t event_id); |
|
void recordEvent(id_t event_id, bool syncEvent); |
|
void waitForEvent(id_t event_id, bool syncEvent); |
|
void synchronizeEvent(id_t event_id); |
|
bool queryEvent(id_t event_id); |
|
|
|
double elapsedTime(id_t start_event_id, id_t end_event_id); |
|
|
|
private: |
|
MPSStream* m_default_stream = nullptr; |
|
std::recursive_mutex m_mutex; |
|
std::stack<std::unique_ptr<MPSEvent>> m_pool{}; |
|
|
|
|
|
|
|
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{}; |
|
uint64_t m_event_counter = 0; |
|
std::function<void(MPSEvent*)> m_default_deleter; |
|
|
|
MPSEvent* getInUseEvent(id_t event_id, bool locked = true); |
|
}; |
|
|
|
|
|
// Accessor for the shared MPSEventPool instance (presumably a lazily created
// global; confirm in the implementation file).
auto getMPSEventPool() -> std::shared_ptr<MPSEventPool>;
|
|
|
} |
|
|