Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/trainer/base.py: 0%

642 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

"""BaseTrainer — composable training skeleton with 13 overridable ``_build_*`` steps.

The steps, in order (each method's docstring carries its step number):
``_setup``, ``_build_model``, ``_freeze_model``, ``_build_model_assets``,
``_build_data_transform``, ``_build_dataset``, ``_build_collate_fn``,
``_build_dataloader``, ``_build_parallelized_model``, ``_build_optimizer``,
``_build_lr_scheduler``, ``_build_training_context``, ``_init_callbacks``.

Design references:
- An upstream BaseTrainer design (reference elided in this copy —
  TODO(review): restore it): composition pattern, 13-step init
- hyper ``fsdp_demo.py``: FSDP wrapping via ``model.layers`` iteration
  (NOTE(review): in this module the actual parallelization is delegated to
  each model's ``ModelSpec.parallelize_fn``)
- hyper st tests: TP → CP → AC → FSDP composition order

Subclasses (LLMTrainer, DiTTrainer, VLMTrainer) use the composition pattern:
instantiate ``BaseTrainer`` and call ``_build_*`` methods selectively, overriding
or skipping steps as needed.
"""

26import json 

27import logging 

28import math 

29import os 

30from contextlib import nullcontext 

31from typing import TYPE_CHECKING, Any, Dict, Optional 

32 

33import torch 

34from torch.utils.data import Dataset, DistributedSampler 

35 

36from hyper_parallel import ( 

37 get_platform, 

38 init_empty_weights, 

39 init_process_group, 

40 destroy_process_group, 

41 hsdp_sync_stream, 

42 SkipDTensorDispatch, 

43 HSDPModule, 

44) 

45from hyper_parallel.core.distributed_checkpoint import load as dcp_load 

46from hyper_parallel.core.dtensor.dtensor import DTensor 

47from hyper_parallel.core.fully_shard.hsdp_utils import GroupInfo 

48from hyper_parallel.core.utils import clip_grad_norm_ 

49from hyper_parallel.models.spec.registry import get_spec 

50from hyper_parallel.trainer.parallel_dims import ParallelDims 

51from hyper_parallel.trainer.utils.loss import count_loss_token 

52from hyper_parallel.trainer.callbacks.base import ( 

53 LoggingCallback, 

54 CheckpointCallback, 

55 SafetensorsExportCallback, 

56 EvalCallback, 

57 ProfilerCallback, 

58 WandbCallback, 

59 ProgressCallback, 

60 MoEMonitorCallback, 

61 GradientHealthCallback, 

62 GCCallback, 

63 TensorBoardCallback, 

64 MemoryMonitorCallback, 

65) 

66 

if TYPE_CHECKING:
    # Type-only imports backing the string annotations on BaseTrainer's class
    # attributes; never executed at runtime.
    # NOTE(review): an earlier comment said this preserves a platform-agnostic
    # "no torch/mindspore in trainer code" rule, but this module imports
    # ``torch`` unconditionally and calls it directly (``torch.optim.AdamW``,
    # ``torch.autocast``, ...) — confirm the intended scope of that rule.
    from torch import nn
    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LRScheduler
    from torch.utils.data import DataLoader
    from hyper_parallel.core.dtensor.device_mesh import DeviceMesh

# Backend abstraction (device type, device handles, rank / world size, seeding),
# resolved once at import time.
platform = get_platform()
# NOTE(review): methods below call ``logger.info_rank0``, which the stdlib
# ``logging.Logger`` does not define; presumably hyper_parallel patches it in
# elsewhere — confirm.
logger = logging.getLogger(__name__)

79 

class TrainerState:
    """Mutable training progress shared by the trainer and its callbacks.

    Attributes:
        global_step: Number of optimizer updates performed so far.
        epoch: Index of the current epoch.
        max_steps: Total number of training steps planned for the run.
        log_history: Per-instance list for accumulated log records; starts
            empty and is filled by its users.
    """

    def __init__(self, max_steps: int = 0):
        # Progress counters always start at zero (assigned left to right).
        self.global_step = self.epoch = 0
        self.max_steps = max_steps
        # Fresh list per instance — never shared between states.
        self.log_history = []

94 

95class BaseTrainer: 

96 """Composable training skeleton. 

97 

98 Provides 13 ``_build_*`` methods that subclasses can call, override, or skip. 

99 The default ``_build_parallelized_model`` applies TP → CP → AC → FSDP by 

100 iterating ``model.layers`` — matching hyper's own ``fsdp_demo.py`` style. 

101 

102 Args: 

103 args: Training configuration (typically parsed from YAML). 

104 """ 

105 

106 # PEP 526 annotations — populated by ``_build_*``; ``None`` until built. 

107 model: Optional["nn.Module"] = None 

108 optimizer: Optional["Optimizer"] = None 

109 lr_scheduler: Optional["LRScheduler"] = None 

110 train_dataloader: Optional["DataLoader"] = None 

111 mesh: Optional["DeviceMesh"] = None 

112 

113 def __init__(self, args): 

114 # Only early-bound fields live here; the rest is built via 

115 # ``_build_*`` methods invoked by the subclass. 

116 self.args = args 

117 self.spec = get_spec(args.model.name) 

118 self.state = TrainerState(max_steps=getattr(args.train, 'max_steps', 0)) 

119 

120 # ------------------------------------------------------------------ 

121 # 13 overridable _build_* methods 

122 # ------------------------------------------------------------------ 

123 

124 def _setup(self): 

125 """Step 1: Initialize distributed environment, device mesh, and seed. 

126 

127 Calls hyper's own ``init_process_group`` and ``init_device_mesh``. 

128 Mesh shape is derived from ``args.parallel`` (dp, tp, cp, pp, ep). 

129 """ 

130 backend = getattr(self.args.train, 'comm_backend', None) 

131 init_process_group(backend=backend) 

132 

133 local_rank = getattr(self.args.train, 'local_rank', 0) 

134 device_type = platform.device_type() # "npu" or "cuda" 

135 # Use platform.device(idx) — backend-agnostic. 

136 self.device = platform.device(local_rank) 

137 device_handle = platform.get_device_handle(device_type) 

138 device_handle.set_device(local_rank) 

139 

140 # Build & validate parallel dims in one place (fail-fast). 

141 

142 self.parallel_dims = ParallelDims.from_config( 

143 self.args.train.accelerator, world_size=platform.get_world_size(), 

144 ) 

145 logger.info_rank0("ParallelDims: %s", self.parallel_dims.summary()) 

146 self.mesh = self.parallel_dims.build_mesh(platform.device_type()) 

147 

148 # Build DP group_info for trainer-level all_reduce (loss/token sync). 

149 # Uses hyper's GroupInfo + mesh.get_group (platform-agnostic). 

150 

151 dp_group = self._get_combined_dp_group() 

152 dp_size = self.parallel_dims.dp_size 

153 self._dp_group_info = GroupInfo( 

154 group_name="trainer_dp", group=dp_group, rank_size=dp_size, 

155 ) 

156 

157 debug_cfg = getattr(self.args.train, 'debug', None) 

158 seed = getattr(self.args.train, 'seed', 42) 

159 platform.manual_seed(seed) 

160 # Seed device-side RNG explicitly: ``platform.manual_seed`` only 

161 # covers CPU. 

162 try: 

163 handle = platform.get_device_handle(device_type) 

164 if hasattr(handle, "manual_seed_all"): 

165 handle.manual_seed_all(seed) 

166 elif hasattr(handle, "manual_seed"): 

167 handle.manual_seed(seed) 

168 except Exception as exc: # pylint: disable=W0718 

169 logger.warning("Device-side seed init skipped: %s", exc) 

170 

171 if debug_cfg is not None and getattr(debug_cfg, 'deterministic', False): 

172 warn_only = getattr(debug_cfg, 'deterministic_warn_only', False) 

173 torch.use_deterministic_algorithms(True, warn_only=warn_only) 

174 torch.backends.cudnn.deterministic = True 

175 torch.backends.cudnn.benchmark = False 

176 logger.info_rank0("Deterministic algorithms enabled (warn_only=%s)", warn_only) 

177 

178 logger.info_rank0( 

179 "Setup complete: rank=%d, world_size=%d, mesh=%s", 

180 platform.get_rank(), platform.get_world_size(), 

181 self.mesh.mesh_dim_names, 

182 ) 

