# Spaces: Runtime error (page-status text captured with this file; not part of the module)
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np | |
import math | |
def sample_to_batch(mod_dict, device, domains):
    """Wrap a single sample into a batch of size one and move it to ``device``.

    Args:
        mod_dict: Nested mapping ``modality -> {key: value}``. Every value must
            support ``unsqueeze(0)`` and ``to(device, non_blocking=True)``
            (presumably torch tensors; confirm against callers).
        device: Target handed to ``.to``.
        domains: Container of modality names to keep; modalities not in it
            are dropped from the result.

    Returns:
        A new nested dict with the same layout, restricted to ``domains``,
        where each value has a new leading axis and lives on ``device``.
    """
    batched = {}
    for modality, entries in mod_dict.items():
        if modality not in domains:
            continue
        moved = {}
        for key, value in entries.items():
            moved[key] = value.unsqueeze(0).to(device, non_blocking=True)
        batched[modality] = moved
    return batched
def unbatch(tensor):
    """Return ``tensor`` detached, with axis 0 squeezed, as a CPU copy.

    The calls are applied in the order ``detach`` -> ``squeeze(0)`` -> ``cpu``.
    """
    detached = tensor.detach()
    squeezed = detached.squeeze(0)
    return squeezed.cpu()
def batch_to_sample(mod_dict, domains):
    """Apply ``unbatch`` to every value of the selected modalities.

    Args:
        mod_dict: Nested mapping ``modality -> {key: value}``.
        domains: Container of modality names to keep; modalities not in it
            are dropped from the result.

    Returns:
        A new nested dict with the same layout, restricted to ``domains``,
        whose values are the results of ``unbatch``.
    """
    sample = {}
    for modality, entries in mod_dict.items():
        if modality not in domains:
            continue
        unbatched = {}
        for key, value in entries.items():
            unbatched[key] = unbatch(value)
        sample[modality] = unbatched
    return sample
def batch_to_device(mod_dict, device, domains):
    """Move every value of the selected modalities to ``device``.

    Args:
        mod_dict: Nested mapping ``modality -> {key: value}``. Every value must
            support ``to(device, non_blocking=True)`` (presumably torch
            tensors; confirm against callers).
        device: Target handed to ``.to``.
        domains: Container of modality names to keep; modalities not in it
            are dropped from the result.

    Returns:
        A new nested dict with the same layout, restricted to ``domains``.
        Unlike ``sample_to_batch``, no leading axis is added.
    """
    on_device = {}
    for modality, entries in mod_dict.items():
        if modality not in domains:
            continue
        moved = {}
        for key, value in entries.items():
            moved[key] = value.to(device, non_blocking=True)
        on_device[modality] = moved
    return on_device
def cosine_schedule(num_steps, total_tokens):
    """Split ``total_tokens`` over ``num_steps`` steps following a cosine curve.

    A cosine decay from 1 towards 0 is sampled at steps ``0 .. num_steps - 1``.
    The drop between consecutive samples, scaled by ``total_tokens`` and
    rounded, is the token count of each of the first ``num_steps - 1`` steps.
    The last step receives whatever is left so the schedule sums to
    ``total_tokens``.

    Args:
        num_steps: Number of steps in the schedule.
        total_tokens: Total number of tokens to distribute.

    Returns:
        Integer ``np.ndarray`` of per-step token counts that sums to
        ``total_tokens``. For ``num_steps <= 1`` this is ``[total_tokens]``.
        For non-negative ``total_tokens`` every entry is non-negative.
    """
    decay = np.array(
        [0.5 * (1 + math.cos(math.pi * step / num_steps)) for step in range(num_steps)]
    )
    step_tokens = [round(total_tokens * drop) for drop in (decay[:-1] - decay[1:])]

    remainder = total_tokens - sum(step_tokens)
    if remainder < 0:
        # Independent rounding can hand out more than total_tokens when there
        # are few tokens and many steps (e.g. 8 steps, 5 tokens), which used to
        # yield a negative final entry. Take the overshoot back from the latest
        # steps so counts stay non-negative and the total is preserved.
        overshoot = -remainder
        for idx in range(len(step_tokens) - 1, -1, -1):
            if overshoot == 0:
                break
            taken = min(step_tokens[idx], overshoot)
            step_tokens[idx] -= taken
            overshoot -= taken
        remainder = 0

    step_tokens.append(remainder)
    return np.array(step_tokens)
def linear_schedule(num_steps, total_tokens):
    """Split ``total_tokens`` as evenly as integer steps allow.

    ``num_steps + 1`` integer boundaries are spaced between 0 and
    ``total_tokens``; the gaps between them are the per-step token counts.

    Args:
        num_steps: Number of steps to spread the tokens over.
        total_tokens: Total number of tokens to distribute.

    Returns:
        Integer ``np.ndarray`` of per-step token counts in descending order,
        with trailing zero-token steps removed (so it may be shorter than
        ``num_steps``).
    """
    boundaries = np.linspace(0, total_tokens, num_steps + 1, dtype=int)
    # Ascending sort followed by a reverse gives the descending order.
    descending = np.sort(np.diff(boundaries))[::-1]
    # 'b' trims only from the back, i.e. the zero-token steps at the end.
    return np.trim_zeros(descending, 'b')
