Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / callbacks / base.py: 0%
319 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Callback base class and built-in callbacks.
17dispatched explicitly in ``on_step_end`` etc. Engineer sees all callbacks and
18order at a glance.
20``checkpoint_callback.py`` (242 lines) + ``trace_callback.py`` (231 lines).
21"""
22import copy
23import gc
24import json
25import logging
26import math
27import os
28import threading
29import time
30from typing import TYPE_CHECKING, Optional
32import torch
34from hyper_parallel import get_platform
35from hyper_parallel.core.distributed_checkpoint import load as dcp_load, save as dcp_save
36from hyper_parallel.core.distributed_checkpoint.offline_transform import (
37 save_state_dict_as_huggingface_format,
38)
39from hyper_parallel.core.fully_shard.api import get_model_state_dict
# Backend handle resolved once at import time. Callbacks in this module use it
# for rank / world-size queries and for RNG state save/restore.
platform = get_platform()

if TYPE_CHECKING:
    # Annotation-only imports: callbacks refer to these types as strings.
    # Presumably kept out of the runtime import path to avoid a cycle with
    # hyper_parallel.trainer.base — confirm.
    from hyper_parallel.trainer.base import BaseTrainer, TrainerState

# NOTE(review): callbacks call ``logger.info_rank0`` / ``logger.warning_rank0``,
# which are not methods of the stdlib ``logging.Logger``; presumably they are
# patched onto ``Logger`` elsewhere in hyper_parallel — confirm.
logger = logging.getLogger(__name__)
class Callback:
    """Common base for every trainer callback.

    A callback keeps a handle to the trainer that owns it, so hooks can
    reach the model, optimizer, state and config. Every hook defined here is
    a no-op; subclasses override only the hooks they care about.

    Args:
        trainer: The ``BaseTrainer`` instance that owns this callback.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        self.trainer = trainer

    # ------------------------------------------------------------------
    # Lifecycle hooks (all no-ops by default)
    # ------------------------------------------------------------------

    def on_init_end(self, state: "TrainerState", **kwargs) -> None:
        """Run once, at the end of ``BaseTrainer.__init__`` / subclass init.

        By then every ``_build_*`` step has run: the model is parallelised
        and the optimizer, scheduler, dataloader and callbacks all exist.
        Good for one-shot setup that needs the FINAL trainer state, e.g.
        logging the parameter count, opening a run-keyed TensorBoard writer,
        or validating user config against the built model.
        """

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        """Run at the start of ``train()``, before any ``optimizer.step``.

        ``CheckpointCallback`` performs resume in this hook, so
        ``state.global_step`` can already be > 0 when a checkpoint was
        loaded.
        """

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Run when training finishes (before ``destroy_process_group``).

        The place for final checkpoints, stopping profilers, finishing W&B
        runs, and similar teardown.
        """

    def on_epoch_begin(self, state: "TrainerState", **kwargs) -> None:
        """Run at the start of every epoch."""

    def on_epoch_end(self, state: "TrainerState", **kwargs) -> None:
        """Run at the end of every epoch."""

    def on_step_begin(self, state: "TrainerState", **kwargs) -> None:
        """Run at the start of every training step (before fwd of mb 0)."""

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        """Run at the end of every training step (after optimizer.step)."""

    def on_substep_end(self, state: "TrainerState", **kwargs) -> None:
        """Run after each micro-batch fwd+bwd (gradient accumulation)."""

    def on_pre_optimizer_step(self, state: "TrainerState", *,
                              grad_norm: Optional[float] = None,
                              **kwargs) -> None:
        """Run after gradient clipping and before ``optimizer.step``.

        ``grad_norm`` is the scalar produced by hyper's DTensor-aware
        clipper. Typical uses are detecting NaN/Inf or logging the
        effective clip ratio. NOTE(review): torch's ``clip_grad_norm_``
        returns the pre-clip total norm; confirm whether hyper reports the
        pre- or post-clip value.
        """

    def on_log(self, state: "TrainerState", *, metrics: dict,
               **kwargs) -> None:
        """Run when ``LoggingCallback`` emits a structured metrics record.

        TensorBoard / W&B / external metric sinks should hook here so every
        backend consumes the SAME record, instead of each recomputing
        throughput or lr on its own.

        Args:
            metrics: Record with at least ``step``, ``loss``, ``grad_norm``,
                ``lr``, ``step_time`` and ``tokens_per_sec`` (``None`` when
                ``logging.report_throughput`` is off). Treat any further
                keys (e.g. ``tflops`` / ``mfu``) as optional.
        """

    def on_save(self, state: "TrainerState", *, checkpoint_dir: str,
                **kwargs) -> None:
        """Run right after ``CheckpointCallback`` completes a save.

        Useful for uploading to remote storage, registering the checkpoint
        with an experiment tracker, or kicking off downstream eval jobs.
        ``checkpoint_dir`` is the on-disk directory holding the model shards
        plus optimizer/scheduler/RNG/dataloader/extra_state files.
        """

    def on_load(self, state: "TrainerState", *, checkpoint_dir: str,
                **kwargs) -> None:
        """Run right after ``CheckpointCallback`` completes a resume.

        Useful for checking the resumed step, logging the restore event, or
        seeding downstream callbacks with the resumed state.
        """

    def on_evaluate(self, state: "TrainerState", *,
                    metrics: Optional[dict] = None, **kwargs) -> None:
        """Run when an evaluation pass completes.

        ``EvalCallback`` in this module is currently a stub that only logs a
        warning and does not invoke this hook. Once a real eval loop exists,
        the eval ``metrics`` dict is meant to arrive here for sinks
        (TensorBoard / W&B) to record.
        """
class LoggingCallback(Callback):
    """Log training metrics: loss, grad_norm, lr, throughput.

    Every ``log_steps`` steps this callback:

    * logs one ``key=value | ...`` line through ``logger.info_rank0``;
    * appends a numeric record to ``state.log_history``;
    * forwards that record to ``trainer.dispatch_log_event`` (when the
      trainer defines it) so other log-event listeners see the same data.

    A non-positive ``log_steps`` disables logging; it used to raise
    ``ZeroDivisionError``. Config is read from ``trainer.args.logging``.
    ``report_global_loss`` is read from config but not used by this callback.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        log_cfg = getattr(trainer.args, 'logging', None)
        self.log_steps = getattr(log_cfg, 'log_steps', 10) if log_cfg else 10
        self.report_global_loss = (
            getattr(log_cfg, 'report_global_loss', False) if log_cfg else False
        )
        self.report_throughput = (
            getattr(log_cfg, 'report_throughput', True) if log_cfg else True
        )
        self.model_flops_per_token = (
            getattr(log_cfg, 'model_flops_per_token', None) if log_cfg else None
        )
        self.peak_tflops = (
            getattr(log_cfg, 'peak_tflops', None) if log_cfg else None
        )
        # Upper-bound estimate of tokens per step (global batch x seq_len).
        # The real per-step count, when available, is read from the
        # trainer's ``_last_global_tokens`` attribute (set by
        # ``BaseTrainer.train_step``).
        gbs = getattr(trainer.args.train, 'global_batch_size', 1)
        seq_len = getattr(trainer.args.data, 'max_seq_len', 1)
        self._tokens_per_step_est = int(gbs) * int(seq_len)
        self._step_start_time = 0.0

    def on_step_begin(self, state: "TrainerState", **kwargs) -> None:
        self._step_start_time = time.time()

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        """Emit the metrics line and record on every ``log_steps``-th step."""
        if self.log_steps <= 0 or state.global_step % self.log_steps != 0:
            return

        # Wall time since on_step_begin; clamped so the throughput division
        # below can never divide by zero.
        elapsed = max(time.time() - self._step_start_time, 1e-9)
        lr = 0.0
        if self.trainer.lr_scheduler is not None:
            lr = self.trainer.lr_scheduler.get_last_lr()[0]

        metrics = {
            "step": state.global_step,
            # 8-decimal precision keeps fp32 sub-bf16 differences visible
            # in the log for sanity comparisons across runs.
            "loss": f"{loss:.8f}" if loss is not None else "N/A",
            "grad_norm": (
                f"{grad_norm:.8f}" if grad_norm is not None else "N/A"
            ),
            "lr": f"{lr:.2e}",
            "step_time": f"{elapsed:.2f}s",
        }

        tokens_per_sec = None
        observed_tflops = None
        mfu = None
        if self.report_throughput:
            # Prefer the real per-step token count stashed by train_step; fall
            # back to the batch x seq_len estimate when it is absent OR not
            # populated yet (None would otherwise crash the division).
            tokens = getattr(self.trainer, '_last_global_tokens', None)
            if tokens is None:
                tokens = self._tokens_per_step_est
            tokens_per_sec = tokens / elapsed
            metrics["tokens_per_sec"] = f"{tokens_per_sec:,.0f}"

            if self.model_flops_per_token and self.peak_tflops:
                # Observed TFLOPS = tokens/sec x flops/token / 1e12.
                # MFU = observed / (peak x world_size).
                world = max(platform.get_world_size(), 1)
                observed_tflops = (
                    tokens_per_sec * self.model_flops_per_token / 1e12
                )
                mfu = observed_tflops / (self.peak_tflops * world)
                metrics["tflops"] = f"{observed_tflops:.1f}"
                metrics["mfu"] = f"{mfu * 100:.1f}%"

        logger.info_rank0(" | ".join(f"{k}={v}" for k, v in metrics.items()))

        record = {
            "step": state.global_step,
            "loss": loss,
            "grad_norm": grad_norm,
            "lr": lr,
            "step_time": elapsed,
            "tokens_per_sec": tokens_per_sec,
        }
        # Numeric TFLOPS / MFU exist only when model FLOPs and hardware peak
        # are configured; include them so sinks can plot what the log shows.
        if observed_tflops is not None:
            record["tflops"] = observed_tflops
            record["mfu"] = mfu
        state.log_history.append(record)

        # Fan-out to other log-event listeners (TB / W&B / sinks).
        dispatch = getattr(self.trainer, "dispatch_log_event", None)
        if dispatch is not None:
            dispatch(record)
