# llm-studio/tests/src/metrics/test_text_causal_language_modeling_metrics.py
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
import pytest

from llm_studio.src.metrics.text_causal_language_modeling_metrics import (
    sacrebleu_score,
)
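
# The tests below exercise sacrebleu_score, which (judging by the expected
# values) computes a per-sample, sentence-level BLEU on sacrebleu's 0-100
# scale: one score per predicted/target pair, with empty targets scored 0.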


@pytest.fixture
def mock_val_df():
    # In these tests the metric only needs a DataFrame to satisfy its
    # signature, so an empty one is enough.
    return pd.DataFrame()


def test_sacrebleu_score_perfect_match(mock_val_df):
    cfg = MagicMock()
    results = {
        "predicted_text": ["Hello world", "Python is great"],
        "target_text": ["Hello world", "Python is great"],
    }
    scores = sacrebleu_score(cfg, results, mock_val_df)
    # Identical strings yield the maximum BLEU of 100 for every sample.
    assert np.allclose(scores, np.array([100.0, 100.0]))


def test_sacrebleu_score_partial_match(mock_val_df):
    cfg = MagicMock()
    results = {
        "predicted_text": ["Hello universe", "Python is awesome"],
        "target_text": ["Hello world", "Python is great"],
    }
    scores = sacrebleu_score(cfg, results, mock_val_df)
    # Overlapping but non-identical strings score strictly between 0 and 100.
    assert np.allclose(scores, np.array([50.0, 55.03212081]))
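
# The reference constants above appear to come from sacrebleu's default
# sentence-level BLEU with effective n-gram order: e.g. "Python is awesome"
# vs. "Python is great" gives precisions of 2/3 (unigrams), 1/2 (bigrams)
# and a smoothed 1/2 (trigrams), i.e. (2/3 * 1/2 * 1/2) ** (1/3) ~= 0.5503,
# reported as 55.032 on the 0-100 scale.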


def test_sacrebleu_score_no_match(mock_val_df):
    cfg = MagicMock()
    results = {
        "predicted_text": ["Goodbye universe", "What a day"],
        "target_text": ["Hello world", "Python is great"],
    }
    scores = sacrebleu_score(cfg, results, mock_val_df)
    # With no unigram overlap at all, BLEU collapses to 0.
    assert np.allclose(scores, np.array([0.0, 0.0]))


def test_sacrebleu_score_all_empty_target(mock_val_df):
    cfg = MagicMock()
    results = {
        "predicted_text": ["Hello world", "Python is great"],
        "target_text": ["", ""],
    }
    scores = sacrebleu_score(cfg, results, mock_val_df)
    # Samples with an empty target are scored 0.
    assert np.allclose(scores, np.array([0.0, 0.0]))


def test_sacrebleu_score_one_empty_target(mock_val_df):
    cfg = MagicMock()
    results = {
        "predicted_text": ["Hello world", "Python is great"],
        "target_text": ["", "Python is great"],
    }
    scores = sacrebleu_score(cfg, results, mock_val_df)
    # Empty targets are handled per sample: the first scores 0, the second 100.
    assert np.allclose(scores, np.array([0.0, 100.0]))


def test_sacrebleu_score_invalid_input_empty(mock_val_df):
    cfg = MagicMock()
    results = {"predicted_text": [], "target_text": []}
    # Empty inputs are rejected rather than silently returning an empty array.
    with pytest.raises(ValueError):
        sacrebleu_score(cfg, results, mock_val_df)


def test_sacrebleu_score_invalid_input_different_lengths(mock_val_df):
    cfg = MagicMock()
    results = {
        "predicted_text": ["Hello world", "Python", "is", "great"],
        "target_text": ["Hello universe", "Python is awesome"],
    }
    # Mismatched prediction/target lengths must raise instead of zipping short.
    with pytest.raises(ValueError):
        sacrebleu_score(cfg, results, mock_val_df)
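

# A minimal, self-contained sketch for spot-checking the expected constants
# against the sacrebleu package directly. It assumes sacrebleu_score wraps
# sacrebleu.sentence_bleu with default settings; that wiring is an assumption
# here, but the printed values should line up with the asserts above.
if __name__ == "__main__":
    import sacrebleu

    # Expected: ~50.0 and ~55.03212081 on sacrebleu's 0-100 scale.
    print(sacrebleu.sentence_bleu("Hello universe", ["Hello world"]).score)
    print(sacrebleu.sentence_bleu("Python is awesome", ["Python is great"]).score)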