Spaces:
Runtime error
Runtime error
File size: 2,835 Bytes
382191a 998bf2a 382191a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
"""The main entry point for analyzing and comparing GPT-MT system outputs."""
from __future__ import annotations
import argparse
import os
import pandas as pd
from zeno_build.experiments import search_space
from zeno_build.experiments.experiment_run import ExperimentRun
from zeno_build.reporting.visualize import visualize
import config
from modeling import (
GptMtInstance,
process_data,
process_output,
)
def analysis_gpt_mt_main(
    input_dir: str,
    results_dir: str,
) -> None:
    """Load GPT-MT data and per-model outputs, then launch the visualization.

    Args:
        input_dir: Directory forwarded to ``process_data`` and
            ``process_output`` as their ``input_dir`` argument.
        results_dir: Directory under which the ``zeno_cache`` folder is
            placed.

    Raises:
        ValueError: If the ``lang_pairs`` dimension of ``config.main_space``
            is not a ``search_space.Constant``, or if its ``model_preset``
            dimension is not a ``search_space.Categorical``.
    """
    dimensions = config.main_space.dimensions

    # The language-pair dimension must be fixed to one preset.
    pair_dimension = dimensions["lang_pairs"]
    if not isinstance(pair_dimension, search_space.Constant):
        raise ValueError(
            "All experiments must be run on a single set of language pairs."
        )
    lang_pairs = config.lang_pairs[pair_dimension.value]

    # Load the test data as a list of GptMtInstance records.
    test_data: list[GptMtInstance] = process_data(
        input_dir=input_dir,
        lang_pairs=lang_pairs,
    )

    # One ExperimentRun per model preset listed in the search space.
    preset_dimension = dimensions["model_preset"]
    if not isinstance(preset_dimension, search_space.Categorical):
        raise ValueError("The model presets must be a categorical parameter.")
    results: list[ExperimentRun] = [
        ExperimentRun(
            preset,
            {"model_preset": preset},
            process_output(
                input_dir=input_dir,
                lang_pairs=lang_pairs,
                model_preset=preset,
            ),
        )
        for preset in preset_dimension.choices
    ]

    # Assemble the frame column by column; each column name doubles as the
    # attribute read from every test instance.
    column_names = ("data", "label", "lang_pair", "doc_id")
    df = pd.DataFrame(
        {
            name: [getattr(instance, name) for instance in test_data]
            for name in column_names
        }
    )
    labels = [instance.label for instance in test_data]

    # Settings handed to the Zeno front end.
    zeno_settings = {
        "cache_path": os.path.join(results_dir, "zeno_cache"),
        "port": 7860,
        "host": "0.0.0.0",
        "editable": False,
    }

    # Perform the visualization.
    visualize(
        df,
        labels,
        results,
        "text-classification",
        "data",
        config.zeno_distill_and_metric_functions,
        zeno_config=zeno_settings,
    )
if __name__ == "__main__":
    # Build the command-line interface.
    cli = argparse.ArgumentParser()
    # No default is given, so this is None when the flag is omitted.
    cli.add_argument(
        "--input-dir",
        type=str,
        help="The directory of the GPT-MT repo.",
    )
    cli.add_argument(
        "--results-dir",
        type=str,
        default="results",
        help="The directory to store the results in.",
    )

    # The parsed attribute names (input_dir, results_dir) match the keyword
    # parameters of the entry point, so the namespace is forwarded directly.
    analysis_gpt_mt_main(**vars(cli.parse_args()))
|