class CheckpointCallback(Callback):
    """Save distributed checkpoints and handle resume.

    Uses hyper's own DCP ``save`` / ``load`` APIs for the model. The
    non-model artifacts are plain objects written with ``torch.save``.
    One checkpoint directory ``<output_dir>/step_<N>/`` holds:

    * model shards (``dcp_save``; each rank writes its own shards);
    * ``optimizer_rank{r}.pt`` (per rank);
    * ``scheduler.pt`` and ``extra_state.json`` (rank 0 only);
    * ``rng_rank{r}.pt`` (per rank);
    * ``dataloader_rank{r}.pt`` (per rank, only when the dataloader exposes
      ``state_dict``).

    With ``checkpoint.save_async`` the write runs on a background thread.
    The small step-sensitive artifacts (scheduler, RNG, dataloader position,
    step/epoch) are captured on the training thread before the worker
    starts, so they describe the step being saved. Model and optimizer
    tensors are still read by the worker while training continues, so an
    async checkpoint's weights may include updates from later steps.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        ckpt_cfg = getattr(trainer.args, 'checkpoint', None)
        self.save_steps = getattr(ckpt_cfg, 'save_steps', 0) if ckpt_cfg else 0
        self.output_dir = (
            getattr(ckpt_cfg, 'output_dir', 'outputs') if ckpt_cfg else 'outputs'
        )
        self.load_path = (
            getattr(ckpt_cfg, 'load_path', None) if ckpt_cfg else None
        )
        self.save_async = (
            getattr(ckpt_cfg, 'save_async', False) if ckpt_cfg else False
        )
        self._last_saved_step = -1
        self._save_thread = None  # async save worker

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        """Resume from ``load_path``: model, step/epoch, optimizer, LR
        scheduler, RNG and dataloader position.

        RFC DoD: "Save -> resume -> continued-training loss is identical
        (including dataloader + RNG restore)".

        Every non-model artifact is optional, and a missing file is skipped.
        ``OSError`` / ``RuntimeError`` / ``ValueError`` raised while loading
        is logged as a warning and training continues, possibly with a
        partially restored state.
        """
        if not self.load_path:
            return
        try:
            if not os.path.isdir(self.load_path):
                logger.warning("Checkpoint path not found: %s", self.load_path)
                return
            rank = platform.get_rank()

            # 1. Model: dcp_load fills this rank's state dict from the DCP
            #    shards, which is then applied to the model.
            model_sd = self.trainer.model.state_dict()
            dcp_load(model_sd, checkpoint_id=self.load_path, use_collectives=False)
            self.trainer.model.load_state_dict(model_sd)
            logger.info("Model restored from %s", self.load_path)

            # 2. Step / epoch (written by rank 0, read by every rank).
            extra_path = os.path.join(self.load_path, "extra_state.json")
            if os.path.isfile(extra_path):
                with open(extra_path, encoding="utf-8") as f:
                    extra = json.load(f)
                state.global_step = extra.get("global_step", 0)
                state.epoch = extra.get("epoch", 0)
                logger.info("Resumed at step=%d, epoch=%d",
                            state.global_step, state.epoch)

            # 3. Optimizer (per rank).
            optim_path = os.path.join(self.load_path, f"optimizer_rank{rank}.pt")
            if os.path.isfile(optim_path) and self.trainer.optimizer:
                optim_sd = torch.load(optim_path, map_location="cpu", weights_only=True)
                self.trainer.optimizer.load_state_dict(optim_sd)
                logger.info("Optimizer restored")

            # 4. LR scheduler (single file written by rank 0).
            sched_path = os.path.join(self.load_path, "scheduler.pt")
            if os.path.isfile(sched_path) and self.trainer.lr_scheduler:
                sched_sd = torch.load(sched_path, map_location="cpu", weights_only=True)
                self.trainer.lr_scheduler.load_state_dict(sched_sd)
                logger.info("LR scheduler restored")

            # 5. RNG state (per rank, via the platform API).
            rng_path = os.path.join(self.load_path, f"rng_rank{rank}.pt")
            if os.path.isfile(rng_path):
                rng_state = torch.load(rng_path, map_location="cpu", weights_only=True)
                platform.set_rng_state(rng_state)
                logger.info("RNG state restored")

            # 6. Dataloader position (StatefulDataLoader). Check that the
            #    current dataloader can take the state; a plain dataloader
            #    would otherwise raise an uncaught AttributeError here.
            dataloader = getattr(self.trainer, 'train_dataloader', None)
            dl_path = os.path.join(self.load_path, f"dataloader_rank{rank}.pt")
            if os.path.isfile(dl_path):
                if hasattr(dataloader, 'load_state_dict'):
                    # NOTE(security): weights_only=False unpickles arbitrary
                    # objects; only resume from trusted checkpoint locations.
                    dl_state = torch.load(dl_path, map_location="cpu", weights_only=False)
                    dataloader.load_state_dict(dl_state)
                    logger.info("Dataloader state restored")
                else:
                    logger.warning(
                        "Dataloader state %s found but train_dataloader has no "
                        "load_state_dict; data position NOT restored", dl_path,
                    )

            # Fan-out the load event so other callbacks (TensorBoard /
            # W&B / external trackers) can record the resume.
            dispatch = getattr(self.trainer, "dispatch_load_event", None)
            if dispatch is not None:
                dispatch(self.load_path)

        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning("Failed to load checkpoint from %s: %s", self.load_path, exc)

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        """Save every ``save_steps`` steps (disabled when ``save_steps <= 0``)."""
        if self.save_steps <= 0:
            return
        if state.global_step % self.save_steps != 0:
            return
        if state.global_step == self._last_saved_step:
            return
        self._dispatch_save(state)

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Save final checkpoint (synchronously, to guarantee completion)."""
        # Wait for any outstanding async save first so the two don't race on
        # the same directory / state-dict iterator.
        self._join_pending()
        if self.save_steps > 0 and state.global_step != self._last_saved_step:
            # Final save always sync — the process is about to exit.
            self._save(state)

    # --- async plumbing -------------------------------------------------
    def _dispatch_save(self, state: "TrainerState") -> None:
        """Route to a sync or async save based on the ``save_async`` flag."""
        if not self.save_async:
            self._save(state)
            return
        # Wait for the previous save before starting a new one; two
        # concurrent saves would double host RAM and race on the filesystem.
        self._join_pending()
        # Capture everything that changes step-to-step on THIS thread, so the
        # worker writes the values of the step being saved rather than
        # whatever the training loop has advanced to by the time it runs.
        state_snapshot = copy.copy(state)
        try:
            host_state = self._capture_host_state()
        except (OSError, RuntimeError, ValueError) as exc:
            # Same best-effort policy as _save: a failed save never aborts
            # training.
            logger.warning("Failed to save checkpoint: %s", exc)
            return
        snap_step = state_snapshot.global_step
        self._save_thread = threading.Thread(
            target=self._save,
            args=(state_snapshot, host_state),
            name=f"ckpt-save-step{snap_step}",
            daemon=True,
        )
        self._save_thread.start()
        logger.info_rank0(
            "Checkpoint save for step %d dispatched async (thread=%s)",
            snap_step, self._save_thread.name,
        )

    def _join_pending(self) -> None:
        """Block until any running async save finishes."""
        t = self._save_thread
        if t is not None and t.is_alive():
            logger.info_rank0(
                "Waiting for prior async ckpt save (%s)...", t.name,
            )
            t.join()
        self._save_thread = None

    def _capture_host_state(self) -> dict:
        """Collect the small, step-sensitive artifacts on the calling thread.

        Returns:
            A dict with three keys:

            * ``"scheduler"``: LR scheduler ``state_dict()`` on rank 0, else
              ``None``;
            * ``"rng"``: ``platform.get_rng_state()``;
            * ``"dataloader"``: the dataloader's ``state_dict()``, or
              ``None`` when there is no dataloader or it is not stateful.
        """
        scheduler = self.trainer.lr_scheduler
        dataloader = getattr(self.trainer, 'train_dataloader', None)
        rank = platform.get_rank()
        return {
            "scheduler": (
                scheduler.state_dict() if scheduler and rank == 0 else None
            ),
            "rng": platform.get_rng_state(),
            "dataloader": (
                dataloader.state_dict()
                if hasattr(dataloader, 'state_dict') else None
            ),
        }

    def _save(self, state: "TrainerState",
              host_state: Optional[dict] = None) -> None:
        """Write a complete checkpoint: model + optimizer + scheduler + step + RNG
        + dataloader position.

        RFC DoD: "Save -> resume -> continued-training loss is identical
        (including dataloader + RNG restore)".

        Args:
            state: Trainer state, or an async snapshot of it. Its
                ``global_step`` / ``epoch`` name and describe the checkpoint.
            host_state: Output of ``_capture_host_state`` taken when the save
                was requested. ``None`` (the sync path) captures it here.

        ``OSError`` / ``RuntimeError`` / ``ValueError``, including failing to
        create the directory, are logged and swallowed, so a failed save
        never aborts training.
        """
        # Optimizer/scheduler/RNG state dicts are plain Python objects, not
        # nn.Module — platform.save_checkpoint expects Module (safetensors).
        # Use torch.save/load for these non-model artifacts.
        save_dir = os.path.join(self.output_dir, f"step_{state.global_step}")
        rank = platform.get_rank()

        try:
            os.makedirs(save_dir, exist_ok=True)
            if host_state is None:
                host_state = self._capture_host_state()

            # 1. Model — via hyper DCP (each rank saves its own shards).
            model_sd = self.trainer.model.state_dict()
            dcp_save(model_sd, checkpoint_id=save_dir, use_collectives=False)

            # 2. Optimizer — per rank. Read live (see the class docstring
            #    regarding async saves).
            if self.trainer.optimizer:
                optim_path = os.path.join(save_dir, f"optimizer_rank{rank}.pt")
                torch.save(self.trainer.optimizer.state_dict(), optim_path)

            # 3. LR scheduler (captured on rank 0 only).
            if host_state["scheduler"] is not None:
                sched_path = os.path.join(save_dir, "scheduler.pt")
                torch.save(host_state["scheduler"], sched_path)

            # 4. Extra state: global_step, epoch.
            if rank == 0:
                extra = {
                    "global_step": state.global_step,
                    "epoch": state.epoch,
                }
                extra_path = os.path.join(save_dir, "extra_state.json")
                with open(extra_path, "w", encoding="utf-8") as f:
                    json.dump(extra, f)

            # 5. RNG state — per rank.
            rng_path = os.path.join(save_dir, f"rng_rank{rank}.pt")
            torch.save(host_state["rng"], rng_path)

            # 6. Dataloader position — per rank (StatefulDataLoader).
            if host_state["dataloader"] is not None:
                dl_path = os.path.join(save_dir, f"dataloader_rank{rank}.pt")
                torch.save(host_state["dataloader"], dl_path)

            self._last_saved_step = state.global_step
            logger.info_rank0("Checkpoint saved to %s", save_dir)

            # Fan-out the save event so other callbacks (W&B artifact
            # upload, remote-storage sync, downstream eval triggers) can
            # observe the new checkpoint without coupling to ckpt internals.
            # In async mode this runs on the save worker thread.
            dispatch = getattr(self.trainer, "dispatch_save_event", None)
            if dispatch is not None:
                dispatch(save_dir)

        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning("Failed to save checkpoint: %s", exc)

    # HF format export is handled by SafetensorsExportCallback (separate concern).
