# Copyright 2024 Big Vision Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Few common utils used in both/all flexi-trainers.""" import functools import itertools import numpy as np def mkrng(xid, wid, step): # Need to cap at 0, for example localruns use -1. rng_key = (max(xid, 0), max(wid, 0), max(step, 0)) return np.random.default_rng(rng_key) def mkprob(x): if x is None: return x return np.array(x) / np.sum(x) def choice(values, ratios, rng=None): rng = rng or np.random.default_rng() return rng.choice(values, p=mkprob(ratios)) def mkpredictfns(predict_fn, config, template="predict_{x}"): # If we have two flexi args a=[1,2], b=[10,20], then we create a # predict_fn for all possible combinations, named "predict_a=1_b=10" etc. all_combinations = [dict(comb) for comb in itertools.product( *[[(arg, val) for val in config[arg].v] for arg in config] )] return { template.format(x="_".join(f"{k}={v}" for k, v in kw.items())): functools.partial(predict_fn, **kw) for kw in all_combinations}