|
#include <algorithm> |
|
#include "Optimiser.h" |
|
#include "Hildreth.h" |
|
#include "moses/StaticData.h" |
|
|
|
using namespace Moses; |
|
using namespace std; |
|
|
|
namespace Mira |
|
{ |
|
|
|
// MIRA weight update against oracle translations.
// For every (sentence i, hypothesis j) pair a margin constraint is built:
// the model-score margin between the oracle and the hypothesis must be at
// least the loss. All constraints are handed to the Hildreth QP solver; the
// resulting alphas combine the feature-difference vectors into one summed
// update which is added to weightUpdate.
//
// Params:
//   weightUpdate        - accumulator the computed update is added into (out)
//   featureValues       - per sentence, per hypothesis feature vectors
//   losses              - per sentence, per hypothesis loss values
//   bleuScores          - unused here (kept for interface symmetry)
//   modelScores         - per sentence, per hypothesis model scores
//   oracleFeatureValues - one oracle feature vector per sentence
//   oracleBleuScores    - one oracle BLEU score per sentence
//   oracleModelScores   - one oracle model score per sentence
//   learning_rate       - scales the summed update if != 1
//   rank, epoch         - only used to tag the cerr log output
// Returns 0 if an update was applied, 1 if no constraint was violated.
size_t MiraOptimiser::updateWeights(
  ScoreComponentCollection& weightUpdate,
  const vector<vector<ScoreComponentCollection> >& featureValues,
  const vector<vector<float> >& losses,
  const vector<vector<float> >& bleuScores,
  const vector<vector<float> >& modelScores,
  const vector<ScoreComponentCollection>& oracleFeatureValues,
  const vector<float> oracleBleuScores,
  const vector<float> oracleModelScores,
  float learning_rate,
  size_t rank,
  size_t epoch)
{
  // One entry per constraint: feature difference, violation amount, raw loss.
  vector<ScoreComponentCollection> featureValueDiffs;
  vector<float> lossMinusModelScoreDiffs;
  vector<float> all_losses;

  // NOTE(review): declared but never used in this method.
  ScoreComponentCollection max_batch_featureValueDiff;

  // Tolerance below which a positive violation is not counted as "violated".
  float epsilon = 0.0001;
  int violatedConstraintsBefore = 0;
  float oldDistanceFromOptimum = 0;  // sum of violations; diagnostic only here

  // Iterate over the batch of input sentences...
  for (size_t i = 0; i < featureValues.size(); ++i) {
    // ...and over the hypothesis translations of each sentence.
    for (size_t j = 0; j < featureValues[i].size(); ++j) {
      ScoreComponentCollection featureValueDiff = oracleFeatureValues[i];
      featureValueDiff.MinusEquals(featureValues[i][j]);

      // Identical feature vectors produce a zero constraint: skip it.
      if (featureValueDiff.GetL1Norm() == 0) {
        cerr << "Rank " << rank << ", epoch " << epoch << ", features equal --> skip" << endl;
        continue;
      }

      float loss = losses[i][j];

      // Constraint: modelScoreDiff >= loss; diff is the hinge-style
      // violation amount (0 when the constraint already holds).
      bool violated = false;
      float modelScoreDiff = oracleModelScores[i] - modelScores[i][j];
      float diff = 0;
      if (loss > modelScoreDiff)
        diff = loss - modelScoreDiff;
      cerr << "Rank " << rank << ", epoch " << epoch << ", constraint: " << modelScoreDiff << " >= " << loss << " (current violation: " << diff << ")" << endl;
      if (diff > epsilon)
        violated = true;

      // Optionally squash margin and loss through a scaled sigmoid so that
      // very large values do not dominate the QP. NOTE(review): 'violated'
      // was already decided on the unnormalised values above.
      if (m_normaliseMargin) {
        modelScoreDiff = (2*m_sigmoidParam/(1 + exp(-modelScoreDiff))) - m_sigmoidParam;
        loss = (2*m_sigmoidParam/(1 + exp(-loss))) - m_sigmoidParam;
        diff = 0;
        if (loss > modelScoreDiff) {
          diff = loss - modelScoreDiff;
        }
        cerr << "Rank " << rank << ", epoch " << epoch << ", normalised constraint: " << modelScoreDiff << " >= " << loss << " (current violation: " << diff << ")" << endl;
      }

      // Optionally weight the violation by the oracle's BLEU score.
      if (m_scale_margin) {
        diff *= oracleBleuScores[i];
        cerr << "Rank " << rank << ", epoch " << epoch << ", scaling margin with oracle bleu score " << oracleBleuScores[i] << endl;
      }

      // All constraints (violated or not) are passed to the optimiser.
      featureValueDiffs.push_back(featureValueDiff);
      lossMinusModelScoreDiffs.push_back(diff);
      all_losses.push_back(loss);
      if (violated) {
        ++violatedConstraintsBefore;
        oldDistanceFromOptimum += diff;
      }
    }
  }

  // Solve the QP: one alpha per constraint; with slack the alphas are capped.
  vector<float> alphas;
  ScoreComponentCollection summedUpdate;
  if (violatedConstraintsBefore > 0) {
    cerr << "Rank " << rank << ", epoch " << epoch << ", number of constraints passed to optimizer: " <<
         featureValueDiffs.size() << " (of which violated: " << violatedConstraintsBefore << ")" << endl;
    if (m_slack != 0) {
      alphas = Hildreth::optimise(featureValueDiffs, lossMinusModelScoreDiffs, m_slack);
    } else {
      alphas = Hildreth::optimise(featureValueDiffs, lossMinusModelScoreDiffs);
    }

    // Combine: summedUpdate = SUM_k alpha_k * featureValueDiffs[k].
    for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
      float alpha = alphas[k];
      cerr << "Rank " << rank << ", epoch " << epoch << ", alpha: " << alpha << endl;
      ScoreComponentCollection update(featureValueDiffs[k]);
      update.MultiplyEquals(alpha);

      summedUpdate.PlusEquals(update);
    }
  } else {
    cerr << "Rank " << rank << ", epoch " << epoch << ", no constraint violated for this batch" << endl;
    // Nothing to do: signal "no update" to the caller.
    return 1;
  }

  // Apply the learning rate to the combined update.
  if (learning_rate != 1) {
    cerr << "Rank " << rank << ", epoch " << epoch << ", apply learning rate " << learning_rate << " to update." << endl;
    summedUpdate.MultiplyEquals(learning_rate);
  }

  // Scale the whole update by the oracle BLEU score (batch size 1 only:
  // with a single sentence there is exactly one oracle score).
  if (oracleBleuScores.size() == 1) {
    if (m_scale_update) {
      cerr << "Rank " << rank << ", epoch " << epoch << ", scaling summed update with oracle bleu score " << oracleBleuScores[0] << endl;
      summedUpdate.MultiplyEquals(oracleBleuScores[0]);
    }
  }

  // Hand the accumulated update back to the caller.
  weightUpdate.PlusEquals(summedUpdate);

  return 0;
}
|
|
|
// MIRA weight update using hope/fear translation pairs.
// For every (sentence i, pair j) a constraint is built from the difference
// between the hope and fear hypotheses (features, BLEU, model score). The
// Hildreth QP produces one alpha per constraint; the alpha-weighted feature
// differences are summed into a single update added to weightUpdate.
//
// Params:
//   weightUpdate                     - accumulator for the update (out)
//   featureValuesHope/Fear           - per sentence, per pair feature vectors
//   bleuScoresHope/Fear              - per sentence, per pair BLEU scores
//   modelScoresHope/Fear             - per sentence, per pair model scores
//   learning_rate                    - scales the summed update if != 1
//   rank, epoch                      - only used to tag the cerr log output
//   updatePosition                   - if != -1, restrict the update to that
//                                      single sentence index of the batch
// Returns 0 if an update was applied, 1 if no constraint was violated.
size_t MiraOptimiser::updateWeightsHopeFear(
  Moses::ScoreComponentCollection& weightUpdate,
  const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValuesHope,
  const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValuesFear,
  const std::vector<std::vector<float> >& bleuScoresHope,
  const std::vector<std::vector<float> >& bleuScoresFear,
  const std::vector<std::vector<float> >& modelScoresHope,
  const std::vector<std::vector<float> >& modelScoresFear,
  float learning_rate,
  size_t rank,
  size_t epoch,
  int updatePosition)
{
  // One entry per constraint: feature diff, violation, raw margin, raw loss.
  vector<ScoreComponentCollection> featureValueDiffs;
  vector<float> lossMinusModelScoreDiffs;
  vector<float> modelScoreDiffs;
  vector<float> all_losses;

  // NOTE(review): declared but never used in this method.
  ScoreComponentCollection max_batch_featureValueDiff;

  // Tolerance below which a positive violation is not counted as "violated".
  float epsilon = 0.0001;
  int violatedConstraintsBefore = 0;
  float oldDistanceFromOptimum = 0;  // sum of violations; diagnostic only here

  // Iterate over the batch of input sentences...
  for (size_t i = 0; i < featureValuesHope.size(); ++i) {
    // If a single sentence position was requested, skip all others.
    if (updatePosition != -1) {
      if (int(i) < updatePosition)
        continue;
      else if (int(i) > updatePosition)
        break;
    }

    // ...and over the hope/fear pairs of each sentence.
    for (size_t j = 0; j < featureValuesHope[i].size(); ++j) {
      ScoreComponentCollection featureValueDiff = featureValuesHope[i][j];
      featureValueDiff.MinusEquals(featureValuesFear[i][j]);

      // Identical feature vectors produce a zero constraint: skip it.
      if (featureValueDiff.GetL1Norm() == 0) {
        cerr << "Rank " << rank << ", epoch " << epoch << ", features equal --> skip" << endl;
        continue;
      }

      // Loss is the BLEU gap between the hope and fear translations.
      float loss = bleuScoresHope[i][j] - bleuScoresFear[i][j];

      // Constraint: modelScoreDiff >= loss; diff is the hinge-style
      // violation amount (0 when the constraint already holds).
      bool violated = false;
      float modelScoreDiff = modelScoresHope[i][j] - modelScoresFear[i][j];
      float diff = 0;
      if (loss > modelScoreDiff)
        diff = loss - modelScoreDiff;
      cerr << "Rank " << rank << ", epoch " << epoch << ", constraint: " << modelScoreDiff << " >= " << loss << " (current violation: " << diff << ")" << endl;

      if (diff > epsilon)
        violated = true;

      // Optionally squash margin and loss through a scaled sigmoid so that
      // very large values do not dominate the QP. NOTE(review): 'violated'
      // was already decided on the unnormalised values above.
      if (m_normaliseMargin) {
        modelScoreDiff = (2*m_sigmoidParam/(1 + exp(-modelScoreDiff))) - m_sigmoidParam;
        loss = (2*m_sigmoidParam/(1 + exp(-loss))) - m_sigmoidParam;
        diff = 0;
        if (loss > modelScoreDiff) {
          diff = loss - modelScoreDiff;
        }
        cerr << "Rank " << rank << ", epoch " << epoch << ", normalised constraint: " << modelScoreDiff << " >= " << loss << " (current violation: " << diff << ")" << endl;
      }

      // Optionally weight the violation by the hope (oracle) BLEU score.
      if (m_scale_margin) {
        diff *= bleuScoresHope[i][j];
        cerr << "Rank " << rank << ", epoch " << epoch << ", scaling margin with oracle bleu score " << bleuScoresHope[i][j] << endl;
      }

      // All constraints (violated or not) are passed to the optimiser.
      featureValueDiffs.push_back(featureValueDiff);
      lossMinusModelScoreDiffs.push_back(diff);
      modelScoreDiffs.push_back(modelScoreDiff);
      all_losses.push_back(loss);
      if (violated) {
        ++violatedConstraintsBefore;
        oldDistanceFromOptimum += diff;
      }
    }
  }

  // Solve the QP: one alpha per constraint; with slack the alphas are capped.
  vector<float> alphas;
  ScoreComponentCollection summedUpdate;
  if (violatedConstraintsBefore > 0) {
    cerr << "Rank " << rank << ", epoch " << epoch << ", number of constraints passed to optimizer: " <<
         featureValueDiffs.size() << " (of which violated: " << violatedConstraintsBefore << ")" << endl;
    if (m_slack != 0) {
      alphas = Hildreth::optimise(featureValueDiffs, lossMinusModelScoreDiffs, m_slack);
    } else {
      alphas = Hildreth::optimise(featureValueDiffs, lossMinusModelScoreDiffs);
    }

    // Combine: summedUpdate = SUM_k alpha_k * featureValueDiffs[k],
    // skipping zero alphas.
    for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
      float alpha = alphas[k];
      cerr << "Rank " << rank << ", epoch " << epoch << ", alpha: " << alpha << endl;
      if (alpha != 0) {

        // Boost constraints whose raw margin was non-positive (fear scored
        // at least as high as hope). NOTE(review): the factor is computed
        // from bleuScoresHope[0][0], not the pair belonging to constraint
        // k — presumably only meaningful for batch size 1; verify. The
        // second min(3.0f, ...) is a no-op after min(1.5f, ...).
        if (m_boost && modelScoreDiffs[k] <= 0) {

          float factor = std::min(1.5f, (float) log2(bleuScoresHope[0][0]));
          factor = min(3.0f, factor);
          alpha = alpha * factor;
          cerr << "Rank " << rank << ", epoch " << epoch << ", apply boosting factor " << factor << " to update." << endl;
        }

        ScoreComponentCollection update(featureValueDiffs[k]);
        update.MultiplyEquals(alpha);

        summedUpdate.PlusEquals(update);
      }
    }
  } else {
    cerr << "Rank " << rank << ", epoch " << epoch << ", no constraint violated for this batch" << endl;
    // Nothing to do: signal "no update" to the caller.
    return 1;
  }

  // Apply the learning rate to the combined update.
  if (learning_rate != 1) {
    cerr << "Rank " << rank << ", epoch " << epoch << ", apply learning rate " << learning_rate << " to update." << endl;
    summedUpdate.MultiplyEquals(learning_rate);
  }

  // Scale the whole update by the first hope BLEU score (batch size 1 only).
  if (featureValuesHope.size() == 1) {
    if (m_scale_update) {
      cerr << "Rank " << rank << ", epoch " << epoch << ", scaling summed update with oracle bleu score " << bleuScoresHope[0][0] << endl;
      summedUpdate.MultiplyEquals(bleuScoresHope[0][0]);
    }
  }

  // Hand the accumulated update back to the caller.
  weightUpdate.PlusEquals(summedUpdate);

  return 0;
}
|
|
|
// Analytical (closed-form, single-constraint) MIRA update from one
// hope/fear pair. With only one constraint the QP solution is
// alpha = violation / ||featureDiff||^2, optionally clipped to the slack,
// scaled by learning rate / hope BLEU, and boosted when the raw margin is
// non-positive. The alpha-scaled feature difference is added to weightUpdate.
//
// Params:
//   weightUpdate                 - accumulator for the update (out)
//   featureValuesHope/Fear       - feature vectors of the hope/fear pair
//   bleuScoreHope/Fear           - BLEU scores of the pair
//   modelScoreHope/Fear          - model scores of the pair
//   learning_rate                - scales alpha if != 1
//   rank, epoch                  - only used to tag the cerr log output
// Returns 0 if an update was applied, 1 if the constraint was already
// satisfied (or the feature vectors were identical).
size_t MiraOptimiser::updateWeightsAnalytically(
  ScoreComponentCollection& weightUpdate,
  ScoreComponentCollection& featureValuesHope,
  ScoreComponentCollection& featureValuesFear,
  float bleuScoreHope,
  float bleuScoreFear,
  float modelScoreHope,
  float modelScoreFear,
  float learning_rate,
  size_t rank,
  size_t epoch)
{
  // Tolerance below which a positive violation is ignored.
  float epsilon = 0.0001;
  float oldDistanceFromOptimum = 0;  // diagnostic only
  bool constraintViolatedBefore = false;

  // Constraint direction: hope features minus fear features.
  ScoreComponentCollection featureValueDiff = featureValuesHope;
  featureValueDiff.MinusEquals(featureValuesFear);
  if (featureValueDiff.GetL1Norm() == 0) {
    // Identical features: the constraint is degenerate, no update possible.
    cerr << "Rank " << rank << ", epoch " << epoch << ", features equal --> skip" << endl;
    return 1;
  }

  // Constraint: modelScoreDiff >= loss; diff is the hinge-style violation.
  float modelScoreDiff = modelScoreHope - modelScoreFear;
  float loss = bleuScoreHope - bleuScoreFear;
  float diff = 0;
  if (loss > modelScoreDiff)
    diff = loss - modelScoreDiff;
  cerr << "Rank " << rank << ", epoch " << epoch << ", constraint: " << modelScoreDiff << " >= " << loss << " (current violation: " << diff << ")" << endl;

  // Optionally squash margin and loss through a scaled sigmoid so that
  // very large values do not dominate the update.
  if (m_normaliseMargin) {
    modelScoreDiff = (2*m_sigmoidParam/(1 + exp(-modelScoreDiff))) - m_sigmoidParam;
    loss = (2*m_sigmoidParam/(1 + exp(-loss))) - m_sigmoidParam;
    // BUG FIX: reset diff before recomputing it from the normalised values
    // (matches updateWeights/updateWeightsHopeFear). Previously the stale
    // unnormalised violation survived when the normalised constraint held,
    // wrongly triggering the update branch below.
    diff = 0;
    if (loss > modelScoreDiff)
      diff = loss - modelScoreDiff;
    cerr << "Rank " << rank << ", epoch " << epoch << ", normalised constraint: " << modelScoreDiff << " >= " << loss << " (current violation: " << diff << ")" << endl;
  }

  if (diff > epsilon) {
    // Record the violation for the "already satisfied" check below.
    oldDistanceFromOptimum += diff;
    constraintViolatedBefore = true;

    // Closed-form solution of the one-constraint QP:
    // alpha = violation / ||featureDiff||^2.
    float squaredNorm = featureValueDiff.GetL2Norm() * featureValueDiff.GetL2Norm();

    float alpha = diff / squaredNorm;
    cerr << "Rank " << rank << ", epoch " << epoch << ", unclipped alpha: " << alpha << endl;
    // Clip alpha to [-m_slack, m_slack] when a slack is configured.
    if (m_slack > 0 ) {
      if (alpha > m_slack) {
        alpha = m_slack;
      } else if (alpha < m_slack*(-1)) {
        alpha = m_slack*(-1);
      }
    }

    // Apply the learning rate to alpha.
    if (learning_rate != 1)
      alpha = alpha * learning_rate;

    // Optionally scale the step by the hope (oracle) BLEU score.
    if (m_scale_update) {
      cerr << "Rank " << rank << ", epoch " << epoch << ", scaling update with oracle bleu score " << bleuScoreHope << endl;
      alpha *= bleuScoreHope;
    }

    cerr << "Rank " << rank << ", epoch " << epoch << ", clipped/scaled alpha: " << alpha << endl;

    // Boost the step when the raw margin was non-positive (fear scored at
    // least as high as hope). Factor is capped at 1.5; the former second
    // cap min(3.0f, factor) was a no-op and has been removed.
    if (m_boost && modelScoreDiff <= 0) {

      float factor = min(1.5f, (float) log2(bleuScoreHope));
      alpha = alpha * factor;
      cerr << "Rank " << rank << ", epoch " << epoch << ", boosted alpha: " << alpha << endl;
    }

    // Apply the alpha-scaled feature difference to the weight update.
    featureValueDiff.MultiplyEquals(alpha);
    weightUpdate.PlusEquals(featureValueDiff);

  }

  if (!constraintViolatedBefore) {
    // Constraint already satisfied: signal "no update" to the caller.
    cerr << "Rank " << rank << ", epoch " << epoch << ", constraint already satisfied" << endl;
    return 1;
  }

  return 0;
}
|
|
|
} |
|
|
|
|