File size: 1,683 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef _MPS_LIBRARY_H_
#define _MPS_LIBRARY_H_

#include <memory>
#include <string>
#include <unordered_map>

#ifndef __OBJC__
// Plain C++ translation units: the Metal headers are not included here, so
// every handle name used by this header is declared as an opaque void*.
using MTLComputePipelineState = void*;
using MTLComputePipelineState_t = void*;
using MTLLibrary = void*;
using MTLLibrary_t = void*;
#else
// Objective-C(++) translation units: include the Metal frameworks and alias
// the protocol-qualified object types.
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

using MTLComputePipelineState_t = id<MTLComputePipelineState>;
using MTLLibrary_t = id<MTLLibrary>;
#endif

// Holds one Metal library handle together with a map from string keys to
// compute pipeline state handles.
class MPSLibrary {
 public:
  // Factory functions returning a raw pointer to an MPSLibrary built from a
  // library URL or from source text, respectively.
  // NOTE(review): ownership of the returned pointer is not visible in this
  // header (MPSLibraryManager keeps libraries in std::unique_ptr) — confirm
  // against the implementation before deleting or wrapping the result.
  static MPSLibrary* createFromUrl(const std::string& library_url);
  static MPSLibrary* createFromSource(const std::string& source);
  ~MPSLibrary();

  // Returns the stored library handle. Const-qualified because it only reads
  // the member, so it can also be called through a const reference/pointer.
  MTLLibrary_t library() const { return _library; }

  // Returns the compute pipeline state associated with `function_name`.
  // NOTE(review): presumably built on first request and memoized in
  // _pso_map — confirm against the implementation.
  MTLComputePipelineState_t getComputePipelineState(
      const std::string& function_name);

 private:
  MTLLibrary_t _library;
  std::unordered_map<std::string, MTLComputePipelineState_t> _pso_map;
};

// Singleton registry of MPSLibrary objects keyed by string. Instances cannot
// be copied or moved, and the constructor is private, so the only access path
// is getInstance().
class MPSLibraryManager {
 public:
  // Copy and move are deleted so the singleton cannot be duplicated.
  MPSLibraryManager(const MPSLibraryManager&) = delete;
  MPSLibraryManager& operator=(const MPSLibraryManager&) = delete;
  MPSLibraryManager(MPSLibraryManager&&) = delete;
  MPSLibraryManager& operator=(MPSLibraryManager&&) = delete;

  // Accessor for the manager instance.
  // NOTE(review): lifetime and thread-safety of the instance are determined
  // by the implementation file — confirm there.
  static MPSLibraryManager* getInstance();

  // Reports whether a library is registered under `name`.
  // NOTE(review): presumably a lookup in _library_map — confirm.
  bool hasLibrary(const std::string& name);

  // Returns the library associated with `library_url`.
  // NOTE(review): presumably loads and registers it on first request, with
  // the manager retaining ownership via _library_map — confirm.
  MPSLibrary* getLibrary(const std::string& library_url);

  // Creates a library from `sources` under the key `name`.
  // The identifier is misspelled ("Souce"); it is kept unchanged because
  // existing callers and the out-of-line definition depend on it. New code
  // should prefer createLibraryFromSource() below.
  MPSLibrary* createLibraryFromSouce(const std::string& name,
                                     const std::string& sources);

  // Correctly spelled alias that forwards to createLibraryFromSouce().
  MPSLibrary* createLibraryFromSource(const std::string& name,
                                      const std::string& sources) {
    return createLibraryFromSouce(name, sources);
  }

  ~MPSLibraryManager();

 private:
  // Private so that instances are only created from within the class.
  MPSLibraryManager();
  std::unordered_map<std::string, std::unique_ptr<MPSLibrary>> _library_map;
};
#endif