Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / utils / logging.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed-aware logging helpers for the trainer. 

16 

17Design (matches industry consensus across Megatron-LM, DeepSpeed, and 

18HF Transformers): 

19 

20- ``logger.info / warning / error`` always fire on **every rank** so that 

21 rank-local failures (OOM on rank 7, NCCL timeout on rank 3) are never 

22 silently dropped. 

23- Rank-0-only progress / status messages call **explicit helpers**: 

24 ``logger.info_rank0(...)`` / ``logger.warning_rank0(...)``. 

25- One-shot dedup helpers ``info_once`` / ``warning_once`` use ``lru_cache`` 

26 so the same message never spams every step; they emit on rank 0 only. 

27 

28Design rejected: a global ``logging.Filter`` returning ``False`` on 

29non-rank-0 records — that's what hyper had before, and it silently dropped 

30rank-local errors. Don't reintroduce it. 

31 

32The functions here are also installed as bound methods on 

33``logging.Logger`` so any module that already does 

34``logger = logging.getLogger(__name__)`` gets the helpers for free. 

35""" 

36import functools 

37import logging 

38import os 

39import sys 

40from typing import Optional 

41 

42# --------------------------------------------------------------------------- 

43# Rank lookup — uses platform when available, falls back to env var. 

44# Going through platform avoids importing torch / mindspore here directly. 

45# --------------------------------------------------------------------------- 

46 

def _get_rank() -> int:
    """Best-effort rank lookup that never raises.

    Sources are tried in this order; the first one that yields an integer
    wins:

    1. ``hyper_parallel.get_platform().get_rank()``. Any ``ImportError``,
       ``RuntimeError``, ``ValueError``, ``AttributeError`` or ``TypeError``
       raised along the way (presumably e.g. when the process group is not
       initialised yet) makes the lookup fall through to the environment.
    2. The ``RANK`` environment variable.
    3. The ``LOCAL_RANK`` environment variable.
    4. ``0`` as the last resort.

    An environment variable that is unset, empty, or not a valid integer is
    skipped instead of raising: this function runs inside a logging filter,
    so an exception here would surface from every log call.

    Returns:
        The rank of the current process, or ``0`` when it cannot be
        determined.
    """
    try:
        # Imported lazily so this module does not pull in the platform
        # backend at import time.
        # pylint: disable=C0415
        from hyper_parallel import get_platform
        return int(get_platform().get_rank())
    except (ImportError, RuntimeError, ValueError, AttributeError, TypeError):
        pass

    for env_name in ("RANK", "LOCAL_RANK"):
        raw_value = os.environ.get(env_name)
        if raw_value is None:
            continue
        try:
            return int(raw_value)
        except ValueError:
            # Malformed value (e.g. "" or "abc"): try the next source.
            continue
    return 0

65 

66# --------------------------------------------------------------------------- 

67# Configuration entry points. 

68# --------------------------------------------------------------------------- 

69 

# Default record layout: time, rank, level, logger name, then the message.
# ``%(rank)s`` is not a standard LogRecord attribute; it has to be supplied
# on the record (this module does so with a handler-level filter).
_DEFAULT_FORMAT = (
    "[%(asctime)s][rank%(rank)s][%(levelname)s] "
    "%(name)s: %(message)s"
)
# Time-of-day only (no date) for ``%(asctime)s``.
_DATE_FORMAT = "%H:%M:%S"

74 

class _RankInjector(logging.Filter):
    """Logging filter that tags records with the current process rank.

    The filter is purely additive: it always lets the record through
    (returns ``True``), so records from every rank are kept. Its only job is
    to make ``%(rank)s`` usable in a format string.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Set ``record.rank`` unless it is already present; always accept.

        Args:
            record: The log record about to be handled.

        Returns:
            Always ``True``.
        """
        already_tagged = hasattr(record, "rank")
        if already_tagged:
            # Keep whatever rank value is already on the record.
            return True
        record.rank = _get_rank()
        return True

87 

def init_logger(
    level: int = logging.INFO,
    fmt: str = _DEFAULT_FORMAT,
    datefmt: str = _DATE_FORMAT,
    stream=None,
) -> None:
    """Configure the root logger with a single rank-injecting stream handler.

    Every handler currently attached to the root logger is removed and
    replaced by one ``logging.StreamHandler``, so calling this function
    repeatedly (tests, notebooks, re-imports) does not double-print.

    NOTE: removed handlers are only detached, not closed.

    Args:
        level: Level applied to both the root logger and the new handler
            (default ``INFO``).
        fmt: Format string. Must include ``%(rank)s`` if you want the rank
            displayed.
        datefmt: ``%(asctime)s`` format.
        stream: Output stream. ``None`` (the default) means ``sys.stdout``,
            resolved at call time.
    """
    root = logging.getLogger()
    root.setLevel(level)

    # Iterate over a copy: removeHandler mutates ``root.handlers``.
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)

    # Compare against None rather than using truthiness, so that a valid
    # stream object that happens to be falsy (e.g. defines ``__len__`` and is
    # currently empty) is not silently replaced by stdout.
    target = sys.stdout if stream is None else stream

    handler = logging.StreamHandler(target)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
    # The filter sits on the handler so every record it emits carries
    # ``record.rank`` before formatting.
    handler.addFilter(_RankInjector())
    root.addHandler(handler)

116 

def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Fetch a standard logger, making sure the rank-aware helpers exist.

    The helper installation runs first (it is a no-op after the first
    successful call), then the lookup is delegated to
    ``logging.getLogger``.

    Args:
        name: Logger name; ``None`` selects the root logger.

    Returns:
        The ``logging.Logger`` registered under ``name``.
    """
    _install_logger_methods()
    logger = logging.getLogger(name)
    return logger

