Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/trainer/callbacks/base.py: 0%

319 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

"""Callback base class and built-in callbacks.

Callback hooks are dispatched explicitly (in ``on_step_end`` etc.), so an
engineer sees all callbacks and their order at a glance.

``checkpoint_callback.py`` (242 lines) + ``trace_callback.py`` (231 lines).
"""

22import copy 

23import gc 

24import json 

25import logging 

26import math 

27import os 

28import threading 

29import time 

30from typing import TYPE_CHECKING, Optional 

31 

32import torch 

33 

34from hyper_parallel import get_platform 

35from hyper_parallel.core.distributed_checkpoint import load as dcp_load, save as dcp_save 

36from hyper_parallel.core.distributed_checkpoint.offline_transform import ( 

37 save_state_dict_as_huggingface_format, 

38) 

39from hyper_parallel.core.fully_shard.api import get_model_state_dict 

40 

# Module logger shared by every callback below.
# NOTE(review): callbacks in this module call ``logger.info_rank0`` and
# ``logger.warning_rank0``, which the standard ``logging.Logger`` does not
# define; presumably hyper_parallel attaches them elsewhere — confirm.
logger = logging.getLogger(name=__name__)

# Backend platform handle, resolved once at import time; the callbacks use it
# for rank / world-size queries and RNG state access.
platform = get_platform()

if TYPE_CHECKING:
    # Imported for type annotations only; not needed at runtime.
    from hyper_parallel.trainer.base import BaseTrainer, TrainerState

47 

class Callback:
    """Base class for all trainer callbacks.

    Each callback holds a reference to the trainer for accessing model,
    optimizer, state, and config. Every hook below is a no-op that returns
    ``None`` and accepts extra keyword arguments via ``**kwargs``; subclass
    and override only the hooks you need.

    Args:
        trainer: The BaseTrainer instance.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        self.trainer = trainer

    # ------------------------------------------------------------------
    # Lifecycle hooks
    # ------------------------------------------------------------------

    def on_init_end(self, state: "TrainerState", **kwargs) -> None:
        """Called once at the end of ``BaseTrainer.__init__`` / subclass init.

        At this point every ``_build_*`` has run — model is parallelised,
        optimizer/scheduler/dataloader are built, callbacks are constructed.
        Use this for one-shot setup that must see the FINAL trainer state
        (e.g. logging the parameter count, opening a TensorBoard writer
        keyed by run_id, validating user config against the built model).
        """

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        """Called at the start of ``train()`` (before any optimizer.step).

        ``CheckpointCallback`` runs resume here, so when this hook fires
        ``state.global_step`` may already be > 0 if a checkpoint was loaded
        (depending on callback order).
        """

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Called at the end of training (before ``destroy_process_group``).

        Final checkpoints, profiler stops, W&B finish, etc. happen here.
        """

    def on_epoch_begin(self, state: "TrainerState", **kwargs) -> None:
        """Called at the start of each epoch."""

    def on_epoch_end(self, state: "TrainerState", **kwargs) -> None:
        """Called at the end of each epoch."""

    def on_step_begin(self, state: "TrainerState", **kwargs) -> None:
        """Called at the start of each training step (before fwd of mb 0)."""

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        """Called at the end of each training step (after optimizer.step).

        Args:
            loss: Step loss, or ``None`` if unavailable.
            grad_norm: Gradient norm of the step, or ``None`` if unavailable.
        """

    def on_substep_end(self, state: "TrainerState", **kwargs) -> None:
        """Called after each micro-batch fwd+bwd (gradient accumulation)."""

    def on_pre_optimizer_step(self, state: "TrainerState", *,
                              grad_norm: Optional[float] = None,
                              **kwargs) -> None:
        """Called after grad clip, before ``optimizer.step``.

        ``grad_norm`` is the scalar produced by hyper's DTensor-aware
        clipper — use it to detect NaN/Inf or to log the effective clip
        ratio. NOTE(review): confirm whether it is the pre- or post-clip
        norm (torch's ``clip_grad_norm_`` returns the pre-clip total norm).
        """

    def on_log(self, state: "TrainerState", *, metrics: dict, **kwargs) -> None:
        """Called when ``LoggingCallback`` emits a structured metrics record.

        Reuse this hook in TensorBoard / W&B / external metric sinks so
        every logging backend sees the SAME record instead of each callback
        computing throughput / lr independently.

        Args:
            metrics: Dict with raw (unformatted) values for ``step``,
                ``loss``, ``grad_norm``, ``lr`` and ``step_time``.
                Throughput fields (``tokens_per_sec`` and, when FLOPs /
                peak TFLOPS are configured, ``tflops`` / ``mfu``) depend
                on the ``logging`` config and may be ``None`` or absent —
                read them with ``metrics.get(...)``.
        """

    def on_save(self, state: "TrainerState", *, checkpoint_dir: str,
                **kwargs) -> None:
        """Called immediately after ``CheckpointCallback`` finishes a save.

        Use to upload to remote storage, register the ckpt with an
        experiment tracker, or trigger downstream eval jobs. ``checkpoint_dir``
        is the on-disk path containing model shards + optimizer/scheduler/RNG/
        dataloader/extra_state.
        """

    def on_load(self, state: "TrainerState", *, checkpoint_dir: str,
                **kwargs) -> None:
        """Called immediately after ``CheckpointCallback`` finishes a resume.

        Use to verify the resumed step matches expectations, log the
        restore event, or seed downstream callbacks with the resumed state.
        """

    def on_evaluate(self, state: "TrainerState", *,
                    metrics: Optional[dict] = None, **kwargs) -> None:
        """Called when an evaluation pass completes.

        Not triggered yet: ``EvalCallback`` is currently a stub that only
        logs a warning. Once a real eval loop lands, it is expected to pass
        the eval ``metrics`` dict here for sinks (TensorBoard / W&B) to log.
        """

