File size: 2,569 Bytes
4bdb245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// SPDX-License-Identifier: Apache-2.0

#include "kompute/operations/OpTensorCopy.hpp"
#include "kompute/Tensor.hpp"

namespace kp {

// Builds a copy operation over the given tensors. The first tensor is treated
// as the source and every following tensor as a destination (see record()).
//
// The tensor list is stored in mTensors and then validated. A
// std::runtime_error is thrown when:
//   - fewer than two tensors are provided,
//   - any entry in the list is a null pointer,
//   - a destination's dataType() differs from the source's, or
//   - a destination's size() differs from the source's.
OpTensorCopy::OpTensorCopy(const std::vector<std::shared_ptr<Tensor>>& tensors)
{
    KP_LOG_DEBUG("Kompute OpTensorCopy constructor with params");

    this->mTensors = tensors;

    if (this->mTensors.size() < 2) {
        throw std::runtime_error(
          "Kompute OpTensorCopy called with less than 2 tensor");
    }

    // Reject null entries up front so the dereferences below (and later in
    // record()/postEval()) never operate on an empty shared_ptr.
    for (const std::shared_ptr<Tensor>& tensor : this->mTensors) {
        if (!tensor) {
            throw std::runtime_error(
              "Kompute OpTensorCopy called with a null tensor");
        }
    }

    const std::shared_ptr<Tensor>& source = this->mTensors[0];
    const kp::Tensor::TensorDataTypes dataType = source->dataType();
    const uint32_t size = source->size();

    // Only the destinations need checking; the source trivially matches itself.
    for (size_t i = 1; i < this->mTensors.size(); i++) {
        const std::shared_ptr<Tensor>& tensor = this->mTensors[i];
        if (tensor->dataType() != dataType) {
            throw std::runtime_error(fmt::format(
              "Attempting to copy tensors of different types from {} to {}",
              Tensor::toString(dataType),
              Tensor::toString(tensor->dataType())));
        }
        if (tensor->size() != size) {
            throw std::runtime_error(fmt::format(
              "Attempting to copy tensors of different sizes from {} to {}",
              size,
              tensor->size()));
        }
    }
}

// Destructor: performs no explicit cleanup here, it only emits a debug log
// entry when destruction begins.
OpTensorCopy::~OpTensorCopy()
{
    KP_LOG_DEBUG("Kompute OpTensorCopy destructor started");
}

// Records, into the given command buffer, one copy per destination tensor:
// every tensor after the first calls recordCopyFrom() with the first tensor
// as its source.
void
OpTensorCopy::record(const vk::CommandBuffer& commandBuffer)
{
    KP_LOG_DEBUG("Kompute OpTensorCopy record called");

    // Index 0 is the source, so destinations start at index 1
    for (size_t dst = 1; dst < this->mTensors.size(); ++dst) {
        const auto& source = this->mTensors[0];
        const auto& destination = this->mTensors[dst];
        destination->recordCopyFrom(commandBuffer, source);
    }
}

// Pre-evaluation hook. The command buffer argument is unused; this operation
// only emits a debug log entry at this stage.
void
OpTensorCopy::preEval(const vk::CommandBuffer& /*commandBuffer*/)
{
    KP_LOG_DEBUG("Kompute OpTensorCopy preEval called");
}

// Post-evaluation hook. The command buffer argument is unused.
//
// Hands the first tensor's rawData() pointer to setRawData() on every other
// tensor. Nothing is done when the first (source) tensor is of eStorage type,
// and any destination of eStorage type is skipped.
void
OpTensorCopy::postEval(const vk::CommandBuffer& /*commandBuffer*/)
{
    KP_LOG_DEBUG("Kompute OpTensorCopy postEval called");

    const auto& source = this->mTensors[0];

    // A storage source is not copied on the CPU side, so bail out early
    if (source->tensorType() == kp::Tensor::TensorTypes::eStorage) {
        KP_LOG_DEBUG("Kompute OpTensorCopy not copying tensor source given it's of eStorage type");
        return;
    }

    void* sourceData = source->rawData();

    // Propagate the source data to each destination that is not storage-only
    for (size_t dst = 1; dst < this->mTensors.size(); ++dst) {
        const auto& destination = this->mTensors[dst];
        if (destination->tensorType() != kp::Tensor::TensorTypes::eStorage) {
            destination->setRawData(sourceData);
        } else {
            KP_LOG_DEBUG("Kompute OpTensorCopy not copying to tensor dest given it's of eStorage type");
        }
    }
}

}