Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / config.py: 0%
288 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Training configuration schema — strict three-tier (model/data/train).
Top-level keys are exactly ``model``, ``data``, ``train`` and nothing else;
together these three sections form the complete arguments schema.
20YAML shape::
22 model:
23 name: qwen3_5
24 weights_path: /path/to/weights
25 data:
26 type: hf_datasets
27 train_path: /path/to/data
28 train:
29 max_steps: 100
30 micro_batch_size: 1
31 global_batch_size: 8
32 seed: 42
33 init_device: meta
34 optimizer:
35 type: adamw
36 lr: 1.0e-4
37 accelerator:
38 dp_shard: 8
39 tp: 1
40 mixed_precision:
41 enabled: true
42 param_dtype: bfloat16
43 checkpoint:
44 output_dir: outputs/run1
45 ...
46"""
47import argparse
48import difflib
49import logging
50import os
51from dataclasses import dataclass, field, fields, is_dataclass
52from typing import Any, Dict, Optional, Type, TypeVar, Union, get_args, get_origin
54import yaml
56logger = logging.getLogger(__name__)
58T = TypeVar("T")
60_BOOL_TRUE_STRINGS = frozenset(("true", "yes", "y", "on", "1", "t"))
61_BOOL_FALSE_STRINGS = frozenset(("false", "no", "n", "off", "0", "f"))
63# ============================================================================
64# model:
65# ============================================================================
@dataclass
class ModelConfig:
    """``model.*`` — model identity, weights, and architecture overrides.

    Only universal transformer fields are typed here. Anything model-
    specific (mRoPE section split, MoE expert geometry, linear-attention
    head counts, ``layer_types`` ...) goes through ``config_overrides`` —
    a free-form ``dict`` that is merged into the underlying model
    constructor by the model's ``build_model_fn``.

    Every field except ``name`` defaults to ``None``. This class performs no
    validation; how a ``None`` value is interpreted is decided by the
    consumers of the config, not here.
    """
    # Model identity and asset locations.
    name: str = "qwen3_5"
    weights_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
    # Per-model parallelism / freezing selections.
    # NOTE(review): element and key formats are not defined in this module —
    # confirm against the code that consumes them.
    freeze_modules: Optional[list] = None
    tp_plan: Optional[dict] = None
    cp_modules: Optional[list] = None
    ep_modules: Optional[list] = None
    # Universal transformer architecture overrides.
    num_hidden_layers: Optional[int] = None
    hidden_size: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    vocab_size: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    # Free-form per-model overrides handed to ``build_model_fn``.
    config_overrides: Optional[dict] = None
95# ============================================================================
96# data:
97# ============================================================================
@dataclass
class DataConfig:
    """``data.*`` — dataset, tokenizer/processor, sampler, batch shape.

    - ``streaming``: only ``False`` supported today.
    - ``num_workers``: keep ≥ 2 for real datasets.
    - ``shuffle``: when ``False``, sampler reads samples in dataset order.

    The class only stores values; none of the constraints above are checked
    here.
    """
    # Dataset selection. The default ``type`` is ``"dummy"``.
    type: str = "dummy"
    train_path: Optional[str] = None
    subset: Optional[str] = None
    max_seq_len: int = 2048
    text_key: str = "text"
    train_size: Optional[int] = None
    # multimodal / VL (synthetic vl_dummy path)
    template: str = "empty"
    image_key: str = "image"
    messages_key: str = "messages"
    image_token_id: int = 151655
    vl_grid_t: int = 2
    vl_grid_h: int = 2
    vl_grid_w: int = 2
    # loader perf
    streaming: bool = False
    # NOTE: the default (0) is below the "≥ 2" recommendation in the class
    # docstring, so real-dataset configs are expected to set this explicitly.
    num_workers: int = 0
    prefetch_factor: Optional[int] = None
    pin_memory: bool = True
    shuffle: bool = True
128# ============================================================================
129# train.* — sub-configs
130# ============================================================================
@dataclass
class AcceleratorConfig:
    """``train.accelerator.*`` — parallelism topology.

    Two ways to express data parallelism:

    - **Legacy single field** (back-compat): ``dp`` only — maps to
      ``dp_shard`` for FSDP.
    - **Explicit split**: ``dp_replicate`` (HSDP outer) and
      ``dp_shard`` (FSDP inner). Pass ``dp_shard=-1`` to auto-fill from
      ``world_size / (dp_replicate * cp * tp * pp)``.

    For MoE: ``etp`` controls expert TP. Must equal ``tp`` or ``1``.

    This class only stores the values. The ``dp`` → ``dp_shard`` mapping, the
    ``-1`` auto-fill and the ``etp`` constraint are not implemented or
    checked here.
    """
    # Data-parallel degrees (see the class docstring for the two styles).
    dp: Optional[int] = None
    dp_replicate: int = 1
    dp_shard: Optional[int] = None
    # Other parallel degrees; a value of 1 is the default for each.
    tp: int = 1
    cp: int = 1
    pp: int = 1
    ep: int = 1
    etp: int = 1
    zero_stage: int = 0
    reshard_after_forward: bool = True
    async_cp: bool = False
    ulysses_degree: Optional[int] = None
    # Bucketed reduce-scatter: single fused RS per FSDP unit, stable fp32
    # reduction order across runs.
    comm_fusion: bool = True
@dataclass
class MixedPrecisionConfig:
    """``train.mixed_precision.*`` — FSDP2 mp_policy fields.

    ``output_dtype`` controls forward-output dtype at FSDP wrap boundaries.
    Set this to ``"bfloat16"`` to keep the cross-FSDP-unit forward output
    in bf16 (matches the typical "uniform low-precision forward" mp_policy
    setup); leave ``None`` to inherit from ``param_dtype``.

    Dtypes are stored as plain strings; conversion to framework dtype
    objects does not happen in this class.
    """
    # Mixed precision is off unless explicitly enabled.
    enabled: bool = False
    param_dtype: str = "bfloat16"
    reduce_dtype: str = "float32"
    # ``None`` means "inherit from ``param_dtype``" (see class docstring).
    output_dtype: Optional[str] = None
@dataclass
class GradientCheckpointingConfig:
    """``train.gradient_checkpointing.*`` — activation recomputation.

    Modes: ``"off"``, ``"full"``, or ``"selective"``.

    The mode is stored as a free string; it is not validated against the
    list above in this class.
    """
    # Recomputation mode; defaults to no activation checkpointing.
    activation_checkpoint: str = "off"
@dataclass
class OptimizerConfig:
    """``train.optimizer.*`` — optimizer + LR schedule + grad clip.

    ``loss_aggregation``: how the per-micro-batch loss is scaled before
    backward. ``"token_weighted"`` divides the summed loss by the global
    valid-token count; ``"rank_average"`` averages per-rank micro-batch
    means and is preferred when batches have variable valid-token counts
    across ranks.

    All values are stored as given; no range or choice validation happens
    in this class.
    """
    # Optimizer selection and learning-rate schedule.
    type: str = "adamw"
    lr: float = 1e-4
    lr_min: float = 1e-5
    lr_decay_style: str = "cosine"
    lr_warmup_ratio: float = 0.1
    # One of the aggregation modes described in the class docstring.
    loss_aggregation: str = "token_weighted"
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    bsz_warmup_ratio: float = 0.0
    eps: float = 1e-8
    # NOTE: a value loaded from YAML arrives as a ``list``; this module does
    # not convert it back to a ``tuple``.
    betas: tuple = (0.9, 0.999)
@dataclass
class CheckpointConfig:
    """``train.checkpoint.*`` — DCP save / load + HF export.

    Plain value holder; no path checks are performed here.
    """
    output_dir: str = "outputs"
    save_steps: int = 500
    save_hf_weights: bool = True
    # ``None`` by default, i.e. no load path configured.
    load_path: Optional[str] = None
    save_async: bool = False
@dataclass
class LoggingConfig:
    """``train.logging.*`` — console / metric output (consumed by LoggingCallback).

    Plain value holder; the callback that reads these fields lives outside
    this module.
    """
    # NOTE: the default is the string ``"none"``, not the ``None`` object.
    report_to: str = "none"
    report_global_loss: bool = False
    log_steps: int = 10
    report_throughput: bool = True
    model_flops_per_token: Optional[int] = None
    peak_tflops: Optional[float] = None  # e.g. 312.0 for A100 bf16
@dataclass
class TensorBoardConfig:
    """``train.tensorboard.*`` — TB SummaryWriter on rank 0.

    Disabled by default.
    """
    enabled: bool = False
    output_dir: str = "tb_traces"
    log_steps: int = 1
@dataclass
class WandbConfig:
    """``train.wandb.*`` — W&B run logging on rank 0.

    Disabled by default.
    """
    enabled: bool = False
    project: str = "hyper-parallel"
    # ``None`` by default, i.e. no explicit run name configured.
    run_name: Optional[str] = None
    log_steps: int = 1
@dataclass
class ProfileConfig:
    """``train.profile.*`` — torch.profiler schedule.

    Schedule semantics: wait → warmup → active. The three ``*_steps`` fields
    give the length of each phase. Profiling is disabled by default.
    """
    enabled: bool = False
    output_dir: str = "profiler_traces"
    wait_steps: int = 1
    warmup_steps: int = 1
    active_steps: int = 3
@dataclass
class MemoryMonitorConfig:
    """``train.memory_monitor.*`` — periodic device-memory snapshot.

    Disabled by default.
    """
    enabled: bool = False
    log_steps: int = 50
    reset_peak_each_step: bool = False
@dataclass
class MoEMonitorConfig:
    """``train.moe_monitor.*`` — MoE routing / load-balance monitor (stub).

    Only an on/off switch is defined so far; disabled by default.
    """
    enabled: bool = False
@dataclass
class EvalConfig:
    """``train.eval.*`` — eval cadence + dataset.

    Plain value holder with no validation.
    """
    # NOTE(review): the default of 0 presumably means "no periodic eval" —
    # confirm against the code that reads this field.
    eval_steps: int = 0
    eval_dataset: Optional[str] = None
@dataclass
class DebugConfig:
    """``train.debug.*`` — reproducibility and numerical-stability knobs.

    All flags here are production-safe; they tune determinism (CI / paper
    reproducibility), guard against numerical blow-ups, and bound memory
    growth in long runs.

    Every knob is off by default (``False`` / ``0``).
    """
    deterministic: bool = False
    deterministic_warn_only: bool = False
    check_nan_inf: bool = False
    # NOTE(review): presumably the interval between explicit garbage
    # collections, with 0 disabling it — confirm against the consumer.
    gc_steps: int = 0
283# ============================================================================
284# train: (top of the train section, holds the sub-configs)
285# ============================================================================
@dataclass
class TrainConfig:
    """``train.*`` — full training-section config.

    Flat fields cover the basic loop knobs (steps, batch shape, init device,
    seed, comm backend); nested sub-configs cover everything else.

    Each sub-config uses ``default_factory`` so that every ``TrainConfig``
    instance gets its own fresh sub-config objects rather than sharing one.
    """
    # Loop shape
    max_steps: int = 100
    num_train_epochs: int = 1
    global_batch_size: int = 8
    micro_batch_size: int = 1
    seed: int = 42

    # Runtime / device
    backend: str = "torch"
    init_device: str = "meta"
    comm_backend: Optional[str] = None
    local_rank: int = 0  # set from LOCAL_RANK env by parser

    # Sub-configs
    accelerator: AcceleratorConfig = field(default_factory=AcceleratorConfig)
    mixed_precision: MixedPrecisionConfig = field(default_factory=MixedPrecisionConfig)
    gradient_checkpointing: GradientCheckpointingConfig = field(
        default_factory=GradientCheckpointingConfig
    )
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    tensorboard: TensorBoardConfig = field(default_factory=TensorBoardConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    memory_monitor: MemoryMonitorConfig = field(default_factory=MemoryMonitorConfig)
    moe_monitor: MoEMonitorConfig = field(default_factory=MoEMonitorConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)
324# ============================================================================
325# Top-level: model / data / train (and only these three)
326# ============================================================================
@dataclass
class HyperTrainerConfig:
    """Top-level config — strict three-tier.

    Allowed top-level keys: ``model``, ``data``, ``train``. Anything else in
    the YAML is rejected by the parser with a typo-suggestion message.

    ``train_steps`` is a derived value: ``__post_init__`` always overwrites
    it with ``train.max_steps``, so a value passed to the constructor is
    discarded.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    # Computed (no user input)
    train_steps: int = 0

    def __post_init__(self):
        """Mirror ``train.max_steps`` into ``train_steps`` after construction."""
        self.train_steps = self.train.max_steps