153 

class LoggingCallback(Callback):
    """Log training metrics: loss, grad_norm, lr, throughput.

    Every ``logging.log_steps`` steps (a value ``<= 0`` disables logging)
    this callback:

    * logs one formatted line via ``logger.info_rank0``;
    * appends a record with raw values to ``state.log_history``;
    * forwards the same record to ``trainer.dispatch_log_event`` when the
      trainer defines it.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        log_cfg = getattr(trainer.args, 'logging', None)
        self.log_steps = getattr(log_cfg, 'log_steps', 10) if log_cfg else 10
        self.report_global_loss = (
            getattr(log_cfg, 'report_global_loss', False) if log_cfg else False
        )
        self.report_throughput = (
            getattr(log_cfg, 'report_throughput', True) if log_cfg else True
        )
        self.model_flops_per_token = (
            getattr(log_cfg, 'model_flops_per_token', None) if log_cfg else None
        )
        self.peak_tflops = (
            getattr(log_cfg, 'peak_tflops', None) if log_cfg else None
        )
        # Upper-bound per-step token estimate (global batch x max seq len),
        # used when the trainer has not stashed a real per-step count in
        # ``trainer._last_global_tokens``.
        gbs = getattr(trainer.args.train, 'global_batch_size', 1)
        seq_len = getattr(trainer.args.data, 'max_seq_len', 1)
        self._tokens_per_step_est = int(gbs) * int(seq_len)
        self._step_start_time = 0.0

    def on_step_begin(self, state: "TrainerState", **kwargs) -> None:
        # perf_counter is monotonic: wall-clock adjustments (NTP, manual
        # changes) cannot produce negative or inflated step times.
        self._step_start_time = time.perf_counter()

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        # log_steps <= 0 disables logging (and avoids a modulo by zero).
        if self.log_steps <= 0 or state.global_step % self.log_steps != 0:
            return

        elapsed = max(time.perf_counter() - self._step_start_time, 1e-9)
        lr = 0.0
        if self.trainer.lr_scheduler is not None:
            lr = self.trainer.lr_scheduler.get_last_lr()[0]

        # Human-readable strings for the log line.
        metrics = {
            "step": state.global_step,
            # 8-decimal precision keeps fp32 sub-bf16 differences visible
            # in the log for sanity comparisons across runs.
            "loss": f"{loss:.8f}" if loss is not None else "N/A",
            "grad_norm": (
                f"{grad_norm:.8f}" if grad_norm is not None else "N/A"
            ),
            "lr": f"{lr:.2e}",
            "step_time": f"{elapsed:.2f}s",
        }
        # Raw values for log_history and downstream sinks (``on_log``).
        record = {
            "step": state.global_step,
            "loss": loss,
            "grad_norm": grad_norm,
            "lr": lr,
            "step_time": elapsed,
            "tokens_per_sec": None,
        }

        if self.report_throughput:
            # Prefer the real per-step token count stashed by train_step;
            # fall back to the estimate when it is missing or None.
            tokens = getattr(self.trainer, '_last_global_tokens', None)
            if tokens is None:
                tokens = self._tokens_per_step_est
            tokens_per_sec = float(tokens) / elapsed
            metrics["tokens_per_sec"] = f"{tokens_per_sec:,.0f}"
            record["tokens_per_sec"] = tokens_per_sec

            if self.model_flops_per_token and self.peak_tflops:
                # Observed TFLOPS = tokens/sec x flops/token / 1e12.
                # MFU = observed / (peak x world_size).
                world = max(platform.get_world_size(), 1)
                observed_tflops = (
                    tokens_per_sec * self.model_flops_per_token / 1e12
                )
                mfu = observed_tflops / (self.peak_tflops * world)
                metrics["tflops"] = f"{observed_tflops:.1f}"
                metrics["mfu"] = f"{mfu * 100:.1f}%"
                # Expose the same numbers to sinks so they need not recompute.
                record["tflops"] = observed_tflops
                record["mfu"] = mfu

        logger.info_rank0(" | ".join(f"{k}={v}" for k, v in metrics.items()))

        state.log_history.append(record)

        # Fan-out to other log-event listeners (TB / W&B / sinks).
        dispatch = getattr(self.trainer, "dispatch_log_event", None)
        if dispatch is not None:
            dispatch(record)

243 

class CheckpointCallback(Callback):
    """Save distributed checkpoints and handle resume.

    Uses hyper's own DCP ``save`` / ``load`` APIs for the model. The
    non-model artifacts (optimizer, LR scheduler, RNG, dataloader position,
    step/epoch) go through ``torch.save`` / ``torch.load`` / JSON.

    Layout of ``<output_dir>/step_<global_step>/``:

    * model shards written by ``dcp_save`` (every rank);
    * ``optimizer_rank{r}.pt``, ``rng_rank{r}.pt``, ``dataloader_rank{r}.pt``
      (one per rank);
    * ``scheduler.pt`` and ``extra_state.json`` (rank 0 only).
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        ckpt_cfg = getattr(trainer.args, 'checkpoint', None)
        self.save_steps = getattr(ckpt_cfg, 'save_steps', 0) if ckpt_cfg else 0
        self.output_dir = (
            getattr(ckpt_cfg, 'output_dir', 'outputs') if ckpt_cfg else 'outputs'
        )
        self.load_path = (
            getattr(ckpt_cfg, 'load_path', None) if ckpt_cfg else None
        )
        self.save_async = (
            getattr(ckpt_cfg, 'save_async', False) if ckpt_cfg else False
        )
        # Step of the last successful save on this rank (-1 = none yet).
        self._last_saved_step = -1
        # Worker thread of the in-flight async save, if any.
        self._save_thread = None
        if self.save_async and platform.get_rank() == 0:
            logger.warning(
                "CheckpointCallback: save_async=True — model and optimizer "
                "tensors are read by the save thread while training "
                "continues, so an async checkpoint may mix values from "
                "later steps. Use save_async=False when exact resume matters."
            )

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        """Resume from ``load_path``: model, step/epoch, optimizer, LR scheduler, RNG, dataloader.

        RFC DoD: "Save -> resume -> continued-training loss matches
        (including dataloader + RNG restore)".

        Every artifact except the model is skipped when its file is missing.
        An ``OSError`` / ``RuntimeError`` / ``ValueError`` stops the resume
        with a warning; artifacts restored before the failure stay restored.
        """
        if not self.load_path:
            return
        try:
            if not os.path.isdir(self.load_path):
                logger.warning("Checkpoint path not found: %s", self.load_path)
                return
            rank = platform.get_rank()

            # 1. Model via hyper DCP (loads into the existing state dict).
            model_sd = self.trainer.model.state_dict()
            dcp_load(model_sd, checkpoint_id=self.load_path, use_collectives=False)
            self.trainer.model.load_state_dict(model_sd)
            logger.info("Model restored from %s", self.load_path)

            # 2. Extra state (step, epoch); written by rank 0 only.
            extra_path = os.path.join(self.load_path, "extra_state.json")
            if os.path.isfile(extra_path):
                with open(extra_path, encoding="utf-8") as f:
                    extra = json.load(f)
                state.global_step = extra.get("global_step", 0)
                state.epoch = extra.get("epoch", 0)
                logger.info("Resumed at step=%d, epoch=%d",
                            state.global_step, state.epoch)

            # 3. Optimizer (per-rank file).
            optim_path = os.path.join(self.load_path, f"optimizer_rank{rank}.pt")
            if os.path.isfile(optim_path) and self.trainer.optimizer:
                optim_sd = torch.load(optim_path, map_location="cpu", weights_only=True)
                self.trainer.optimizer.load_state_dict(optim_sd)
                logger.info("Optimizer restored")

            # 4. LR scheduler.
            sched_path = os.path.join(self.load_path, "scheduler.pt")
            if os.path.isfile(sched_path) and self.trainer.lr_scheduler:
                sched_sd = torch.load(sched_path, map_location="cpu", weights_only=True)
                self.trainer.lr_scheduler.load_state_dict(sched_sd)
                logger.info("LR scheduler restored")

            # 5. RNG state (per-rank file).
            rng_path = os.path.join(self.load_path, f"rng_rank{rank}.pt")
            if os.path.isfile(rng_path):
                rng_state = torch.load(rng_path, map_location="cpu", weights_only=True)
                platform.set_rng_state(rng_state)
                logger.info("RNG state restored")

            # 6. Dataloader position (per-rank, StatefulDataLoader). Mirrors
            # the save side: only when the dataloader exists and supports it.
            dataloader = getattr(self.trainer, 'train_dataloader', None)
            dl_path = os.path.join(self.load_path, f"dataloader_rank{rank}.pt")
            if os.path.isfile(dl_path) and hasattr(dataloader, 'load_state_dict'):
                # SECURITY: weights_only=False unpickles arbitrary Python
                # objects. Only resume from trusted checkpoint directories.
                dl_state = torch.load(dl_path, map_location="cpu", weights_only=False)
                dataloader.load_state_dict(dl_state)
                logger.info("Dataloader state restored")

            # Fan-out the load event so other callbacks (TensorBoard /
            # W&B / external trackers) can record the resume.
            dispatch = getattr(self.trainer, "dispatch_load_event", None)
            if dispatch is not None:
                dispatch(self.load_path)

        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning("Failed to load checkpoint from %s: %s", self.load_path, exc)

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        """Save every ``save_steps`` steps (``<= 0`` disables periodic saves)."""
        if self.save_steps <= 0:
            return
        if state.global_step % self.save_steps != 0:
            return
        if state.global_step == self._last_saved_step:
            return
        self._dispatch_save(state)

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Save final checkpoint (synchronously, to guarantee completion)."""
        # Wait for any outstanding async save first so the two don't race on
        # the same directory / state-dict iterator.
        self._join_pending()
        if self.save_steps > 0 and state.global_step != self._last_saved_step:
            # Final save always sync — the process is about to exit.
            self._save(state)

    # --- async plumbing -------------------------------------------------
    def _dispatch_save(self, state: "TrainerState") -> None:
        """Route to sync or async save based on the ``save_async`` flag.

        Async mode waits for any previous save, then snapshots ``state`` and
        the small per-step state (LR scheduler, RNG, dataloader position) on
        the calling thread, so those match the checkpoint's step. The disk
        writes then run on a daemon thread.

        NOTE: model and optimizer state dicts are still collected inside the
        worker. They reference live tensors that training keeps updating, so
        an async checkpoint is not guaranteed to be consistent with its step.
        """
        if not self.save_async:
            self._save(state)
            return
        # Wait for previous save to finish before starting a new one; saving
        # twice concurrently would double RAM and race the filesystem.
        self._join_pending()
        # Snapshot state fields so the worker doesn't see later mutations.
        state_snapshot = copy.copy(state)
        state_snapshot.global_step = state.global_step
        state_snapshot.epoch = state.epoch
        try:
            aux_state = self._capture_aux_state()
        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning("Failed to save checkpoint: %s", exc)
            return
        snap_step = state_snapshot.global_step
        self._save_thread = threading.Thread(
            target=self._save,
            args=(state_snapshot,),
            kwargs={"aux_state": aux_state},
            name=f"ckpt-save-step{snap_step}",
            daemon=True,
        )
        self._save_thread.start()
        logger.info_rank0(
            "Checkpoint save for step %d dispatched async (thread=%s)",
            snap_step, self._save_thread.name,
        )

    def _capture_aux_state(self) -> dict:
        """Collect LR scheduler, RNG and dataloader state for one checkpoint.

        Returns:
            Dict with keys ``"scheduler"``, ``"rng"`` and ``"dataloader"``.
            ``"scheduler"`` / ``"dataloader"`` are ``None`` when the trainer
            has no scheduler or the dataloader has no ``state_dict``.
        """
        scheduler = self.trainer.lr_scheduler
        dataloader = getattr(self.trainer, 'train_dataloader', None)
        return {
            "scheduler": scheduler.state_dict() if scheduler else None,
            "rng": platform.get_rng_state(),
            "dataloader": (
                dataloader.state_dict() if hasattr(dataloader, 'state_dict') else None
            ),
        }

    def _join_pending(self) -> None:
        """Block until any running async save finishes."""
        t = self._save_thread
        if t is not None and t.is_alive():
            logger.info_rank0(
                "Waiting for prior async ckpt save (%s)...", t.name,
            )
            t.join()
        self._save_thread = None

    def _save(self, state: "TrainerState",
              aux_state: Optional[dict] = None) -> None:
        """Save complete training state: model, optimizer, scheduler, step, RNG, dataloader.

        RFC DoD: "Save -> resume -> continued-training loss matches
        (including dataloader + RNG restore)".

        In async mode this method, including the ``dispatch_save_event``
        fan-out, runs on the worker thread.

        Args:
            state: Trainer state (or a snapshot of it) providing
                ``global_step`` and ``epoch``.
            aux_state: Output of ``_capture_aux_state`` taken by the caller;
                captured here when ``None`` (synchronous path).

        Failures (``OSError`` / ``RuntimeError`` / ``ValueError``) are logged
        as a warning and leave ``_last_saved_step`` unchanged.
        """
        # Optimizer/scheduler/RNG state dicts are plain Python dicts, not
        # nn.Module — platform.save_checkpoint expects Module (safetensors).
        # Use torch.save/load for these non-model artifacts.
        save_dir = os.path.join(self.output_dir, f"step_{state.global_step}")

        try:
            os.makedirs(save_dir, exist_ok=True)
            rank = platform.get_rank()
            if aux_state is None:
                aux_state = self._capture_aux_state()

            # 1. Model — via hyper DCP (each rank saves its own shards).
            model_sd = self.trainer.model.state_dict()
            dcp_save(model_sd, checkpoint_id=save_dir, use_collectives=False)

            # 2. Optimizer — per rank.
            if self.trainer.optimizer:
                optim_path = os.path.join(save_dir, f"optimizer_rank{rank}.pt")
                torch.save(self.trainer.optimizer.state_dict(), optim_path)

            # 3. LR scheduler — rank 0 only.
            if aux_state["scheduler"] is not None and rank == 0:
                sched_path = os.path.join(save_dir, "scheduler.pt")
                torch.save(aux_state["scheduler"], sched_path)

            # 4. Extra state: global_step, epoch — rank 0 only.
            if rank == 0:
                extra = {
                    "global_step": state.global_step,
                    "epoch": state.epoch,
                }
                extra_path = os.path.join(save_dir, "extra_state.json")
                with open(extra_path, "w", encoding="utf-8") as f:
                    json.dump(extra, f)

            # 5. RNG state — per rank.
            rng_path = os.path.join(save_dir, f"rng_rank{rank}.pt")
            torch.save(aux_state["rng"], rng_path)

            # 6. Dataloader position — per rank (StatefulDataLoader).
            if aux_state["dataloader"] is not None:
                dl_path = os.path.join(save_dir, f"dataloader_rank{rank}.pt")
                torch.save(aux_state["dataloader"], dl_path)

            self._last_saved_step = state.global_step
            logger.info_rank0("Checkpoint saved to %s", save_dir)

            # Fan-out the save event so other callbacks (W&B artifact
            # upload, remote-storage sync, downstream eval triggers) can
            # observe the new checkpoint without coupling to ckpt internals.
            dispatch = getattr(self.trainer, "dispatch_save_event", None)
            if dispatch is not None:
                dispatch(save_dir)

        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning("Failed to save checkpoint: %s", exc)

453 

454 # HF format export is handled by SafetensorsExportCallback (separate concern). 

455 

class SafetensorsExportCallback(Callback):
    """Export model weights in HuggingFace safetensors format.

    Separated from CheckpointCallback per RFC Section 5.2. Uses
    ``get_model_state_dict`` with ``full_state_dict=True`` to gather all
    FSDP shards into a full state dict; rank 0 then writes it to
    ``<output_dir>/step_<global_step>/hf_ckpt``.

    Enabled by ``checkpoint.save_hf_weights``. Runs every
    ``checkpoint.save_steps`` steps and once more in ``on_train_end`` if
    the final step was not exported yet.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        ckpt_cfg = getattr(trainer.args, 'checkpoint', None)
        self.enabled = getattr(ckpt_cfg, 'save_hf_weights', False) if ckpt_cfg else False
        self.save_steps = getattr(ckpt_cfg, 'save_steps', 0) if ckpt_cfg else 0
        self.output_dir = getattr(ckpt_cfg, 'output_dir', 'outputs') if ckpt_cfg else 'outputs'
        # Last step whose full state dict was gathered (-1 = none yet).
        self._last_saved_step = -1

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        if not self.enabled or self.save_steps <= 0:
            return
        if state.global_step % self.save_steps != 0:
            return
        if state.global_step == self._last_saved_step:
            return
        self._export(state)

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        if self.enabled and self.save_steps > 0 and state.global_step != self._last_saved_step:
            self._export(state)

    def _export(self, state: "TrainerState") -> None:
        """Gather the full state dict from FSDP shards and save it in HF format.

        The gather (``get_model_state_dict`` with ``full_state_dict=True``)
        runs on every rank, so all ranks must call ``_export`` for the same
        steps. ``_last_saved_step``, which gates ``on_train_end``, is
        therefore updated on every rank right after the gather and before
        the rank-0-only write. A failed write on rank 0 is logged and not
        retried. Otherwise rank 0 would re-enter the gather alone at train
        end while the other ranks skipped it.
        """
        rank = platform.get_rank()
        save_dir = os.path.join(self.output_dir, f"step_{state.global_step}", "hf_ckpt")

        try:
            # ``StateDictOptions`` is a torch-backend type; hyper does not yet
            # provide a wrapper, so the trainer reaches into torch directly.
            from torch.distributed.checkpoint.state_dict import StateDictOptions  # pylint: disable=C0415
            # full_state_dict=True gathers all FSDP shards; cpu_offload avoids OOM
            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
            full_sd = get_model_state_dict(self.trainer.model, options=options)

            # Same value on every rank: keeps the train-end decision in sync.
            self._last_saved_step = state.global_step

            if rank == 0:
                self._write_hf_checkpoint(full_sd, save_dir)

        except (OSError, RuntimeError, ValueError) as exc:
            logger.warning_rank0("Failed to save HF checkpoint: %s", exc)

    def _write_hf_checkpoint(self, full_sd: dict, save_dir: str) -> None:
        """Write a gathered full state dict to ``save_dir`` (called on rank 0 only).

        Routes through ``spec.state_dict_adapter().save_hf_state_dict`` when
        the trainer's ``spec`` provides one, so per-model HF tensor renaming
        and per-expert packing live in the model package, not in this
        generic callback. Falls back to
        ``save_state_dict_as_huggingface_format`` when there is no adapter
        (keeps ad-hoc / template models working).
        """
        os.makedirs(save_dir, exist_ok=True)

        spec = getattr(self.trainer, "spec", None)
        adapter_cls = getattr(spec, "state_dict_adapter", None) if spec else None
        save_fn = (
            getattr(adapter_cls(), "save_hf_state_dict", None)
            if adapter_cls is not None else None
        )
        if save_fn is None:
            save_state_dict_as_huggingface_format(full_sd, save_dir)
            logger.info(
                "HF checkpoint saved (no adapter on spec) to %s", save_dir,
            )
            return

        hf_sd = save_fn(full_sd, self.trainer.model.config)
        from safetensors.torch import save_file  # pylint: disable=C0415
        save_file(hf_sd, os.path.join(save_dir, "model.safetensors"))
        logger.info(
            "HF checkpoint saved via %s.save_hf_state_dict to %s",
            adapter_cls.__name__, save_dir,
        )

