|
"""Script to prepare the ReDial data for CRB-CRS model.""" |
|
|
|
import argparse |
|
import json |
|
import logging |
|
import os |
|
from typing import Any, Dict, List, Tuple |
|
|
|
from tqdm import tqdm |
|
|
|
from src.model.crb_crs.retriever.retriever import ( |
|
CONV_PREFIX, |
|
CRS_PREFIX, |
|
USER_PREFIX, |
|
) |
|
from src.model.crb_crs.utils_preprocessing import preprocess_utterance |
|
|
|
# Alias for a flat list of corpus lines as produced by the parsing functions
# below: a conversation-marker line followed by one string per utterance.
ParsedDialogue = List[str]
|
|
|
|
|
def read_jsonl_data(path: str) -> List[Dict[str, Any]]:
    """Reads data from a JSON Lines file.

    Every non-blank line is decoded as one JSON value. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped rather
    than causing a ``json.JSONDecodeError``.

    Args:
        path: Path to the jsonl file.

    Returns:
        List of decoded records, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the file object lazily instead of calling readlines(), which
        # would hold a second full copy of the file in memory.
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
def parse_dialogue(
    dialogue: Dict[str, Any], idx: int
) -> Tuple[ParsedDialogue, ParsedDialogue, ParsedDialogue]:
    """Parses a dialogue.

    Produces three line-aligned parsed versions of the dialogue:
      1. Original utterances, each prefixed with the participant marker
         (``USER_PREFIX`` for the initiator, ``CRS_PREFIX`` for the
         respondent).
      2. Preprocessed utterances, including stopwords.
      3. Preprocessed utterances, without stopwords.

    Each version starts with a ``CONV_PREFIX`` line carrying ``idx``.
    Messages whose sender is neither the initiator nor the respondent are
    skipped in all three versions, so entry ``i`` of each list refers to the
    same message.

    Args:
        dialogue: Raw dialogue record; reads ``initiatorWorkerId``,
          ``respondentWorkerId`` and ``messages`` (each message's
          ``senderWorkerId`` and ``text``).
        idx: Index of the dialogue, written after ``CONV_PREFIX``.

    Returns:
        Tuple ``(original, preprocessed, preprocessed_no_stopwords)`` of
        lists of corpus lines.
    """
    header = f"{CONV_PREFIX} {idx}"
    parsed_dialogue_original = [header]
    parsed_dialogue_preprocessed = [header]
    parsed_dialogue_preprocessed_no_stopwords = [header]

    user_id = dialogue.get("initiatorWorkerId")
    system_id = dialogue.get("respondentWorkerId")

    for message in dialogue.get("messages", []):
        sender_id = message.get("senderWorkerId")
        if sender_id == user_id:
            prefix = USER_PREFIX
        elif sender_id == system_id:
            prefix = CRS_PREFIX
        else:
            # Unknown sender: drop the message from every version so the
            # three corpora stay aligned line by line.
            continue

        utterance = message.get("text")
        parsed_dialogue_original.append(f"{prefix} {utterance}")
        parsed_dialogue_preprocessed.append(
            preprocess_utterance(
                {"text": utterance}, "redial", no_stopwords=False
            )
        )
        parsed_dialogue_preprocessed_no_stopwords.append(
            preprocess_utterance(
                {"text": utterance}, "redial", no_stopwords=True
            )
        )

    return (
        parsed_dialogue_original,
        parsed_dialogue_preprocessed,
        parsed_dialogue_preprocessed_no_stopwords,
    )
|
|
|
|
|
def parse_dialogues(
    dialogues: List[Dict[str, Any]]
) -> Tuple[ParsedDialogue, ParsedDialogue, ParsedDialogue]:
    """Parses all dialogues and concatenates them into three flat corpora.

    Each dialogue is parsed with ``parse_dialogue`` using its position in
    ``dialogues`` as its index. The lines it yields are appended, in order,
    to the corresponding corpus.

    Args:
        dialogues: List of raw dialogues.

    Returns:
        Tuple ``(original, preprocessed, preprocessed_no_stopwords)``, each
        a flat list of corpus lines covering every dialogue.
    """
    corpora: Tuple[List[str], List[str], List[str]] = ([], [], [])

    for position, raw_dialogue in enumerate(tqdm(dialogues)):
        # parse_dialogue returns its three views in the same order as corpora.
        for corpus, lines in zip(corpora, parse_dialogue(raw_dialogue, position)):
            corpus.extend(lines)

    return corpora
|
|
|
|
|
def save_parsed_dialogues(parsed_dialogues: List[str], path: str) -> None:
    """Writes corpus lines to a UTF-8 text file, one entry per line.

    The file at ``path`` is created or overwritten.

    Args:
        parsed_dialogues: Corpus lines to write, in order.
        path: Destination file path.
    """
    with open(path, "w", encoding="utf-8") as out_file:
        out_file.writelines(f"{entry}\n" for entry in parsed_dialogues)
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Builds the command-line parser and parses ``sys.argv``.

    Returns:
        Namespace with ``redial_folder`` and ``output_folder`` string
        attributes.
    """
    # (flag, default, help text) for each string-valued option.
    option_specs = (
        (
            "--redial_folder",
            "data/redial",
            "Path to folder with ReDial dialogues.",
        ),
        (
            "--output_folder",
            "data/redial/corpus/",
            "Path to output folder.",
        ),
    )

    parser = argparse.ArgumentParser(
        description="Prepare ReDial data for CRB-CRS model."
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(
            flag, type=str, default=default_value, help=help_text
        )
    return parser.parse_args()
|
|
|
|
|
if __name__ == "__main__":
    # Without explicit configuration the root logger only emits WARNING and
    # above, which would silently drop every progress message below.
    logging.basicConfig(level=logging.INFO)

    args = parse_args()

    raw_dialogues = []
    for split_file in ["train_data.jsonl", "valid_data.jsonl", "test_data.jsonl"]:
        split_path = os.path.join(args.redial_folder, split_file)
        if not os.path.exists(split_path):
            # Missing splits are tolerated, but make it visible so an empty
            # or partial corpus is not produced unnoticed.
            logging.warning("Split file not found, skipping: %s", split_path)
            continue
        raw_dialogues.extend(read_jsonl_data(split_path))

    logging.info("Loaded %d dialogues.", len(raw_dialogues))

    (
        parsed_dialogues_original,
        parsed_dialogues_preprocessed,
        parsed_dialogues_preprocessed_no_stopwords,
    ) = parse_dialogues(raw_dialogues)

    logging.info("Finished parsing dialogues.")

    os.makedirs(args.output_folder, exist_ok=True)
    save_parsed_dialogues(
        parsed_dialogues_original,
        os.path.join(args.output_folder, "original_corpus.txt"),
    )
    save_parsed_dialogues(
        parsed_dialogues_preprocessed,
        os.path.join(args.output_folder, "preprocessed_corpus.txt"),
    )
    save_parsed_dialogues(
        parsed_dialogues_preprocessed_no_stopwords,
        os.path.join(
            args.output_folder,
            "preprocessed_corpus_no_stopwords.txt",
        ),
    )

    logging.info("Saved parsed dialogues to %s.", args.output_folder)
|
|