Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/trainer/utils/logging.py: 0%
67 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed-aware logging helpers for the trainer.
Design (matches common practice in Megatron-LM, DeepSpeed, HF Transformers
and similar training frameworks):
20- ``logger.info / warning / error`` always fire on **every rank** so that
21 rank-local failures (OOM on rank 7, NCCL timeout on rank 3) are never
22 silently dropped.
23- Rank-0-only progress / status messages call **explicit helpers**:
24 ``logger.info_rank0(...)`` / ``logger.warning_rank0(...)``.
25- One-shot dedup helpers ``info_once`` / ``warning_once`` use ``lru_cache``
26 so the same message never spams every step.
28Design rejected: a global ``logging.Filter`` returning ``False`` on
29non-rank-0 records — that's what hyper had before, and it silently dropped
30rank-local errors. Don't reintroduce it.
32The functions here are also installed as bound methods on
33``logging.Logger`` so any module that already does
34``logger = logging.getLogger(__name__)`` gets the helpers for free.
35"""
36import functools
37import logging
38import os
39import sys
40from typing import Optional
42# ---------------------------------------------------------------------------
43# Rank lookup — uses platform when available, falls back to env var.
44# Going through platform avoids importing torch / mindspore here directly.
45# ---------------------------------------------------------------------------
47def _get_rank() -> int:
48 """Best-effort rank lookup.
50 Order:
51 1. ``hyper_parallel.platform.get_platform().get_rank()`` if the process
52 group is initialised.
53 2. ``RANK`` env var (set by ``torchrun`` / ``msrun``).
54 3. ``LOCAL_RANK`` env var.
55 4. ``0`` as the last resort.
56 """
57 try:
58 # pylint: disable=C0415
59 from hyper_parallel import get_platform
60 platform = get_platform()
61 return int(platform.get_rank())
62 except (ImportError, RuntimeError, ValueError, AttributeError):
63 pass
64 return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
66# ---------------------------------------------------------------------------
67# Configuration entry points.
68# ---------------------------------------------------------------------------
70_DEFAULT_FORMAT = (
71 "[%(asctime)s][rank%(rank)s][%(levelname)s] %(name)s: %(message)s"
72)
73_DATE_FORMAT = "%H:%M:%S"
75class _RankInjector(logging.Filter):
76 """Inject the current rank into every record as ``record.rank``.
78 Unlike a "drop non-rank-0 record" filter, this **never returns False** —
79 every rank's record is preserved. The format string can use
80 ``%(rank)s`` to display it.
81 """
83 def filter(self, record: logging.LogRecord) -> bool:
84 if not hasattr(record, "rank"):
85 record.rank = _get_rank()
86 return True
def init_logger(
    level: int = logging.INFO,
    fmt: str = _DEFAULT_FORMAT,
    datefmt: str = _DATE_FORMAT,
    stream=None,
) -> None:
    """Configure the root logger with a single rank-injecting stream handler.

    Safe to call repeatedly. Every handler already attached to the root logger
    is removed and closed, as ``logging.basicConfig(force=True)`` does, before
    the new ``StreamHandler`` is installed. Re-initialising in tests or
    notebooks therefore neither double-prints nor leaks the file held by a
    previously attached ``FileHandler``. Note that this also removes handlers
    that other code installed on the root logger.

    Args:
        level: Level applied to both the root logger and the new handler
            (default ``logging.INFO``).
        fmt: Format string. Must include ``%(rank)s`` if you want the rank
            displayed; ``_RankInjector`` on the handler fills in
            ``record.rank``.
        datefmt: ``%(asctime)s`` format.
        stream: Output stream; ``None`` selects ``sys.stdout``.
    """
    root = logging.getLogger()
    root.setLevel(level)
    # Detach and release the previous handlers so re-init doesn't double-log
    # or leak open files.
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)
        old_handler.close()
    out = sys.stdout if stream is None else stream
    handler = logging.StreamHandler(out)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
    handler.addFilter(_RankInjector())
    root.addHandler(handler)
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Fetch ``logging.getLogger(name)`` after making sure the helpers exist.

    Before returning, the rank-aware helpers (``info_rank0``,
    ``warning_rank0``, ``info_once``, ``warning_once``) are attached to the
    ``logging.Logger`` class. That step is a no-op after the first call.
    """
    _install_logger_methods()
    logger = logging.getLogger(name)
    return logger