183 logger.info_rank0( 

184 "Config: data.type=%s, model.name=%s, model.num_hidden_layers=%s, " 

185 "init_device=%s, max_steps=%d, global_bs=%d", 

186 getattr(self.args.data, 'type', '?'), 

187 getattr(self.args.model, 'name', '?'), 

188 getattr(self.args.model, 'num_hidden_layers', '?'), 

189 getattr(self.args.train, 'init_device', '?'), 

190 self.state.max_steps, 

191 getattr(self.args.train, 'global_batch_size', '?'), 

192 ) 

193 

194 def _build_model(self): 

195 """Step 2: Construct model via ``spec.build_model_fn``. 

196 

197 The model is a plain ``nn.Module`` at this point — not yet parallelized. 

198 When ``args.runtime.init_device == "meta"``, the model is constructed on 

199 the meta device (no memory allocated) and real weights are loaded after 

200 FSDP sharding via ``_load_weights_after_parallel``. 

201 """ 

202 init_device = getattr(self.args.train, 'init_device', 'meta') 

203 # Meta-device init: each rank materialises only its own shard 

204 # post-FSDP — pre-trained weights via DCP, otherwise random init. 

205 if init_device == "meta": 

206 

207 with init_empty_weights(): 

208 self.model = self.spec.build_model_fn(self.args) 

209 logger.info_rank0( 

210 "Model built on meta device (no memory allocated): %s", 

211 type(self.model).__name__, 

212 ) 

213 else: 

214 self.model = self.spec.build_model_fn(self.args) 

215 logger.info_rank0("Model built on %s: %s", init_device, type(self.model).__name__) 

216 

217 # Cross-check parallel degrees against the actual model hyperparams 

218 # (heads%tp, kv_heads%tp, num_experts%ep, seq_len%(cp*tp)). 

219 # Fails fast here instead of crashing inside parallelize_module. 

220 seq_len = getattr(self.args.data, 'max_seq_len', None) 

221 self.parallel_dims.validate_against_model(self.model, seq_len=seq_len) 

222 

223 def _freeze_model(self): 

224 """Step 3: Freeze specified modules (optional).""" 

225 freeze_modules = getattr(self.args.model, 'freeze_modules', None) 

226 if not freeze_modules: 

227 return 

228 for name, param in self.model.named_parameters(): 

229 if any(pattern in name for pattern in freeze_modules): 

230 param.requires_grad_(False) 

231 

232 def _build_model_assets(self): 

233 """Step 4: Build tokenizer, processor, chat_template. 

234 

235 Default: no-op. LLMTrainer overrides to build tokenizer + chat_template. 

236 VLMTrainer overrides to build processor. 

237 """ 

238 self.tokenizer = None 

239 self.processor = None 

240 

241 def _build_data_transform(self): 

242 """Step 5: Build data preprocessing transform. 

243 

244 Default: identity transform. LLMTrainer overrides for tokenization. 

245 """ 

246 self.data_transform = None 

247 

248 def _build_dataset(self): 

249 """Step 6: Build training dataset. 

250 

251 Supports: 

252 - ``data.type = "dummy"``: random tokens for validation 

253 - ``data.type = "hf_datasets"``: HuggingFace datasets 

254 - ``data.type = "megatron_indexed"``: Megatron .bin/.idx format 

255 

256 Subclass can override for custom dataset logic. 

257 """ 

258 

259 data_type = getattr(self.args.data, 'type', 'dummy') 

260 seq_len = getattr(self.args.data, 'max_seq_len', 2048) 

261 

262 if data_type == "dummy": 

263 vocab_size = getattr(self.model, 'config', None) 

264 vocab_size = vocab_size.vocab_size if vocab_size else 32000 

265 total_samples = self.state.max_steps * getattr(self.args.train, 'global_batch_size', 8) 

266 

267 class DummyDataset(Dataset): 

268 """Deterministic random token dataset for FSDP validation. 

269 

270 Each sample's content is fixed by its index (seeded by 

271 ``base_seed + idx``), so the same index always produces the 

272 same tokens regardless of access order or DP configuration. 

273 """ 

274 def __init__(self, num_samples, seq_length, vocab, base_seed=42): 

275 self.num_samples = num_samples 

276 self.seq_length = seq_length 

277 self.vocab = vocab 

278 self.base_seed = base_seed 

279 def __len__(self): 

280 return self.num_samples 

281 def __getitem__(self, idx): 

282 g = torch.Generator().manual_seed(self.base_seed + idx) 

283 input_ids = torch.randint( 

284 0, self.vocab, (self.seq_length,), generator=g, 

285 ) 

286 return {"input_ids": input_ids, "labels": input_ids.clone()} 

287 

288 self.train_dataset = DummyDataset(total_samples, seq_len, vocab_size) 

289 logger.info_rank0("Dummy dataset created: %d samples, seq_len=%d", total_samples, seq_len) 

290 

291 elif data_type == "hf_datasets": 

292 from datasets import load_dataset # pylint: disable=C0415 # optional dep 

293 train_path = self.args.data.train_path 

294 streaming = getattr(self.args.data, 'streaming', False) 

295 if streaming: 

296 # ``DistributedSampler`` requires ``__len__``, which 

297 # ``IterableDataset`` lacks; reject loudly until a 

298 # sampler-less streaming path is wired. 

299 raise NotImplementedError( 

300 "data.streaming=True not yet wired. The current " 

301 "_build_dataloader uses DistributedSampler which requires " 

302 "len(dataset); IterableDataset has no __len__. " 

303 "Use data.streaming=False (map-style) for now, or " 

304 "subclass _build_dataset + _build_dataloader to emit an " 

305 "IterableDataset that self-shards via dp_rank/dp_size." 

306 ) 

307 ds = load_dataset(train_path, split="train", streaming=False) 

308 if self.data_transform: 

309 ds = ds.map( 

310 self.data_transform, remove_columns=ds.column_names, 

311 ) 

312 self.train_dataset = ds 

313 logger.info_rank0("HF dataset loaded: %s (map-style)", train_path) 

314 

315 else: 

316 raise ValueError(f"Unknown data type: {data_type}. Supported: dummy, hf_datasets") 

317 

318 def _build_collate_fn(self): 

319 """Step 7: Build data collator. 

320 

321 Default: pads input_ids and labels to max length in the batch. 

322 """ 

323 

324 def _default_collate(batch): 

325 """Simple padding collator.""" 

326 max_len = max(item["input_ids"].size(0) for item in batch) 

327 input_ids_list = [] 

328 labels_list = [] 

329 for item in batch: 

330 pad_len = max_len - item["input_ids"].size(0) 

331 input_ids_list.append( 

332 torch.nn.functional.pad(item["input_ids"], (0, pad_len), value=0) 

333 ) 

334 labels_list.append( 

335 torch.nn.functional.pad(item["labels"], (0, pad_len), value=-100) 

336 ) 

337 return { 

338 "input_ids": torch.stack(input_ids_list), 

339 "labels": torch.stack(labels_list), 

340 } 

341 

342 self.collate_fn = _default_collate 

343 

344 def _build_dataloader(self): 

345 """Step 8: Build distributed stateful dataloader. 

346 

347 Uses ``torchdata.stateful_dataloader.StatefulDataLoader`` so that 

348 iterator position is checkpointable — enabling exact resume after 

349 restart (matching ). 

350 

351 Each ``next()`` call yields a list of micro-batches (for gradient 

352 accumulation). 

353 """ 

354 from torchdata.stateful_dataloader import StatefulDataLoader # pylint: disable=C0415 # optional dep 

355 

356 micro_bs = getattr(self.args.train, 'micro_batch_size', 1) 

357 

358 # Sampler uses DP rank/size — TP/CP/PP/EP peers share data. 

359 dp_size = self.parallel_dims.dp_size 

360 non_dp = self.parallel_dims.non_dp_size 

361 global_rank = platform.get_rank() 

362 dp_rank = global_rank // non_dp if non_dp > 1 else global_rank 

363 

364 shuffle = getattr(self.args.data, "shuffle", True) 

365 sampler_seed = getattr(self.args.train, 'seed', 0) 

366 

