# Source: CRSArena/script/crb_crs/redial_data_preparation.py
# Author: Nolwenn. Commit b599481 ("Initial commit").
"""Script to prepare the ReDial data for the CRB-CRS model.

Reads whichever of ``train_data.jsonl``, ``valid_data.jsonl`` and
``test_data.jsonl`` exist in the ReDial folder, and writes three line-based
corpora to the output folder: ``original_corpus.txt``,
``preprocessed_corpus.txt`` and ``preprocessed_corpus_no_stopwords.txt``.
"""
import argparse
import json
import logging
import os
from typing import Any, Dict, List, Tuple
from tqdm import tqdm
from src.model.crb_crs.retriever.retriever import (
CONV_PREFIX,
CRS_PREFIX,
USER_PREFIX,
)
from src.model.crb_crs.utils_preprocessing import preprocess_utterance
# Type alias: a parsed dialogue (and a whole flattened corpus) is a list of
# text lines; each entry becomes one line of a corpus file.
ParsedDialogue = List[str]
def read_jsonl_data(path: str) -> List[Dict[str, Any]]:
    """Reads data from a JSON Lines file.

    Each non-blank line is parsed as one JSON value. Blank or
    whitespace-only lines (e.g. an extra empty line at the end of the file)
    are skipped instead of making ``json.loads`` fail.

    Args:
        path: Path to the jsonl file.

    Returns:
        List of parsed records, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the file object directly rather than calling readlines(),
        # which would first build a list of every raw line.
        for line in f:
            if not line.strip():
                continue
            data.append(json.loads(line))
    return data
def parse_dialogue(
    dialogue: Dict[str, Any], idx: int
) -> Tuple[ParsedDialogue, ParsedDialogue, ParsedDialogue]:
    """Parses a dialogue into three line-aligned representations.

    Produces three types of parsed dialogues:
    1. Parsed dialogue with original utterances, each prefixed with
       USER_PREFIX (initiator) or CRS_PREFIX (respondent).
    2. Parsed dialogue with preprocessed utterances, including stopwords.
    3. Parsed dialogue with preprocessed utterances, without stopwords.

    All three lists start with the same ``f"{CONV_PREFIX} {idx}"`` header
    line, followed by one line per kept message. Line ``i`` of each list
    therefore refers to the same message. Messages whose sender is neither
    the initiator nor the respondent are skipped in all three lists.

    Args:
        dialogue: ReDial dialogue record. Reads "initiatorWorkerId",
            "respondentWorkerId" and "messages"; each message is read for
            "senderWorkerId" and "text".
        idx: Index of the dialogue, used in the conversation header line.

    Returns:
        Tuple ``(original, preprocessed, preprocessed_no_stopwords)`` of
        parsed dialogues.
    """
    header = f"{CONV_PREFIX} {idx}"
    parsed_dialogue_original = [header]
    parsed_dialogue_preprocessed = [header]
    parsed_dialogue_preprocessed_no_stopwords = [header]

    user_id = dialogue.get("initiatorWorkerId")
    system_id = dialogue.get("respondentWorkerId")

    for message in dialogue.get("messages", []):
        sender_id = message.get("senderWorkerId")
        if sender_id == user_id:
            prefix = USER_PREFIX
        elif sender_id == system_id:
            prefix = CRS_PREFIX
        else:
            # Appending to only some of the corpora would shift every
            # following line out of alignment, so unknown senders are
            # skipped in all three lists.
            continue

        utterance = message.get("text")
        parsed_dialogue_original.append(f"{prefix} {utterance}")
        parsed_dialogue_preprocessed.append(
            preprocess_utterance(
                {"text": utterance}, "redial", no_stopwords=False
            )
        )
        parsed_dialogue_preprocessed_no_stopwords.append(
            preprocess_utterance(
                {"text": utterance}, "redial", no_stopwords=True
            )
        )

    return (
        parsed_dialogue_original,
        parsed_dialogue_preprocessed,
        parsed_dialogue_preprocessed_no_stopwords,
    )
def parse_dialogues(
    dialogues: List[Dict[str, Any]]
) -> Tuple[ParsedDialogue, ParsedDialogue, ParsedDialogue]:
    """Parses every dialogue and concatenates the results into three corpora.

    Each dialogue is parsed with ``parse_dialogue``, using its position in
    ``dialogues`` as the conversation index. The three per-dialogue line
    lists are appended, in order, to three flat corpora. A tqdm progress
    bar is shown while iterating.

    Args:
        dialogues: List of raw ReDial dialogues.

    Returns:
        Tuple ``(original, preprocessed, preprocessed_no_stopwords)``. Each
        element is a flat list of corpus lines covering all dialogues.
    """
    corpus_original: ParsedDialogue = []
    corpus_preprocessed: ParsedDialogue = []
    corpus_no_stopwords: ParsedDialogue = []
    corpora = (corpus_original, corpus_preprocessed, corpus_no_stopwords)

    for dialogue_idx, dialogue in enumerate(tqdm(dialogues)):
        per_dialogue = parse_dialogue(dialogue, dialogue_idx)
        for corpus, lines in zip(corpora, per_dialogue):
            corpus.extend(lines)

    return corpora
def save_parsed_dialogues(parsed_dialogues: List[str], path: str) -> None:
    """Writes corpus lines to a UTF-8 text file, one entry per line.

    Any existing file at ``path`` is overwritten. Every entry, including the
    last one, is followed by a newline.

    Args:
        parsed_dialogues: Corpus lines to write, in order.
        path: Path to the output file.
    """
    with open(path, "w", encoding="utf-8") as output_file:
        output_file.writelines(f"{entry}\n" for entry in parsed_dialogues)
def parse_args() -> argparse.Namespace:
    """Builds the command-line interface and parses ``sys.argv``.

    Returns:
        Namespace with ``redial_folder`` (folder holding the ReDial jsonl
        files) and ``output_folder`` (folder the corpora are written to).
    """
    arg_parser = argparse.ArgumentParser(
        description="Prepare ReDial data for CRB-CRS model."
    )
    cli_options = (
        (
            "--redial_folder",
            "data/redial",
            "Path to folder with ReDial dialogues.",
        ),
        (
            "--output_folder",
            "data/redial/corpus/",
            "Path to output folder.",
        ),
    )
    for flag, default_value, help_text in cli_options:
        arg_parser.add_argument(
            flag, type=str, default=default_value, help=help_text
        )
    return arg_parser.parse_args()
def main() -> None:
    """Builds the three ReDial corpora for CRB-CRS from the command line.

    Loads every ReDial split (train/valid/test) found in ``--redial_folder``,
    parses all dialogues, and writes the original, preprocessed and
    stopword-free corpora to ``--output_folder`` (created if needed).
    """
    # The root logger defaults to WARNING; without this call every INFO
    # progress message below would be discarded.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()

    raw_dialogues = []
    for file in ["train_data.jsonl", "valid_data.jsonl", "test_data.jsonl"]:
        path = os.path.join(args.redial_folder, file)
        if os.path.exists(path):
            raw_dialogues.extend(read_jsonl_data(path))
        else:
            logging.warning("ReDial split not found, skipping: %s", path)
    logging.info("Loaded %d dialogues.", len(raw_dialogues))

    (
        parsed_dialogue_original,
        parsed_dialogue_preprocessed,
        parsed_dialogue_preprocessed_no_stopwords,
    ) = parse_dialogues(raw_dialogues)
    logging.info("Finished parsing dialogues.")

    os.makedirs(args.output_folder, exist_ok=True)
    outputs = (
        (parsed_dialogue_original, "original_corpus.txt"),
        (parsed_dialogue_preprocessed, "preprocessed_corpus.txt"),
        (
            parsed_dialogue_preprocessed_no_stopwords,
            "preprocessed_corpus_no_stopwords.txt",
        ),
    )
    for corpus, filename in outputs:
        save_parsed_dialogues(
            corpus, os.path.join(args.output_folder, filename)
        )
    logging.info("Saved parsed dialogues to %s.", args.output_folder)


if __name__ == "__main__":
    main()