"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import json
import os
import logging
import inspect
import numpy as np
import torch
from lavis.common.dist_utils import main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
@registry.register_task("multimodal_classification")
class MultimodalClassificationTask(BaseTask):
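    """Multimodal classification task.

    Evaluates a model's ``predict`` output against ground-truth labels and
    reports accuracy. Handles both generative models (string predictions)
    and discriminative models (class scores).
    """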
    def __init__(self, max_len, min_len, length_penalty, segments):
super().__init__()
self.max_len = max_len
self.min_len = min_len
self.length_penalty = length_penalty
self.segments = segments
@classmethod
def setup_task(cls, cfg):
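        """Build the task from the run configuration (``cfg.run_cfg``).

        A minimal run-config sketch, limited to the keys read below with
        their defaults (all other runner keys omitted; exact placement
        depends on the surrounding LAVIS config):

            run:
              task: multimodal_classification
              max_len: 30
              min_len: 1
              length_penalty: -1.
              segments: 1
        """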
run_cfg = cfg.run_cfg
max_len = run_cfg.get("max_len", 30)
min_len = run_cfg.get("min_len", 1)
length_penalty = run_cfg.get("length_penalty", -1.)
segments = run_cfg.get("segments", 1)
return cls(
max_len=max_len,
min_len=min_len,
length_penalty=length_penalty,
segments=segments
)
def valid_step(self, model, samples):
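        """Run prediction on one batch and collect per-instance results.

        If the model's ``predict`` accepts generation arguments, they are
        forwarded; otherwise prediction is run with ``n_segments=self.segments``.
        """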
results = []
        # getfullargspec replaces the deprecated getargspec (removed in Python 3.11)
        argspec = inspect.getfullargspec(model.predict)
        # check if the model accepts generation arguments for classification
        if all(k in argspec.args for k in ("max_length", "min_length", "length_penalty")):
            outputs = model.predict(
                samples,
                max_length=self.max_len,
                min_length=self.min_len,
                length_penalty=self.length_penalty,
            )
        else:
            outputs = model.predict(samples, n_segments=self.segments)
        if outputs is None:  # missing data
            return {}
predictions = outputs["predictions"]
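        # generative models return answer strings; otherwise predictions are class scores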
if isinstance(predictions[0], str):
targets = samples["label"]
indices = samples[self.inst_id_key]
for pred, tgt, index in zip(predictions, targets, indices):
results.append(
{
self.inst_id_key: index,
"prediction": pred,
"target": tgt,
}
)
else:
targets = outputs["targets"]
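            # argmax over class scores gives the predicted label index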
predictions = predictions.max(1)[1].cpu().numpy()
targets = targets.cpu().numpy()
indices = samples[self.inst_id_key]
for pred, tgt, index in zip(predictions, targets, indices):
if isinstance(index, torch.Tensor):
index = index.item()
results.append(
{
self.inst_id_key: index,
"prediction": pred.item(),
"target": tgt.item(),
}
)
return results
def after_evaluation(self, val_result, split_name, epoch, **kwargs):
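        """Save de-duplicated results for the split and report metrics."""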
eval_result_file = self.save_result(
result=val_result,
result_dir=registry.get_path("result_dir"),
filename="{}_epoch{}".format(split_name, epoch),
remove_duplicate=self.inst_id_key,
)
metrics = self._report_metrics(
eval_result_file=eval_result_file, split_name=split_name
)
return metrics
@main_process
def _report_metrics(self, eval_result_file, split_name):
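        """Compute accuracy from the saved result file and append it to evaluate.txt."""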
        with open(eval_result_file) as f:
            results = json.load(f)
predictions = np.array([res["prediction"] for res in results])
targets = np.array([res["target"] for res in results])
accuracy = (targets == predictions).sum() / targets.shape[0]
metrics = {"agg_metrics": accuracy, "acc": accuracy}
log_stats = {split_name: {k: v for k, v in metrics.items()}}
with open(
os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
) as f:
f.write(json.dumps(log_stats) + "\n")
logging.info(metrics)
return metrics