541 

class EvalCallback(Callback):
    """Placeholder for periodic evaluation (no eval loop exists yet).

    Whenever ``eval.eval_steps`` is positive and the current step is a
    multiple of it, rank 0 logs a warning that evaluation did not run. A
    missing eval therefore shows up in the training log instead of being
    skipped silently.
    """

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        cfg = getattr(self.trainer.args, 'eval', None)
        interval = 0
        if cfg:
            interval = getattr(cfg, 'eval_steps', 0)
        # Guard clauses: disabled, off-interval step, or non-zero rank.
        if interval <= 0:
            return
        if state.global_step % interval:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "EvalCallback: evaluation not implemented (step=%d)", state.global_step
        )

559 

class ProfilerCallback(Callback):
    """Training profiler callback — STUB (not verified).

    Reserved for ``torch.profiler.profile`` integration; no profiler is
    started. When ``args.profiler.enabled`` is truthy, rank 0 logs a single
    warning at construction so missing traces are not a surprise.

    To implement: start the profiler in ``on_train_begin``, step it in
    ``on_step_end`` and stop it in ``on_train_end``.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        prof_cfg = getattr(trainer.args, 'profiler', None)
        requested = getattr(prof_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "ProfilerCallback: enabled=True but the implementation is "
            "a stub — torch profiler is NOT started. Implement before "
            "relying on traces."
        )

579 

class WandbCallback(Callback):
    """Weights & Biases logging callback — STUB (not verified).

    Reserved for W&B integration; nothing is sent to W&B. When
    ``args.wandb.enabled`` is truthy, rank 0 logs a single warning at
    construction so missing W&B logs are visible.

    To implement: call ``wandb.init`` in ``on_train_begin``, ``wandb.log``
    in ``on_step_end`` and ``wandb.finish`` in ``on_train_end``, then verify
    against a real W&B run.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        wandb_cfg = getattr(trainer.args, 'wandb', None)
        requested = getattr(wandb_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "WandbCallback: enabled=True but the implementation is a "
            "stub — nothing is sent to W&B. Implement before relying on "
            "W&B dashboards."
        )

