KawshikManikantan's picture
upload_trial
98e2ea5
raw
history blame
474 Bytes
import torch
from typing import Tuple
from torch import Tensor
def sort_mentions(
    ment_starts: Tensor, ment_ends: Tensor, return_sorted_indices=False
) -> Tuple:
    """Sort mentions lexicographically by (start, end), ascending.

    Args:
        ment_starts: Start positions of the mentions. Expected to be a 1-D
            tensor with the same length as ``ment_ends`` (sorting is done
            along dim 0).
        ment_ends: End positions of the mentions, aligned with ``ment_starts``.
        return_sorted_indices: If True, also return the permutation that was
            applied to the inputs.

    Returns:
        ``(sorted_starts, sorted_ends)``, or
        ``(sorted_starts, sorted_ends, sorted_indices)`` when
        ``return_sorted_indices`` is True. Mentions with identical
        (start, end) pairs keep their original relative order.
    """
    # Exact lexicographic ordering via two stable passes: sort by the
    # secondary key (end) first, then stably by the primary key (start).
    # The previous combined float key `start + 1e-5 * end` mis-ordered
    # mentions whenever end values differed by >= 1e5 (long documents).
    _, by_end = torch.sort(ment_ends, dim=0, stable=True)
    _, by_start = torch.sort(ment_starts[by_end], dim=0, stable=True)
    # Compose the two permutations into indices over the original inputs.
    sorted_indices = by_end[by_start]

    output: Tuple = (ment_starts[sorted_indices], ment_ends[sorted_indices])
    if return_sorted_indices:
        output = output + (sorted_indices,)
    return output