Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,120 Bytes
dd9600d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import argparse
import logging
import json
import os
import numpy as np
import torch
import tqdm
import time
from transformers import T5EncoderModel, AutoTokenizer
import glob
def parse_args():
    """Parse command-line options for the T5 caption-encoding script.

    Returns:
        argparse.Namespace with ``save_dir`` (root containing the ``jsons``
        dir and the output ``t5`` dir), plus ``start``/``end`` slice bounds
        used to shard the work across parallel jobs.
    """
    ap = argparse.ArgumentParser(
        description="Encode the data captionings using t5 model"
    )
    ap.add_argument(
        '--save_dir',
        type=str,
        default=None,
        help="path to the manifest, phonemes, and encodec codes dirs",
    )
    ap.add_argument(
        '--start',
        type=int,
        default=0,
        help='start index for parallel processing',
    )
    ap.add_argument(
        '--end',
        type=int,
        default=10000000,
        help='end index for parallel processing',
    )
    return ap.parse_args()
if __name__ == "__main__":
    # Script entry point: encode every caption found under <save_dir>/jsons/*.json
    # with Flan-T5-large and save each segment's encoder hidden states as a
    # compressed .npz file under <save_dir>/t5/<segment_id>.npz.
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    args = parse_args()

    # NOTE(review): requires a CUDA device; no CPU fallback is provided.
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    caption_encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").cuda().eval()

    # Output directory for the encoded captions.
    phn_save_root = os.path.join(args.save_dir, "t5")
    os.makedirs(phn_save_root, exist_ok=True)

    stime = time.time()
    logging.info("captioning...")
    json_paths = glob.glob(os.path.join(args.save_dir, 'jsons', '*.json'))
    for json_path in json_paths:
        with open(json_path, 'r', encoding="utf-8") as json_file:
            jsondata = json.load(json_file)
        # Shard the work: each parallel job handles its [start:end) slice.
        # NOTE(review): the same slice bounds are applied to every json file —
        # presumably intentional for per-file sharding; confirm with the caller.
        jsondata = jsondata[args.start:args.end]
        for entry in tqdm.tqdm(jsondata):
            save_fn = os.path.join(phn_save_root, entry['segment_id'] + ".npz")
            if os.path.exists(save_fn):
                # Resume-friendly: skip segments already encoded by a prior run.
                continue
            with torch.no_grad():
                batch_encoding = tokenizer(entry['caption'], return_tensors="pt")
                input_ids = batch_encoding["input_ids"].cuda()
                # Last-layer hidden states, shape (1, seq_len, hidden_dim).
                hidden = caption_encoder(input_ids=input_ids).last_hidden_state
            # Stored under the default npz key ('arr_0'), same as the original.
            np.savez_compressed(save_fn, hidden.cpu().numpy())
    logging.info("done in %.1f s", time.time() - stime)