"""Temperature-sample a corpus, grouping lines by language code.

Each language is allotted a number of output lines proportional to
``(raw_lines_in_lang / total_raw_lines) ** alpha`` (``alpha`` = 0.3 by
default), renormalised so the sampled corpus has approximately the same total
size as the input. With ``alpha < 1`` smaller languages are upsampled and
larger languages are downsampled.

Assumes the input corpus is sorted (grouped) by language, i.e. all lines of a
given language are contiguous. Sampled lines are written to stdout; progress
is logged to ``sampling.log``.
"""
import argparse
import itertools
import logging
import random
import sys

logger = logging.getLogger(__name__)

# Default temperature exponent applied to each language's share of the corpus.
DEFAULT_ALPHA = 0.3

# Maximum tolerated relative difference between the raw and sampled corpus
# sizes; the only source of difference is rounding each language's line count.
MAX_SIZE_DIFFERENCE = 0.01


def parse_args():
    """Parse command-line arguments for the sampler."""
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus_filepath", type=str,
                        help="path to input corpus to sample")
    parser.add_argument("linecounts_filepath", type=str,
                        help="path to file containing line counts of input corpus (from 'uniq -c')")
    parser.add_argument("--alpha", type=float, default=DEFAULT_ALPHA,
                        help="temperature exponent applied to each language's share of the corpus "
                             "(default: %(default)s)")
    parser.add_argument("--seed", type=int, default=None,
                        help="random seed for reproducible sampling (default: unseeded)")
    return parser.parse_args()


def _read_line_counts(path):
    """Read per-language line counts from a ``uniq -c``-style file.

    Each non-blank line must contain a count followed by a language code,
    separated by whitespace (e.g. ``"   1200 en"``).

    Returns:
        ``(lc_lookup, total_raw_lines)`` where ``lc_lookup`` maps each language
        code to ``{"raw_lines": count}``.

    Raises:
        ValueError: on a malformed line, a non-integer count, or a language
            listed more than once.
    """
    lc_lookup = {}
    total_raw_lines = 0
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            if len(fields) != 2:
                raise ValueError(f"malformed line count entry in {path}: {line!r}")
            count, lang = int(fields[0]), fields[1]
            if lang in lc_lookup:
                # a repeated language means the data was not grouped before 'uniq -c'
                raise ValueError(f"language {lang!r} listed more than once in {path}")
            lc_lookup[lang] = {"raw_lines": count}
            total_raw_lines += count
    return lc_lookup, total_raw_lines


def _compute_lines_to_sample(lc_lookup, total_raw_lines, alpha=DEFAULT_ALPHA):
    """Add ``sampling_factor`` and ``lines_to_sample`` to every entry of *lc_lookup*.

    For each language::

        sampling_factor = (raw_lines / total_raw_lines) ** alpha
        lines_to_sample = round(sampling_factor / sum(sampling_factors) * total_raw_lines)

    Returns:
        The total number of lines to sample over all languages.

    Raises:
        ValueError: if *total_raw_lines* is not positive.
    """
    if total_raw_lines <= 0:
        raise ValueError("line counts file contains no lines to sample")

    logger.info("calculating sampling factors")
    total_sampling_factors = 0
    for entry in lc_lookup.values():
        # we sample lines proportional to this so smaller langs are upsampled and larger langs are downsampled
        sampling_factor = (entry["raw_lines"] / total_raw_lines) ** alpha
        entry["sampling_factor"] = sampling_factor
        total_sampling_factors += sampling_factor
    logger.info(f"sampling factor total is {total_sampling_factors}")

    logger.info("calculating number of lines to sample")
    total_lines_to_sample = 0
    for entry in lc_lookup.values():
        lines_to_sample = round(entry["sampling_factor"] / total_sampling_factors * total_raw_lines)
        entry["lines_to_sample"] = lines_to_sample
        total_lines_to_sample += lines_to_sample
    return total_lines_to_sample