367 self.sampler = DistributedSampler( 

368 self.train_dataset, 

369 num_replicas=dp_size, 

370 rank=dp_rank, 

371 shuffle=shuffle, 

372 seed=sampler_seed, 

373 drop_last=True, 

374 ) 

375 

376 # StatefulDataLoader supports state_dict() / load_state_dict() 

377 # for checkpoint resume (torchdata API, used by + ). 

378 num_workers = getattr(self.args.data, 'num_workers', 0) 

379 prefetch_factor = getattr(self.args.data, 'prefetch_factor', None) 

380 pin_memory = getattr(self.args.data, 'pin_memory', True) 

381 loader_kwargs = { 

382 "batch_size": micro_bs, 

383 "sampler": self.sampler, 

384 "collate_fn": self.collate_fn, 

385 "num_workers": num_workers, 

386 "pin_memory": pin_memory, 

387 "drop_last": True, 

388 } 

389 # prefetch_factor is only accepted when num_workers > 0 

390 if num_workers > 0 and prefetch_factor is not None: 

391 loader_kwargs["prefetch_factor"] = prefetch_factor 

392 self.train_dataloader = StatefulDataLoader( 

393 self.train_dataset, **loader_kwargs, 

394 ) 

395 

396 # Use dp_size (not world_size) — TP/CP/PP ranks share data, not split it. 

397 self._grad_accum = max( 

398 getattr(self.args.train, 'global_batch_size', micro_bs) // (micro_bs * dp_size), 

399 1, 

400 ) 

401 

402 logger.info_rank0( 

403 "Dataloader built: micro_bs=%d, grad_accum=%d, dataset_size=%d", 

404 micro_bs, self._grad_accum, len(self.train_dataset), 

405 ) 

406 

407 def _build_parallelized_model(self): 

408 """Step 9: Apply parallel strategies to the model. 

409 

410 Each model owns its full parallelize pipeline in 

411 ``models/<name>/parallelize.py`` (convention) and 

412 registers it via ``ModelSpec.parallelize_fn``. There is no shared 

413 "default" template — model-specific TP/EP/CP/AC/FSDP/Prefetch 

414 composition lives next to the model that needs it. 

415 """ 

416 if self.spec.parallelize_fn is None: 

417 raise ValueError( 

418 f"Model '{self.spec.name}' has no ``parallelize_fn`` registered " 

419 f"on its ModelSpec. Each model must own its parallelize " 

420 f"pipeline in models/<name>/parallelize.py." 

421 ) 

422 self.model = self.spec.parallelize_fn(self.model, self.mesh, self.args) 

423 self._post_parallelize() 

424 

425 def _post_parallelize(self): 

426 """Common steps after parallelization (materialize weights + train mode). 

427 

428 Order when ``init_device == "meta"`` and ``weights_path`` is set: 

429 

430 1. Run ``_materialize_and_init_shards`` first — this calls 

431 ``model.to_empty(device=...)`` + kaiming / zero init for every 

432 parameter. That is the **baseline** state so no param stays on 

433 meta (which would trip ``HSDPState._validate_no_meta_params``). 

434 2. Then ``_load_weights`` copies the upstream checkpoint on top. 

435 Every key that matches overwrites the random init; anything 

436 missing in the checkpoint stays with its kaiming / zero init. 

437 

438 This pattern handles partial checkpoints cleanly — e.g. hyper's 

439 current Qwen3-VL-MoE model lacks ``q_norm`` / ``k_norm`` / 

440 ``pos_embed`` / ``deepstack_merger_list``; those stay random 

441 while all other weights come from the pretrained checkpoint. 

442 """ 

443 init_device = getattr(self.args.train, 'init_device', 'meta') 

444 weights_path = getattr(self.args.model, 'weights_path', None) 

445 if init_device == "meta": 

446 # Always materialize first (random init baseline) so no param 

447 # stays on meta — then overlay the checkpoint. 

448 self._materialize_and_init_shards() 

449 # Re-tie weights — ``to_empty`` gives every nn.Parameter fresh 

450 # storage so ``__init__``-time ties are broken. 

451 if hasattr(self.model, "tie_weights"): 

452 self.model.tie_weights() 

453 if weights_path: 

454 self._load_weights(weights_path) 

455 elif weights_path: 

456 self._load_weights(weights_path) 

457 # Mixed-precision storage policy: trainable params keep an fp32 

458 # master; frozen params keep their loaded low-precision dtype so the 

459 # forward stays uniform-bf16 within their FSDP unit. 

460 self._maybe_downcast_frozen_params() 

461 self._maybe_upcast_trainable_params() 

462 self.model.train() 

463 

464 def _maybe_downcast_frozen_params(self) -> None: 

465 """Maybe downcast frozen params (internal).""" 

466 freeze_modules = getattr(self.args.model, 'freeze_modules', None) 

467 if not freeze_modules: 

468 return 

469 mp_cfg = getattr(self.args.train, 'mixed_precision', None) 

470 if mp_cfg is None or not getattr(mp_cfg, 'enabled', False): 

471 return 

472 

473 target_dtype = { 

474 'bfloat16': torch.bfloat16, 

475 'bf16': torch.bfloat16, 

476 'float16': torch.float16, 

477 'fp16': torch.float16, 

478 }.get(getattr(mp_cfg, 'param_dtype', 'bfloat16')) 

479 if target_dtype is None: 

480 return 

481 n_cast = 0 

482 for name, param in self.model.named_parameters(): 

483 if not any(pat in name for pat in freeze_modules): 

484 continue 

485 if param.requires_grad: 

486 continue 

487 local = param.data 

488 if hasattr(local, 'to_local'): 

489 local = local.to_local() 

490 if local.dtype == target_dtype: 

491 continue 

492 new_local = local.to(target_dtype) 

493 # DTensor: rebuild the global view via from_local with same placements. 

494 if hasattr(param.data, 'to_local'): 

495 

496 if isinstance(param.data, DTensor): 

497 param.data = DTensor.from_local( 

498 new_local, 

499 device_mesh=param.data.device_mesh, 

500 placements=param.data.placements, 

501 ) 

502 else: 

503 param.data = new_local 

504 else: 

505 param.data = new_local 

506 n_cast += 1 

507 logger.info_rank0( 

508 "Post-load: cast %d frozen params to %s", 

509 n_cast, target_dtype, 

510 ) 

511 

512 def _maybe_upcast_trainable_params(self) -> None: 

513 """Upcast trainable params to ``float32``. 

514 

515 Implements the standard mixed-precision pattern (fp32 master weight 

516 + low-precision forward): trainable params loaded from a bf16 

517 checkpoint would otherwise stay in bf16, and Adam moment estimates 

518 accumulating at bf16's 7-bit mantissa diverge noticeably after only 

519 a handful of optimizer steps. 

520 

521 Runs AFTER ``_maybe_downcast_frozen_params`` so frozen params keep 

522 their bf16 storage and only trainable params get the fp32 master. 

523 """ 

524 mp_cfg = getattr(self.args.train, 'mixed_precision', None) 

525 if mp_cfg is None or not getattr(mp_cfg, 'enabled', False): 

526 return 

527 

528 n_cast = 0 

529 for _, param in self.model.named_parameters(): 

530 if not param.requires_grad: 

531 continue 

532 local = param.data 

533 if hasattr(local, 'to_local'): 

534 local = local.to_local() 

535 if local.dtype == torch.float32: 

536 continue 

537 new_local = local.to(torch.float32) 

538 if hasattr(param.data, 'to_local'): 

539 

540 if isinstance(param.data, DTensor): 

541 param.data = DTensor.from_local( 

542 new_local, 

543 device_mesh=param.data.device_mesh, 

544 placements=param.data.placements, 

545 ) 

546 else: 

547 param.data = new_local 

548 else: 

549 param.data = new_local 

550 n_cast += 1 

551 logger.info_rank0( 

552 "Post-load: upcast %d trainable params to float32", n_cast, 

553 ) 

554 

555 def _build_optimizer(self): 

556 """Step 10: Build optimizer. Must be called AFTER ``_build_parallelized_model``. 

557 

558 After FSDP, parameters are DTensor shards — optimizer operates on local shards. 

559 Optimizer must be created after ``fully_shard``. 

560 """ 

561 lr = getattr(self.args.train.optimizer, 'lr', 1e-4) 

