Spaces:
Runtime error
Runtime error
File size: 3,740 Bytes
57e3690 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_GEMM_HPP
#define GGML_SYCL_GEMM_HPP
#include <fstream>
#include <iostream>
#include <type_traits>
#include <unordered_map>

#include "ggml-sycl.h"
#if GGML_SYCL_DNNL
#include "dnnl.hpp"
#include "dnnl_sycl.hpp"
class DnnlGemmWrapper {
public:
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
template<typename T>
static constexpr dt to_dt() {
if constexpr (std::is_same_v<T, float>) return dt::f32;
else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
else static_assert(0);
}
static inline void row_gemm(sycl::queue& q, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
// Get the device associated with the queue
sycl::device dev = q.get_device();
// Get the context associated with the queue
sycl::context ctx = q.get_context();
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_prim.execute(stream, matmul_args);
}
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
auto const eng = stream.get_engine();
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_prim.execute(stream, matmul_args);
}
};
#endif
#endif // GGML_SYCL_GEMM_HPP
|