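"""Detect and redact PII in a code dataset.

The script loads a Hugging Face dataset, applies ``redact_pii_batch`` (with the
replacement values returned by ``get_replacements``), keeps a sample of the
modified files as a separate "checks" dataset for manual inspection, optionally
prepends repository metadata tokens (<reponame>, <filename>, <gh_stars>) to the
text, and saves the result to the Hub, to local disk, or as manual shards.

Example invocation (the file name is illustrative; all flags have defaults):

    python pii_redaction.py --dataset_name bigcode/pii-for-code --save_mode manual_shards
"""
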
import argparse
import json
import logging
import os
import random
import time
from functools import partial
from pprint import pformat

import numpy as np
from datasets import load_dataset
from datasets.utils.logging import set_verbosity_info

from manual_sharding import save_manual_shards
from utils import get_replacements, redact_pii_batch


REPONAME_TOKEN = "<reponame>"
FILENAME_TOKEN = "<filename>"
STARS_TOKEN = "<gh_stars>"
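

def str2bool(value: str) -> bool:
    """Parse a boolean-ish command-line value.

    Helper added here because argparse's ``type=bool`` treats any non-empty
    string (including "False") as True, so flags that default to True could
    never be disabled from the command line.
    """
    return str(value).lower() in ("1", "true", "yes", "y")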


def get_num_stars_bucket(num_stars: int) -> str:
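    """Bucket a raw star count into the coarse label appended after the <gh_stars> token."""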
    if num_stars is None or num_stars == 0:
        return "0"
    elif num_stars <= 10:
        return "1-10"
    elif num_stars <= 100:
        return "10-100"
    elif num_stars <= 1000:
        return "100-1000"
    else:
        return "1000+"


def content_with_meta(example):
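    """Prepend repo-name, filename and stars metadata tokens to the content.

    Each token is added independently with probability 0.2, followed by a
    newline when at least one token was added.
    """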
    res = ""
    if np.random.binomial(n=1, p=0.2):
        res += f"{REPONAME_TOKEN}{example['max_stars_repo_name']}"
    if np.random.binomial(n=1, p=0.2):
        res += f"{FILENAME_TOKEN}{example['max_stars_repo_path']}"
    if np.random.binomial(n=1, p=0.2):
        num_stars = get_num_stars_bucket(example["max_stars_count"])
        res += f"{STARS_TOKEN}{num_stars}"
    if len(res) > 0:
        res += "\n"
    res += example["content"]

    return {"content_with_meta": res}


def parseArgs():
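    """Parse the command-line arguments for the PII detection/redaction job."""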
    parser = argparse.ArgumentParser(description="PII detection and redaction")
    parser.add_argument(
        "--dataset_name",
        default="bigcode/pii-for-code",
        type=str,
        help="HF repo name/path of the dataset.",
    )

    parser.add_argument(
        "--add_metadata",
        action="store_true",
        help="If set, add metadata (repo name, file name, stars) to the text.",
    )
    parser.add_argument(
        "--num_load_proc",
        default=64,
        type=int,
        help="Number of processes to use for loading the dataset.",
    )
    parser.add_argument(
        "--text_column",
        default="content",
        type=str,
        help="Text column to use; it will be renamed to 'content'.",
    )
    parser.add_argument(
        "--split",
        default="train",
        type=str,
        help="Dataset split to process.",
    )
    parser.add_argument(
        "--batch_size",
        default=100,
        type=int,
        help="Batch size for the PII detection/redaction.",
    )
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="Seed for the random number generators.",
    )
    parser.add_argument(
        "--num_proc",
        default=96,
        type=int,
        help="Number of processes to use for the PII detection/redaction.",
    )
    parser.add_argument(
        "--no_redaction",
        action="store_true",
        help="If set, don't perform redaction.",
    )
    # str2bool (defined above) instead of type=bool, which would parse "False" as True.
    parser.add_argument(
        "--load_replacements",
        default=True,
        type=str2bool,
        help="If True, load the replacements from the file replacements.json.",
    )
    parser.add_argument(
        "--add_reference_text",
        default=True,
        type=str2bool,
        help="If True, add the reference text with PII between delimiters "
        "in the redacted text (used for visualization).",
    )
    parser.add_argument(
        "--check_all_files",
        action="store_true",
        help="If set, check all files, not only the ones that contain PII.",
    )
    parser.add_argument(
        "--check_sampling_size",
        default=0,
        type=int,
        help="Number of samples to check for PII (0 means check all of them).",
    )

    parser.add_argument(
        "--save_mode",
        default="manual_shards",
        type=str,
        choices=["hub", "local", "manual_shards"],
        help="How to save the dataset.",
    )
    parser.add_argument(
        "--save_mode_checks",
        default="hub",
        type=str,
        choices=["hub", "local", "manual_shards"],
        help="How to save the checks dataset.",
    )

    parser.add_argument(
        "--target_dataset",
        default="bigcode-pii2",
        type=str,
        help="HF repo name of the target dataset when save_mode=hub.",
    )
    parser.add_argument(
        "--hub_username",
        default="loubnabnl",
        type=str,
        help="Username for the Hub.",
    )
    parser.add_argument(
        "--save_path_disk",
        default="/fsx/loubna/data/the-stack-march-no-pii",
        type=str,
        help="Path to save the dataset on disk when save_mode=local.",
    )
    return parser.parse_args()


def get_check_ds(ds, args):
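    """Select the subset of files kept for manual PII checks.

    By default only files flagged as "modified" by the redaction step are kept;
    a random sample of `check_sampling_size` of them is returned (0 keeps all).
    """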
    if not args.check_all_files:
        ds_checks = ds.filter(
            lambda exs: exs["modified"],
            batched=True,
            batch_size=args.batch_size,
            num_proc=args.num_proc,
        )
    else:
        ds_checks = ds
    # A sampling size of 0 means "keep every candidate file".
    if not args.check_sampling_size:
        sampling_size = len(ds_checks)
    else:
        sampling_size = args.check_sampling_size
    idx_samples = random.sample(
        range(len(ds_checks)), min(len(ds_checks), sampling_size)
    )
    ds_checks = ds_checks.select(idx_samples)

    return ds_checks


def main():
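    """Load the dataset, redact PII, save a checks subset, then save the final dataset."""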
    set_verbosity_info()
    args = parseArgs()
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
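    # The FileHandler below writes to logs/, so make sure the directory exists first.
    os.makedirs("logs", exist_ok=True)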
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(f"logs/pii-{args.dataset_name.split('/')[-1]}.log"),
            logging.StreamHandler(),
        ],
    )
    logger.info(
        f"** The job is running with the following arguments: **\n{args}\n **** "
    )

    logger.info(f" ===== Loading {args.dataset_name} =====")
    ds = load_dataset(
        args.dataset_name,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_load_proc,
    )
    if args.text_column != "content":
        ds = ds.rename_column(args.text_column, "content")

    logger.info(" ===== Applying PII redaction =====")
    random.seed(args.seed)
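    # content_with_meta (used with --add_metadata) samples with numpy's global RNG,
    # which random.seed does not affect, so seed it as well for reproducibility.
    np.random.seed(args.seed)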

    replacements = get_replacements()
    # Persist the replacement values used for this run next to the outputs.
    with open("replacements.json", "w") as f:
        json.dump(replacements, f)
    logger.info(f"Using the following replacements:\n{pformat(replacements)}")
    ds = ds.map(
        partial(
            redact_pii_batch,
            replacements=replacements,
            add_references=args.add_reference_text,
        ),
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
    )
    logger.info(f"Dataset info after PII redaction:\n{ds}")

    logger.info(
        f" ===== Checking {args.check_sampling_size} samples from those modified in the dataset ====="
    )
    ds_checks = get_check_ds(ds, args)

    if len(ds_checks) == 0:
        logger.info("Checks dataset is empty. Not saving anything.")
    else:
        logger.info(f"Checks dataset info {ds_checks}")
        if args.save_mode_checks == "hub":
            logger.info(
                f"Pushing the checks dataset to the Hub as {args.target_dataset}_checks"
            )
            ds_checks.push_to_hub(args.target_dataset + "_checks", private=True)

        elif args.save_mode_checks == "local":
            logger.info("Saving the checks dataset to disk")
            ds_checks.save_to_disk(args.save_path_disk + "_checks")

        elif args.save_mode_checks == "manual_shards":
            logger.info("Saving the checks dataset in manual shards")
            save_manual_shards(
                ds_checks,
                user=args.hub_username,
                remote_dataset_repo=args.target_dataset + "_checks",
                local_dir="/fsx/loubna/data/the-stack-march-no-pii_checks",
            )

    logger.info("Removing columns that are not needed for the final dataset")
    # The redacted text lives in "new_content"; drop the original text and the
    # PII bookkeeping columns, then rename it back to "content".
    columns = ["content", "modified", "entities"]
    if args.add_reference_text:
        columns.append("references")
    ds = ds.remove_columns(columns)
    ds = ds.rename_column("new_content", "content")
    logger.info(f"Dataset info after removing columns:\n{ds}")

    if args.add_metadata:
        logger.info(" ===== Adding metadata =====")
        ds = ds.map(
            content_with_meta, remove_columns=["content"], num_proc=args.num_proc
        )
        ds = ds.rename_column("content_with_meta", "content")

    if args.save_mode == "hub":
        logger.info(
            f" ===== Pushing the dataset to the Hub as: {args.target_dataset} ====="
        )
        ds.push_to_hub(args.target_dataset, private=True)

    elif args.save_mode == "local":
        logger.info(" ===== Saving the dataset to disk =====")
        ds.save_to_disk(args.save_path_disk)

    elif args.save_mode == "manual_shards":
        logger.info(
            f" ===== Saving the dataset in manual shards to {args.save_path_disk} ====="
        )
        save_manual_shards(
            ds,
            user=args.hub_username,
            remote_dataset_repo="the-stack-no-pii-march",
            local_dir=args.save_path_disk,
        )

    logger.info(" ===== Dataset saved successfully =====")


if __name__ == "__main__":
    main()