ssa-perin / utility /predict.py
larkkin's picture
Add application code and models, update README
8044721
raw
history blame
1.78 kB
import os
import json
import torch
import sys
from subprocess import run
from data.batch import Batch
sys.path.append("../evaluation")
from evaluate_single_dataset import evaluate
def predict(model, data, input_path, raw_input_path, args, logger, output_directory, device, mode="validation", epoch=None):
model.eval()
framework, language = args.framework, args.language
sentences = {}
with open(input_path, encoding="utf8") as f:
for line in f.readlines():
line = json.loads(line)
line["nodes"], line["edges"], line["tops"] = [], [], []
line["framework"], line["language"] = framework, language
sentences[line["id"]] = line
for i, batch in enumerate(data):
with torch.no_grad():
predictions = model(Batch.to(batch, device), inference=True)
for prediction in predictions:
for key, value in prediction.items():
sentences[prediction["id"]][key] = value
if epoch is not None:
output_path = f"{output_directory}/prediction_{mode}_{epoch}_{framework}_{language}.json"
else:
output_path = f"{output_directory}/prediction.json"
with open(output_path, "w", encoding="utf8") as f:
for sentence in sentences.values():
json.dump(sentence, f, ensure_ascii=False)
f.write("\n")
f.flush()
run(["./convert.sh", output_path] + (["--node_centric "] if args.graph_mode != "labeled-edge" else []))
if raw_input_path:
results = evaluate(raw_input_path, f"{output_path}_converted")
print(mode, results, flush=True)
if logger is not None:
logger.log_evaluation(results, mode, epoch)
return results["sentiment_tuple/f1"]