File size: 1,745 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#ifndef moses_Normalizer_h
#define moses_Normalizer_h

#include <algorithm>
#include <cmath>
#include <vector>
#include "Util.h"

namespace Discriminative
{

class Normalizer
{
public:
  virtual void operator()(std::vector<float> &losses) const = 0;
  virtual ~Normalizer() {}
};

class SquaredLossNormalizer : public Normalizer
{
public:
  virtual void operator()(std::vector<float> &losses) const {
    // This is (?) a good choice for sqrt loss (default loss function in VW)

    float sum = 0;

    // clip to [0,1] and take 1-Z as non-normalized prob
    std::vector<float>::iterator it;
    for (it = losses.begin(); it != losses.end(); it++) {
      if (*it <= 0.0) *it = 1.0;
      else if (*it >= 1.0) *it = 0.0;
      else *it = 1.0 - *it;
      sum += *it;
    }

    if (! Moses::Equals(sum, 0)) {
      // normalize
      for (it = losses.begin(); it != losses.end(); it++)
        *it /= sum;
    } else {
      // sum of non-normalized probs is 0, then take uniform probs
      for (it = losses.begin(); it != losses.end(); it++)
        *it = 1.0 / losses.size();
    }
  }

  virtual ~SquaredLossNormalizer() {}
};

// safe softmax
class LogisticLossNormalizer : public Normalizer
{
public:
  virtual void operator()(std::vector<float> &losses) const {
    std::vector<float>::iterator it;

    float sum = 0;
    float max = 0;
    for (it = losses.begin(); it != losses.end(); it++) {
      *it = -*it;
      max = std::max(max, *it);
    }

    for (it = losses.begin(); it != losses.end(); it++) {
      *it = exp(*it - max);
      sum += *it;
    }

    for (it = losses.begin(); it != losses.end(); it++) {
      *it /= sum;
    }
  }

  virtual ~LogisticLossNormalizer() {}
};

} // namespace Discriminative

#endif // moses_Normalizer_h