562 weight_decay = getattr(self.args.train.optimizer, 'weight_decay', 0.01) 

563 

564 # bias / LayerNorm / RMSNorm go to no-decay; grouping matters even 

565 # at wd=0 — foreach Adam reduction order differs per group on NPU. 

566 decay_keywords = ("bias", "layernorm", "norm", "rmsnorm") 

567 

568 def _is_no_decay(name: str) -> bool: 

569 lname = name.lower() 

570 return any(kw in lname for kw in decay_keywords) 

571 

572 decay_params = [] 

573 no_decay_params = [] 

574 seen_ids = set() 

575 for n, p in self.model.named_parameters(): 

576 if not p.requires_grad: 

577 continue 

578 # Dedup tied params (same nn.Parameter shared across modules). 

579 if id(p) in seen_ids: 

580 continue 

581 seen_ids.add(id(p)) 

582 if _is_no_decay(n): 

583 no_decay_params.append(p) 

584 else: 

585 decay_params.append(p) 

586 

587 param_groups = [ 

588 {"params": decay_params, "weight_decay": weight_decay}, 

589 {"params": no_decay_params, "weight_decay": 0.0}, 

590 ] 

591 adam_eps = getattr(self.args.train.optimizer, 'eps', 1e-8) 

592 adam_betas = getattr(self.args.train.optimizer, 'betas', (0.9, 0.999)) 

593 adam_foreach = getattr(self.args.train.optimizer, 'foreach', None) 

594 self.optimizer = torch.optim.AdamW( 

595 param_groups, 

596 lr=lr, 

597 betas=adam_betas, 

598 eps=adam_eps, 

599 foreach=adam_foreach, 

600 ) 

601 logger.info_rank0( 

602 "Optimizer: AdamW lr=%.2e wd=%.3g decay_params=%d no_decay_params=%d", 

603 lr, weight_decay, len(decay_params), len(no_decay_params), 

604 ) 

605 

606 def _build_lr_scheduler(self): 

607 """Step 11: Build learning rate scheduler. 

608 

609 Supports cosine decay with warmup. Falls back to constant LR if 

610 warmup_ratio is 0 and decay_style is 'constant'. 

611 """ 

612 

613 total_steps = self.state.max_steps 

614 warmup_ratio = getattr(self.args.train.optimizer, 'lr_warmup_ratio', 0.0) 

615 # ``ceil`` matches the standard warmup convention so a fractional 

616 # ``warmup_ratio * max_steps`` rounds up to the next full step. 

617 warmup_steps = math.ceil(total_steps * warmup_ratio) 

618 decay_style = getattr(self.args.train.optimizer, 'lr_decay_style', 'cosine') 

619 lr_min = getattr(self.args.train.optimizer, 'lr_min', 0.0) 

620 lr_max = getattr(self.args.train.optimizer, 'lr', 1e-4) 

621 

622 def _lr_lambda(current_step): 

623 if current_step < warmup_steps: 

624 return float(current_step) / float(max(1, warmup_steps)) 

625 if decay_style == 'constant': 

626 return 1.0 

627 # Cosine decay 

628 progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) 

629 cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) 

630 min_ratio = lr_min / lr_max if lr_max > 0 else 0.0 

631 return min_ratio + (1.0 - min_ratio) * cosine_decay 

632 

633 self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, _lr_lambda) 

634 logger.info_rank0( 

635 "LR scheduler: %s, warmup_steps=%d/%d, lr=%.2e→%.2e", 

636 decay_style, warmup_steps, total_steps, lr_max, lr_min, 

637 ) 

638 

639 def _build_training_context(self): 

640 """Step 12: Build forward/backward context managers. 

641 

642 Delegates AMP context to ``_make_amp_context`` — the single place that 

643 knows about backend-specific autocast. MindSpore subclass should 

644 override ``_make_amp_context`` only. 

645 """ 

646 

647 

648 mp_cfg = getattr(self.args.train, 'mixed_precision', None) 

649 # FSDP2 mp_policy already runs forward/backward in low precision; 

650 # stacking ``torch.autocast`` on top would re-promote LayerNorm / 

651 # softmax / log_softmax to fp32 and break the uniform-bf16 forward. 

652 self.model_fwd_context = nullcontext() 

653 self.grad_scaler = None 

654 if mp_cfg and getattr(mp_cfg, 'enabled', False): 

655 param_dtype_str = getattr(mp_cfg, 'param_dtype', 'bfloat16') 

656 logger.info_rank0( 

657 "Mixed precision via FSDP2 mp_policy: dtype=%s on %s " 

658 "(autocast disabled — pure low-precision forward under mp_policy)", 

659 param_dtype_str, platform.device_type(), 

660 ) 

661 

662 self.model_bwd_context = nullcontext() 

663 

664 def _make_amp_context(self, param_dtype_str: str): 

665 """Build the AMP forward context. Backend-specific. 

666 

667 This is the SOLE method in the trainer that touches the backend AMP 

668 API directly. Override this single method to add MindSpore support 

669 (``ms.amp.auto_mixed_precision``). Default uses ``torch.autocast``. 

670 """ 

671 dtype = getattr(torch, param_dtype_str, torch.bfloat16) 

672 return torch.autocast(platform.device_type(), dtype=dtype) 

673 

    def _init_callbacks(self):
        """Step 13: Instantiate the built-in callbacks as named attributes.

        Each callback is an explicit field constructed with this trainer, so
        the whole set is visible at a glance; construction happens in the
        order written below. Dispatch order is defined separately: the list
        from ``_builtin_callbacks`` drives ``on_init_end`` / ``on_step_end`` /
        ``on_epoch_*``, while ``on_train_begin`` / ``on_train_end`` /
        ``on_step_begin`` / ``on_substep_end`` / ``on_pre_optimizer_step``
        each call a hand-picked subset. Adding a callback means a field here
        plus an entry in ``_builtin_callbacks``.

        Also (re)creates ``self.user_callbacks`` as an empty list, so any
        callback registered via ``add_callback`` before this step is dropped.
        """
        self.logging_callback = LoggingCallback(self)
        self.checkpoint_callback = CheckpointCallback(self)
        self.hf_export_callback = SafetensorsExportCallback(self)
        self.eval_callback = EvalCallback(self)
        self.profiler_callback = ProfilerCallback(self)
        self.wandb_callback = WandbCallback(self)
        self.tensorboard_callback = TensorBoardCallback(self)
        self.progress_callback = ProgressCallback(self)
        self.moe_monitor_callback = MoEMonitorCallback(self)
        # Health + operability (no-ops unless enabled in cfg.train.debug / .memory_monitor).
        self.gradient_health_callback = GradientHealthCallback(self)
        self.memory_monitor_callback = MemoryMonitorCallback(self)
        self.gc_callback = GCCallback(self)
        # ``user_callbacks`` lets external code append extra Callback instances
        # (e.g. domain-specific monitors) without editing this method. The
        # ``on_*`` dispatchers notify them after the built-ins.
        self.user_callbacks: list = []
        logger.info_rank0(
            "Callbacks initialized: logging, checkpoint, hf_export, eval, "
            "profiler, wandb, tensorboard, progress, moe_monitor, "
            "gradient_health, memory_monitor, gc"
        )

702 

703 # ------------------------------------------------------------------ 

704 # Public API: external callback registration 

705 # ------------------------------------------------------------------ 

706 

707 def add_callback(self, callback) -> None: 

708 """Register an extra ``Callback`` to receive every lifecycle event. 

709 

710 Use this to plug domain-specific monitors (custom metric sinks, 

711 in-house experiment trackers, RL reward loggers) without editing 

712 the trainer. Built-in callbacks always run first; user callbacks 

713 run in registration order so a later user callback can read state 

714 the earlier ones updated. 

715 """ 

716 self.user_callbacks.append(callback) 

717 logger.info_rank0( 

718 "User callback registered: %s", type(callback).__name__, 

719 ) 

720 

721 # ------------------------------------------------------------------ 

722 # Callback dispatch (explicit mode) 

723 # ------------------------------------------------------------------ 

724 

725 def _builtin_callbacks(self) -> list: 

726 """Return built-in callbacks in fixed dispatch order. 

727 

728 Centralised so every dispatcher iterates the same list — adding a 

729 callback only needs an entry here plus a named field in 

730 ``_init_callbacks`` (no per-event copy/paste). 

731 """ 

