|
|
|
|
|
|
|
|
|
from typing import Any, Optional

import numpy as np
import torch
|
|
|
# Names exported by ``from <module> import *``: small helpers that draw random
# values through torch, each accepting an optional ``torch.Generator``.
__all__ = [
    "torch_randint",
    "torch_random",
    "torch_shuffle",
    "torch_uniform",
    "torch_random_choices",
]
|
|
|
|
|
def torch_randint(
    low: int, high: int, generator: Optional[torch.Generator] = None
) -> int:
    """Draw an integer uniformly from the half-open interval [low, high).

    Args:
        low: Inclusive lower bound.
        high: Exclusive upper bound.
        generator: Optional ``torch.Generator`` forwarded to ``torch.randint``
            so draws can be made reproducible.

    Returns:
        A Python ``int``. As a special case, when ``low == high`` the value
        ``low`` is returned without consuming any randomness (the interval
        [low, low) would otherwise be empty).

    Raises:
        ValueError: If ``low > high``.
    """
    if low == high:
        # Degenerate interval: return the bound instead of letting torch raise.
        return low
    if low > high:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError(f"low must be <= high, got low={low}, high={high}")
    sample = torch.randint(low=low, high=high, size=(1,), generator=generator)
    return int(sample.item())
|
|
|
|
|
def torch_random(generator: Optional[torch.Generator] = None) -> float:
    """Draw a float uniformly from the half-open interval [0, 1).

    Args:
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``
            so draws can be made reproducible.

    Returns:
        A Python ``float`` in [0, 1).

    Note:
        ``torch.rand`` is called without an explicit dtype, so the sample is
        generated in torch's default floating dtype (float32 unless changed
        via ``torch.set_default_dtype``) before conversion to a Python float;
        its resolution is limited accordingly.
    """
    return float(torch.rand(1, generator=generator).item())
|
|
|
|
|
def torch_shuffle(
    src_list: list[Any], generator: Optional[torch.Generator] = None
) -> list[Any]:
    """Return a new list holding the elements of ``src_list`` in random order.

    The input list is not modified; the permutation comes from
    ``torch.randperm``.

    Args:
        src_list: Elements to shuffle.
        generator: Optional ``torch.Generator`` forwarded to ``torch.randperm``
            so shuffles can be made reproducible.

    Returns:
        A new list with the same elements as ``src_list``, permuted.
    """
    permutation = torch.randperm(len(src_list), generator=generator).tolist()
    return [src_list[i] for i in permutation]
|
|
|
|
|
def torch_uniform(
    low: float, high: float, generator: Optional[torch.Generator] = None
) -> float:
    """Draw a float uniformly from the half-open interval [low, high).

    One value ``u`` in [0, 1) is drawn with ``torch_random`` and mapped to
    ``(high - low) * u + low``.

    Args:
        low: Lower bound of the interval.
        high: Upper bound of the interval.
        generator: Optional ``torch.Generator`` forwarded to ``torch_random``.

    Returns:
        A Python ``float``. If ``high < low`` the result lies in (high, low].
        Floating-point rounding in the affine map may, in rare cases, yield a
        value equal to ``high``.
    """
    rand_val = torch_random(generator)
    # Keep this exact operation order: seeded callers depend on bit-identical output.
    return (high - low) * rand_val + low
|
|
|
|
|
def torch_random_choices(
    src_list: list[Any],
    generator: Optional[torch.Generator] = None,
    k: int = 1,
    weight_list: Optional[list[float]] = None,
) -> Any:
    """Sample ``k`` elements from ``src_list`` with replacement.

    Args:
        src_list: Population to sample from.
        generator: Optional ``torch.Generator`` used for every draw so results
            can be made reproducible.
        k: Number of samples to draw.
        weight_list: Optional non-negative relative weights, one per element of
            ``src_list``. When omitted, every element is equally likely.

    Returns:
        A single element when ``k == 1``; otherwise a list of ``k`` elements.

    Raises:
        ValueError: If ``weight_list`` has a different length than
            ``src_list``, contains a negative weight, or does not sum to a
            positive finite total.
    """
    if weight_list is None:
        # .tolist() turns the index tensor into plain ints, avoiding per-item
        # indexing with 0-d tensors.
        rand_idx = torch.randint(
            low=0, high=len(src_list), generator=generator, size=(k,)
        ).tolist()
        out_list = [src_list[i] for i in rand_idx]
    else:
        if len(weight_list) != len(src_list):
            raise ValueError(
                "weight_list and src_list must have the same length, got "
                f"{len(weight_list)} and {len(src_list)}"
            )
        weights = np.asarray(weight_list, dtype=np.float64)
        if (weights < 0).any():
            raise ValueError("weights must be non-negative")
        accumulate_weight_list = np.cumsum(weights)
        total = float(accumulate_weight_list[-1]) if len(accumulate_weight_list) else 0.0
        if not (np.isfinite(total) and total > 0):
            raise ValueError("Total of weights must be a positive finite number")

        last_idx = len(src_list) - 1
        out_list = []
        for _ in range(k):
            val = torch_uniform(0, total, generator)
            # With non-negative weights the running totals are sorted, so a
            # binary search finds the first index whose running total exceeds
            # val. This is the same index the linear scan used to find. The
            # clamp keeps the old fallback to the last element if rounding
            # ever makes val reach the total.
            active_id = int(
                np.searchsorted(accumulate_weight_list, val, side="right")
            )
            out_list.append(src_list[min(active_id, last_idx)])

    return out_list[0] if k == 1 else out_list
|
|