345# ==============================================================================
346# CLI / YAML parser
347# ==============================================================================
348# Configuration parser: YAML file + CLI dot-path overrides.
349#
350# Supports:
351# - Unknown YAML/CLI keys emit a warning with difflib closest-match suggestions.
352# - Bool fields accept string aliases: ``true/yes/y/on/1/t`` -> ``True``,
353# ``false/no/n/off/0/f`` -> ``False``. Only applied when the dataclass
354# field type resolves to ``bool`` or ``Optional[bool]`` to avoid ambiguity.
def _string_to_bool(value: Any) -> bool:
    """Map a bool, or a textual boolean alias, onto a real ``bool``.

    Matching is case-insensitive. Truthy aliases are the members of
    ``_BOOL_TRUE_STRINGS`` (``true/yes/y/on/1/t``); falsy aliases are the
    members of ``_BOOL_FALSE_STRINGS`` (``false/no/n/off/0/f``).

    Args:
        value: A ``bool`` (returned unchanged) or a string alias.

    Returns:
        The ``bool`` the input stands for.

    Raises:
        ValueError: If ``value`` is neither a ``bool`` nor a recognised alias
            (this includes every non-string, non-bool input).
    """
    if isinstance(value, bool):
        return value

    if isinstance(value, str):
        normalized = value.lower()
        # Try the truthy alias set first, then the falsy one.
        for outcome, aliases in (
            (True, _BOOL_TRUE_STRINGS),
            (False, _BOOL_FALSE_STRINGS),
        ):
            if normalized in aliases:
                return outcome

    raise ValueError(
        f"Cannot convert {value!r} to bool. "
        "Expected one of: true/false/yes/no/y/n/on/off/1/0/t/f"
    )