599 

class ProgressCallback(Callback):
    """tqdm progress bar over training steps (rank 0 only).

    Shows the live loss and grad_norm as the bar's postfix. ``tqdm`` is
    optional: if it cannot be imported, a warning is logged and no bar is
    shown.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        # The tqdm bar; stays None on non-zero ranks or without tqdm.
        self._pbar = None

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        if platform.get_rank() == 0:
            try:
                from tqdm import tqdm  # pylint: disable=C0415  # optional dep
                bar_options = {
                    "total": state.max_steps,
                    # Start from the (possibly resumed) current step.
                    "initial": state.global_step,
                    "desc": "Training",
                    "unit": "step",
                    "dynamic_ncols": True,
                }
                self._pbar = tqdm(**bar_options)
            except ImportError:
                logger.warning("ProgressCallback: 'tqdm' not installed — progress bar disabled")

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        bar = self._pbar
        if bar is None:
            return
        shown = (("loss", loss), ("gnorm", grad_norm))
        postfix = {label: f"{value:.4f}" for label, value in shown if value is not None}
        bar.set_postfix(postfix)
        bar.update(1)

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        if self._pbar is None:
            return
        self._pbar.close()
        self._pbar = None

643 

class MoEMonitorCallback(Callback):
    """Mixture-of-Experts load-balancing monitor — stub.

    Intended to track expert routing statistics (per-expert token counts,
    load imbalance), which needs MoE-specific hooks in the model's router.
    For now it only logs an info message on rank 0 at train start when
    ``args.moe_monitor.enabled`` is truthy.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        moe_cfg = getattr(trainer.args, 'moe_monitor', None)
        self.enabled = False
        if moe_cfg:
            self.enabled = getattr(moe_cfg, 'enabled', False)

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        if not self.enabled:
            return
        if platform.get_rank() != 0:
            return
        logger.info(
            "MoEMonitorCallback: MoE expert-load monitoring enabled "
            "(stub — full implementation pending model router hooks)"
        )