732 return [ 

733 self.logging_callback, 

734 self.eval_callback, 

735 self.profiler_callback, 

736 self.wandb_callback, 

737 self.tensorboard_callback, 

738 self.progress_callback, 

739 self.checkpoint_callback, 

740 self.hf_export_callback, 

741 self.moe_monitor_callback, 

742 self.gradient_health_callback, 

743 self.memory_monitor_callback, 

744 self.gc_callback, 

745 ] 

746 

747 def _all_callbacks(self) -> list: 

748 """Built-in callbacks followed by user-registered ones.""" 

749 return self._builtin_callbacks() + list(self.user_callbacks) 

750 

751 def on_init_end(self): 

752 """Dispatch one-shot ``on_init_end`` after every ``_build_*`` ran. 

753 

754 Fired by the subclass at the end of its own ``__init__`` (see 

755 ``LLMTrainer.__init__``); ``BaseTrainer.train()`` does NOT call it 

756 because BaseTrainer instances are sometimes wrapped (composition 

757 pattern) and the wrapper owns the init lifecycle. 

758 """ 

759 for cb in self._all_callbacks(): 

760 cb.on_init_end(self.state) 

761 

762 def on_train_begin(self): 

763 """Dispatch on_train_begin to all callbacks.""" 

764 # Memory monitor first so it captures the truly-initial peak. 

765 self.memory_monitor_callback.on_train_begin(self.state) 

766 self.moe_monitor_callback.on_train_begin(self.state) 

767 self.profiler_callback.on_train_begin(self.state) 

768 self.wandb_callback.on_train_begin(self.state) 

769 self.tensorboard_callback.on_train_begin(self.state) 

770 self.progress_callback.on_train_begin(self.state) 

771 # Checkpoint runs LAST so resume sees an already-armed TB writer 

772 # (it'll record the load event via dispatch_load_event). 

773 self.checkpoint_callback.on_train_begin(self.state) 

774 for cb in self.user_callbacks: 

775 cb.on_train_begin(self.state) 

776 

def on_train_end(self):
    """Fire ``on_train_end`` on the built-ins in a fixed order, then on user callbacks.

    The checkpoint and HF-export callbacks run first; the sinks
    (progress / TensorBoard / W&B) and the profiler follow.
    """
    ordered_builtins = (
        self.checkpoint_callback,
        self.hf_export_callback,
        self.progress_callback,
        self.tensorboard_callback,
        self.wandb_callback,
        self.profiler_callback,
    )
    for callback in ordered_builtins:
        callback.on_train_end(self.state)
    for callback in self.user_callbacks:
        callback.on_train_end(self.state)

787 

def on_step_begin(self):
    """Fire ``on_step_begin`` on the logging callback, then on every user callback."""
    for callback in (self.logging_callback, *self.user_callbacks):
        callback.on_step_begin(self.state)

793 

def on_step_end(self, loss=None, grad_norm=None):
    """Fire ``on_step_end`` on every callback (built-ins first, then user ones).

    Args:
        loss: Reported loss for the finished step (forwarded as ``loss=``).
        grad_norm: Gradient norm for the finished step (forwarded as ``grad_norm=``).
    """
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_step_end(self.state, loss=loss, grad_norm=grad_norm)

798 

def on_substep_end(self):
    """Fire ``on_substep_end`` (once per micro-batch) on the MoE monitor, then user callbacks."""
    for callback in (self.moe_monitor_callback, *self.user_callbacks):
        callback.on_substep_end(self.state)

804 

def on_pre_optimizer_step(self, grad_norm=None):
    """Fire ``on_pre_optimizer_step`` after grad clipping and before ``optimizer.step``.

    The gradient-health callback runs first, so a NaN can abort before the
    loggers record misleading numbers. Logging, W&B and TensorBoard follow,
    then user callbacks.

    Args:
        grad_norm: Clipped-gradient norm, forwarded as ``grad_norm=``.
    """
    ordered_builtins = (
        self.gradient_health_callback,
        self.logging_callback,
        self.wandb_callback,
        self.tensorboard_callback,
    )
    for callback in ordered_builtins:
        callback.on_pre_optimizer_step(self.state, grad_norm=grad_norm)
    for callback in self.user_callbacks:
        callback.on_pre_optimizer_step(self.state, grad_norm=grad_norm)

816 

def on_epoch_begin(self):
    """Fire ``on_epoch_begin`` on every callback (built-ins first, then user ones)."""
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_epoch_begin(self.state)

821 

def on_epoch_end(self):
    """Fire ``on_epoch_end`` on every callback (built-ins first, then user ones)."""
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_epoch_end(self.state)

826 

827 # ------------------------------------------------------------------ 

828 # Event fan-out (LoggingCallback / CheckpointCallback emit these) 

829 # ------------------------------------------------------------------ 

830 

def dispatch_log_event(self, metrics: dict) -> None:
    """Hand one metrics record to every callback's ``on_log``.

    ``LoggingCallback`` calls this so that TensorBoard, W&B and any external
    sink all record the SAME numbers: a single source of truth with no
    duplicated computation.

    Args:
        metrics: The metrics record, forwarded unchanged as ``metrics=``.
    """
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_log(self.state, metrics=metrics)

839 

def dispatch_save_event(self, checkpoint_dir: str) -> None:
    """Forward a checkpoint-save event to every callback's ``on_save``.

    Args:
        checkpoint_dir: Checkpoint directory, forwarded as ``checkpoint_dir=``.
    """
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_save(self.state, checkpoint_dir=checkpoint_dir)

844 

def dispatch_load_event(self, checkpoint_dir: str) -> None:
    """Forward a checkpoint-load event to every callback's ``on_load``.

    Args:
        checkpoint_dir: Checkpoint directory, forwarded as ``checkpoint_dir=``.
    """
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_load(self.state, checkpoint_dir=checkpoint_dir)

849 

def dispatch_evaluate_event(self, metrics: dict = None) -> None:
    """Forward an eval-pass-complete event to every callback's ``on_evaluate``.

    Args:
        metrics: Optional evaluation metrics, forwarded as ``metrics=``.
    """
    callbacks = self._all_callbacks()
    for callback in callbacks:
        callback.on_evaluate(self.state, metrics=metrics)

854 

855 # ------------------------------------------------------------------ 

856 # Training core 

857 # ------------------------------------------------------------------ 

858 

