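# Usage note (inferred from the code below; the dataset layout is an assumption based
# on how the paths are built here): each JSON under <save_dir>/jsons is expected to be
# a list of records carrying at least 'segment_id' and 'caption' fields. The
# flan-t5-large encoder output for each caption is saved to <save_dir>/t5/<segment_id>.npz.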
import argparse
import logging
import json
import os
import numpy as np
import torch
import tqdm
import time
from transformers import T5EncoderModel, AutoTokenizer
import glob

def parse_args():
    parser = argparse.ArgumentParser(description="Encode caption text with a T5 encoder and save the embeddings")
    parser.add_argument('--save_dir', type=str, default=None, help="root dir containing the 'jsons' caption files; embeddings are written to a 't5' subdir")
    parser.add_argument('--start', type=int, default=0, help='start index for parallel processing')
    parser.add_argument('--end', type=int, default=10000000, help='end index for parallel processing')
    return parser.parse_args()

if __name__ == "__main__":

    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    args = parse_args()
    
    # load the flan-t5-large tokenizer and encoder (encoder-only, eval mode, on GPU)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    caption_encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").cuda().eval()

    # output dir for the T5 caption embeddings (<save_dir>/t5)
    phn_save_root = os.path.join(args.save_dir, "t5")
    os.makedirs(phn_save_root, exist_ok=True)
    
    stime = time.time()

    logging.info(f"captioning...")
    json_paths = glob.glob(os.path.join(args.save_dir, 'jsons', '*.json'))
    for json_path in json_paths:
        with open(json_path, 'r', encoding="utf-8") as json_file:
            jsondata = json.load(json_file)
            
        # keep only the [start, end) slice so parallel workers can split the file
        jsondata = jsondata[args.start:args.end]

        for item in tqdm.tqdm(jsondata):
            save_fn = os.path.join(phn_save_root, item['segment_id'] + ".npz")
            # skip segments that were already encoded in a previous run
            if not os.path.exists(save_fn):
                text = item['caption']

                # tokenize the caption and take the encoder's last hidden state
                with torch.no_grad():
                    batch_encoding = tokenizer(text, return_tensors="pt")
                    ori_tokens = batch_encoding["input_ids"].cuda()
                    outputs = caption_encoder(input_ids=ori_tokens).last_hidden_state
                
                # save the encoder output as a compressed .npz keyed by segment_id
                phn = outputs.cpu().numpy()
                np.savez_compressed(save_fn, phn)

    logging.info(f"done, elapsed {time.time() - stime:.1f}s")
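
# Example invocation (hypothetical script name and paths):
#   python encode_t5_captions.py --save_dir /data/captions --start 0 --end 100000
# Loading a saved embedding afterwards (positional arrays are stored under 'arr_0'):
#   emb = np.load("/data/captions/t5/<segment_id>.npz")["arr_0"]  # shape (1, seq_len, hidden_size)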