class SafetensorsExportCallback(Callback):
    """Write HuggingFace-format safetensors next to the DCP checkpoints.

    Kept separate from ``CheckpointCallback`` (RFC Section 5.2). The full,
    unsharded weights are gathered with ``get_model_state_dict`` using
    ``full_state_dict=True``; rank 0 then writes them under
    ``<output_dir>/step_<N>/hf_ckpt``.

    Turned on by ``checkpoint.save_hf_weights`` and paced by the same
    ``checkpoint.save_steps`` value that ``CheckpointCallback`` reads.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        ckpt_cfg = getattr(trainer.args, 'checkpoint', None)
        if ckpt_cfg:
            self.enabled = getattr(ckpt_cfg, 'save_hf_weights', False)
            self.save_steps = getattr(ckpt_cfg, 'save_steps', 0)
            self.output_dir = getattr(ckpt_cfg, 'output_dir', 'outputs')
        else:
            self.enabled = False
            self.save_steps = 0
            self.output_dir = 'outputs'
        self._last_saved_step = -1

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        step = state.global_step
        export_due = (
            self.enabled
            and self.save_steps > 0
            and step % self.save_steps == 0
            and step != self._last_saved_step
        )
        if export_due:
            self._export(state)

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        if not (self.enabled and self.save_steps > 0):
            return
        if state.global_step == self._last_saved_step:
            return
        self._export(state)

    def _export(self, state: "TrainerState") -> None:
        """Gather the full state dict on every rank, then write it from rank 0.

        The gather runs on all ranks (it reassembles the FSDP shards); only
        rank 0 creates the directory and writes files, via
        ``_write_hf_rank0``. ``OSError`` / ``RuntimeError`` / ``ValueError``
        are reported through ``logger.warning_rank0`` and swallowed.
        """
        rank = platform.get_rank()
        save_dir = os.path.join(self.output_dir, f"step_{state.global_step}", "hf_ckpt")

        try:
            # ``StateDictOptions`` is a torch-backend type with no hyper
            # wrapper yet, hence the direct torch import.
            # pylint: disable=C0415
            from torch.distributed.checkpoint.state_dict import StateDictOptions
            # full_state_dict=True gathers all shards; cpu_offload avoids OOM.
            gather_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
            full_sd = get_model_state_dict(self.trainer.model, options=gather_opts)

            if rank == 0:
                os.makedirs(save_dir, exist_ok=True)
                self._write_hf_rank0(full_sd, save_dir)

            self._last_saved_step = state.global_step

        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning_rank0("Failed to save HF checkpoint: %s", exc)

    def _write_hf_rank0(self, full_sd, save_dir: str) -> None:
        """Serialize ``full_sd`` into ``save_dir`` in HF format (rank 0 only).

        When the trainer's ``spec`` provides ``state_dict_adapter`` and the
        instantiated adapter has ``save_hf_state_dict``, that method converts
        the tensors (per-model renaming / expert packing lives in the model
        package) and the result goes to a single ``model.safetensors``.
        Otherwise the generic ``save_state_dict_as_huggingface_format`` path
        is used, which keeps ad-hoc / template models working.
        """
        spec = getattr(self.trainer, "spec", None)
        adapter_cls = None
        if spec:
            adapter_cls = getattr(spec, "state_dict_adapter", None)
        save_fn = None
        if adapter_cls is not None:
            save_fn = getattr(adapter_cls(), "save_hf_state_dict", None)

        if save_fn is None:
            save_state_dict_as_huggingface_format(full_sd, save_dir)
            logger.info(
                "HF checkpoint saved (no adapter on spec) to %s", save_dir,
            )
            return

        hf_sd = save_fn(full_sd, self.trainer.model.config)
        from safetensors.torch import save_file  # pylint: disable=C0415
        save_file(hf_sd, os.path.join(save_dir, "model.safetensors"))
        logger.info(
            "HF checkpoint saved via %s.save_hf_state_dict to %s",
            adapter_cls.__name__, save_dir,
        )
class EvalCallback(Callback):
    """Placeholder for evaluation.

    No evaluation loop exists yet. Instead, on every ``eval.eval_steps``-th
    step rank 0 logs a warning, so the missing eval shows up in the
    training logs instead of being skipped silently.
    """

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        # Config is re-read on every call.
        eval_cfg = getattr(self.trainer.args, 'eval', None)
        interval = getattr(eval_cfg, 'eval_steps', 0) if eval_cfg else 0
        if interval <= 0:
            return
        if state.global_step % interval != 0:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "EvalCallback: evaluation not implemented (step=%d)", state.global_step
        )
class ProfilerCallback(Callback):
    """Training profiler hook — STUB (not verified).

    Reserved for ``torch.profiler.profile`` integration; nothing is profiled
    yet. When ``args.profiler.enabled`` is set, rank 0 logs a single warning
    at construction so the missing traces are noticed. Implementation plan:
    start the profiler in ``on_train_begin``, step it in ``on_step_end`` and
    stop it in ``on_train_end``.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        prof_cfg = getattr(trainer.args, 'profiler', None)
        requested = getattr(prof_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "ProfilerCallback: enabled=True but the implementation is "
            "a stub — torch profiler is NOT started. Implement before "
            "relying on traces."
        )