127# ---------------------------------------------------------------------------
128# Standalone module-level functions.
129# ---------------------------------------------------------------------------
def info_rank0(self, msg, *args, **kwargs) -> None:
    """``logger.info`` that fires only on rank 0.

    Installed on ``logging.Logger``, so ``self`` is the logger. The
    ``stacklevel`` keyword (default 1) is shifted by one to skip this wrapper
    frame. The recorded filename/lineno therefore mean the same thing as for a
    direct ``logger.info`` call made where ``info_rank0`` was called.
    """
    if _get_rank() != 0:
        return
    # +1: account for this wrapper frame, including a caller-supplied value.
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    self.info(msg, *args, **kwargs)
def warning_rank0(self, msg, *args, **kwargs) -> None:
    """``logger.warning`` that fires only on rank 0.

    Installed on ``logging.Logger``, so ``self`` is the logger. The
    ``stacklevel`` keyword (default 1) is shifted by one to skip this wrapper
    frame. The recorded filename/lineno therefore mean the same thing as for a
    direct ``logger.warning`` call made where ``warning_rank0`` was called.
    """
    if _get_rank() != 0:
        return
    # +1: account for this wrapper frame, including a caller-supplied value.
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    self.warning(msg, *args, **kwargs)
143@functools.lru_cache(maxsize=None)
144def _info_once_cached(name: str, msg: str) -> None:
145 """LRU-cached one-shot info; key = (logger name, message)."""
146 if _get_rank() == 0:
147 logging.getLogger(name).info(msg)
149def info_once(self, msg, *args, **kwargs) -> None: # pylint: disable=W0613
150 """``logger.info`` that fires at most once across the whole run."""
151 if args:
152 msg = msg % args
153 _info_once_cached(self.name, str(msg))
155@functools.lru_cache(maxsize=None)
156def _warning_once_cached(name: str, msg: str) -> None:
157 if _get_rank() == 0:
158 logging.getLogger(name).warning(msg)
160def warning_once(self, msg, *args, **kwargs) -> None: # pylint: disable=W0613
161 """``logger.warning`` that fires at most once across the whole run."""
162 if args:
163 msg = msg % args
164 _warning_once_cached(self.name, str(msg))
166# ---------------------------------------------------------------------------
167# Method installation — bind helpers onto Logger so existing
168# ``logging.getLogger(__name__)`` consumers get them for free.
169# ---------------------------------------------------------------------------
# Set to True once the helpers have been attached, so patching runs only once.
_INSTALLED = False


def _install_logger_methods() -> None:
    """Attach the rank-aware helpers to ``logging.Logger`` (idempotent).

    The helpers are assigned as class attributes. Every ``logging.Logger``
    instance therefore exposes ``info_rank0``, ``warning_rank0``,
    ``info_once`` and ``warning_once``, including loggers created before this
    ran. Existing attributes with those names on ``logging.Logger`` are
    overwritten.
    """
    global _INSTALLED  # pylint: disable=W0603
    if _INSTALLED:
        return
    helpers = {
        "info_rank0": info_rank0,
        "warning_rank0": warning_rank0,
        "info_once": info_once,
        "warning_once": warning_once,
    }
    for attr_name, func in helpers.items():
        setattr(logging.Logger, attr_name, func)
    _INSTALLED = True


# Patch at import time, so importing this module is enough to make the
# helpers available on every logger without calling get_logger().
_install_logger_methods()