File size: 2,197 Bytes
7e50900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>

#pragma once

namespace at { namespace native {
namespace {
  // Validates the input/target shapes for multilabel_margin_loss and reports
  // the derived sizes through the `nframe` and `dim` out-parameters.
  //
  // Accepted input layouts (NOTE(review): `ndims` is presumably input.dim();
  // confirm at call sites):
  //   * 0-d input            -> nframe = 1, dim = 1
  //   * 1-d input of size D  -> nframe = 1, dim = D   (D must be non-zero)
  //   * 2-d input of (N, D)  -> nframe = N, dim = D   (D must be non-zero; N may be 0)
  // Any other rank fails the first check.
  //
  // Target requirements:
  //   * unbatched (ndims <= 1): target is at most 1-d and has exactly `dim` elements
  //   * batched   (ndims == 2): target is 2-d with shape (nframe, dim)
  // Failures are raised via TORCH_CHECK; `target_arg` names the target in the
  // error message.
  static C10_UNUSED void multilabel_margin_loss_shape_check(
    int64_t& nframe,
    int64_t& dim,
    const int64_t& ndims,
    TensorArg& target_arg,
    const Tensor& input,
    const Tensor& target) {
    // Each size() lookup is guarded by its own rank test, so at most one of
    // these touches `input`.
    const bool is_matrix_with_columns = ndims == 2 && input.size(1) != 0;
    const bool is_nonempty_vector = ndims == 1 && input.size(0) != 0;
    const bool is_scalar = ndims == 0;
    TORCH_CHECK(
        is_matrix_with_columns || is_nonempty_vector || is_scalar,
        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
        input.sizes());

    if (ndims > 1) {
      // Batched input: rows are frames, columns are the per-frame dimension.
      nframe = input.size(0);
      dim = input.size(1);
      TORCH_CHECK(
          target.dim() == 2 && target.size(0) == nframe && target.size(1) == dim,
          "inconsistent size ", target.sizes(), " for ", target_arg);
      return;
    }

    // Unbatched input: a single frame; a 0-d input is treated as dim == 1.
    nframe = 1;
    dim = (ndims == 0) ? 1 : input.size(0);
    TORCH_CHECK(
        target.dim() <= 1 && target.numel() == dim,
        "inconsistent size ", target.sizes(), " for ", target_arg);
  }

  // Validates the input/target shapes for multi_margin_loss and reports the
  // derived sizes through the `nframe` and `dim` out-parameters.
  //
  // Accepted input layouts (NOTE(review): `ndims` is presumably input.dim();
  // confirm at call sites):
  //   * 0-d input            -> nframe = 1, dim = 1
  //   * 1-d input of size D  -> nframe = 1, dim = D   (D must be non-zero)
  //   * 2-d input of (N, D)  -> nframe = N, dim = D   (D must be non-zero; N may be 0)
  // Any other rank fails the first check.
  //
  // The target must be at most 1-d and hold exactly one entry per frame
  // (numel == nframe). Failures are raised via TORCH_CHECK.
  //
  // `target_arg` is not used by this check; it is kept so the signature
  // parallels multilabel_margin_loss_shape_check.
  static C10_UNUSED void multi_margin_loss_shape_check(
    int64_t& nframe,
    int64_t& dim,
    const int64_t& ndims,
    TensorArg& target_arg,
    const Tensor& input,
    const Tensor& target) {
    // Validate the input before writing any out-parameter, so a rejected
    // input leaves the caller's nframe/dim untouched.
    bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
    TORCH_CHECK(
                valid_inputs,
                "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
                input.sizes());

    if (ndims <= 1) {
      // Unbatched input: a single frame; a 0-d input is treated as dim == 1.
      nframe = 1;
      dim = ndims == 0 ? 1 : input.size(0);
    } else {
      // Batched input: rows are frames, columns are the per-frame dimension.
      nframe = input.size(0);
      dim = input.size(1);
    }

    // valid_inputs is known true here (the check above throws otherwise).
    TORCH_CHECK(
                target.dim() <= 1 && target.numel() == nframe,
                "inconsistent target size, got: ",
                target.sizes());
  }


}  // anonymous namespace
}} // namespace at::native