Spaces:
Running
Running
File size: 295 Bytes
375a1cf |
1 2 3 4 5 6 7 8 9 |
import numpy as np
def categorical_sample(prob_n, np_random: np.random.Generator):
    """Draw one class index from a categorical distribution.

    The index is chosen by inverse-CDF sampling: a single uniform draw is
    compared against the running sum of the probabilities, and the first
    class whose cumulative probability exceeds the draw is selected.

    Args:
        prob_n: Array-like of class probabilities. It is converted with
            ``np.asarray`` and accumulated with ``np.cumsum`` without an
            axis, so multi-dimensional input is treated as one flattened
            distribution and the result indexes the flattened array.
        np_random: Generator supplying the uniform draw via ``random()``.

    Returns:
        The sampled class index, as returned by ``np.argmax``.

    Raises:
        ValueError: If ``prob_n`` is empty (raised by ``np.argmax``).
    """
    prob_n = np.asarray(prob_n)
    csprob_n = np.cumsum(prob_n)
    exceeds_draw = csprob_n > np_random.random()
    if not exceeds_draw.any():
        # Floating-point rounding can leave the final cumulative value just
        # below 1.0 (e.g. ten 0.1's sum to 1 - 2**-53), so a draw near 1 may
        # exceed every entry. argmax over an all-False mask would silently
        # return class 0; instead pick the first index where the cumulative
        # sum peaks, i.e. the last class that carries probability mass.
        return np.argmax(csprob_n)
    # argmax returns the first True, i.e. the first class whose cumulative
    # probability is strictly greater than the draw.
    return np.argmax(exceeds_draw)
|