Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / base.py: 0%
642 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""BaseTrainer — composable training skeleton with 13 overridable ``_build_*`` steps.
17Design references:
18- BaseTrainer (``): composition pattern, 13-step init
19- hyper ``fsdp_demo.py``: FSDP wrapping via ``model.layers`` iteration
20- hyper st tests: TP → CP → AC → FSDP composition order
22Subclasses (LLMTrainer, DiTTrainer, VLMTrainer) use the composition pattern:
23instantiate ``BaseTrainer`` and call ``_build_*`` methods selectively, overriding
24or skipping steps as needed.
25"""
26import json
27import logging
28import math
29import os
30from contextlib import nullcontext
31from typing import TYPE_CHECKING, Any, Dict, Optional
33import torch
34from torch.utils.data import Dataset, DistributedSampler
36from hyper_parallel import (
37 get_platform,
38 init_empty_weights,
39 init_process_group,
40 destroy_process_group,
41 hsdp_sync_stream,
42 SkipDTensorDispatch,
43 HSDPModule,
44)
45from hyper_parallel.core.distributed_checkpoint import load as dcp_load
46from hyper_parallel.core.dtensor.dtensor import DTensor
47from hyper_parallel.core.fully_shard.hsdp_utils import GroupInfo
48from hyper_parallel.core.utils import clip_grad_norm_
49from hyper_parallel.models.spec.registry import get_spec
50from hyper_parallel.trainer.parallel_dims import ParallelDims
51from hyper_parallel.trainer.utils.loss import count_loss_token
52from hyper_parallel.trainer.callbacks.base import (
53 LoggingCallback,
54 CheckpointCallback,
55 SafetensorsExportCallback,
56 EvalCallback,
57 ProfilerCallback,
58 WandbCallback,
59 ProgressCallback,
60 MoEMonitorCallback,
61 GradientHealthCallback,
62 GCCallback,
63 TensorBoardCallback,
64 MemoryMonitorCallback,
65)
if TYPE_CHECKING:
    # Type-only imports: this branch is skipped at runtime, so these names
    # exist solely for the quoted annotations on ``BaseTrainer``.
    # NOTE(review): the original comment cites a platform-agnostic rule
    # ("no torch/mindspore in trainer code"), yet this module also imports
    # ``torch`` unconditionally above — confirm which is intended.
    from torch import nn
    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LRScheduler
    from torch.utils.data import DataLoader
    from hyper_parallel.core.dtensor.device_mesh import DeviceMesh

# Module-wide handles: the active platform object and this module's logger.
platform = get_platform()
logger = logging.getLogger(__name__)
class TrainerState:
    """Mutable training state shared across callbacks.

    Attributes:
        global_step: Current training step (update count); starts at 0.
        epoch: Current epoch index; starts at 0.
        max_steps: Total number of training steps.
        log_history: List created empty here; this class never fills it.
    """

    def __init__(self, max_steps: int = 0):
        # Configured value first, then the counters that always start at zero.
        self.max_steps: int = max_steps
        self.log_history: list = []
        self.epoch: int = 0
        self.global_step: int = 0
class BaseTrainer:
    """Composable training skeleton.

    Provides 13 numbered step methods (``_setup``, the ``_build_*`` family,
    ``_freeze_model`` and ``_init_callbacks``) that subclasses can call,
    override, or skip. ``_build_parallelized_model`` does not apply a built-in
    strategy: it delegates to the ``parallelize_fn`` registered on the model's
    spec and raises if none is registered.

    Args:
        args: Training configuration (typically parsed from YAML).
    """

    # PEP 526 annotations — populated by the step methods; ``None`` until built.
    # The quoted types are only imported under ``TYPE_CHECKING``.
    model: Optional["nn.Module"] = None
    optimizer: Optional["Optimizer"] = None
    lr_scheduler: Optional["LRScheduler"] = None
    train_dataloader: Optional["DataLoader"] = None
    mesh: Optional["DeviceMesh"] = None
113 def __init__(self, args):
114 # Only early-bound fields live here; the rest is built via
115 # ``_build_*`` methods invoked by the subclass.
116 self.args = args
117 self.spec = get_spec(args.model.name)
118 self.state = TrainerState(max_steps=getattr(args.train, 'max_steps', 0))
120 # ------------------------------------------------------------------
121 # 13 overridable _build_* methods
122 # ------------------------------------------------------------------
    def _setup(self):
        """Step 1: Initialize distributed environment, device mesh, and seed.

        Sequence performed here:
        1. ``init_process_group`` with ``args.train.comm_backend`` (default ``None``).
        2. Select the device for ``args.train.local_rank`` (default 0).
        3. Build ``self.parallel_dims`` from ``args.train.accelerator`` and the
           world size, then build ``self.mesh`` from it.
        4. Wrap the combined DP group in a ``GroupInfo`` (``self._dp_group_info``).
        5. Seed the platform RNG and, best-effort, the device RNG
           (``args.train.seed``, default 42).
        6. Optionally enable deterministic algorithms via ``args.train.debug``.
        """
        backend = getattr(self.args.train, 'comm_backend', None)
        init_process_group(backend=backend)

        local_rank = getattr(self.args.train, 'local_rank', 0)
        device_type = platform.device_type()  # "npu" or "cuda"
        # Use platform.device(idx) — backend-agnostic.
        self.device = platform.device(local_rank)
        device_handle = platform.get_device_handle(device_type)
        device_handle.set_device(local_rank)

        # Build & validate parallel dims in one place (fail-fast).
        self.parallel_dims = ParallelDims.from_config(
            self.args.train.accelerator, world_size=platform.get_world_size(),
        )
        logger.info_rank0("ParallelDims: %s", self.parallel_dims.summary())
        self.mesh = self.parallel_dims.build_mesh(platform.device_type())

        # Build DP group_info for trainer-level all_reduce (loss/token sync).
        # ``_get_combined_dp_group`` is defined outside this view.
        dp_group = self._get_combined_dp_group()
        dp_size = self.parallel_dims.dp_size
        self._dp_group_info = GroupInfo(
            group_name="trainer_dp", group=dp_group, rank_size=dp_size,
        )

        debug_cfg = getattr(self.args.train, 'debug', None)
        seed = getattr(self.args.train, 'seed', 42)
        platform.manual_seed(seed)
        # Seed device-side RNG explicitly: ``platform.manual_seed`` only
        # covers CPU. Deliberately best-effort — a failure is logged, not raised.
        try:
            handle = platform.get_device_handle(device_type)
            if hasattr(handle, "manual_seed_all"):
                handle.manual_seed_all(seed)
            elif hasattr(handle, "manual_seed"):
                handle.manual_seed(seed)
        except Exception as exc:  # pylint: disable=W0718
            logger.warning("Device-side seed init skipped: %s", exc)

        # Opt-in determinism: only when ``debug.deterministic`` is truthy.
        if debug_cfg is not None and getattr(debug_cfg, 'deterministic', False):
            warn_only = getattr(debug_cfg, 'deterministic_warn_only', False)
            torch.use_deterministic_algorithms(True, warn_only=warn_only)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            logger.info_rank0("Deterministic algorithms enabled (warn_only=%s)", warn_only)

        logger.info_rank0(
            "Setup complete: rank=%d, world_size=%d, mesh=%s",
            platform.get_rank(), platform.get_world_size(),
            self.mesh.mesh_dim_names,
        )
        # Config summary; '?' marks fields absent from the config.
        logger.info_rank0(
            "Config: data.type=%s, model.name=%s, model.num_hidden_layers=%s, "
            "init_device=%s, max_steps=%d, global_bs=%d",
            getattr(self.args.data, 'type', '?'),
            getattr(self.args.model, 'name', '?'),
            getattr(self.args.model, 'num_hidden_layers', '?'),
            getattr(self.args.train, 'init_device', '?'),
            self.state.max_steps,
            getattr(self.args.train, 'global_batch_size', '?'),
        )
194 def _build_model(self):
195 """Step 2: Construct model via ``spec.build_model_fn``.
197 The model is a plain ``nn.Module`` at this point — not yet parallelized.
198 When ``args.runtime.init_device == "meta"``, the model is constructed on
199 the meta device (no memory allocated) and real weights are loaded after
200 FSDP sharding via ``_load_weights_after_parallel``.
201 """
202 init_device = getattr(self.args.train, 'init_device', 'meta')
203 # Meta-device init: each rank materialises only its own shard
204 # post-FSDP — pre-trained weights via DCP, otherwise random init.
205 if init_device == "meta":
207 with init_empty_weights():
208 self.model = self.spec.build_model_fn(self.args)
209 logger.info_rank0(
210 "Model built on meta device (no memory allocated): %s",
211 type(self.model).__name__,
212 )
213 else:
214 self.model = self.spec.build_model_fn(self.args)
215 logger.info_rank0("Model built on %s: %s", init_device, type(self.model).__name__)
217 # Cross-check parallel degrees against the actual model hyperparams
218 # (heads%tp, kv_heads%tp, num_experts%ep, seq_len%(cp*tp)).
219 # Fails fast here instead of crashing inside parallelize_module.
220 seq_len = getattr(self.args.data, 'max_seq_len', None)
221 self.parallel_dims.validate_against_model(self.model, seq_len=seq_len)
223 def _freeze_model(self):
224 """Step 3: Freeze specified modules (optional)."""
225 freeze_modules = getattr(self.args.model, 'freeze_modules', None)
226 if not freeze_modules:
227 return
228 for name, param in self.model.named_parameters():
229 if any(pattern in name for pattern in freeze_modules):
230 param.requires_grad_(False)
232 def _build_model_assets(self):
233 """Step 4: Build tokenizer, processor, chat_template.
235 Default: no-op. LLMTrainer overrides to build tokenizer + chat_template.
236 VLMTrainer overrides to build processor.
237 """
238 self.tokenizer = None
239 self.processor = None
241 def _build_data_transform(self):
242 """Step 5: Build data preprocessing transform.
244 Default: identity transform. LLMTrainer overrides for tokenization.
245 """
246 self.data_transform = None
248 def _build_dataset(self):
249 """Step 6: Build training dataset.
251 Supports:
252 - ``data.type = "dummy"``: random tokens for validation
253 - ``data.type = "hf_datasets"``: HuggingFace datasets
254 - ``data.type = "megatron_indexed"``: Megatron .bin/.idx format
256 Subclass can override for custom dataset logic.
257 """
259 data_type = getattr(self.args.data, 'type', 'dummy')
260 seq_len = getattr(self.args.data, 'max_seq_len', 2048)
262 if data_type == "dummy":
263 vocab_size = getattr(self.model, 'config', None)
264 vocab_size = vocab_size.vocab_size if vocab_size else 32000
265 total_samples = self.state.max_steps * getattr(self.args.train, 'global_batch_size', 8)
267 class DummyDataset(Dataset):
268 """Deterministic random token dataset for FSDP validation.
270 Each sample's content is fixed by its index (seeded by
271 ``base_seed + idx``), so the same index always produces the
272 same tokens regardless of access order or DP configuration.
273 """
274 def __init__(self, num_samples, seq_length, vocab, base_seed=42):
275 self.num_samples = num_samples
276 self.seq_length = seq_length
277 self.vocab = vocab
278 self.base_seed = base_seed
279 def __len__(self):
280 return self.num_samples
281 def __getitem__(self, idx):
282 g = torch.Generator().manual_seed(self.base_seed + idx)
283 input_ids = torch.randint(
284 0, self.vocab, (self.seq_length,), generator=g,
285 )
286 return {"input_ids": input_ids, "labels": input_ids.clone()}
288 self.train_dataset = DummyDataset(total_samples, seq_len, vocab_size)
289 logger.info_rank0("Dummy dataset created: %d samples, seq_len=%d", total_samples, seq_len)
291 elif data_type == "hf_datasets":
292 from datasets import load_dataset # pylint: disable=C0415 # optional dep
293 train_path = self.args.data.train_path
294 streaming = getattr(self.args.data, 'streaming', False)
295 if streaming:
296 # ``DistributedSampler`` requires ``__len__``, which
297 # ``IterableDataset`` lacks; reject loudly until a
298 # sampler-less streaming path is wired.
299 raise NotImplementedError(
300 "data.streaming=True not yet wired. The current "
301 "_build_dataloader uses DistributedSampler which requires "
302 "len(dataset); IterableDataset has no __len__. "
303 "Use data.streaming=False (map-style) for now, or "
304 "subclass _build_dataset + _build_dataloader to emit an "
305 "IterableDataset that self-shards via dp_rank/dp_size."
306 )
307 ds = load_dataset(train_path, split="train", streaming=False)
308 if self.data_transform:
309 ds = ds.map(
310 self.data_transform, remove_columns=ds.column_names,
311 )
312 self.train_dataset = ds
313 logger.info_rank0("HF dataset loaded: %s (map-style)", train_path)
315 else:
316 raise ValueError(f"Unknown data type: {data_type}. Supported: dummy, hf_datasets")
318 def _build_collate_fn(self):
319 """Step 7: Build data collator.
321 Default: pads input_ids and labels to max length in the batch.
322 """
324 def _default_collate(batch):
325 """Simple padding collator."""
326 max_len = max(item["input_ids"].size(0) for item in batch)
327 input_ids_list = []
328 labels_list = []
329 for item in batch:
330 pad_len = max_len - item["input_ids"].size(0)
331 input_ids_list.append(
332 torch.nn.functional.pad(item["input_ids"], (0, pad_len), value=0)
333 )
334 labels_list.append(
335 torch.nn.functional.pad(item["labels"], (0, pad_len), value=-100)
336 )
337 return {
338 "input_ids": torch.stack(input_ids_list),
339 "labels": torch.stack(labels_list),
340 }
342 self.collate_fn = _default_collate
344 def _build_dataloader(self):
345 """Step 8: Build distributed stateful dataloader.
347 Uses ``torchdata.stateful_dataloader.StatefulDataLoader`` so that
348 iterator position is checkpointable — enabling exact resume after
349 restart (matching ).
351 Each ``next()`` call yields a list of micro-batches (for gradient
352 accumulation).
353 """
354 from torchdata.stateful_dataloader import StatefulDataLoader # pylint: disable=C0415 # optional dep
356 micro_bs = getattr(self.args.train, 'micro_batch_size', 1)
358 # Sampler uses DP rank/size — TP/CP/PP/EP peers share data.
359 dp_size = self.parallel_dims.dp_size
360 non_dp = self.parallel_dims.non_dp_size
361 global_rank = platform.get_rank()
362 dp_rank = global_rank // non_dp if non_dp > 1 else global_rank
364 shuffle = getattr(self.args.data, "shuffle", True)
365 sampler_seed = getattr(self.args.train, 'seed', 0)
367 self.sampler = DistributedSampler(
368 self.train_dataset,
369 num_replicas=dp_size,
370 rank=dp_rank,
371 shuffle=shuffle,
372 seed=sampler_seed,
373 drop_last=True,
374 )
376 # StatefulDataLoader supports state_dict() / load_state_dict()
377 # for checkpoint resume (torchdata API, used by + ).
378 num_workers = getattr(self.args.data, 'num_workers', 0)
379 prefetch_factor = getattr(self.args.data, 'prefetch_factor', None)
380 pin_memory = getattr(self.args.data, 'pin_memory', True)
381 loader_kwargs = {
382 "batch_size": micro_bs,
383 "sampler": self.sampler,
384 "collate_fn": self.collate_fn,
385 "num_workers": num_workers,
386 "pin_memory": pin_memory,
387 "drop_last": True,
388 }
389 # prefetch_factor is only accepted when num_workers > 0
390 if num_workers > 0 and prefetch_factor is not None:
391 loader_kwargs["prefetch_factor"] = prefetch_factor
392 self.train_dataloader = StatefulDataLoader(
393 self.train_dataset, **loader_kwargs,
394 )
396 # Use dp_size (not world_size) — TP/CP/PP ranks share data, not split it.
397 self._grad_accum = max(
398 getattr(self.args.train, 'global_batch_size', micro_bs) // (micro_bs * dp_size),
399 1,
400 )
402 logger.info_rank0(
403 "Dataloader built: micro_bs=%d, grad_accum=%d, dataset_size=%d",
404 micro_bs, self._grad_accum, len(self.train_dataset),
405 )
407 def _build_parallelized_model(self):
408 """Step 9: Apply parallel strategies to the model.
410 Each model owns its full parallelize pipeline in
411 ``models/<name>/parallelize.py`` (convention) and
412 registers it via ``ModelSpec.parallelize_fn``. There is no shared
413 "default" template — model-specific TP/EP/CP/AC/FSDP/Prefetch
414 composition lives next to the model that needs it.
415 """
416 if self.spec.parallelize_fn is None:
417 raise ValueError(
418 f"Model '{self.spec.name}' has no ``parallelize_fn`` registered "
419 f"on its ModelSpec. Each model must own its parallelize "
420 f"pipeline in models/<name>/parallelize.py."
421 )
422 self.model = self.spec.parallelize_fn(self.model, self.mesh, self.args)
423 self._post_parallelize()
425 def _post_parallelize(self):
426 """Common steps after parallelization (materialize weights + train mode).
428 Order when ``init_device == "meta"`` and ``weights_path`` is set:
430 1. Run ``_materialize_and_init_shards`` first — this calls
431 ``model.to_empty(device=...)`` + kaiming / zero init for every
432 parameter. That is the **baseline** state so no param stays on
433 meta (which would trip ``HSDPState._validate_no_meta_params``).
434 2. Then ``_load_weights`` copies the upstream checkpoint on top.
435 Every key that matches overwrites the random init; anything
436 missing in the checkpoint stays with its kaiming / zero init.
438 This pattern handles partial checkpoints cleanly — e.g. hyper's
439 current Qwen3-VL-MoE model lacks ``q_norm`` / ``k_norm`` /
440 ``pos_embed`` / ``deepstack_merger_list``; those stay random
441 while all other weights come from the pretrained checkpoint.
442 """
443 init_device = getattr(self.args.train, 'init_device', 'meta')
444 weights_path = getattr(self.args.model, 'weights_path', None)
445 if init_device == "meta":
446 # Always materialize first (random init baseline) so no param
447 # stays on meta — then overlay the checkpoint.
448 self._materialize_and_init_shards()
449 # Re-tie weights — ``to_empty`` gives every nn.Parameter fresh
450 # storage so ``__init__``-time ties are broken.
451 if hasattr(self.model, "tie_weights"):
452 self.model.tie_weights()
453 if weights_path:
454 self._load_weights(weights_path)
455 elif weights_path:
456 self._load_weights(weights_path)
457 # Mixed-precision storage policy: trainable params keep an fp32
458 # master; frozen params keep their loaded low-precision dtype so the
459 # forward stays uniform-bf16 within their FSDP unit.
460 self._maybe_downcast_frozen_params()
461 self._maybe_upcast_trainable_params()
462 self.model.train()
464 def _maybe_downcast_frozen_params(self) -> None:
465 """Maybe downcast frozen params (internal)."""
466 freeze_modules = getattr(self.args.model, 'freeze_modules', None)
467 if not freeze_modules:
468 return
469 mp_cfg = getattr(self.args.train, 'mixed_precision', None)
470 if mp_cfg is None or not getattr(mp_cfg, 'enabled', False):
471 return
473 target_dtype = {
474 'bfloat16': torch.bfloat16,
475 'bf16': torch.bfloat16,
476 'float16': torch.float16,
477 'fp16': torch.float16,
478 }.get(getattr(mp_cfg, 'param_dtype', 'bfloat16'))
479 if target_dtype is None:
480 return
481 n_cast = 0
482 for name, param in self.model.named_parameters():
483 if not any(pat in name for pat in freeze_modules):
484 continue
485 if param.requires_grad:
486 continue
487 local = param.data
488 if hasattr(local, 'to_local'):
489 local = local.to_local()
490 if local.dtype == target_dtype:
491 continue
492 new_local = local.to(target_dtype)
493 # DTensor: rebuild the global view via from_local with same placements.
494 if hasattr(param.data, 'to_local'):
496 if isinstance(param.data, DTensor):
497 param.data = DTensor.from_local(
498 new_local,
499 device_mesh=param.data.device_mesh,
500 placements=param.data.placements,
501 )
502 else:
503 param.data = new_local
504 else:
505 param.data = new_local
506 n_cast += 1
507 logger.info_rank0(
508 "Post-load: cast %d frozen params to %s",
509 n_cast, target_dtype,
510 )
512 def _maybe_upcast_trainable_params(self) -> None:
513 """Upcast trainable params to ``float32``.
515 Implements the standard mixed-precision pattern (fp32 master weight
516 + low-precision forward): trainable params loaded from a bf16
517 checkpoint would otherwise stay in bf16, and Adam moment estimates
518 accumulating at bf16's 7-bit mantissa diverge noticeably after only
519 a handful of optimizer steps.
521 Runs AFTER ``_maybe_downcast_frozen_params`` so frozen params keep
522 their bf16 storage and only trainable params get the fp32 master.
523 """
524 mp_cfg = getattr(self.args.train, 'mixed_precision', None)
525 if mp_cfg is None or not getattr(mp_cfg, 'enabled', False):
526 return
528 n_cast = 0
529 for _, param in self.model.named_parameters():
530 if not param.requires_grad:
531 continue
532 local = param.data
533 if hasattr(local, 'to_local'):
534 local = local.to_local()
535 if local.dtype == torch.float32:
536 continue
537 new_local = local.to(torch.float32)
538 if hasattr(param.data, 'to_local'):
540 if isinstance(param.data, DTensor):
541 param.data = DTensor.from_local(
542 new_local,
543 device_mesh=param.data.device_mesh,
544 placements=param.data.placements,
545 )
546 else:
547 param.data = new_local
548 else:
549 param.data = new_local
550 n_cast += 1
551 logger.info_rank0(
552 "Post-load: upcast %d trainable params to float32", n_cast,
553 )
555 def _build_optimizer(self):
556 """Step 10: Build optimizer. Must be called AFTER ``_build_parallelized_model``.
558 After FSDP, parameters are DTensor shards — optimizer operates on local shards.
559 Optimizer must be created after ``fully_shard``.
560 """
561 lr = getattr(self.args.train.optimizer, 'lr', 1e-4)
562 weight_decay = getattr(self.args.train.optimizer, 'weight_decay', 0.01)
564 # bias / LayerNorm / RMSNorm go to no-decay; grouping matters even
565 # at wd=0 — foreach Adam reduction order differs per group on NPU.
566 decay_keywords = ("bias", "layernorm", "norm", "rmsnorm")
568 def _is_no_decay(name: str) -> bool:
569 lname = name.lower()
570 return any(kw in lname for kw in decay_keywords)
572 decay_params = []
573 no_decay_params = []
574 seen_ids = set()
575 for n, p in self.model.named_parameters():
576 if not p.requires_grad:
577 continue
578 # Dedup tied params (same nn.Parameter shared across modules).
579 if id(p) in seen_ids:
580 continue
581 seen_ids.add(id(p))
582 if _is_no_decay(n):
583 no_decay_params.append(p)
584 else:
585 decay_params.append(p)
587 param_groups = [
588 {"params": decay_params, "weight_decay": weight_decay},
589 {"params": no_decay_params, "weight_decay": 0.0},
590 ]
591 adam_eps = getattr(self.args.train.optimizer, 'eps', 1e-8)
592 adam_betas = getattr(self.args.train.optimizer, 'betas', (0.9, 0.999))
593 adam_foreach = getattr(self.args.train.optimizer, 'foreach', None)
594 self.optimizer = torch.optim.AdamW(
595 param_groups,
596 lr=lr,
597 betas=adam_betas,
598 eps=adam_eps,
599 foreach=adam_foreach,
600 )
601 logger.info_rank0(
602 "Optimizer: AdamW lr=%.2e wd=%.3g decay_params=%d no_decay_params=%d",
603 lr, weight_decay, len(decay_params), len(no_decay_params),
604 )
606 def _build_lr_scheduler(self):
607 """Step 11: Build learning rate scheduler.
609 Supports cosine decay with warmup. Falls back to constant LR if
610 warmup_ratio is 0 and decay_style is 'constant'.
611 """
613 total_steps = self.state.max_steps
614 warmup_ratio = getattr(self.args.train.optimizer, 'lr_warmup_ratio', 0.0)
615 # ``ceil`` matches the standard warmup convention so a fractional
616 # ``warmup_ratio * max_steps`` rounds up to the next full step.
617 warmup_steps = math.ceil(total_steps * warmup_ratio)
618 decay_style = getattr(self.args.train.optimizer, 'lr_decay_style', 'cosine')
619 lr_min = getattr(self.args.train.optimizer, 'lr_min', 0.0)
620 lr_max = getattr(self.args.train.optimizer, 'lr', 1e-4)
622 def _lr_lambda(current_step):
623 if current_step < warmup_steps:
624 return float(current_step) / float(max(1, warmup_steps))
625 if decay_style == 'constant':
626 return 1.0
627 # Cosine decay
628 progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
629 cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
630 min_ratio = lr_min / lr_max if lr_max > 0 else 0.0
631 return min_ratio + (1.0 - min_ratio) * cosine_decay
633 self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, _lr_lambda)
634 logger.info_rank0(
635 "LR scheduler: %s, warmup_steps=%d/%d, lr=%.2e→%.2e",
636 decay_style, warmup_steps, total_steps, lr_max, lr_min,
637 )
639 def _build_training_context(self):
640 """Step 12: Build forward/backward context managers.
642 Delegates AMP context to ``_make_amp_context`` — the single place that
643 knows about backend-specific autocast. MindSpore subclass should
644 override ``_make_amp_context`` only.
645 """
648 mp_cfg = getattr(self.args.train, 'mixed_precision', None)
649 # FSDP2 mp_policy already runs forward/backward in low precision;
650 # stacking ``torch.autocast`` on top would re-promote LayerNorm /
651 # softmax / log_softmax to fp32 and break the uniform-bf16 forward.
652 self.model_fwd_context = nullcontext()
653 self.grad_scaler = None
654 if mp_cfg and getattr(mp_cfg, 'enabled', False):
655 param_dtype_str = getattr(mp_cfg, 'param_dtype', 'bfloat16')
656 logger.info_rank0(
657 "Mixed precision via FSDP2 mp_policy: dtype=%s on %s "
658 "(autocast disabled — pure low-precision forward under mp_policy)",
659 param_dtype_str, platform.device_type(),
660 )
662 self.model_bwd_context = nullcontext()
664 def _make_amp_context(self, param_dtype_str: str):
665 """Build the AMP forward context. Backend-specific.
667 This is the SOLE method in the trainer that touches the backend AMP
668 API directly. Override this single method to add MindSpore support
669 (``ms.amp.auto_mixed_precision``). Default uses ``torch.autocast``.
670 """
671 dtype = getattr(torch, param_dtype_str, torch.bfloat16)
672 return torch.autocast(platform.device_type(), dtype=dtype)
    def _init_callbacks(self):
        """Step 13: Initialize callbacks (explicit mode).

        Each built-in callback is constructed with this trainer and stored in
        its own named attribute. ``self.user_callbacks`` is (re)set to an empty
        list, so callbacks added through ``add_callback`` must be registered
        after this method has run.
        """
        self.logging_callback = LoggingCallback(self)
        self.checkpoint_callback = CheckpointCallback(self)
        self.hf_export_callback = SafetensorsExportCallback(self)
        self.eval_callback = EvalCallback(self)
        self.profiler_callback = ProfilerCallback(self)
        self.wandb_callback = WandbCallback(self)
        self.tensorboard_callback = TensorBoardCallback(self)
        self.progress_callback = ProgressCallback(self)
        self.moe_monitor_callback = MoEMonitorCallback(self)
        # Health + operability (no-ops unless enabled in cfg.train.debug / .memory_monitor).
        self.gradient_health_callback = GradientHealthCallback(self)
        self.memory_monitor_callback = MemoryMonitorCallback(self)
        self.gc_callback = GCCallback(self)
        # ``user_callbacks`` lets external code append extra Callback instances
        # (e.g. domain-specific monitors) without editing this method.
        self.user_callbacks: list = []
        logger.info_rank0(
            "Callbacks initialized: logging, checkpoint, hf_export, eval, "
            "profiler, wandb, tensorboard, progress, moe_monitor, "
            "gradient_health, memory_monitor, gc"
        )
703 # ------------------------------------------------------------------
704 # Public API: external callback registration
705 # ------------------------------------------------------------------
707 def add_callback(self, callback) -> None:
708 """Register an extra ``Callback`` to receive every lifecycle event.
710 Use this to plug domain-specific monitors (custom metric sinks,
711 in-house experiment trackers, RL reward loggers) without editing
712 the trainer. Built-in callbacks always run first; user callbacks
713 run in registration order so a later user callback can read state
714 the earlier ones updated.
715 """
716 self.user_callbacks.append(callback)
717 logger.info_rank0(
718 "User callback registered: %s", type(callback).__name__,
719 )
721 # ------------------------------------------------------------------
722 # Callback dispatch (explicit mode)
723 # ------------------------------------------------------------------
725 def _builtin_callbacks(self) -> list:
726 """Return built-in callbacks in fixed dispatch order.
728 Centralised so every dispatcher iterates the same list — adding a
729 callback only needs an entry here plus a named field in
730 ``_init_callbacks`` (no per-event copy/paste).
731 """
732 return [
733 self.logging_callback,
734 self.eval_callback,
735 self.profiler_callback,
736 self.wandb_callback,
737 self.tensorboard_callback,
738 self.progress_callback,
739 self.checkpoint_callback,
740 self.hf_export_callback,
741 self.moe_monitor_callback,
742 self.gradient_health_callback,
743 self.memory_monitor_callback,
744 self.gc_callback,
745 ]
747 def _all_callbacks(self) -> list:
748 """Built-in callbacks followed by user-registered ones."""
749 return self._builtin_callbacks() + list(self.user_callbacks)
751 def on_init_end(self):
752 """Dispatch one-shot ``on_init_end`` after every ``_build_*`` ran.
754 Fired by the subclass at the end of its own ``__init__`` (see
755 ``LLMTrainer.__init__``); ``BaseTrainer.train()`` does NOT call it
756 because BaseTrainer instances are sometimes wrapped (composition
757 pattern) and the wrapper owns the init lifecycle.
758 """
759 for cb in self._all_callbacks():
760 cb.on_init_end(self.state)
762 def on_train_begin(self):
763 """Dispatch on_train_begin to all callbacks."""
764 # Memory monitor first so it captures the truly-initial peak.
765 self.memory_monitor_callback.on_train_begin(self.state)
766 self.moe_monitor_callback.on_train_begin(self.state)
767 self.profiler_callback.on_train_begin(self.state)
768 self.wandb_callback.on_train_begin(self.state)
769 self.tensorboard_callback.on_train_begin(self.state)
770 self.progress_callback.on_train_begin(self.state)
771 # Checkpoint runs LAST so resume sees an already-armed TB writer
772 # (it'll record the load event via dispatch_load_event).
773 self.checkpoint_callback.on_train_begin(self.state)
774 for cb in self.user_callbacks:
775 cb.on_train_begin(self.state)
777 def on_train_end(self):
778 """Dispatch on_train_end to all callbacks."""
779 self.checkpoint_callback.on_train_end(self.state)
780 self.hf_export_callback.on_train_end(self.state)
781 self.progress_callback.on_train_end(self.state)
782 self.tensorboard_callback.on_train_end(self.state)
783 self.wandb_callback.on_train_end(self.state)
784 self.profiler_callback.on_train_end(self.state)
785 for cb in self.user_callbacks:
786 cb.on_train_end(self.state)
788 def on_step_begin(self):
789 """Dispatch on_step_begin to all callbacks."""
790 self.logging_callback.on_step_begin(self.state)
791 for cb in self.user_callbacks:
792 cb.on_step_begin(self.state)
794 def on_step_end(self, loss=None, grad_norm=None):
795 """Dispatch on_step_end to all callbacks (built-ins + user)."""
796 for cb in self._all_callbacks():
797 cb.on_step_end(self.state, loss=loss, grad_norm=grad_norm)
799 def on_substep_end(self):
800 """Dispatch on_substep_end (after each micro-batch forward/backward)."""
801 self.moe_monitor_callback.on_substep_end(self.state)
802 for cb in self.user_callbacks:
803 cb.on_substep_end(self.state)
805 def on_pre_optimizer_step(self, grad_norm=None):
806 """Dispatch on_pre_optimizer_step (after grad clip, before optimizer.step)."""
807 # Health check runs FIRST so a NaN aborts before the logger misleads.
808 self.gradient_health_callback.on_pre_optimizer_step(
809 self.state, grad_norm=grad_norm,
810 )
811 self.logging_callback.on_pre_optimizer_step(self.state, grad_norm=grad_norm)
812 self.wandb_callback.on_pre_optimizer_step(self.state, grad_norm=grad_norm)
813 self.tensorboard_callback.on_pre_optimizer_step(self.state, grad_norm=grad_norm)
814 for cb in self.user_callbacks:
815 cb.on_pre_optimizer_step(self.state, grad_norm=grad_norm)
817 def on_epoch_begin(self):
818 """Dispatch on_epoch_begin."""
819 for cb in self._all_callbacks():
820 cb.on_epoch_begin(self.state)
822 def on_epoch_end(self):
823 """Dispatch on_epoch_end."""
824 for cb in self._all_callbacks():
825 cb.on_epoch_end(self.state)
827 # ------------------------------------------------------------------
828 # Event fan-out (LoggingCallback / CheckpointCallback emit these)
829 # ------------------------------------------------------------------
def dispatch_log_event(self, metrics: dict) -> None:
    """Relay one metrics record to the ``on_log`` hook of every callback.

    ``LoggingCallback`` is the intended caller, so that TensorBoard / W&B /
    external sinks all record the same numbers instead of recomputing them.

    Args:
        metrics: Metrics record passed through as the ``metrics`` keyword.
    """
    for callback in self._all_callbacks():
        callback.on_log(self.state, metrics=metrics)
def dispatch_save_event(self, checkpoint_dir: str) -> None:
    """Relay a checkpoint-save event to the ``on_save`` hook of every callback.

    Args:
        checkpoint_dir: Passed through as the ``checkpoint_dir`` keyword.
    """
    for callback in self._all_callbacks():
        callback.on_save(self.state, checkpoint_dir=checkpoint_dir)
def dispatch_load_event(self, checkpoint_dir: str) -> None:
    """Relay a checkpoint-load event to the ``on_load`` hook of every callback.

    Args:
        checkpoint_dir: Passed through as the ``checkpoint_dir`` keyword.
    """
    for callback in self._all_callbacks():
        callback.on_load(self.state, checkpoint_dir=checkpoint_dir)
def dispatch_evaluate_event(self, metrics: Optional[dict] = None) -> None:
    """Forward an eval-pass-complete event to every callback's ``on_evaluate``.

    Args:
        metrics: Evaluation metrics passed through as the ``metrics``
            keyword; ``None`` (the default) is forwarded as-is.
    """
    for callback in self._all_callbacks():
        callback.on_evaluate(self.state, metrics=metrics)
# ------------------------------------------------------------------
# Training core
# ------------------------------------------------------------------
def forward_backward_step(
    self,
    micro_batch: Dict[str, Any],
    micro_batch_tokens: int,
    global_tokens: int,
    num_micro: int = 1,
):
    """Run forward + backward for one micro-batch.

    The loss scaling depends on ``args.train.optimizer.loss_aggregation``:

    - ``'token_weighted'`` (default): the loss is multiplied by
      ``micro_batch_tokens / global_tokens * dp_size`` so that every token
      across all ranks and micro-batches contributes equally to the
      gradient after the data-parallel average, regardless of DP size or
      grad_accum.
    - ``'rank_average'``: the loss is divided by ``num_micro`` (when
      ``num_micro > 1``); cross-rank averaging is left to the gradient
      reduction.

    Args:
        micro_batch: Dict of model inputs; tensors (anything with ``.to``)
            nested inside dicts / lists / tuples are moved to
            ``self.device``.
        micro_batch_tokens: Non-padding token count for this micro-batch.
        global_tokens: Total non-padding tokens across **all** ranks and
            **all** micro-batches (computed via all-reduce). Values below
            1 are treated as 1 to avoid a zero division.
        num_micro: Number of micro-batches in the current step; only used
            by the ``'rank_average'`` aggregation.

    Returns:
        Tuple ``(raw_loss, micro_batch_tokens)`` where ``raw_loss`` is the
        detached, unscaled loss tensor (for logging).
    """
    # Multimodal batches may contain nested lists/tuples/dicts, so the
    # device move is recursive instead of assuming a flat LM-only batch.
    def _to_device(value):
        if hasattr(value, "to"):
            return value.to(self.device, non_blocking=True)
        if isinstance(value, dict):
            return {k: _to_device(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_to_device(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_to_device(v) for v in value)
        return value

    micro_batch = {k: _to_device(v) for k, v in micro_batch.items()}

    # Forward (with training context for activation offload).
    with self.model_fwd_context:
        outputs = self.model(**micro_batch, use_cache=False)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss

    # TP scenario: loss may be a Partial DTensor — reduce before backward.
    if hasattr(loss, 'is_partial') and loss.is_partial():
        loss = loss.reduce_partial()

    # Keep the raw loss for logging before any scaling is applied.
    raw_loss = loss.detach()

    agg = getattr(self.args.train.optimizer, 'loss_aggregation', 'token_weighted')
    if agg == 'rank_average':
        # Per-rank mean; dividing by ``num_micro`` averages the
        # gradient-accumulation micro-batches.
        scaled_loss = loss / num_micro if num_micro > 1 else loss
    else:
        # token_weighted: ``loss * (micro_tokens / global_tokens) * dp_size``.
        # The DP size is only needed on this path, so it is read here.
        dp_size = self.parallel_dims.dp_size
        # Guard the denominator: a direct caller passing 0 would otherwise
        # raise ZeroDivisionError.
        token_denominator = max(global_tokens, 1)
        scaled_loss = loss * (micro_batch_tokens / token_denominator) * dp_size

    # Backward (with training context).
    with self.model_bwd_context:
        scaled_loss.backward()

    return raw_loss, micro_batch_tokens
def train_step(self, data_iterator):
    """Execute one training step with gradient accumulation.

    Precision-aligned across different DP configurations by:
      1. All-reducing the global token count before loss scaling.
      2. Syncing gradients only on the last micro-batch.
      3. All-reducing the loss (token-weighted or rank-averaged) for
         reporting.

    Args:
        data_iterator: Iterator yielding lists of micro-batch dicts.

    Returns:
        Dict with ``"loss"`` (reported average loss, a Python float) and
        ``"grad_norm"`` (gradient norm returned by the clip function,
        as a Python float).

    Raises:
        StopIteration: When ``data_iterator`` is exhausted. It propagates
            before ``state.global_step`` is incremented.
    """
    # Pull data first; ``StopIteration`` propagates without bumping the
    # step counter so checkpoint dirs / log indices match the steps that
    # actually trained.
    micro_batches = next(data_iterator)
    self.state.global_step += 1
    num_micro = len(micro_batches)

    # ---- Phase 1: count global tokens ----
    # All-reduce BEFORE forward so every rank uses the same denominator.
    token_counts = [count_loss_token(mb) for mb in micro_batches]
    local_tokens = sum(token_counts)
    if local_tokens == 0:
        local_tokens = 1
    global_tokens = local_tokens
    if platform.get_world_size() > 1:
        gt = platform.full((1,), local_tokens).to(self.device)
        platform.all_reduce(gt, self._dp_group_info)
        global_tokens = max(int(gt.item()), 1)
    # Expose for callbacks (e.g. LoggingCallback throughput).
    self._last_global_tokens = global_tokens

    # ---- Phase 2: forward / backward with per-micro-batch sync control ----
    total_loss_sum = 0.0  # weighted: loss * micro_tokens
    total_loss_arith_sum = 0.0  # arithmetic sum across micro-batches
    total_tokens_local = 0
    for i, mb in enumerate(micro_batches):
        is_last = i == num_micro - 1

        # Gradient sync: only on the last micro-batch. Before that,
        # gradients accumulate locally with no communication.
        if isinstance(self.model, HSDPModule):
            self.model.set_requires_gradient_sync(is_last)
            self.model.set_is_last_backward(is_last)

        # FSDP reshard optimization for gradient accumulation.
        self._maybe_toggle_reshard(i, num_micro)

        raw_loss, mb_tokens = self.forward_backward_step(
            mb, token_counts[i], global_tokens, num_micro=num_micro,
        )
        # Read the loss back to the host once per micro-batch; every
        # ``.item()`` call is a separate device-to-host transfer.
        loss_value = raw_loss.item()
        total_loss_sum += loss_value * mb_tokens
        total_loss_arith_sum += loss_value
        total_tokens_local += mb_tokens
        self.on_substep_end()

    # Wait for async gradient reduce.
    hsdp_sync_stream()

    # Gradient clipping — DTensor-aware, returns plain Tensor.
    max_grad_norm = getattr(self.args.train.optimizer, 'max_grad_norm', 1.0)
    clip_fn = self.spec.clip_grad_fn or clip_grad_norm_
    grad_norm = clip_fn(self.model.parameters(), max_grad_norm)
    # Convert once and reuse for both the callbacks and the return value.
    grad_norm_value = grad_norm.item()
    self.on_pre_optimizer_step(grad_norm=grad_norm_value)

    # Optimizer step — must be inside SkipDTensorDispatch.
    with SkipDTensorDispatch():
        self.optimizer.step()

    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

    self.optimizer.zero_grad()

    # ---- Phase 3: all-reduce loss for reporting ----
    agg = getattr(self.args.train.optimizer, 'loss_aggregation', 'token_weighted')
    if agg == 'token_weighted':
        # Global token-weighted: sum the local ``loss * tokens`` across DP
        # and divide by the global token count, i.e. mean loss per token.
        if platform.get_world_size() > 1:
            ls = platform.full((1,), total_loss_sum).to(self.device)
            platform.all_reduce(ls, self._dp_group_info)
            avg_loss = ls.item() / max(global_tokens, 1)
        else:
            avg_loss = total_loss_sum / max(total_tokens_local, 1)
    else:
        # rank_average: each rank averages its micro-batches, then
        # ``all_reduce`` combines the DP ranks. The divisor is the DP
        # group size (NOT global world_size); under TP/EP/PP/CP the
        # all-reduce only covers DP ranks, so dividing by world_size
        # would systematically under-report the loss.
        local_mean = total_loss_arith_sum / max(num_micro, 1)
        dp_size = self._dp_group_info.rank_size
        if dp_size > 1:
            ls = platform.full((1,), local_mean).to(self.device)
            platform.all_reduce(ls, self._dp_group_info)
            avg_loss = ls.item() / dp_size
        else:
            avg_loss = local_mean

    return {"loss": avg_loss, "grad_norm": grad_norm_value}
def train(self):
    """Main training loop: epoch → step → micro-batch.

    Dispatches callbacks at each lifecycle point (explicit mode).
    ``on_train_begin`` is called first — CheckpointCallback uses it to
    restore ``state.global_step`` from a saved checkpoint, so the loop
    below will correctly skip already-completed steps.

    The loop stops when ``state.global_step`` reaches ``state.max_steps``
    or when all ``num_train_epochs`` epochs have been iterated. After the
    loop, ``on_train_end`` is dispatched and the process group is
    destroyed.
    """
    logger.info_rank0(
        "Training starts: max_steps=%d, epochs=%d",
        self.state.max_steps,
        getattr(self.args.train, 'num_train_epochs', 1),
    )
    # on_train_begin runs checkpoint resume — state.global_step may be
    # updated to the resumed step before the loop starts.
    self.on_train_begin()
    num_epochs = getattr(self.args.train, 'num_train_epochs', 1)

    if self.state.global_step > 0:
        logger.info_rank0(
            "Resuming training from step %d", self.state.global_step,
        )

    for epoch in range(num_epochs):
        if self.state.global_step >= self.state.max_steps:
            break
        self.state.epoch = epoch
        # ``sampler`` may be absent or explicitly set to None; only a real
        # sampler gets its epoch updated.
        sampler = getattr(self, 'sampler', None)
        if sampler is not None:
            sampler.set_epoch(epoch)
        self.on_epoch_begin()

        # Build micro-batch iterator from the stateful dataloader.
        # StatefulDataLoader tracks iterator position internally,
        # so after resume it skips already-consumed batches.
        data_iterator = self._make_micro_batch_iterator()

        # Drive the loop on the live ``global_step`` so total training
        # never exceeds ``max_steps`` regardless of ``num_train_epochs``
        # or resume offset.
        while self.state.global_step < self.state.max_steps:
            self.on_step_begin()
            try:
                metrics = self.train_step(data_iterator)
            except StopIteration:
                logger.info_rank0("Epoch %d: dataloader exhausted", epoch)
                break

            self.on_step_end(
                loss=metrics["loss"],
                grad_norm=metrics["grad_norm"],
            )

        self.on_epoch_end()

    self.on_train_end()
    destroy_process_group()
    logger.info_rank0("Training completed")
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _make_micro_batch_iterator(self):
    """Generate gradient-accumulation groups from ``self.train_dataloader``.

    Consecutive batches are collected until the group holds at least
    ``self._grad_accum`` entries, then the group is yielded as a list. A
    shorter trailing group is yielded once the dataloader is exhausted.
    The underlying ``StatefulDataLoader`` tracks iteration position, so
    checkpoint/resume skips consumed batches.
    """
    group = []
    for item in self.train_dataloader:
        group.append(item)
        if len(group) < self._grad_accum:
            continue
        yield group
        group = []
    # Flush whatever is left when the batch count is not a multiple of
    # the accumulation size.
    if group:
        yield group
def _get_layers(self) -> list:
    """Collect the model's repeating layers for FSDP/AC wrapping.

    The default reads ``model.layers`` (standard transformer convention);
    subclasses override this for models with a different structure.

    Returns:
        A new list built from ``self.model.layers``.

    Raises:
        ValueError: If the model has no ``layers`` attribute.
    """
    if not hasattr(self.model, 'layers'):
        raise ValueError(
            f"Model {type(self.model).__name__} has no .layers attribute. "
            "Either add self.layers to the model, or override _get_layers() "
            "in the Trainer subclass."
        )
    return list(self.model.layers)
def _get_combined_dp_group(self):
    """Pick the combined data-parallel ProcessGroup used for trainer all-reduce.

    Mesh axis names are tried in priority order: the ``"loss"`` flatten
    alias registered by ``ParallelDims.build_mesh`` (which folds CP into
    the DP group when CP is active, so token-count denominators include
    CP-sharded contributions), then ``"dp"``, then the legacy
    ``dp_shard`` / ``dp_replicate`` axes for callers that built a custom
    mesh. If none resolves, the mesh's default group is returned.
    """
    for axis in ("loss", "dp", "dp_shard", "dp_replicate"):
        try:
            group = self.mesh.get_group(axis)
        except (KeyError, ValueError):
            # Axis not present on this mesh; try the next candidate.
            continue
        else:
            return group
    return self.mesh.get_group()
def _build_fsdp_kwargs(self) -> dict:
    """Assemble the kwargs for ``fully_shard`` calls on dense parameters.

    For expert parameters when EP > 1, use ``_build_expert_fsdp_kwargs``.

    Returns:
        Dict with ``"mesh"`` (the first of the ``dp_shard`` / ``dp`` /
        ``dp_replicate`` sub-meshes that can be indexed, else the whole
        mesh) and ``"reshard_after_forward"`` (taken from the accelerator
        config, defaulting to ``True``).
    """
    selected = None
    resolved = False
    for axis in ("dp_shard", "dp", "dp_replicate"):
        try:
            selected = self.mesh[axis]
        except (KeyError, TypeError):
            continue
        resolved = True
        break
    if not resolved:
        # No named DP axis could be indexed: shard over the full mesh.
        selected = self.mesh

    return {
        "mesh": selected,
        "reshard_after_forward": getattr(
            self.args.train.accelerator, 'reshard_after_forward', True,
        ),
    }
def _build_expert_fsdp_kwargs(self) -> dict:
    """Assemble the kwargs for ``fully_shard`` calls on expert parameters.

    When EP > 1, expert parameters are sharded across the EP group using
    the mesh's ``"ep"`` dimension. The dense kwargs from
    ``_build_fsdp_kwargs`` are returned instead when EP is disabled, or
    (with a warning) when the mesh has no ``"ep"`` dimension.
    """
    if self.parallel_dims.ep_enabled:
        try:
            expert_mesh = self.mesh["ep"]
        except (KeyError, TypeError):
            logger.warning("EP=%d but no 'ep' dimension in mesh, falling back to dp mesh",
                           self.parallel_dims.ep)
        else:
            return {
                "mesh": expert_mesh,
                "reshard_after_forward": getattr(
                    self.args.train.accelerator, 'reshard_after_forward', True,
                ),
            }
    # EP disabled, or enabled without an "ep" mesh axis: dense kwargs.
    return self._build_fsdp_kwargs()
def _materialize_and_init_shards(self) -> None:
    """Move a meta-device model onto the real device and initialise it.

    Rationale (per the FSDP2 meta-init pattern): after ``fully_shard`` on
    a meta-device model, each rank's parameters are meta DTensor shards
    and FSDP2 keeps internal views into those meta storages. Swapping the
    ``DTensor._local_tensor`` attribute alone would leave those views on
    meta storage, so the first forward's all-gather would fail.
    ``nn.Module.to_empty(device=...)`` allocates real, uninitialised
    storage for every parameter/buffer in place instead.

    Steps, in order:
      1. ``to_empty`` on the model, plus manual materialisation of HSDP
         ``replicate_params``.
      2. Initialise every local shard (kaiming-uniform for >=2-D, zeros
         otherwise) and zero every buffer.
      3. Call ``reset_inv_freq()`` on modules that define it.
      4. Run HSDP ``lazy_init`` so shards are DTensor-wrapped again.
    """
    target_device = platform.device_type()

    # Step 1: meta → real storage, in-place (FSDP views preserved).
    self.model.to_empty(device=target_device)
    self._materialize_replicate_params(target_device)

    # Step 2: init the local shard of every param (and zero every buffer).
    initialised = self._init_local_shards()

    # Re-derive buffers wiped by ``to_empty`` (e.g. ``inv_freq``);
    # without this RoPE silently returns identity rotation.
    for submodule in self.model.modules():
        if hasattr(submodule, "reset_inv_freq"):
            submodule.reset_inv_freq()

    # ``lazy_init`` re-wraps shards as DTensor before ``_load_weights`` /
    # the optimizer step see the params (the forward pre-hook does the
    # same later, but the loader needs DTensor first).
    rewrapped = self._lazy_init_hsdp_modules()

    logger.info_rank0(
        "Meta → real on %s: to_empty + kaiming/zero init on %d params; "
        "FSDP lazy_init re-wrapped %d modules back to DTensor",
        target_device, initialised, rewrapped,
    )
def _iter_hsdp_states(self):
    """Generate the HSDP state of each HSDP-wrapped submodule of the model.

    A submodule is skipped when it is not an ``HSDPModule``, when its
    ``hsdp_scheduler`` attribute is missing or falsy, or when that
    scheduler's ``hsdp_state`` is missing or ``None``.
    """
    for submodule in self.model.modules():
        if isinstance(submodule, HSDPModule):
            scheduler = getattr(submodule, 'hsdp_scheduler', None)
            hsdp_state = getattr(scheduler, 'hsdp_state', None) if scheduler else None
            if hsdp_state is not None:
                yield hsdp_state
def _materialize_replicate_params(self, device_type: str) -> None:
    """Give meta-device ``replicate_params`` real (uninitialised) storage.

    ``to_empty`` skips ``replicate_params`` because their ``_local_tensor``
    lives outside ``module._parameters``, so they are materialised here by
    hand.

    Args:
        device_type: Device on which the replacement tensors are allocated.
    """
    for hsdp_state in self._iter_hsdp_states():
        replicated = getattr(hsdp_state, 'replicate_params', []) or []
        for entry in replicated:
            current = getattr(entry.sharded_param, "_local_tensor", None)
            if current is None or not current.is_meta:
                # Nothing to do: no local tensor, or already on a real device.
                continue
            replacement = torch.empty_like(current, device=device_type)
            entry.sharded_param._local_tensor = replacement  # pylint: disable=W0212
def _init_local_shards(self) -> int:
    """Initialise each parameter's local shard in place and zero all buffers.

    Shards with two or more dimensions get ``kaiming_uniform_``; all other
    shards are zero-filled. Shards still on the meta device are skipped.

    Returns:
        The number of parameters that were initialised (meta shards are
        not counted).
    """
    initialised = 0
    with torch.no_grad():
        for _name, param in self.model.named_parameters():
            # DTensor-like params expose their shard via ``_local_tensor``;
            # plain tensors are initialised directly.
            shard = getattr(param, '_local_tensor', param)  # pylint: disable=W0212
            if shard.is_meta:
                continue
            if shard.dim() < 2:
                torch.nn.init.zeros_(shard)
            else:
                torch.nn.init.kaiming_uniform_(shard)
            initialised += 1
        for _name, buffer in self.model.named_buffers():
            if buffer is not None:
                buffer.zero_()
    return initialised
def _lazy_init_hsdp_modules(self) -> int:
    """Run ``lazy_init`` on every HSDP state that provides it.

    This re-wraps HSDP shards into DTensor so the loader / optimizer see
    them.

    Returns:
        How many states had ``lazy_init`` invoked.
    """
    invoked = 0
    for hsdp_state in self._iter_hsdp_states():
        if not hasattr(hsdp_state, 'lazy_init'):
            continue
        hsdp_state.lazy_init()
        invoked += 1
    return invoked
def _load_weights(self, weights_path: str) -> None:
    """Load pre-trained weights from ``weights_path`` into the (possibly sharded) model.

    Dispatch:
      - Directory containing ``model.safetensors.index.json`` while the
        spec provides a ``state_dict_adapter``: HF safetensors loader.
      - Any other directory: hyper's distributed checkpoint loader.
      - Anything that is not a directory: single-file loader.

    Args:
        weights_path: Path to a checkpoint directory or a single
            checkpoint file.

    Raises:
        RuntimeError: If loading fails for any reason; the original
            exception is chained as the cause.
    """
    logger.info_rank0("Loading weights from %s", weights_path)
    try:
        if not os.path.isdir(weights_path):
            self._load_single_file(weights_path)
        else:
            index_file = os.path.join(weights_path, "model.safetensors.index.json")
            # Model-specific renaming / expert-splitting is delegated to
            # the per-spec ``state_dict_adapter``.
            adapter_cls = getattr(self.spec, "state_dict_adapter", None)
            use_hf_loader = os.path.isfile(index_file) and adapter_cls is not None
            if use_hf_loader:
                self._load_hf_safetensors(weights_path, adapter_cls)
            else:
                self._load_hyper_dcp(weights_path)
        logger.info_rank0("Weights loaded from %s", weights_path)
    except Exception as exc:
        # An explicit weights path was given, so failing loudly is safer
        # than silently continuing from random initialisation.
        raise RuntimeError(
            f"Failed to load weights from {weights_path}: {exc}. "
            "weights_path was provided so silent random-init fallback is unsafe — "
            "uniform-logits loss would corrupt downstream training metrics."
        ) from exc
def _load_hf_safetensors(self, weights_path: str, adapter_cls) -> None:
    """Load an HF safetensors checkpoint through the spec's ``state_dict_adapter``.

    Tensors whose shape differs from the model parameter are dropped (with
    a warning) before a non-strict ``load_state_dict``.

    Args:
        weights_path: Checkpoint directory handed to the adapter.
        adapter_cls: Adapter class; instantiated without arguments.
    """
    # Cast loaded params down to the checkpoint's advertised dtype so the
    # fp32 master matches what forward consumes.
    target_dtype = self._resolve_hf_load_dtype(weights_path)
    state_adapter = adapter_cls()
    hf_state = state_adapter.load_hf_state_dict(
        weights_path, self.model.config, dtype=target_dtype,
    )
    loadable, dropped, missing, unexpected = self._validate_hf_state_dict(hf_state)
    if dropped:
        logger.warning(
            "Dropped %d keys due to shape mismatch (first 5: %s)",
            len(dropped), dropped[:5],
        )

    # missing/unexpected come from the validation above because
    # ``HSDPModule.load_state_dict`` returns ``None``.
    self.model.load_state_dict(loadable, strict=False)
    logger.info_rank0(
        "HF (%s) load: %d tensors into hyper model",
        getattr(self.args.model, "name", ""), len(loadable),
    )

    key_reports = (
        ("Missing (randomly initialised): %d keys, e.g. %s ...", missing),
        ("Unexpected (ignored): %d keys, e.g. %s ...", unexpected),
    )
    for message, keys in key_reports:
        if keys:
            logger.warning(message, len(keys), keys[:5])
def _resolve_hf_load_dtype(self, weights_path: str):
    """Resolve the dtype to cast loaded HF tensors to (matches checkpoint config).

    Lookup order: ``model.config.dtype``, ``model.config.torch_dtype``,
    then the ``dtype`` / ``torch_dtype`` entries of
    ``<weights_path>/config.json``.

    String values are matched case-insensitively and may carry a
    ``torch.`` prefix (e.g. ``"torch.bfloat16"``).

    Args:
        weights_path: Checkpoint directory that may contain ``config.json``.

    Returns:
        The resolved ``torch.dtype``, or ``None`` when no dtype is
        configured, the name is not recognised, or ``config.json`` cannot
        be read / parsed.
    """
    dtype_map = {
        'bfloat16': torch.bfloat16, 'bf16': torch.bfloat16,
        'float16': torch.float16, 'fp16': torch.float16,
        'float32': torch.float32, 'fp32': torch.float32,
    }
    model_config = self.model.config
    cfg_dtype = (
        getattr(model_config, 'dtype', None)
        or getattr(model_config, 'torch_dtype', None)
    )
    if cfg_dtype is None:
        cfg_json = os.path.join(weights_path, 'config.json')
        if os.path.isfile(cfg_json):
            try:
                with open(cfg_json, 'r', encoding='utf-8') as f:
                    cfg = json.load(f)
            except (OSError, ValueError):
                # ValueError covers both malformed JSON (JSONDecodeError)
                # and files that are not valid UTF-8 (UnicodeDecodeError).
                cfg = None
            # Only a JSON object can carry the dtype keys.
            if isinstance(cfg, dict):
                cfg_dtype = cfg.get('dtype') or cfg.get('torch_dtype')

    if isinstance(cfg_dtype, torch.dtype):
        return cfg_dtype
    if isinstance(cfg_dtype, str):
        # Accept "torch.bfloat16" / "BFloat16" style spellings as well.
        key = cfg_dtype.strip().lower()
        if key.startswith('torch.'):
            key = key[len('torch.'):]
        return dtype_map.get(key)
    return None
def _validate_hf_state_dict(self, hf_sd: dict):
    """Match checkpoint tensors to model parameters and drop shape mismatches.

    Shapes are pre-validated because ``load_state_dict`` aborts on the
    first mismatch and leaves later keys un-loaded. Model parameter names
    are compared with wrapper segments (e.g.
    ``._checkpoint_wrapped_module``) removed, and the returned dict is
    keyed by the real (wrapped) parameter names.

    Args:
        hf_sd: Mapping of checkpoint key → tensor.

    Returns:
        ``(valid_sd, dropped, missing, unexpected)`` where

        - ``valid_sd`` maps real parameter names to shape-compatible tensors,
        - ``dropped`` lists ``(real_name, checkpoint_shape, model_shape)``
          tuples for shape mismatches,
        - ``missing`` is the sorted list of model parameters that received
          no tensor (including the dropped ones),
        - ``unexpected`` is the sorted list of checkpoint keys that match
          no model parameter.
    """
    # Strip wrapper segments so loader keys match ``named_parameters`` paths.
    wrapper_segments = ("._checkpoint_wrapped_module",)

    def _strip(k: str) -> str:
        for s in wrapper_segments:
            k = k.replace(s, "")
        return k

    logical_to_real = {}
    real_to_param = {}
    for name, param in self.model.named_parameters():
        logical_to_real[_strip(name)] = name
        real_to_param[name] = param

    valid_sd: dict = {}
    dropped: list = []
    unexpected: list = []
    for hf_name, hf_tensor in hf_sd.items():
        real_name = logical_to_real.get(hf_name)
        if real_name is None:
            # The checkpoint key has no counterpart in the model. It must
            # be recorded here: deriving ``unexpected`` from ``valid_sd``
            # can never find it, since ``valid_sd`` only holds model names.
            unexpected.append(hf_name)
            continue
        tgt = tuple(real_to_param[real_name].shape)
        src = tuple(hf_tensor.shape)
        if src == tgt:
            valid_sd[real_name] = hf_tensor
        else:
            dropped.append((real_name, src, tgt))

    missing = sorted(set(real_to_param) - set(valid_sd))
    unexpected.sort()
    return valid_sd, dropped, missing, unexpected
def _load_hyper_dcp(self, weights_path: str) -> None:
    """Restore model weights from a checkpoint in hyper's own DCP format.

    The model's current state dict is handed to ``dcp_load`` as the load
    target and then applied back to the model.

    Args:
        weights_path: Checkpoint location passed as ``checkpoint_id``.
    """
    target_state = self.model.state_dict()
    dcp_load(target_state, checkpoint_id=weights_path, use_collectives=False)
    self.model.load_state_dict(target_state)
def _load_single_file(self, weights_path: str) -> None:
    """Load weights from a single ``torch.save``-style file (``.pt`` / ``.bin``).

    The file is read on CPU with ``weights_only=True`` and applied with a
    non-strict ``load_state_dict``. Missing / unexpected keys are logged
    as warnings when the model reports them.

    Args:
        weights_path: Path to the checkpoint file.
    """
    sd = torch.load(weights_path, map_location="cpu", weights_only=True)
    result = self.model.load_state_dict(sd, strict=False)
    if result is None:
        # Some wrappers (``HSDPModule.load_state_dict``) return ``None``
        # instead of the (missing, unexpected) pair; unpacking it would
        # raise TypeError after the weights were already loaded.
        return
    missing, unexpected = result
    if missing:
        logger.warning("Missing keys when loading weights: %s", missing)
    if unexpected:
        logger.warning("Unexpected keys when loading weights: %s", unexpected)
def _maybe_toggle_reshard(self, micro_step: int, num_micro_steps: int):
    """Switch FSDP reshard-after-backward off/on around gradient accumulation.

    Resharding between micro-steps is skipped to avoid redundant
    all-gathers: it is disabled on the first micro-step and re-enabled on
    the last one. Nothing happens when the model is not an ``HSDPModule``
    or when there is at most one micro-step.

    Args:
        micro_step: Zero-based index of the current micro-step.
        num_micro_steps: Number of micro-steps in this training step.
    """
    if isinstance(self.model, HSDPModule) and num_micro_steps > 1:
        final_step = num_micro_steps - 1
        if micro_step == 0:
            self.model.set_reshard_after_backward(False)
        elif micro_step == final_step:
            self.model.set_reshard_after_backward(True)