664 

class GradientHealthCallback(Callback):
    """Detect a NaN / Inf grad_norm and stop before it reaches the optimizer.

    Hooks ``on_pre_optimizer_step``, which fires after ``clip_grad_norm_``
    and before ``optimizer.step()``. ``grad_norm`` at that point is a plain
    scalar produced by hyper's DTensor-aware clipper. If it is not finite,
    ``optimizer.step()`` would silently corrupt weights with NaN, so the
    problem is surfaced immediately.

    Every rank that observes a non-finite norm logs an error and raises
    ``RuntimeError``. Raising only on rank 0 would let a rank-local NaN on
    another rank through to ``optimizer.step()``.

    Config: read from ``trainer.args.debug.check_nan_inf``.
    NOTE(review): earlier docs and the error message refer to
    ``cfg.train.debug.check_nan_inf`` — confirm which path users set.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        debug_cfg = getattr(trainer.args, 'debug', None)
        self.enabled = (
            getattr(debug_cfg, 'check_nan_inf', False) if debug_cfg else False
        )

    def on_pre_optimizer_step(self, state: "TrainerState", *,
                              grad_norm: Optional[float] = None,
                              **kwargs) -> None:
        """Raise ``RuntimeError`` if ``grad_norm`` is NaN or +/-Inf.

        Args:
            grad_norm: Norm from the clipper; ``None`` skips the check.
        """
        if not self.enabled or grad_norm is None:
            return
        if math.isfinite(grad_norm):
            return
        # Log on every rank — divergence may be rank-local.
        logger.error(
            "GradientHealthCallback: grad_norm=%s at step %d "
            "(NaN/Inf). Optimizer.step would corrupt weights.",
            grad_norm, state.global_step,
        )
        # Raise on whichever rank sees it, so optimizer.step never runs with
        # a non-finite norm on this rank.
        raise RuntimeError(
            f"Non-finite grad_norm={grad_norm} at "
            f"step {state.global_step}. "
            "Disable cfg.train.debug.check_nan_inf to skip this guard."
        )

703 

class GCCallback(Callback):
    """Explicit garbage-collection scheduler.

    Python's cyclic GC can stall large training jobs when it decides to run.
    Forcing a collection every N steps, outside the compute hot path, keeps
    pauses predictable.

    When enabled, automatic GC is disabled at construction and re-enabled in
    ``on_train_end`` (only if it was enabled beforehand), so the rest of the
    process does not keep running with automatic GC switched off.

    Config: read from ``trainer.args.debug.gc_steps`` (``0`` disables).
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        debug_cfg = getattr(trainer.args, 'debug', None)
        self.gc_steps = (
            getattr(debug_cfg, 'gc_steps', 0) if debug_cfg else 0
        )
        # True when we turned automatic GC off and must turn it back on.
        self._restore_auto_gc = False
        if self.gc_steps > 0:
            self._restore_auto_gc = gc.isenabled()
            # Disable the automatic generational collector; we'll drive it.
            gc.disable()
            logger.info("GCCallback: Python gc.collect every %d steps "
                        "(auto GC disabled)", self.gc_steps)

    def on_step_end(self, state: "TrainerState", *,
                    loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        if self.gc_steps <= 0:
            return
        if state.global_step % self.gc_steps != 0:
            return
        gc.collect()

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Re-enable automatic GC if this callback disabled it."""
        if self._restore_auto_gc:
            gc.enable()
            self._restore_auto_gc = False

734 

class TensorBoardCallback(Callback):
    """TensorBoard scalar writer — STUB (not verified).

    Reserved for ``torch.utils.tensorboard.SummaryWriter`` integration;
    nothing is written. When ``args.tensorboard.enabled`` is truthy, rank 0
    logs a single warning at construction so missing TB scalars are visible.

    To implement: open a SummaryWriter in ``on_train_begin``, write scalars
    in ``on_log`` and close it in ``on_train_end``.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        tb_cfg = getattr(trainer.args, 'tensorboard', None)
        requested = getattr(tb_cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "TensorBoardCallback: enabled=True but the implementation "
            "is a stub — nothing is written to TensorBoard. Implement "
            "before relying on TB scalars."
        )

754 

class MemoryMonitorCallback(Callback):
    """Peak / current device memory monitor — STUB (not verified).

    Reserved for polling ``platform.get_device_handle().memory_allocated`` /
    ``max_memory_allocated``; no stats are emitted. When
    ``args.memory_monitor.enabled`` is truthy, rank 0 logs a single warning
    at construction so missing memory logs are visible.

    To implement: poll the device handle in ``on_step_end`` (gated by
    ``log_steps``) and log ``cur=...GB peak=...GB``.
    """

    def __init__(self, trainer: "BaseTrainer") -> None:
        super().__init__(trainer)
        cfg = getattr(trainer.args, 'memory_monitor', None)
        requested = getattr(cfg, 'enabled', False)
        if not requested:
            return
        if platform.get_rank() != 0:
            return
        logger.warning(
            "MemoryMonitorCallback: enabled=True but the implementation "
            "is a stub — no memory stats are emitted. Implement before "
            "relying on these logs."
        )