"""Plotting utilities for DPO (chosen/rejected) causal language modeling."""
import hashlib
import os
from typing import Any, Dict, List

import pandas as pd

from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    create_batch_prediction_df,
    plot_validation_predictions,
)
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
from llm_studio.src.utils.utils import PatchedAttribute
class Plots:
    """Plot factory for DPO-style (chosen vs. rejected answer) experiments.

    Every method is a ``classmethod`` returning a :class:`PlotData` that points
    at a parquet file, which the UI renders as a scrollable dataframe.
    """

    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Render one collated training batch as a token/label dataframe.

        Args:
            batch: Collated batch; must contain ``chosen_input_ids`` and
                ``chosen_labels`` entries.
            cfg: Experiment configuration; ``cfg.output_directory`` is where
                the parquet cache file is written.

        Returns:
            PlotData wrapping the parquet path (``encoding="df"``).
        """
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch,
            tokenizer,
            ids_for_tokenized_text="chosen_input_ids",
            labels_column="chosen_labels",
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """
        Plots the data in a scrollable table.
        We limit the number of rows to max 600 to avoid rendering issues in Wave.
        As the data visualization is instantiated on every page load, we cache the
        data visualization in a parquet file.

        Args:
            cfg: Experiment configuration; the relevant dataset column settings
                are hashed to key the parquet cache.

        Returns:
            PlotData wrapping the (possibly cached) parquet path.
        """
        # Cache key: any change to the dataframe path or column mapping
        # invalidates the cached visualization.
        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.rejected_answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")

        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
        conversations_chosen = get_conversation_chains(
            df, cfg, limit_chained_samples=True
        )
        # Re-run chaining with the rejected answer column temporarily patched in
        # as the answer column, yielding the rejected counterpart of each chain.
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            conversations_rejected = get_conversation_chains(
                df, cfg, limit_chained_samples=True
            )

        # Limit to max 15 prompt-conversation-answer rounds.
        # default=0 keeps an empty dataset from raising ValueError; it then
        # simply yields an empty table instead of crashing the page.
        max_conversation_length = min(
            max(
                (len(conversation["prompts"]) for conversation in conversations_chosen),
                default=0,
            ),
            15,
        )

        # Show at most 5 (chosen, rejected) conversation pairs per chain length.
        conversations_to_display: List = []
        for conversation_length in range(1, max_conversation_length + 1):
            conversations_to_display += [
                (chosen, rejected)
                for chosen, rejected in zip(
                    conversations_chosen, conversations_rejected
                )
                if len(chosen["prompts"]) == conversation_length
            ][:5]

        # Convert into a scrollable table by transposing the dataframe
        df_transposed = pd.DataFrame(columns=["Sample Number", "Field", "Content"])
        i = 0
        for sample_number, (conversation_chosen, conversation_rejected) in enumerate(
            conversations_to_display
        ):
            if conversation_chosen["systems"][0] != "":
                df_transposed.loc[i] = [
                    sample_number,
                    "System",
                    conversation_chosen["systems"][0],
                ]
                i += 1
            for prompt, answer_chosen, answer_rejected in zip(
                conversation_chosen["prompts"],
                conversation_chosen["answers"],
                conversation_rejected["answers"],  # type: ignore
            ):
                df_transposed.loc[i] = [
                    sample_number,
                    "Prompt",
                    prompt,
                ]
                i += 1
                # Collapse identical answers into a single "Answer" row; only
                # show the chosen/rejected split when they actually differ.
                if answer_chosen == answer_rejected:
                    df_transposed.loc[i] = [
                        sample_number,
                        "Answer",
                        answer_chosen,
                    ]
                    i += 1
                else:
                    df_transposed.loc[i] = [
                        sample_number,
                        "Answer Chosen",
                        answer_chosen,
                    ]
                    i += 1
                    df_transposed.loc[i] = [
                        sample_number,
                        "Answer Rejected",
                        answer_rejected,
                    ]
                    i += 1

        df_transposed["Content"] = df_transposed["Content"].apply(
            format_for_markdown_visualization
        )
        df_transposed.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        """Delegate to the shared causal-LM validation prediction plot."""
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)