def _language_of(line):
    """Return the language code of a corpus line (its second tab-separated field).

    Only the first two tabs are split on, so tabs inside the text column are
    tolerated.

    Raises:
        ValueError: if the line has no second field.
    """
    fields = line.split("\t", 2)
    if len(fields) < 2:
        raise ValueError(f"malformed corpus line (no language column): {line!r}")
    return fields[1]


def _sample_group(lines, k):
    """Return *k* lines drawn at random from the non-empty list *lines*.

    Samples without replacement when there are at least *k* lines
    (downsampling, or a random permutation when the counts are equal).
    Otherwise samples with replacement (upsampling).
    """
    if len(lines) >= k:
        return random.sample(lines, k)
    # need to oversample, so now use sampling with replacement
    return random.choices(lines, k=k)


def main():
    """Temperature-sample the corpus given on the command line and write it to stdout.

    Raises:
        ValueError: if the counts file and the corpus disagree, the corpus is
            not grouped by language, or the sampled size deviates too much
            from the raw size.
    """
    logging.basicConfig(
        level=logging.INFO,
        filename='sampling.log',
        filemode='w',
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')

    args = parse_args()
    if args.seed is not None:
        random.seed(args.seed)

    logger.info(f"creating counts lookup dict from {args.linecounts_filepath}")
    lc_lookup, total_raw_lines = _read_line_counts(args.linecounts_filepath)
    logger.info(f"lookup dict finished ({len(lc_lookup)} entries)")
    logger.info(f"dataset contains {total_raw_lines} lines")

    # calculate lines to keep with
    # (((raw_lines_in_lang / total_line_count) ** alpha) / total_proportions) * total_lines
    total_lines_to_sample = _compute_lines_to_sample(lc_lookup, total_raw_lines, args.alpha)

    # sense check that sampled corpus is right size
    if total_lines_to_sample == 0:
        raise ValueError("no lines would be sampled")
    prop_size_difference = abs((total_raw_lines - total_lines_to_sample) / total_lines_to_sample)
    if prop_size_difference >= MAX_SIZE_DIFFERENCE:
        raise ValueError(
            f"sampled corpus size {total_lines_to_sample} differs from raw size "
            f"{total_raw_lines} by {prop_size_difference:.3%}")
    logger.info(
        f"total raw lines is {total_raw_lines}, total sampled lines is {total_lines_to_sample} ({prop_size_difference:.3%} difference)")

    # assume input file is sorted by group
    logger.info(f"sampling from {args.corpus_filepath}")
    seen_langs = set()
    with open(args.corpus_filepath, "r") as f:
        stripped_lines = (line.strip() for line in f)
        non_blank_lines = (line for line in stripped_lines if line)
        # groupby yields each run of consecutive same-language lines, including
        # the final run at EOF, so every language is flushed to stdout.
        for langcode, group in itertools.groupby(non_blank_lines, key=_language_of):
            if langcode in seen_langs:
                raise ValueError(f"corpus is not sorted by language: {langcode!r} appears in more than one block")
            seen_langs.add(langcode)
            if langcode not in lc_lookup:
                raise ValueError(f"language {langcode!r} missing from {args.linecounts_filepath}")

            single_lang_line_store = list(group)
            raw_lines_in_lang = len(single_lang_line_store)
            expected_raw_lines = lc_lookup[langcode]["raw_lines"]
            if raw_lines_in_lang != expected_raw_lines:  # sanity check it's same data
                raise ValueError(
                    f"read {raw_lines_in_lang} lines for {langcode!r} but counts file says {expected_raw_lines}")

            num_lines_to_keep = lc_lookup[langcode]["lines_to_sample"]
            logger.info(f"finished reading {langcode}: read in {raw_lines_in_lang}, writing {num_lines_to_keep}")
            sampled_lines = _sample_group(single_lang_line_store, num_lines_to_keep)
            sys.stdout.writelines(f"{out}\n" for out in sampled_lines)
            logger.info(f"finished writing {langcode} to stdout")

    missing_langs = set(lc_lookup) - seen_langs
    if missing_langs:
        logger.warning(f"languages in counts file but not in corpus: {sorted(missing_langs)}")
    logger.info("sampling complete!")


if __name__ == "__main__":
    main()