class WandbCallback(Callback):
    """Weights & Biases logging hook — STUB (not verified).

    Reserved for W&B integration; nothing is sent yet. When
    ``args.wandb.enabled`` is set, rank 0 logs a single warning at
    construction so the missing W&B data is noticed. Implementation plan:
    ``wandb.init`` in ``on_train_begin``, ``wandb.log`` in ``on_step_end``,
    ``wandb.finish`` in ``on_train_end``, then verify against a real run.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        wandb_cfg = getattr(trainer.args, 'wandb', None)
        requested = getattr(wandb_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "WandbCallback: enabled=True but the implementation is a "
            "stub — nothing is sent to W&B. Implement before relying on "
            "W&B dashboards."
        )
class ProgressCallback(Callback):
    """Rank-0 ``tqdm`` progress bar over training steps.

    Shows the current loss and grad norm as bar postfix. ``tqdm`` is an
    optional dependency; without it a warning is logged and no bar is shown.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        self._pbar = None

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        if platform.get_rank() == 0:
            try:
                from tqdm import tqdm  # pylint: disable=C0415 # optional dep
                # Start from global_step so a resumed run shows true progress.
                self._pbar = tqdm(
                    total=state.max_steps,
                    initial=state.global_step,
                    desc="Training",
                    unit="step",
                    dynamic_ncols=True,
                )
            except ImportError:
                logger.warning("ProgressCallback: 'tqdm' not installed — progress bar disabled")

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        bar = self._pbar
        if bar is None:
            return
        fields = (("loss", loss), ("gnorm", grad_norm))
        bar.set_postfix({name: f"{value:.4f}" for name, value in fields if value is not None})
        bar.update(1)

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        bar = self._pbar
        if bar is None:
            return
        bar.close()
        self._pbar = None