def forward_backward_step(
    self,
    micro_batch: Dict[str, Any],
    micro_batch_tokens: int,
    global_tokens: int,
    num_micro: int = 1,
):
    """Run forward + backward on a single micro-batch.

    The backward loss is scaled according to
    ``args.train.optimizer.loss_aggregation``:

    * ``'token_weighted'`` (default): ``loss * (micro_batch_tokens /
      global_tokens) * dp_size``. Every token across all ranks and all
      micro-batches then contributes equally, independent of DP size or
      grad-accum, once FSDP's ``AVG`` reduction has run.
    * ``'rank_average'``: ``loss / num_micro``. This is a per-rank mean over
      micro-batches; cross-rank averaging is left to FSDP's ``AVG``.

    Args:
        micro_batch: Dict of model inputs. Values may be tensors or nested
            dicts / lists / tuples of them. Anything exposing ``.to`` is moved
            to ``self.device`` with ``non_blocking=True``.
        micro_batch_tokens: Non-padding token count of this micro-batch.
        global_tokens: Non-padding tokens summed over **all** ranks and
            **all** micro-batches of the step (computed via all-reduce).
        num_micro: Number of micro-batches in the step (``rank_average`` only).

    Returns:
        ``(raw_loss, micro_batch_tokens)``. ``raw_loss`` is the detached,
        unscaled loss used for logging.
    """
    def _move(obj):
        # Anything that can move itself (tensors, HF BatchEncoding, ...) is
        # moved directly; containers are walked recursively.
        if hasattr(obj, "to"):
            return obj.to(self.device, non_blocking=True)
        if isinstance(obj, dict):
            return {key: _move(item) for key, item in obj.items()}
        if isinstance(obj, list):
            return [_move(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_move(item) for item in obj)
        return obj

    inputs = {key: _move(item) for key, item in micro_batch.items()}

    # Forward inside the training context (used for activation offload).
    with self.model_fwd_context:
        outputs = self.model(**inputs, use_cache=False)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss

    # Under TP the loss can be a Partial DTensor; reduce it before backward.
    if hasattr(loss, 'is_partial') and loss.is_partial():
        loss = loss.reduce_partial()

    raw_loss = loss.detach()

    dp_size = self.parallel_dims.dp_size
    agg = getattr(self.args.train.optimizer, 'loss_aggregation', 'token_weighted')
    if agg != 'rank_average':
        token_share = micro_batch_tokens / global_tokens
        scaled_loss = loss * token_share * dp_size
    elif num_micro > 1:
        scaled_loss = loss / num_micro
    else:
        scaled_loss = loss

    with self.model_bwd_context:
        scaled_loss.backward()

    return raw_loss, micro_batch_tokens

927 

def train_step(self, data_iterator):
    """Execute one optimizer step with gradient accumulation over micro-batches.

    Precision is kept aligned across DP configurations by:
    1. All-reducing the global loss-token count before any loss scaling, so
       every rank uses the same denominator.
    2. Syncing gradients only on the last micro-batch.
    3. All-reducing the reported loss (token-weighted or rank-averaged).

    Args:
        data_iterator: Iterator yielding lists of micro-batch dicts.

    Returns:
        Dict with ``"loss"`` (float, DP-aggregated) and ``"grad_norm"`` (float).

    Raises:
        StopIteration: If ``data_iterator`` is exhausted. ``state.global_step``
            is not bumped in that case, so checkpoint dirs / log indices match
            the steps that actually trained.
    """
    micro_batches = next(data_iterator)
    self.state.global_step += 1
    num_micro = len(micro_batches)

    # ---- Phase 1: count global tokens ----
    # All-reduce BEFORE forward so every rank scales with the same denominator.
    token_counts = [count_loss_token(mb) for mb in micro_batches]
    local_tokens = sum(token_counts)
    # Reduce the TRUE per-rank counts and clamp only afterwards. Clamping each
    # rank to >= 1 before the reduce would add a phantom token per empty rank
    # and skew every rank's loss scale.
    global_tokens = local_tokens
    if platform.get_world_size() > 1:
        gt = platform.full((1,), local_tokens).to(self.device)
        platform.all_reduce(gt, self._dp_group_info)
        global_tokens = int(gt.item())
    global_tokens = max(global_tokens, 1)
    # Exposed for callbacks (e.g. LoggingCallback throughput).
    self._last_global_tokens = global_tokens

    # ---- Phase 2: forward / backward with per-micro-batch sync control ----
    total_loss_sum = 0.0        # token-weighted: loss * micro_tokens
    total_loss_arith_sum = 0.0  # plain sum of per-micro-batch losses
    total_tokens_local = 0
    for i, mb in enumerate(micro_batches):
        is_last = i == num_micro - 1

        # Gradient sync only on the last micro-batch; earlier ones accumulate
        # locally with no communication.
        if isinstance(self.model, HSDPModule):
            self.model.set_requires_gradient_sync(is_last)
            self.model.set_is_last_backward(is_last)

        # Skip resharding between accumulation micro-steps.
        self._maybe_toggle_reshard(i, num_micro)

        raw_loss, mb_tokens = self.forward_backward_step(
            mb, token_counts[i], global_tokens, num_micro=num_micro,
        )
        # One host/device sync per micro-batch instead of two.
        loss_value = raw_loss.item()
        total_loss_sum += loss_value * mb_tokens
        total_loss_arith_sum += loss_value
        total_tokens_local += mb_tokens
        self.on_substep_end()

    # Wait for the asynchronous gradient reduce to finish.
    hsdp_sync_stream()

    # Gradient clipping: DTensor-aware, returns a plain tensor.
    max_grad_norm = getattr(self.args.train.optimizer, 'max_grad_norm', 1.0)
    clip_fn = self.spec.clip_grad_fn or clip_grad_norm_
    grad_norm = clip_fn(self.model.parameters(), max_grad_norm)
    grad_norm_value = grad_norm.item()
    self.on_pre_optimizer_step(grad_norm=grad_norm_value)

    # The optimizer step must run inside SkipDTensorDispatch.
    with SkipDTensorDispatch():
        self.optimizer.step()

    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

    self.optimizer.zero_grad()

    # ---- Phase 3: all-reduce loss for reporting ----
    agg = getattr(self.args.train.optimizer, 'loss_aggregation', 'token_weighted')
    if agg == 'token_weighted':
        # Sum the local token-weighted loss across DP and divide by the global
        # token count: the mean loss per token under the post-DP gradient mean.
        if platform.get_world_size() > 1:
            ls = platform.full((1,), total_loss_sum).to(self.device)
            platform.all_reduce(ls, self._dp_group_info)
            avg_loss = ls.item() / max(global_tokens, 1)
        else:
            avg_loss = total_loss_sum / max(total_tokens_local, 1)
    else:
        # rank_average: each rank averages its micro-batches, then the DP
        # all-reduce sums those means. The divisor is the DP group size (NOT
        # world_size): under TP/EP/PP/CP the all-reduce only spans DP ranks,
        # so dividing by world_size would under-report the loss.
        local_mean = total_loss_arith_sum / max(num_micro, 1)
        dp_size = self._dp_group_info.rank_size
        if dp_size > 1:
            ls = platform.full((1,), local_mean).to(self.device)
            platform.all_reduce(ls, self._dp_group_info)
            avg_loss = ls.item() / dp_size
        else:
            avg_loss = local_mean

    return {"loss": avg_loss, "grad_norm": grad_norm_value}

1032 

def train(self):
    """Drive training end to end: epoch -> optimizer step -> micro-batch.

    Callbacks are dispatched explicitly at every lifecycle point.
    ``on_train_begin`` fires first because CheckpointCallback restores
    ``state.global_step`` there. The step loop is keyed on that live counter,
    so already-completed steps are not repeated. Total work never exceeds
    ``state.max_steps``, whatever ``num_train_epochs`` says. After the loop,
    ``on_train_end`` fires and the process group is destroyed.
    """
    logger.info_rank0(
        "Training starts: max_steps=%d, epochs=%d",
        self.state.max_steps,
        getattr(self.args.train, 'num_train_epochs', 1),
    )
    # Checkpoint resume runs inside this dispatch and may advance global_step.
    self.on_train_begin()
    total_epochs = getattr(self.args.train, 'num_train_epochs', 1)

    if self.state.global_step > 0:
        logger.info_rank0(
            "Resuming training from step %d", self.state.global_step,
        )

    for epoch in range(total_epochs):
        if self.state.global_step >= self.state.max_steps:
            break
        self.state.epoch = epoch
        # NOTE(review): the epoch loop always restarts at 0, so a resumed run
        # re-seeds the sampler with epoch 0 even if the checkpoint was taken
        # in a later epoch — confirm this matches how the checkpoint restores
        # the dataloader / state.epoch.
        if hasattr(self, 'sampler'):
            self.sampler.set_epoch(epoch)
        self.on_epoch_begin()

        # The stateful dataloader tracks its own position, so after a resume
        # this iterator skips the batches that were already consumed.
        micro_batch_iter = self._make_micro_batch_iterator()

        while self.state.global_step < self.state.max_steps:
            self.on_step_begin()
            try:
                step_metrics = self.train_step(micro_batch_iter)
            except StopIteration:
                logger.info_rank0("Epoch %d: dataloader exhausted", epoch)
                break
            self.on_step_end(
                loss=step_metrics["loss"],
                grad_norm=step_metrics["grad_norm"],
            )

        self.on_epoch_end()

    self.on_train_end()
    destroy_process_group()
    logger.info_rank0("Training completed")

1090 

1091 # ------------------------------------------------------------------ 

1092 # Helpers 

1093 # ------------------------------------------------------------------ 

1094 

1095 def _make_micro_batch_iterator(self): 

1096 """Yield lists of micro-batches from the stateful dataloader. 

1097 

1098 Groups ``self._grad_accum`` consecutive batches into a list for 

1099 gradient accumulation. The underlying ``StatefulDataLoader`` tracks 

1100 iteration position, so checkpoint/resume skips consumed batches. 

1101 """ 

1102 batch_buffer = [] 

1103 for batch in self.train_dataloader: 

1104 batch_buffer.append(batch) 

1105 if len(batch_buffer) >= self._grad_accum: 

1106 yield batch_buffer 

1107 batch_buffer = [] 

1108 if batch_buffer: 

1109 yield batch_buffer 

1110 

1111 def _get_layers(self) -> list: 

1112 """Return the repeating layers for FSDP/AC wrapping. 

1113 

1114 Default: ``model.layers`` (standard transformer convention). 

1115 Override in subclass for models with different structure. 

1116 """ 

1117 if hasattr(self.model, 'layers'): 

1118 return list(self.model.layers) 

1119 raise ValueError( 

1120 f"Model {type(self.model).__name__} has no .layers attribute. " 

1121 f"Either add self.layers to the model, or override _get_layers() " 

1122 f"in the Trainer subclass." 

1123 ) 

1124 

1125 def _get_combined_dp_group(self): 

1126 """Return the combined data-parallel ProcessGroup for trainer all-reduce. 

1127 

1128 Prefers the ``"loss"`` flatten alias registered by 

1129 ``ParallelDims.build_mesh`` (folds CP into the DP group when CP is 

1130 active so token-count denominators include CP-sharded contributions). 

1131 Falls back to ``"dp"``, then to the legacy ``dp_shard`` / 

1132 ``dp_replicate`` axes for callers that built a custom mesh. 

1133 """ 

1134 for name in ("loss", "dp", "dp_shard", "dp_replicate"): 

1135 try: 

1136 return self.mesh.get_group(name) 

1137 except (KeyError, ValueError): 

1138 continue 

1139 return self.mesh.get_group() 

1140 

1141 def _build_fsdp_kwargs(self) -> dict: 

1142 """Build kwargs for ``fully_shard`` calls (dense parameters). 

1143 

1144 For expert parameters when EP > 1, use ``_build_expert_fsdp_kwargs``. 

1145 """ 

1146 for name in ("dp_shard", "dp", "dp_replicate"): 

1147 try: 

1148 dp_mesh = self.mesh[name] 

1149 break 

1150 except (KeyError, TypeError): 

1151 continue 

1152 else: 

1153 dp_mesh = self.mesh 

1154 kwargs = {"mesh": dp_mesh} 

1155 

1156 reshard = getattr(self.args.train.accelerator, 'reshard_after_forward', True) 

1157 kwargs["reshard_after_forward"] = reshard 

1158 

1159 return kwargs 

1160 

1161 def _build_expert_fsdp_kwargs(self) -> dict: 

1162 """Build kwargs for ``fully_shard`` calls on expert parameters. 

1163 

1164 When EP > 1, expert parameters are sharded across the EP group 

1165 with a separate mesh dimension. Falls back to dense FSDP kwargs 

1166 if EP is not enabled. 

1167 """ 

1168 if not self.parallel_dims.ep_enabled: 

1169 return self._build_fsdp_kwargs() 

1170 

1171 try: 

1172 ep_mesh = self.mesh["ep"] 

1173 except (KeyError, TypeError): 

1174 logger.warning("EP=%d but no 'ep' dimension in mesh, falling back to dp mesh", 

1175 self.parallel_dims.ep) 

1176 return self._build_fsdp_kwargs() 

1177 

1178 kwargs = {"mesh": ep_mesh} 

1179 reshard = getattr(self.args.train.accelerator, 'reshard_after_forward', True) 

1180 kwargs["reshard_after_forward"] = reshard 

1181 return kwargs 

1182 

def _materialize_and_init_shards(self) -> None:
    """Move meta-device parameters/buffers onto the real device, in place.

    After ``fully_shard`` on a meta-device model, each rank holds meta
    DTensor shards, and FSDP2 keeps internal views into those meta storages.
    Swapping ``DTensor._local_tensor`` would leave those views on the old
    meta storage, and the first forward all-gather would then fail on meta
    (``c10d::_allgather_base_``). ``nn.Module.to_empty(device=...)``
    allocates real storage in place and keeps the views valid.

    Steps:
    1. ``to_empty`` onto ``platform.device_type()``, then materialize the
       ``replicate_params`` that ``to_empty`` does not reach.
    2. Initialise every local shard: kaiming-uniform for >= 2-D params, zeros
       for 1-D params and for all buffers.
    3. Call ``reset_inv_freq()`` on any module that has it, so RoPE buffers
       are recomputed instead of staying zero (identity rotation).
    4. Re-wrap shards as DTensor via the HSDP states' ``lazy_init`` so the
       weight loader / optimizer see DTensors.

    NOTE(review): step 2 also zeroes 1-D parameters such as normalization
    weights; that is presumably only fine when pretrained weights are loaded
    afterwards — confirm for from-scratch runs.
    """
    target_device = platform.device_type()

    # 1) Real storage, allocated in place so FSDP's views stay valid.
    self.model.to_empty(device=target_device)
    self._materialize_replicate_params(target_device)

    # 2) Fill the (uninitialised) local shards and zero the buffers.
    initialised = self._init_local_shards()

    # 3) Recompute derived buffers (e.g. ``inv_freq``) wiped above.
    for submodule in self.model.modules():
        if hasattr(submodule, "reset_inv_freq"):
            submodule.reset_inv_freq()

    # 4) ``to_empty`` strips the DTensor wrapper; restore it before the
    #    loader runs (the forward pre-hook would otherwise do it later).
    rewrapped = self._lazy_init_hsdp_modules()

    logger.info_rank0(
        "Meta → real on %s: to_empty + kaiming/zero init on %d params; "
        "FSDP lazy_init re-wrapped %d modules back to DTensor",
        target_device, initialised, rewrapped,
    )

1223 

def _iter_hsdp_states(self):
    """Generate the ``hsdp_state`` of every ``HSDPModule`` inside ``self.model``.

    A module is skipped when it is not an ``HSDPModule``, when its
    ``hsdp_scheduler`` is missing or falsy, or when that scheduler has no
    ``hsdp_state``.
    """
    for submodule in self.model.modules():
        if not isinstance(submodule, HSDPModule):
            continue
        scheduler = getattr(submodule, 'hsdp_scheduler', None)
        if not scheduler:
            continue
        hsdp_state = getattr(scheduler, 'hsdp_state', None)
        if hsdp_state is not None:
            yield hsdp_state

1234 

1235 def _materialize_replicate_params(self, device_type: str) -> None: 

1236 """``to_empty`` skips ``replicate_params`` whose ``_local_tensor`` lives 

1237 outside ``module._parameters`` — materialize them manually. 

1238 """ 

1239 for state in self._iter_hsdp_states(): 

1240 for hsdp_param in getattr(state, 'replicate_params', []) or []: 

1241 local = getattr(hsdp_param.sharded_param, "_local_tensor", None) 

1242 if local is not None and local.is_meta: 

1243 new_local = torch.empty_like(local, device=device_type) 

1244 hsdp_param.sharded_param._local_tensor = new_local # pylint: disable=W0212 

1245 

1246 def _init_local_shards(self) -> int: 

1247 """Init local shard of every param (kaiming for >=2D, zero else); zero buffers.""" 

1248 param_count = 0 

1249 with torch.no_grad(): 

1250 for _, param in self.model.named_parameters(): 

1251 local = param._local_tensor if hasattr(param, '_local_tensor') else param # pylint: disable=W0212 

1252 if local.is_meta: 

1253 continue 

1254 if local.dim() >= 2: 

1255 torch.nn.init.kaiming_uniform_(local) 

1256 else: 

1257 torch.nn.init.zeros_(local) 

1258 param_count += 1 

1259 for _, buf in self.model.named_buffers(): 

1260 if buf is not None: 

1261 buf.zero_() 

1262 return param_count 

1263 

1264 def _lazy_init_hsdp_modules(self) -> int: 

1265 """Re-wrap HSDP shards into DTensor so loader / optimizer see them.""" 

1266 reset_count = 0 

1267 for state in self._iter_hsdp_states(): 

1268 if hasattr(state, 'lazy_init'): 

1269 state.lazy_init() 

1270 reset_count += 1 

1271 return reset_count 

1272 

def _load_weights(self, weights_path: str) -> None:
    """Load pre-trained weights from ``weights_path`` into the (possibly sharded) model.

    Dispatch:

    * Directory containing ``model.safetensors.index.json`` **and** a spec
      that defines ``state_dict_adapter``: HF safetensors, loaded through the
      adapter (model-specific renaming / expert splitting).
    * Any other directory: hyper's distributed checkpoint ``load``, so each
      rank reads only the shards it owns.
    * Anything else: a single file read with ``torch.load``. This covers
      pickle/zip formats such as ``.pt`` / ``.bin``; ``torch.load`` cannot
      read ``.safetensors`` files.

    Args:
        weights_path: Directory (HF or distributed checkpoint) or single
            checkpoint file.

    Raises:
        RuntimeError: Wraps any loading failure. A ``weights_path`` was
            explicitly provided, so silently falling back to random init
            (uniform logits) would corrupt downstream training metrics.
    """
    logger.info_rank0("Loading weights from %s", weights_path)
    try:
        if os.path.isdir(weights_path):
            hf_index = os.path.join(weights_path, "model.safetensors.index.json")
            # Model-specific renaming / expert-splitting is delegated to the
            # per-spec ``state_dict_adapter``.
            adapter_cls = getattr(self.spec, "state_dict_adapter", None)
            if os.path.isfile(hf_index) and adapter_cls is not None:
                self._load_hf_safetensors(weights_path, adapter_cls)
            else:
                self._load_hyper_dcp(weights_path)
        else:
            self._load_single_file(weights_path)
        logger.info_rank0("Weights loaded from %s", weights_path)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load weights from {weights_path}: {exc}. "
            "weights_path was provided so silent random-init fallback is unsafe — "
            "uniform-logits loss would corrupt downstream training metrics."
        ) from exc

