Spaces:
Configuration error
Configuration error
| import logging | |
# Names of loggers that ``get_logger`` has already configured. Only key
# membership is consulted; the stored value is always ``True``.
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Return the logger called ``name``, attaching handlers on first use.

    The first call for a given name adds a ``StreamHandler`` and, on rank 0
    when ``log_file`` is given, a ``FileHandler``. Later calls for the same
    name, or for a dotted descendant of an already configured name, return
    the logger unchanged so handlers are never attached twice.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Optional path of a log file. Only opened when the process
            rank is 0.
        log_level: Level set on the handlers and, on rank 0, on the logger.
            Processes with a non-zero rank get ``logging.ERROR`` on the
            logger instead.
        file_mode: Mode used to open ``log_file``.

    Returns:
        The ``logging.Logger`` instance registered under ``name``.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    # The logging hierarchy is dot-delimited: 'a.b' is a child of 'a' and
    # reaches the parent's handlers through propagation, whereas 'ab' is an
    # unrelated logger that needs handlers of its own. A bare prefix test
    # would confuse the two. The root logger ('') is everyone's ancestor.
    for configured_name in logger_initialized:
        if configured_name == '' or name.startswith(configured_name + '.'):
            return logger

    handlers = [logging.StreamHandler()]

    # Rank is 0 unless torch.distributed is importable, available and
    # initialized, in which case it is the process rank.
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
    except ImportError:
        rank = 0

    # Only rank 0 writes the log file.
    if rank == 0 and log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-zero ranks only let errors through.
    logger.setLevel(log_level if rank == 0 else logging.ERROR)

    logger_initialized[name] = True
    return logger