# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
from typing import List, Sequence, Tuple


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
    r"""
    Finds the index of the largest number that fits into the knapsack with the given capacity.

    Args:
        numbers: candidate sizes, which must already be sorted in ascending order.
        capacity: the space that is still available.

    Returns:
        The index of the rightmost element that is less than or equal to ``capacity``,
        or -1 if no element fits (including when ``numbers`` is empty).
    """
    # bisect (= bisect_right) returns the insertion point after every element <= capacity,
    # so the element just before that point is the largest one that still fits.
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    r"""
    An efficient greedy algorithm with binary search for the knapsack problem.

    Each knapsack is filled by repeatedly taking the largest remaining number that still fits
    into its remaining capacity; a new knapsack is opened once nothing else fits.

    Note that ``numbers`` is sorted and consumed in place: it is empty when the function returns.

    Args:
        numbers: the item sizes to pack.
        capacity: the maximum total size of a single knapsack.

    Returns:
        A list of knapsacks, each being the list of sizes packed into it (in the order taken).

    Raises:
        ValueError: if some number can never be packed because it exceeds ``capacity`` on its own.
            ``numbers`` may already be partially consumed when this is raised.
    """
    numbers.sort()  # sort numbers in ascending order for binary search
    knapsacks = []

    while numbers:
        current_knapsack = []
        remaining_capacity = capacity

        while True:
            index = search_for_fit(numbers, remaining_capacity)
            if index == -1:
                break  # no more numbers fit in this knapsack

            remaining_capacity -= numbers[index]  # update the remaining capacity
            current_knapsack.append(numbers.pop(index))  # add the number to knapsack

        if not current_knapsack:
            # Even a fresh knapsack could not take the smallest remaining number, so no later
            # knapsack could either; without this check the outer loop would never terminate.
            raise ValueError(
                f"Cannot pack {numbers[0]} into a knapsack of capacity {capacity}."
            )

        knapsacks.append(current_knapsack)

    return knapsacks


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    r"""
    Computes the real sequence length after truncation by the cutoff_len.

    The budget ``cutoff_len`` is shared between source and target as follows:
      - if the target is shorter than half of the budget, only the source is truncated;
      - otherwise, if the source is shorter than half of the budget, only the target is truncated;
      - otherwise both are truncated proportionally to their original lengths.

    Args:
        source_len: the original length of the source (prompt) part.
        target_len: the original length of the target (response) part.
        cutoff_len: the maximum total length allowed; expected to be non-negative.

    Returns:
        A tuple ``(new_source_len, new_target_len)`` of the lengths to keep.
    """
    if target_len * 2 < cutoff_len:  # truncate source
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # truncate target
        max_target_len = cutoff_len - source_len
    else:  # truncate both
        total_len = source_len + target_len
        if total_len == 0:
            # nothing to split; also avoids dividing by zero when cutoff_len <= 0
            max_target_len = 0
        else:
            # Exact integer arithmetic: the float form cutoff_len * (target_len / total_len)
            # can land just below an integer and lose one token after truncation.
            max_target_len = cutoff_len * target_len // total_len

    new_target_len = min(max_target_len, target_len)
    max_source_len = max(cutoff_len - new_target_len, 0)  # the source gets whatever is left
    new_source_len = min(max_source_len, source_len)
    return new_source_len, new_target_len