|
#pragma once |
|
|
|
#include <ATen/quantized/Quantizer.h> |
|
#include <c10/core/TensorImpl.h> |
|
#include <c10/util/Exception.h> |
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * QTensorImpl is the TensorImpl used for quantized tensors. On top of the
 * state every c10::TensorImpl carries, it holds a QuantizerPtr describing how
 * the tensor is quantized (see ATen/quantized/Quantizer.h).
 */
struct TORCH_API QTensorImpl : public c10::TensorImpl {
 public:
  // Constructs a quantized tensor impl over `storage`. Defined out of line;
  // presumably stores `quantizer` into quantizer_ — confirm in the .cpp.
  QTensorImpl(
      Storage&& storage,
      DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      QuantizerPtr quantizer);

  // Same as above, but additionally takes an ImplType.
  // NOTE(review): ImplType's meaning is defined by c10::TensorImpl (likely a
  // special construction mode such as a view); confirm in TensorImpl.h.
  QTensorImpl(
      ImplType type,
      Storage&& storage,
      DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      QuantizerPtr quantizer);

  // Returns this tensor's quantizer handle by value. Read-only, hence const.
  // NOTE(review): presumably the handle is shared with this impl rather than
  // deep-copied — depends on QuantizerPtr's definition.
  QuantizerPtr quantizer() const {
    return quantizer_;
  }

  // Replaces this tensor's quantizer. The parameter is already a copy owned by
  // this call, so move it into place instead of copying it a second time.
  void set_quantizer_(QuantizerPtr quantizer) {
    quantizer_ = std::move(quantizer);
  }

  /**
   * Returns a new QTensorImpl built from a copy of this tensor's Storage
   * handle, key set, dtype and quantizer, then copies the remaining tensor
   * metadata (including quantizer_) into it with the given version counter and
   * allow_tensor_metadata_change flag. numel and contiguity are recomputed on
   * the new impl before it is returned.
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<QTensorImpl>(
        Storage(storage()), key_set(), data_type_, quantizer_);
    copy_tensor_metadata(
        /*src_q_impl=*/this,
        /*dest_q_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    impl->refresh_contiguous();
    return impl;
  }

  /**
   * Rvalue overload of the above. The version counter is forwarded as an
   * rvalue all the way down (via the rvalue copy_tensor_metadata overload
   * below), so it can be moved rather than copied.
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<QTensorImpl>(
        Storage(storage()), key_set(), data_type_, quantizer_);
    copy_tensor_metadata(
        /*src_q_impl=*/this,
        /*dest_q_impl=*/impl.get(),
        /*version_counter=*/std::move(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    impl->refresh_contiguous();
    return impl;
  }

  /**
   * Copies tensor metadata and the quantizer from `impl` into this object.
   * This tensor keeps its own version counter and allow_tensor_metadata_change
   * setting. numel and contiguity are recomputed afterwards.
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    // The static_cast below assumes `impl` really is a QTensorImpl; the key-set
    // compatibility check is what is relied on to guarantee that.
    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    auto q_impl = static_cast<const QTensorImpl*>(impl.get());
    copy_tensor_metadata(
        /*src_q_impl=*/q_impl,
        /*dest_q_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
    refresh_contiguous();
  }

 private:
  // Quantization scheme/parameters for this tensor.
  QuantizerPtr quantizer_;

  const char* tensorimpl_type_name() const override;

  /**
   * Copies the generic TensorImpl metadata from src to dest (delegating to
   * TensorImpl::copy_tensor_metadata), then copies the QTensorImpl-specific
   * quantizer_ handle.
   */
  static void copy_tensor_metadata(
      const QTensorImpl* src_q_impl,
      QTensorImpl* dest_q_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change);
    // QTensorImpl-specific field.
    dest_q_impl->quantizer_ = src_q_impl->quantizer_;
  }

  /**
   * Rvalue overload. Without it, shallow_copy_and_detach(VariableVersion&&)
   * would silently bind std::move(version_counter) to the const& overload
   * above and copy it. The counter is forwarded as an rvalue; if the base
   * class offers a VariableVersion&& overload it is chosen, otherwise the
   * call still binds to the base's const& overload.
   */
  static void copy_tensor_metadata(
      const QTensorImpl* src_q_impl,
      QTensorImpl* dest_q_impl,
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_q_impl,
        dest_q_impl,
        std::move(version_counter),
        allow_tensor_metadata_change);
    // QTensorImpl-specific field.
    dest_q_impl->quantizer_ = src_q_impl->quantizer_;
  }
};
|
|
|
} |
|
|