ZeqiangLai's picture
Upload 16 files
6b88e7a verified
raw
history blame
2.02 kB
import logging
import os
import time
from functools import wraps

import torch
def get_logger(name):
    """Return the logger ``name`` configured to emit INFO+ records to the console.

    The logger level is set to INFO and a ``StreamHandler`` with a
    timestamped formatter is attached. The handler is attached only once per
    logger, so calling this repeatedly with the same name (or re-importing the
    module) does not duplicate every log line. Handlers added to the logger by
    other code are left untouched.

    Args:
        name: Logger name passed to :func:`logging.getLogger`.

    Returns:
        The configured :class:`logging.Logger`.
    """
    handler_name = 'hy3dgen_console'
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # getLogger returns the same object for the same name, so without this
    # guard each call would stack another console handler on it.
    if any(h.get_name() == handler_name for h in logger.handlers):
        return logger
    console_handler = logging.StreamHandler()
    console_handler.set_name(handler_name)
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger
# Module-wide logger; synchronize_timer reports its measurements through it.
logger = get_logger(name='hy3dgen.shapgen')
class synchronize_timer:
    """Synchronized timer measuring the elapsed time of a block or function.

    Timing happens only when the environment variable ``HY3DGEN_DEBUG`` equals
    ``'1'`` at the moment the block is entered; otherwise the timer is a no-op.
    When CUDA is available, CUDA events are recorded and the device is
    synchronized before the elapsed time is read. Otherwise
    ``time.perf_counter`` is used. Times are in milliseconds.

    Supports both context-manager and decorator usage.

    Example as context manager::

        with synchronize_timer('name') as t:
            run()
        elapsed_ms = t()  # None if timing was disabled

    Example as decorator::

        @synchronize_timer('Export to trimesh')
        def export_to_trimesh(mesh_output):
            pass

    Args:
        name: Optional label. When not None, the elapsed time is logged at
            INFO level through the module-level ``logger``.

    Attributes:
        time: Elapsed milliseconds of the most recent timed run, or ``None``
            if no run has been timed yet.
    """

    def __init__(self, name=None):
        self.name = name
        # Initialised so the callable returned by __enter__ never raises
        # AttributeError when timing is disabled.
        self.time = None
        self._enabled = False
        self._use_cuda = False

    def __enter__(self):
        """Start timing (if enabled); return a callable yielding ``self.time``."""
        # Latch the debug flag so __exit__ agrees with __enter__ even if the
        # environment changes while the block is running.
        self._enabled = os.environ.get('HY3DGEN_DEBUG', '0') == '1'
        if self._enabled:
            self._use_cuda = torch.cuda.is_available()
            if self._use_cuda:
                self.start = torch.cuda.Event(enable_timing=True)
                self.end = torch.cuda.Event(enable_timing=True)
                self.start.record()
            else:
                # CPU-only fallback: CUDA events cannot be recorded here.
                self._start_s = time.perf_counter()
        return lambda: self.time

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Stop timing (if enabled), store the elapsed ms and log it.

        Returns None, so any exception raised in the block propagates.
        """
        if not self._enabled:
            return
        if self._use_cuda:
            self.end.record()
            # Wait for queued device work so the events hold final timestamps.
            torch.cuda.synchronize()
            self.time = self.start.elapsed_time(self.end)
        else:
            self.time = (time.perf_counter() - self._start_s) * 1000.0
        if self.name is not None:
            logger.info(f'{self.name} takes {self.time} ms')

    def __call__(self, func):
        """Decorator: time every call of ``func``."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            # A fresh timer per call keeps recursive or concurrent calls from
            # overwriting each other's start/end state.
            timer = type(self)(self.name)
            try:
                with timer:
                    return func(*args, **kwargs)
            finally:
                # Keep exposing the latest measurement on the decorator object.
                if timer.time is not None:
                    self.time = timer.time
        return wrapper