384def _resolve_field_type(cls: Type, dot_path: str) -> Optional[Type]:
385 """Walk a dataclass hierarchy to find the resolved type of a dot-path field.
387 Args:
388 cls: Root dataclass class.
389 dot_path: Dot-separated field path, e.g. ``"debug.deterministic"``.
391 Returns:
392 The resolved Python type, or ``None`` if the path cannot be resolved.
393 """
394 parts = dot_path.split(".")
395 current_cls = cls
396 for part in parts:
397 if not is_dataclass(current_cls):
398 return None
399 found = None
400 for f in fields(current_cls):
401 if f.name == part:
402 found = f
403 break
404 if found is None:
405 return None
406 field_type = found.type
407 # Unwrap Optional[X] → X
408 origin = get_origin(field_type)
409 if origin is Union:
410 unwrapped = [a for a in get_args(field_type) if a is not type(None)]
411 field_type = unwrapped[0] if len(unwrapped) == 1 else field_type
412 current_cls = field_type
413 return current_cls
def _coerce_cli_value(raw: str, dot_path: str, root_class: Type) -> Any:
    """Parse a CLI string value, coercing to the correct type for the field.

    The declared type of the target field (looked up through
    ``_resolve_field_type``) drives the conversion where it is known:

    - ``str`` fields keep the raw text untouched, so values such as ``123``,
      ``nan`` or ``true`` are not turned into numbers / bools.
    - ``bool`` fields accept the extended alias set of ``_string_to_bool``.
    - ``float`` fields are parsed with ``float`` so that ``1`` becomes ``1.0``.

    For every other field (including paths that cannot be resolved), and when
    the typed conversion fails, the generic int → float → ``true``/``false``
    → str heuristic is applied.

    Args:
        raw: Raw string from the CLI.
        dot_path: Dot-separated field path used for type lookup.
        root_class: Root dataclass class for type resolution.

    Returns:
        Coerced value.
    """
    field_type = _resolve_field_type(root_class, dot_path)

    if field_type is str:
        # A string field must never be re-interpreted as a number or a bool.
        return raw
    if field_type is bool:
        try:
            return _string_to_bool(raw)
        except ValueError:
            pass  # unrecognised alias: fall through to the heuristic below
    elif field_type is float:
        try:
            return float(raw)
        except ValueError:
            pass  # not a number: fall through to the heuristic below

    # Generic heuristic: int → float → "true"/"false" → str
    try:
        return int(raw)
    except ValueError:
        pass
    try:
        return float(raw)
    except ValueError:
        pass
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    return raw
448def _deep_update(source: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
449 """Recursively update source dict with overrides dict."""
450 for key, value in overrides.items():
451 if isinstance(value, dict) and isinstance(source.get(key), dict):
452 _deep_update(source[key], value)
453 else:
454 source[key] = value
455 return source
457_ALLOWED_TOP_LEVEL_KEYS = frozenset({"model", "data", "train"})
459def _validate_top_level(config: Dict[str, Any]) -> None:
460 """Reject any top-level key other than ``model`` / ``data`` / ``train``.
462 Strict three-tier YAML shape. Any flat-style legacy key
463 (``parallel``, ``optim``, ``mixed_precision``, ``runtime``, ``debug`` ...)
464 must be moved under ``train.*`` — see schema.py for the canonical layout.
466 Raises:
467 ValueError: With migration hints when forbidden top-level keys are
468 present in the YAML.
469 """
470 forbidden = sorted(set(config) - _ALLOWED_TOP_LEVEL_KEYS)
471 if not forbidden:
472 return
474 legacy_to_train_path = {
475 "parallel": "train.accelerator",
476 "optim": "train.optimizer",
477 "mixed_precision": "train.mixed_precision",
478 "memory": "train.gradient_checkpointing",
479 "checkpoint": "train.checkpoint",
480 "logging": "train.logging",
481 "tensorboard": "train.tensorboard",
482 "wandb": "train.wandb",
483 "profiler": "train.profile",
484 "memory_monitor": "train.memory_monitor",
485 "moe_monitor": "train.moe_monitor",
486 "eval": "train.eval",
487 "runtime": "train (flatten init_device / backend / comm_backend)",
488 "debug": "train.debug",
489 "seed": "train.seed",
490 }
491 hints = []
492 for key in forbidden:
493 new_path = legacy_to_train_path.get(key)
494 if new_path:
495 hints.append(f" - top-level '{key}:' → move under {new_path}")
496 else:
497 hints.append(f" - top-level '{key}:' is not allowed")
498 raise ValueError(
499 "Forbidden top-level YAML keys: %s. The schema is strict three-tier "
500 "(model / data / train) — see config/schema.py. Migrate as follows:\n%s"
501 % (forbidden, "\n".join(hints))
502 )
def _instantiate_recursive(cls: Type[T], config_dict: Dict[str, Any]) -> T:
    """Recursively convert a dict into nested dataclass instances.

    Unknown keys in ``config_dict`` that have no corresponding field on
    ``cls`` emit a ``logger.warning`` with a closest-match suggestion from
    ``difflib``, helping users catch typos in YAML configs.

    Per-field handling:

    - dataclass-typed field + ``dict`` value → recurse.
    - dataclass-typed, non-``Optional`` field + ``None`` value (an empty YAML
      section such as ``accelerator:``) → the entry is skipped so the field's
      declared default is used instead of storing ``None``. Fields without a
      default still receive the ``None`` unchanged.
    - ``bool`` field + string value → converted with ``_string_to_bool``.
    - anything else → stored as given.

    Args:
        cls: Target type. When it is not a dataclass, ``config_dict`` is
            returned unchanged.
        config_dict: Mapping of field names to raw values; ``None`` is
            treated as an empty mapping.

    Returns:
        An instance of ``cls`` (or ``config_dict`` itself for non-dataclass
        ``cls``).

    Raises:
        ValueError: Propagated from ``_string_to_bool`` when a bool field
            holds an unrecognised string.
    """
    # Local import: ``MISSING`` is not part of the module-level imports.
    from dataclasses import MISSING

    if not is_dataclass(cls):
        return config_dict
    if config_dict is None:
        config_dict = {}

    cls_fields = fields(cls)
    known = {f.name for f in cls_fields}
    # ``key=str`` / ``str(name)`` keep sorting and difflib from raising when a
    # YAML document contains non-string keys (e.g. ``1:``).
    for name in sorted((key for key in config_dict if key not in known), key=str):
        matches = difflib.get_close_matches(str(name), known, n=1)
        suggestion = f" Did you mean '{matches[0]}'?" if matches else ""
        logger.warning(
            "Unknown config key '%s' for %s ignored.%s",
            name, cls.__name__, suggestion,
        )

    field_values = {}
    for field_info in cls_fields:
        if field_info.name not in config_dict:
            continue
        raw_value = config_dict[field_info.name]
        field_type = field_info.type

        # Unwrap Optional[X] → X, remembering whether None is a legal value.
        accepts_none = False
        if get_origin(field_type) is Union:
            members = get_args(field_type)
            unwrapped = [a for a in members if a is not type(None)]
            accepts_none = len(unwrapped) < len(members)
            if len(unwrapped) == 1:
                field_type = unwrapped[0]

        if is_dataclass(field_type):
            if isinstance(raw_value, dict):
                field_values[field_info.name] = _instantiate_recursive(field_type, raw_value)
                continue
            has_default = (
                field_info.default is not MISSING
                or field_info.default_factory is not MISSING
            )
            if raw_value is None and not accepts_none and has_default:
                # Empty section: fall back to the declared default rather
                # than replacing the sub-config with None.
                continue

        if field_type is bool and isinstance(raw_value, str):
            field_values[field_info.name] = _string_to_bool(raw_value)
        else:
            field_values[field_info.name] = raw_value

    return cls(**field_values)
def parse_args(root_class: Type[T]) -> T:
    """Parse training config from YAML file + CLI overrides.

    Usage::

        args = parse_args(HyperTrainerConfig)

    The first positional argument is the YAML config file path.
    CLI arguments use dot-path notation under the strict three-tier schema:
    ``--train.accelerator.dp_shard=8 --train.optimizer.lr=3e-4``

    Bool fields accept extended string aliases (``yes/no/on/off/y/n/t/f/1/0``).
    Unknown YAML keys emit a warning with a closest-match suggestion.
    Extra CLI tokens that are not of the form ``--key=value`` are ignored
    with a warning. A missing config file is not fatal: a warning is logged
    and defaults are used.

    Args:
        root_class: The root config dataclass type.

    Returns:
        An instance of root_class populated from YAML + CLI, with
        ``train.local_rank`` taken from the ``LOCAL_RANK`` environment
        variable (``0`` when unset).

    Raises:
        ValueError: If the YAML document root or the ``train`` section is not
            a mapping, if forbidden top-level keys are present, or if
            ``LOCAL_RANK`` is not an integer.
    """
    parser = argparse.ArgumentParser(description="HyperParallel Trainer")
    parser.add_argument("config_file", nargs="?", help="Path to YAML config file")
    args, remaining = parser.parse_known_args()

    # Load YAML
    final_config: dict = {}
    if args.config_file:
        if not os.path.isfile(args.config_file):
            logger.warning(
                "Config file not found: %s (cwd=%s). Using all defaults.",
                args.config_file, os.getcwd(),
            )
        else:
            with open(args.config_file, encoding="utf-8") as f:
                yaml_config = yaml.safe_load(f)
            # An empty file loads as None; anything else must be a mapping,
            # otherwise the merge / validation below fails obscurely.
            if yaml_config is not None and not isinstance(yaml_config, dict):
                raise ValueError(
                    "Config file %s must contain a YAML mapping at the top level, got %s."
                    % (args.config_file, type(yaml_config).__name__)
                )
            if yaml_config:
                final_config = yaml_config

    # Parse CLI dot-path overrides: --train.accelerator.dp=8 → nested dict
    cli_config: dict = {}
    for item in remaining:
        if not (item.startswith("--") and "=" in item):
            logger.warning(
                "Ignoring CLI argument %r: overrides must use the form --dotted.key=value.",
                item,
            )
            continue
        dot_key, raw_value = item[2:].split("=", 1)
        coerced = _coerce_cli_value(raw_value, dot_key, root_class)
        keys = dot_key.split(".")
        current = cli_config
        for k in keys[:-1]:
            child = current.get(k)
            if not isinstance(child, dict):
                # Either absent, or an earlier override stored a scalar at
                # this path; the later, deeper override wins.
                child = current[k] = {}
            current = child
        current[keys[-1]] = coerced

    # CLI overrides YAML
    final_config = _deep_update(final_config, cli_config)

    # Strict three-tier validation — only model / data / train allowed.
    _validate_top_level(final_config)

    # local_rank from environment (torchrun sets it).
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    train_section = final_config.get("train")
    if train_section is None:
        # Covers both a missing section and an empty ``train:`` (YAML null).
        train_section = final_config["train"] = {}
    elif not isinstance(train_section, dict):
        raise ValueError(
            "The 'train' section must be a mapping, got %s."
            % type(train_section).__name__
        )
    train_section["local_rank"] = local_rank

    return _instantiate_recursive(root_class, final_config)