Spaces:
Running
Running
#!/usr/bin/env python | |
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import threading | |
import time | |
from contextlib import ContextDecorator | |
class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    Timing state is kept in a ``threading.local`` so that a single instance can be shared between
    threads: each thread records and reads back only its own measurement.

    Note: the start time is a single per-thread slot, so nesting the *same* instance inside itself
    (e.g. decorating a recursive function) overwrites the outer start time.

    Args:
        print: If True, prints the elapsed time upon exiting the context or completing the function.
            Defaults to False.

    Attributes:
        result: Elapsed time in seconds of the last completed measurement in the calling thread, or
            None if that thread has not completed one yet.
        result_ms: Same as ``result`` but in milliseconds (None if no measurement is available).

    Examples:
        Using as a context manager:

        ```python
        benchmark = TimeBenchmark()
        with benchmark:
            time.sleep(1)
        print(f"Block took {benchmark.result:.4f} seconds")
        # Block took 1.0001 seconds (approximately)
        ```

        Using as a decorator:

        ```python
        @TimeBenchmark(print=True)
        def slow_function():
            time.sleep(1)

        slow_function()
        # Elapsed time: 1.0001 seconds (approximately)
        ```

        Using with multithreading:

        ```python
        import threading
        import time

        benchmark = TimeBenchmark()

        def context_manager_example():
            with benchmark:
                time.sleep(0.01)
            print(f"Block took {benchmark.result_ms:.2f} milliseconds")

        threads = []
        for _ in range(3):
            t1 = threading.Thread(target=context_manager_example)
            threads.append(t1)
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # Each thread prints its own measurement, roughly:
        # Block took 10.05 milliseconds
        # Block took 10.07 milliseconds
        # Block took 10.06 milliseconds
        ```
    """

    # NOTE: the parameter is named `print` for backward compatibility with existing callers; it
    # shadows the builtin only inside __init__, where the builtin is not needed.
    def __init__(self, print=False):
        # Per-thread storage for start_time / end_time / elapsed_time.
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        """Record the start timestamp for the calling thread and return this instance."""
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        """Record the end timestamp, compute the elapsed time and optionally print it.

        Returns False so that any exception raised inside the timed block propagates.
        """
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        return False

    @property
    def result(self):
        """Elapsed seconds of the calling thread's last measurement, or None if there is none."""
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Elapsed milliseconds of the calling thread's last measurement, or None if there is none."""
        elapsed = self.result
        # Mirror `result`: report "no measurement yet" as None instead of failing on None * 1e3.
        if elapsed is None:
            return None
        return elapsed * 1e3