#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
from torch import nn | |
def populate_queues(queues, batch):
    """Push the latest value of each batch entry into its matching queue.

    For every key of ``batch`` that already has a queue in ``queues``:

    * if the queue is bounded (``maxlen`` is not None) and not yet full (e.g. on the first call after the queues
      were created), it is padded up to ``maxlen`` by repeating the current value;
    * if it is already full, the current value is appended once (a bounded ``deque`` then drops its oldest entry);
    * if it is unbounded (``maxlen is None``), the current value is simply appended once.

    Keys of ``batch`` without a queue in ``queues`` are ignored: it is the caller's responsibility to create the
    queues it wants populated.

    Args:
        queues: mapping from key to a ``collections.deque``-like container exposing ``append``, ``len`` and
            ``maxlen``.
        batch: mapping from key to the latest value for that key.

    Returns:
        The same ``queues`` mapping, mutated in place.
    """
    for key in batch:
        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
        # queues have the keys they want).
        if key not in queues:
            continue
        queue = queues[key]
        value = batch[key]
        # Always record the latest observation.
        queue.append(value)
        # An unbounded queue never reaches a maximum length, so padding it "until full" would never terminate.
        if queue.maxlen is None:
            continue
        # Pad a not-yet-full bounded queue by repeating the same observation until it reaches ``maxlen``.
        while len(queue) < queue.maxlen:
            queue.append(value)
    return queues
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Note: assumes that all parameters have the same device; only the first parameter yielded by
    ``module.parameters()`` is inspected.

    Args:
        module: the module to inspect.

    Returns:
        The device of the module's first parameter.

    Raises:
        ValueError: if ``module`` has no parameters, so its device cannot be inferred this way.
    """
    # Use ``next``'s default rather than letting StopIteration escape: a bare StopIteration is an opaque error,
    # and when raised inside a callback driven by an iterator (e.g. ``map``) it silently ends that iteration.
    first_param = next(module.parameters(), None)
    if first_param is None:
        raise ValueError(f"Cannot infer the device of {type(module).__name__}: it has no parameters.")
    return first_param.device
def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
    """Get a module's parameter dtype by checking one of its parameters.

    Note: assumes that all parameters have the same dtype; only the first parameter yielded by
    ``module.parameters()`` is inspected.

    Args:
        module: the module to inspect.

    Returns:
        The dtype of the module's first parameter.

    Raises:
        ValueError: if ``module`` has no parameters, so its dtype cannot be inferred this way.
    """
    # Use ``next``'s default rather than letting StopIteration escape: a bare StopIteration is an opaque error,
    # and when raised inside a callback driven by an iterator (e.g. ``map``) it silently ends that iteration.
    first_param = next(module.parameters(), None)
    if first_param is None:
        raise ValueError(f"Cannot infer the dtype of {type(module).__name__}: it has no parameters.")
    return first_param.dtype
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
    """
    Calculates the output shape of a PyTorch module given an input shape.

    A zero-filled dummy input of shape ``input_shape`` is passed through ``module`` under
    ``torch.inference_mode()``; the module is called as-is (its train/eval mode is not changed).

    The dummy input is created on the device of the module's first parameter, and with that parameter's dtype
    when it is floating point. Modules that were moved to an accelerator, or cast to e.g. float64/float16, can
    therefore be probed. Parameter-less modules get a tensor with torch's default device and dtype.

    Args:
        module (nn.Module): a PyTorch module
        input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)

    Returns:
        tuple: The output shape of the module.
    """
    factory_kwargs = {}
    first_param = next(module.parameters(), None)
    if first_param is not None:
        factory_kwargs["device"] = first_param.device
        # Only a floating-point parameter indicates the activation dtype; an integer parameter must not turn
        # the probe input into an integer tensor.
        if first_param.is_floating_point():
            factory_kwargs["dtype"] = first_param.dtype
    dummy_input = torch.zeros(size=input_shape, **factory_kwargs)
    with torch.inference_mode():
        output = module(dummy_input)
    return tuple(output.shape)