File size: 7,473 Bytes
7e50900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/ScalarOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type.h>
#endif

namespace at {
namespace native {

// Prepares the tensors handed to the searchsorted kernels.
//
// The original values are given by raw_*. If an original value is not contiguous, a contiguous copy is written to
// the corresponding trimmed_* output. Additionally, if the dtypes of the boundary and input tensors do not match,
// they are converted to their common promoted type (computed with the ResultTypeState rules) so comparisons are
// done between values of the same type. The sorter's dtype is never changed here; only its layout is handled.
//
// For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically an undefined tensor),
// then the corresponding raw_* version should be used since it was already contiguous and of the right type.
//
// raw_sorter may be undefined (no sorter supplied); in that case trimmed_sorter is left untouched.
inline void searchsorted_maybe_trim_input_tensors(
    Tensor& trimmed_input,
    Tensor& trimmed_boundaries,
    Tensor& trimmed_sorter,
    const Tensor& raw_input,
    const Tensor& raw_boundaries,
    const Tensor& raw_sorter) {
  bool in_is_contiguous = raw_input.is_contiguous();
  bool bd_is_contiguous = raw_boundaries.is_contiguous();
  // The sorter is optional (the 4-argument overload always passes an undefined tensor). Do not query the layout
  // of an undefined tensor; treating it as contiguous means no copy is attempted for it.
  bool sort_is_contiguous = !raw_sorter.defined() || raw_sorter.is_contiguous();

  // Each warning is a separate TORCH_WARN_ONCE call site, so each one can be emitted (once) independently.
  if (!in_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
      "tensor if possible. This message will only appear once per program.");
    trimmed_input = raw_input.contiguous();
  }
  if (!bd_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
      "tensor if possible. This message will only appear once per program.");
    trimmed_boundaries = raw_boundaries.contiguous();
  }
  if (!sort_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
      "tensor if possible. This message will only appear once per program.");
    trimmed_sorter = raw_sorter.contiguous();
  }
  if (raw_input.dtype() != raw_boundaries.dtype()) {
    // Promote both operands the same way a binary operation between them would.
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(raw_boundaries, state);
    state = at::native::update_result_type_state(raw_input, state);
    ScalarType common_stype = at::native::result_type(state);

    TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
    // Only convert the side(s) whose dtype differs from the common type. If a contiguous copy was already made
    // above, convert that copy instead of the raw (non-contiguous) tensor.
    if (common_stype != raw_input.scalar_type()) {
      trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
    }
    if (common_stype != raw_boundaries.scalar_type()) {
      trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
    }
  }
}

/* unused but needed for internal jagged tensor class */
// Sorter-less variant: forwards to the 6-argument overload with an undefined raw sorter and discards whatever
// that overload writes to its trimmed-sorter output.
inline void searchsorted_maybe_trim_input_tensors(
    Tensor& trimmed_input,
    Tensor& trimmed_boundaries,
    const Tensor& raw_input,
    const Tensor& raw_boundaries) {
  Tensor discarded_trimmed_sorter;
  searchsorted_maybe_trim_input_tensors(
      trimmed_input, trimmed_boundaries, discarded_trimmed_sorter,
      raw_input, raw_boundaries, /*raw_sorter=*/Tensor());
}

// Returns true when `boundaries` and `input` have the same number of dimensions and identical sizes in every
// dimension except the last one, which is deliberately not compared. Equal-rank tensors with at most one dimension
// therefore always match.
inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
  const int64_t rank = boundaries.dim();
  if (rank != input.dim()) {
    return false;
  }
  const auto bd_sizes = boundaries.sizes();
  const auto in_sizes = input.sizes();
  // Walk from the second-to-last dimension back to the first; for rank < 2 the loop body never runs.
  for (int64_t d = rank - 2; d >= 0; --d) {
    if (bd_sizes[d] != in_sizes[d]) {
      return false;
    }
  }
  return true;
}

// Converts `scalar` into a tensor on `device` and marks it as a wrapped number. The flag makes the scalar follow
// the scalar promotion rules from native/TypeProperties.h, i.e. the same type promotion as binary operations.
inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
  Tensor wrapped = c10::scalar_to_tensor(scalar, device);
  wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
  return wrapped;
}

// Validates the arguments of torch.searchsorted(); every failed check raises through TORCH_CHECK.
//
// Checks, in order:
//  * side (if given) is "left" or "right", and is not combined with right=True when it is "left";
//  * boundaries and input are on the same device;
//  * if a sorter is defined: it is on the boundaries' device, has the boundaries' sizes and is of Long dtype;
//  * input is either non-0-dimensional, or a 0-dim value searched in a 1-D boundaries tensor;
//  * boundaries is not 0-dimensional;
//  * boundaries is 1-D, or its first N-1 dimensions match those of input;
//  * output's dtype is Int when out_int32 is set and Long otherwise;
//  * with out_int32, the size of boundaries' last dimension is below INT_MAX.
inline void searchsorted_pre_check(
    const Tensor& boundaries,
    const Tensor& input,
    const Tensor& output,
    const bool out_int32,
    const bool right,
    const c10::optional<c10::string_view> side_opt,
    const Tensor& sorter) {
  if (side_opt) {
    const c10::string_view side = *side_opt;
    TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
      "got ", side);

    // Only right=True combined with side="left" is rejected. (right=False, side="right") is accepted on the
    // assumption that the caller did not explicitly pass right=False (presumably False is also its default).
    TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
    "of ", side, " while right was True");
  }

  TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
    "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
    "tensor device type ", input.device());

  if (sorter.defined()) {
    // The second device printed is the boundaries' device, so the message must name the boundary tensor.
    TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
      "have same device type, but got sorter tensor device type ", sorter.device(), " and boundary tensor ",
      "device type ", boundaries.device());

    TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
      "size, but got boundary tensor ", boundaries.sizes(), " and got sorter tensor ", sorter.sizes());

    TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
      "dtype but got dtype ", sorter.scalar_type());
  }

  TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
    "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
    "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
    input.numel(), ")");

  TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
    "got 0 dimension");

  TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
    "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
    "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
    input.sizes());

  ScalarType output_dtype = output.scalar_type();
  TORCH_CHECK(
      (output_dtype == ScalarType::Long && !out_int32) ||
          (output_dtype == ScalarType::Int && out_int32),
      "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
      "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
      " and out_int32 flag is ", (out_int32 ? "True" : "False"));

  if (out_int32) {
    // boundaries.dim() != 0 was checked above, so back() is safe here.
    TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
      "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
      boundaries.sizes().back());
  }
}

}}