'''
Infers log-likelihoods from UniRep models.
'''
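
# Example invocation (the script filename is illustrative; the arguments follow
# the argparse definition at the bottom of this file):
#   python unirep_inference.py MODEL_PATH DATA_PATH OUTPUT_DIR \
#       --batch_size 64 --save_hidden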

import argparse
import os
import pathlib

import numpy as np
import tensorflow as tf

from unirep import babbler1900
from utils import load_and_filter_seqs, save, format_batch_seqs, nonpad_len
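
# The UniRep babbler builds a TF1-style graph; this script assumes graph mode,
# so when running under TensorFlow 2, disable eager execution up front.
tf.compat.v1.disable_eager_execution()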


def run_inference(seqs, model_weight_path, output_dir,
                  batch_size=64, save_hidden=False):
    # Shrink the batch if there are fewer sequences than one full batch.
    if len(seqs) < batch_size:
        batch_size = len(seqs)

    # babbler1900 builds the 1900-unit UniRep mLSTM graph and loads the
    # pretrained weights found at model_weight_path.
    b = babbler1900(batch_size=batch_size, model_path=model_weight_path)

    # get_rep_ops() and get_babbler_ops() expose ops on the same graph, so the
    # shared placeholders (x_ph, batch_size_ph, init_state_ph) are simply
    # rebound to the same tensors here.
    final_hidden_op, avg_hidden_op, x_ph, batch_size_ph, seq_len_ph, init_state_ph = b.get_rep_ops()
    logits_op, loss_op, x_ph, y_ph, batch_size_ph, init_state_ph = b.get_babbler_ops()
    # Unreduced per-sequence losses.
    batch_loss_op = b.batch_losses

    avg_hidden_vals = []
    loss_vals = []
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        n_batches = len(seqs) // batch_size
        leftover = len(seqs) % batch_size
        n_batches += int(bool(leftover))
        for i in range(n_batches):
            print('----Running inference for batch # %d------' % i)
            if i == n_batches - 1:
                # The last batch reuses the final batch_size sequences so that
                # every batch is full-sized; the overlap is trimmed below.
                batch_seqs = seqs[-batch_size:]
            else:
                batch_seqs = seqs[i*batch_size:(i+1)*batch_size]
            # Replace alignment gap characters with the unknown residue 'X',
            # since the UniRep vocabulary has no gap token.
            batch_seqs = [seq.replace('-', 'X') for seq in batch_seqs]
            batch = format_batch_seqs(batch_seqs)
            length = nonpad_len(batch)
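
            # Next-token objective: the input is the sequence minus its last
            # token, the target is the sequence shifted left by one position.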
            avg_hidden_, loss_ = sess.run(
                [avg_hidden_op, batch_loss_op],
                feed_dict={
                    x_ph: batch[:, :-1],
                    y_ph: batch[:, 1:],
                    batch_size_ph: batch.shape[0],
                    seq_len_ph: length,
                    init_state_ph: b._zero_state
                })
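            # The last batch may overlap earlier batches; keep only results for
            # the trailing `leftover` sequences. When leftover == 0, the slice
            # [-0:] keeps the whole (exactly full) batch, which is correct.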
            if i == n_batches - 1:
                loss_vals.append(loss_[-leftover:])
                if save_hidden:
                    avg_hidden_vals.append(avg_hidden_[-leftover:])
            else:
                loss_vals.append(loss_)
                if save_hidden:
                    avg_hidden_vals.append(avg_hidden_)

    loss_vals = np.concatenate(loss_vals, axis=0)
    loss_filename = os.path.join(output_dir, 'loss.npy')
    save(loss_filename, loss_vals)

    if save_hidden:
        avg_hidden_vals = np.concatenate(avg_hidden_vals, axis=0)
        avg_hidden_filename = os.path.join(output_dir, 'avg_hidden.npy')
        save(avg_hidden_filename, avg_hidden_vals)

    print('Ran inference on %d sequences. Saved results to %s.' %
          (len(seqs), output_dir))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', type=str)
    parser.add_argument('data_path', type=str)
    parser.add_argument('output_dir', type=str)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--save_hidden', action='store_true')
    args = parser.parse_args()

    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    seqs = load_and_filter_seqs(args.data_path)
    # Record the filtered sequences alongside the results. Note this writes
    # plain text via np.savetxt despite the .npy extension.
    np.savetxt(os.path.join(args.output_dir, 'seqs.npy'), seqs, '%s')

    run_inference(seqs, args.model_path,
                  args.output_dir, batch_size=args.batch_size,
                  save_hidden=args.save_hidden)
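
# Downstream usage sketch (assumes utils.save writes standard .npy arrays):
#   losses = np.load(os.path.join(output_dir, 'loss.npy'))
#   scores = -losses  # lower loss = higher model likelihood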