class MoEMonitorCallback(Callback):
    """Placeholder for Mixture-of-Experts load-balance monitoring.

    Intended to track expert routing statistics (tokens per expert, load
    imbalance). That needs MoE-specific hooks in the model's router, so for
    now the callback only logs one info line on rank 0 at train start when
    ``args.moe_monitor.enabled`` is set.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        moe_cfg = getattr(trainer.args, 'moe_monitor', None)
        self.enabled = False
        if moe_cfg:
            self.enabled = getattr(moe_cfg, 'enabled', False)

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        if not self.enabled:
            return
        if platform.get_rank() != 0:
            return
        logger.info(
            "MoEMonitorCallback: MoE expert-load monitoring enabled "
            "(stub — full implementation pending model router hooks)"
        )
class GradientHealthCallback(Callback):
    """Detect a NaN / Inf ``grad_norm`` and stop before the weights are updated.

    Hooks ``on_pre_optimizer_step``, which fires after ``clip_grad_norm_``
    and before ``optimizer.step()``. A non-finite norm means the following
    ``optimizer.step()`` would write NaN/Inf into the weights, so the
    callback logs an error and raises ``RuntimeError`` on EVERY rank that
    observes it. Raising on rank 0 only (the previous behaviour) let a rank
    with a rank-local divergence continue into ``optimizer.step`` and
    corrupt its weights.

    Config: ``trainer.args.debug.check_nan_inf`` (off by default).
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        debug_cfg = getattr(trainer.args, 'debug', None)
        self.enabled = (
            getattr(debug_cfg, 'check_nan_inf', False) if debug_cfg else False
        )

    def on_pre_optimizer_step(self, state: "TrainerState", *,
                              grad_norm: Optional[float] = None,
                              **kwargs) -> None:
        """Raise if ``grad_norm`` is NaN or +/-Inf.

        Args:
            state: Current trainer state; only ``global_step`` is read.
            grad_norm: Gradient norm, anything ``float()`` accepts (e.g. a
                Python number or a 0-dim tensor). ``None`` skips the check.

        Raises:
            RuntimeError: On each rank whose ``grad_norm`` is non-finite.
        """
        if not self.enabled or grad_norm is None:
            return
        if math.isfinite(float(grad_norm)):
            return
        # Always log on every rank — divergence may be rank-local.
        logger.error(
            "GradientHealthCallback: grad_norm=%s at step %d "
            "(NaN/Inf). Optimizer.step would corrupt weights.",
            grad_norm, state.global_step,
        )
        # Raise on whichever rank sees the problem so no rank reaches
        # optimizer.step with a non-finite gradient.
        # NOTE(review): the message names ``cfg.train.debug.check_nan_inf``
        # while this class reads ``trainer.args.debug.check_nan_inf`` —
        # confirm both refer to the same setting.
        raise RuntimeError(
            f"Non-finite grad_norm={grad_norm} at "
            f"step {state.global_step}. "
            "Disable cfg.train.debug.check_nan_inf to skip this guard."
        )