126 

127# --------------------------------------------------------------------------- 

128# Standalone module-level functions. 

129# --------------------------------------------------------------------------- 

130 

def info_rank0(self, msg, *args, **kwargs) -> None:
    """``logger.info`` that fires only on rank 0.

    Written as a plain function taking ``self`` so it can be attached to
    ``logging.Logger`` as a method.

    Args:
        self: The logger to emit on.
        msg: Message or %-style format string.
        *args: Format arguments, forwarded to ``self.info``.
        **kwargs: Keyword arguments forwarded to ``self.info``.
            ``stacklevel`` is interpreted relative to the caller of this
            helper, exactly as it would be for ``logger.info``.
    """
    if _get_rank() != 0:
        return
    # This wrapper contributes one extra frame. Offset the effective
    # stacklevel by one (default 1 -> 2) so the record points at the caller,
    # including when the caller supplies its own stacklevel.
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    self.info(msg, *args, **kwargs)

136 

def warning_rank0(self, msg, *args, **kwargs) -> None:
    """``logger.warning`` that fires only on rank 0.

    Written as a plain function taking ``self`` so it can be attached to
    ``logging.Logger`` as a method.

    Args:
        self: The logger to emit on.
        msg: Message or %-style format string.
        *args: Format arguments, forwarded to ``self.warning``.
        **kwargs: Keyword arguments forwarded to ``self.warning``.
            ``stacklevel`` is interpreted relative to the caller of this
            helper, exactly as it would be for ``logger.warning``.
    """
    if _get_rank() != 0:
        return
    # This wrapper contributes one extra frame. Offset the effective
    # stacklevel by one (default 1 -> 2) so the record points at the caller,
    # including when the caller supplies its own stacklevel.
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    self.warning(msg, *args, **kwargs)

142 

@functools.lru_cache(maxsize=None)
def _info_once_cached(name: str, msg: str) -> None:
    """Emit ``msg`` at INFO level on rank 0, once per ``(name, msg)`` pair.

    ``lru_cache`` memoises the ``None`` result, so the body runs only on the
    first call for a given key; the rank check is therefore also evaluated
    only on that first call. The cache is unbounded: messages with
    ever-changing content (e.g. a step counter) add one entry each.

    Args:
        name: Name of the logger to emit on.
        msg: Fully formatted message text.
    """
    if _get_rank() == 0:
        # Frames between the user and ``Logger.info``: this function and the
        # ``info_once`` wrapper (CPython's lru_cache wrapper adds no Python
        # frame). stacklevel=3 attributes the record to the user's call site
        # rather than to this module.
        logging.getLogger(name).info(msg, stacklevel=3)