1304 

def _load_hf_safetensors(self, weights_path: str, adapter_cls) -> None:
    """Load an HF safetensors checkpoint through the spec's ``state_dict_adapter``.

    Tensors are requested in the checkpoint's advertised dtype (see
    ``_resolve_hf_load_dtype``). Shape-mismatched keys are dropped before
    ``load_state_dict(strict=False)``. Dropped, missing and unexpected keys
    are reported as warnings.

    Args:
        weights_path: HF checkpoint directory.
        adapter_cls: Adapter class; instantiated with no arguments and asked
            for ``load_hf_state_dict(weights_path, config, dtype=...)``.
    """
    # Cast to the checkpoint's advertised dtype so the fp32 master matches
    # what forward consumes.
    target_dtype = self._resolve_hf_load_dtype(weights_path)
    hf_state = adapter_cls().load_hf_state_dict(
        weights_path, self.model.config, dtype=target_dtype,
    )
    valid_sd, dropped, missing, unexpected = self._validate_hf_state_dict(hf_state)
    if dropped:
        logger.warning(
            "Dropped %d keys due to shape mismatch (first 5: %s)",
            len(dropped), dropped[:5],
        )
    # ``HSDPModule.load_state_dict`` returns ``None``, hence missing /
    # unexpected come from ``_validate_hf_state_dict`` rather than from here.
    self.model.load_state_dict(valid_sd, strict=False)
    logger.info_rank0(
        "HF (%s) load: %d tensors into hyper model",
        getattr(self.args.model, "name", ""), len(valid_sd),
    )
    key_reports = (
        (missing, "Missing (randomly initialised): %d keys, e.g. %s ..."),
        (unexpected, "Unexpected (ignored): %d keys, e.g. %s ..."),
    )
    for keys, template in key_reports:
        if keys:
            logger.warning(template, len(keys), keys[:5])