class GCCallback(Callback):
    """Explicit garbage-collection scheduler.

    Python's cyclic GC can stall large training jobs when it decides to run.
    Forcing a collection every N steps, outside the compute hot path, keeps
    pauses predictable.

    While active, the callback turns off Python's automatic collector. It
    turns it back on in ``on_train_end``, but only if it was on before, so
    the process-global GC setting does not leak past training.

    Config: ``trainer.args.debug.gc_steps`` (``0`` disables).
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        debug_cfg = getattr(trainer.args, 'debug', None)
        self.gc_steps = (
            getattr(debug_cfg, 'gc_steps', 0) if debug_cfg else 0
        )
        # True only when this callback disabled an active automatic GC and
        # therefore owes a gc.enable() at the end of training.
        self._restore_auto_gc = False
        if self.gc_steps > 0:
            self._restore_auto_gc = gc.isenabled()
            # Disable the automatic generational collector; we'll drive it.
            gc.disable()
            logger.info("GCCallback: Python gc.collect every %d steps "
                        "(auto GC disabled)", self.gc_steps)

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        """Run a full ``gc.collect()`` on every ``gc_steps``-th step."""
        if self.gc_steps <= 0:
            return
        if state.global_step % self.gc_steps != 0:
            return
        gc.collect()

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Re-enable automatic GC if this callback was the one that disabled it."""
        if self._restore_auto_gc:
            gc.enable()
            self._restore_auto_gc = False