148 

def info_once(self, msg, *args, **kwargs) -> None:  # pylint: disable=W0613
    """``logger.info`` that fires at most once across the whole run.

    The message is formatted eagerly and deduplicated on
    ``(logger name, formatted text)``. Formatting follows the standard
    ``logging`` rules: ``msg`` is converted with ``str()`` first, and a
    single non-empty mapping argument enables ``%(key)s`` placeholders.

    Args:
        self: The logger whose name is used for emission and deduplication.
        msg: Message or %-style format string.
        *args: Format arguments.
        **kwargs: Accepted for signature compatibility with ``logger.info``
            but ignored.
    """
    # pylint: disable=C0415
    from collections.abc import Mapping

    text = str(msg)
    if args:
        # Same special case as ``logging.LogRecord``: one non-empty mapping
        # is used directly as the right-hand side of ``%``.
        if len(args) == 1 and isinstance(args[0], Mapping) and args[0]:
            args = args[0]
        text = text % args
    _info_once_cached(self.name, text)

154 

@functools.lru_cache(maxsize=None)
def _warning_once_cached(name: str, msg: str) -> None:
    """Emit ``msg`` at WARNING level on rank 0, once per ``(name, msg)`` pair.

    ``lru_cache`` memoises the ``None`` result, so the body runs only on the
    first call for a given key; the rank check is therefore also evaluated
    only on that first call. The cache is unbounded: messages with
    ever-changing content (e.g. a step counter) add one entry each.

    Args:
        name: Name of the logger to emit on.
        msg: Fully formatted message text.
    """
    if _get_rank() == 0:
        # Frames between the user and ``Logger.warning``: this function and
        # the ``warning_once`` wrapper (CPython's lru_cache wrapper adds no
        # Python frame). stacklevel=3 attributes the record to the user's
        # call site rather than to this module.
        logging.getLogger(name).warning(msg, stacklevel=3)

159 

def warning_once(self, msg, *args, **kwargs) -> None:  # pylint: disable=W0613
    """``logger.warning`` that fires at most once across the whole run.

    The message is formatted eagerly and deduplicated on
    ``(logger name, formatted text)``. Formatting follows the standard
    ``logging`` rules: ``msg`` is converted with ``str()`` first, and a
    single non-empty mapping argument enables ``%(key)s`` placeholders.

    Args:
        self: The logger whose name is used for emission and deduplication.
        msg: Message or %-style format string.
        *args: Format arguments.
        **kwargs: Accepted for signature compatibility with
            ``logger.warning`` but ignored.
    """
    # pylint: disable=C0415
    from collections.abc import Mapping

    text = str(msg)
    if args:
        # Same special case as ``logging.LogRecord``: one non-empty mapping
        # is used directly as the right-hand side of ``%``.
        if len(args) == 1 and isinstance(args[0], Mapping) and args[0]:
            args = args[0]
        text = text % args
    _warning_once_cached(self.name, text)

165 

166# --------------------------------------------------------------------------- 

167# Method installation — bind helpers onto Logger so existing 

168# ``logging.getLogger(__name__)`` consumers get them for free. 

169# --------------------------------------------------------------------------- 

170 

# Module-level guard: flipped to True after the helpers have been attached.
_INSTALLED = False


def _install_logger_methods() -> None:
    """Attach the rank-aware helpers to ``logging.Logger`` exactly once.

    Sets ``info_rank0``, ``warning_rank0``, ``info_once`` and
    ``warning_once`` as attributes of the ``logging.Logger`` class, so every
    logger instance exposes them as methods. Calls after the first
    successful one return without doing anything.
    """
    global _INSTALLED
    if not _INSTALLED:
        helpers = (info_rank0, warning_rank0, info_once, warning_once)
        for helper in helpers:
            # Each helper is bound under its own function name.
            setattr(logging.Logger, helper.__name__, helper)
        _INSTALLED = True

183 

# Run the installation as an import side effect: importing this module is
# enough to make the helpers available as methods on ``logging.Logger``.
_install_logger_methods()