1338 

1339 def _resolve_hf_load_dtype(self, weights_path: str): 

1340 """Resolve the dtype to cast loaded HF tensors to (matches checkpoint config).""" 

1341 dtype_map = { 

1342 'bfloat16': torch.bfloat16, 'bf16': torch.bfloat16, 

1343 'float16': torch.float16, 'fp16': torch.float16, 

1344 'float32': torch.float32, 'fp32': torch.float32, 

1345 } 

1346 cfg_dtype = ( 

1347 getattr(self.model.config, 'dtype', None) 

1348 or getattr(self.model.config, 'torch_dtype', None) 

1349 ) 

1350 if cfg_dtype is None: 

1351 cfg_json = os.path.join(weights_path, 'config.json') 

1352 if os.path.isfile(cfg_json): 

1353 try: 

1354 with open(cfg_json, 'r', encoding='utf-8') as f: 

1355 cfg = json.load(f) 

1356 cfg_dtype = cfg.get('dtype') or cfg.get('torch_dtype') 

1357 except (OSError, json.JSONDecodeError): 

1358 cfg_dtype = None 

1359 if isinstance(cfg_dtype, str): 

1360 return dtype_map.get(cfg_dtype) 

1361 if isinstance(cfg_dtype, torch.dtype): 

1362 return cfg_dtype 

1363 return None 

1364 

1365 def _validate_hf_state_dict(self, hf_sd: dict): 

1366 """Strip wrapper segments and drop tensors whose shape differs from the model. 

1367 

1368 Pre-validate shapes: ``load_state_dict`` aborts on the first mismatch 

1369 and leaves later keys un-loaded. 

1370 

1371 Returns: 

1372 ``(valid_sd, dropped, missing, unexpected)``. 

1373 """ 

1374 # Strip wrapper segments (e.g. ``_checkpoint_wrapped_module``) so 

1375 # loader keys match ``named_parameters`` paths. 

1376 wrapper_segments = ("._checkpoint_wrapped_module",) 

1377 def _strip(k: str) -> str: 

1378 for s in wrapper_segments: 

1379 k = k.replace(s, "") 

1380 return k 

1381 logical_to_real = {} 

1382 real_to_param = {} 

1383 for name, param in self.model.named_parameters(): 

1384 logical_to_real[_strip(name)] = name 

1385 real_to_param[name] = param 

1386 valid_sd: dict = {} 

1387 dropped: list = [] 

1388 for hf_name, hf_tensor in hf_sd.items(): 

1389 real_name = logical_to_real.get(hf_name) 

1390 if real_name is None: 

1391 continue 

1392 tgt = tuple(real_to_param[real_name].shape) 

1393 src = tuple(hf_tensor.shape) 

1394 if src == tgt: 

1395 valid_sd[real_name] = hf_tensor 

1396 else: 

1397 dropped.append((real_name, src, tgt)) 

1398 param_names = set(real_to_param.keys()) 

1399 loaded_names = set(valid_sd.keys()) 

1400 missing = sorted(param_names - loaded_names) 

1401 unexpected = sorted(loaded_names - param_names) 

1402 return valid_sd, dropped, missing, unexpected 

1403 

def _load_hyper_dcp(self, weights_path: str) -> None:
    """Load weights from a checkpoint in hyper's own distributed-checkpoint format.

    The model's current state dict is used as the load template. ``dcp_load``
    is presumably expected to fill it in place (its return value is unused),
    and the result is then applied with a strict ``load_state_dict``.
    """
    template_sd = self.model.state_dict()
    dcp_load(template_sd, checkpoint_id=weights_path, use_collectives=False)
    self.model.load_state_dict(template_sd)

1409 

1410 def _load_single_file(self, weights_path: str) -> None: 

1411 """Load weights from a single ``.pt`` / ``.safetensors`` / ``.bin`` file.""" 

1412 sd = torch.load(weights_path, map_location="cpu", weights_only=True) 

1413 missing, unexpected = self.model.load_state_dict(sd, strict=False) 

1414 if missing: 

1415 logger.warning("Missing keys when loading weights: %s", missing) 

1416 if unexpected: 

1417 logger.warning("Unexpected keys when loading weights: %s", unexpected) 

1418 

1419 def _maybe_toggle_reshard(self, micro_step: int, num_micro_steps: int): 

1420 """Toggle FSDP reshard_after_backward for gradient accumulation optimization. 

1421 

1422 During gradient accumulation, skip resharding between micro-steps to avoid 

1423 redundant all-gather. Only reshard after the last micro-step. 

1424 """ 

1425 if not isinstance(self.model, HSDPModule) or num_micro_steps <= 1: 

1426 return 

1427 if micro_step == 0: 

1428 self.model.set_reshard_after_backward(False) 

1429 elif micro_step == num_micro_steps - 1: 

1430 self.model.set_reshard_after_backward(True)