class TensorBoardCallback(Callback):
    """TensorBoard scalar writer — STUB (not verified).

    Reserved for ``torch.utils.tensorboard.SummaryWriter`` integration;
    nothing is written yet. When ``args.tensorboard.enabled`` is set, rank 0
    logs a single warning at construction so the missing scalars are
    noticed. Implementation plan: open the writer in ``on_train_begin``,
    write scalars in ``on_log``, close it in ``on_train_end``.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        tb_cfg = getattr(trainer.args, 'tensorboard', None)
        requested = getattr(tb_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "TensorBoardCallback: enabled=True but the implementation "
            "is a stub — nothing is written to TensorBoard. Implement "
            "before relying on TB scalars."
        )
class MemoryMonitorCallback(Callback):
    """Peak / current device memory monitor — STUB (not verified).

    Reserved for polling ``platform.get_device_handle().memory_allocated``
    / ``max_memory_allocated``; no stats are emitted yet. When
    ``args.memory_monitor.enabled`` is set, rank 0 logs a single warning at
    construction so the missing memory logs are noticed. Implementation
    plan: poll the device handle in ``on_step_end``, gated by
    ``log_steps``, and log ``cur=...GB peak=...GB``.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        mem_cfg = getattr(trainer.args, 'memory_monitor', None)
        requested = getattr(mem_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "MemoryMonitorCallback: enabled=True but the implementation "
            "is a stub — no memory stats are emitted. Implement before "
            "relying on these logs."
        )