def continue_schedule(schedule, num_current_tokens):
    """Return the part of ``schedule`` that remains after some tokens exist.

    Steps whose cumulative token count is already covered by
    ``num_current_tokens`` are dropped. The first remaining step is shortened
    to the number of tokens still missing from it.

    Args:
        schedule: 1-D array (or sequence) of per-step token counts.
        num_current_tokens: Number of tokens already present.

    Returns:
        A new ``np.ndarray`` with the remaining per-step token counts; the
        input is not modified. The result is empty when ``num_current_tokens``
        is at least the schedule total (previously this raised ``IndexError``).
    """
    schedule = np.asarray(schedule)
    schedule_cumsum = np.cumsum(schedule)
    keep_mask = schedule_cumsum > num_current_tokens

    # Boolean-mask indexing returns a copy, so editing it leaves the caller's
    # array untouched.
    new_schedule = schedule[keep_mask]
    if new_schedule.size == 0:
        # Nothing left to generate: return the empty schedule instead of
        # failing on the [0] lookup below.
        return new_schedule

    new_schedule[0] = schedule_cumsum[keep_mask][0] - num_current_tokens
    return new_schedule
def decreasing_temp_schedule(max, min, token_schedule):
    """Interpolate a temperature from ``max`` towards ``min`` over the schedule.

    The temperature of a step is ``min + (max - min) * (1 - p)``, where ``p``
    is the fraction of all tokens scheduled up to and including that step.
    The last step therefore gets exactly ``min``.

    Args:
        max: Upper end of the temperature range (the name shadows the builtin
            but is part of the public signature).
        min: Temperature reached at the final step (same naming caveat).
        token_schedule: Per-step token counts.

    Returns:
        ``np.ndarray`` with one temperature per schedule step.
    """
    fraction_done = np.cumsum(token_schedule) / np.sum(token_schedule)
    return min + (max - min) * (1 - fraction_done)
def onex_temp_schedule(max_t, min_t, token_schedule, power=0.5, min_linspace=1, max_linspace=100):
    """Arbitrary temperature schedule shaped like one over x.

    A ``1 / x**power`` curve is sampled on ``[min_linspace, max_linspace]`` and
    rescaled to span ``[0, 1]``. Each step's sample is then damped by the
    fraction of tokens still to be scheduled after that step, mapped into
    ``[min_t, max_t]`` and clipped from below at ``1e-9``.

    Args:
        max_t: Temperature corresponding to a scaled value of 1.
        min_t: Temperature corresponding to a scaled value of 0.
        token_schedule: Per-step token counts.
        power: Exponent applied to ``x`` in the ``1 / x**power`` curve.
        min_linspace: Start of the range the curve is sampled on.
        max_linspace: End of the range the curve is sampled on.

    Returns:
        ``np.ndarray`` of temperatures, one per schedule step (fewer if the
        curve has fewer samples than there are steps).
    """
    # NOTE(review): the curve is sampled at sum(token_schedule) points, but only
    # the first len(token_schedule) of them are paired with a step below.
    # Behaviour preserved as-is; confirm whether this is intended.
    xs = np.linspace(min_linspace, max_linspace, num=sum(token_schedule))
    curve = 1 / (xs ** power)
    curve = curve - min(curve)
    curve = curve / max(curve)

    fraction_done = np.cumsum(token_schedule) / np.sum(token_schedule)

    # Pair samples and steps up to the shorter of the two sequences.
    n_pairs = len(curve) if len(curve) < len(fraction_done) else len(fraction_done)
    damped = (1 - fraction_done[:n_pairs]) * curve[:n_pairs]

    temperatures = min_t + (max_t - min_t) * damped
    return temperatures.clip(min=1e-9)
def linear_temp_schedule(temp, token_schedule):
    """Decay the temperature linearly with the share of tokens still to come.

    The first step uses ``temp`` itself. Every later step uses ``temp`` scaled
    by the fraction of tokens that remained after the previous step. All values
    are clipped from below at ``1e-9``.

    Args:
        temp: Starting temperature.
        token_schedule: ``np.ndarray`` of per-step token counts (its ``sum``
            and ``cumsum`` methods are used).

    Returns:
        ``np.ndarray`` with one temperature per schedule step.
    """
    total = token_schedule.sum()
    remaining = token_schedule.sum() - token_schedule.cumsum()
    decayed = temp * remaining / total
    first = np.array([temp * 1.0])
    # The last decayed value belongs to "after the final step" and is unused.
    temperatures = np.concatenate([first, decayed[:-1]])
    return temperatures.clip(min=1e-9)