Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / config.py: 0%

288 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Training configuration schema — strict three-tier (model/data/train). 

16 

17Top-level keys are exactly ``model``, ``data``, ``train`` and nothing else. 

18This module defines the arguments schema. 

19 

20YAML shape:: 

21 

22 model: 

23 name: qwen3_5 

24 weights_path: /path/to/weights 

25 data: 

26 type: hf_datasets 

27 train_path: /path/to/data 

28 train: 

29 max_steps: 100 

30 micro_batch_size: 1 

31 global_batch_size: 8 

32 seed: 42 

33 init_device: meta 

34 optimizer: 

35 type: adamw 

36 lr: 1.0e-4 

37 accelerator: 

38 dp_shard: 8 

39 tp: 1 

40 mixed_precision: 

41 enabled: true 

42 param_dtype: bfloat16 

43 checkpoint: 

44 output_dir: outputs/run1 

45 ... 

46""" 

47import argparse 

48import difflib 

49import logging 

50import os 

51from dataclasses import dataclass, field, fields, is_dataclass 

52from typing import Any, Dict, Optional, Type, TypeVar, Union, get_args, get_origin 

53 

54import yaml 

55 

56logger = logging.getLogger(__name__) 

57 

58T = TypeVar("T") 

59 

60_BOOL_TRUE_STRINGS = frozenset(("true", "yes", "y", "on", "1", "t")) 

61_BOOL_FALSE_STRINGS = frozenset(("false", "no", "n", "off", "0", "f")) 

62 

63# ============================================================================ 

64# model: 

65# ============================================================================ 

66 

@dataclass
class ModelConfig:
    """``model.*`` — model identity, weights, and architecture overrides.

    Only universal transformer fields are typed here. Anything model-
    specific (mRoPE section split, MoE expert geometry, linear-attention
    head counts, ``layer_types`` ...) goes through ``config_overrides`` —
    a free-form ``dict`` that is merged into the underlying model
    constructor by the model's ``build_model_fn``.

    Every field except ``name`` defaults to ``None``.
    """
    # Model identity and on-disk locations.
    name: str = "qwen3_5"
    weights_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
    # Per-model module selections / plans. NOTE(review): the expected element
    # format is defined by the consumer, not by this schema — confirm there.
    freeze_modules: Optional[list] = None
    tp_plan: Optional[dict] = None
    cp_modules: Optional[list] = None
    ep_modules: Optional[list] = None
    # Universal transformer architecture overrides.
    num_hidden_layers: Optional[int] = None
    hidden_size: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    vocab_size: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    # Free-form per-model overrides handed to ``build_model_fn``.
    config_overrides: Optional[dict] = None

94 

95# ============================================================================ 

96# data: 

97# ============================================================================ 

98 

@dataclass
class DataConfig:
    """``data.*`` — dataset, tokenizer/processor, sampler, batch shape.

    - ``streaming``: only ``False`` supported today.
    - ``num_workers``: keep ≥ 2 for real datasets (note that the default
      declared here is ``0``).
    - ``shuffle``: when ``False``, sampler reads samples in dataset order.
    """
    # Dataset selection and text extraction.
    type: str = "dummy"
    train_path: Optional[str] = None
    subset: Optional[str] = None
    max_seq_len: int = 2048
    text_key: str = "text"
    train_size: Optional[int] = None
    # multimodal / VL (synthetic vl_dummy path)
    template: str = "empty"
    image_key: str = "image"
    messages_key: str = "messages"
    image_token_id: int = 151655
    vl_grid_t: int = 2
    vl_grid_h: int = 2
    vl_grid_w: int = 2
    # loader perf
    streaming: bool = False
    num_workers: int = 0
    prefetch_factor: Optional[int] = None
    pin_memory: bool = True
    shuffle: bool = True

127 

128# ============================================================================ 

129# train.* — sub-configs 

130# ============================================================================ 

131 

@dataclass
class AcceleratorConfig:
    """``train.accelerator.*`` — parallelism topology.

    Two ways to express data parallelism:

    - **Legacy single field** (back-compat): ``dp`` only — maps to
      ``dp_shard`` for FSDP.
    - **Split fields**: ``dp_replicate`` (HSDP outer) and
      ``dp_shard`` (FSDP inner). Pass ``dp_shard=-1`` to auto-fill from
      ``world_size / (dp_replicate * cp * tp * pp)``.

    For MoE: ``etp`` controls expert TP. Must equal ``tp`` or ``1``.

    This dataclass only stores the values; it performs no validation or
    auto-fill itself.
    """
    # Data-parallel degrees (see the class docstring for the two styles).
    dp: Optional[int] = None
    dp_replicate: int = 1
    dp_shard: Optional[int] = None
    # Other parallel degrees; all default to 1.
    tp: int = 1
    cp: int = 1
    pp: int = 1
    ep: int = 1
    etp: int = 1
    zero_stage: int = 0
    reshard_after_forward: bool = True
    async_cp: bool = False
    ulysses_degree: Optional[int] = None
    # Bucketed reduce-scatter: single fused RS per FSDP unit, stable fp32
    # reduction order across runs.
    comm_fusion: bool = True

162 

@dataclass
class MixedPrecisionConfig:
    """``train.mixed_precision.*`` — FSDP2 mp_policy fields.

    ``output_dtype`` controls forward-output dtype at FSDP wrap boundaries.
    Set this to ``"bfloat16"`` to keep the cross-FSDP-unit forward output
    in bf16 (matches the typical "uniform low-precision forward" mp_policy
    setup); leave ``None`` to inherit from ``param_dtype``.

    Dtypes are stored as plain strings; this class does not validate them.
    """
    # Mixed precision is off unless explicitly enabled.
    enabled: bool = False
    param_dtype: str = "bfloat16"
    reduce_dtype: str = "float32"
    output_dtype: Optional[str] = None

176 

@dataclass
class GradientCheckpointingConfig:
    """``train.gradient_checkpointing.*`` — activation recomputation.

    Modes: ``"off"``, ``"full"``, or ``"selective"``. The value is stored
    as a plain string and is not validated by this class.
    """
    activation_checkpoint: str = "off"

184 

@dataclass
class OptimizerConfig:
    """``train.optimizer.*`` — optimizer + LR schedule + grad clip.

    ``loss_aggregation``: how the per-micro-batch loss is scaled before
    backward. ``"token_weighted"`` divides the summed loss by the global
    valid-token count; ``"rank_average"`` averages per-rank micro-batch
    means and is preferred when batches have variable valid-token counts
    across ranks.

    ``betas`` is declared as a ``tuple``. A YAML sequence loads as a
    ``list``, so a ``list`` value is converted to a ``tuple`` after
    construction to keep the stored value consistent with the annotation.
    """
    type: str = "adamw"
    lr: float = 1e-4
    lr_min: float = 1e-5
    lr_decay_style: str = "cosine"
    lr_warmup_ratio: float = 0.1
    loss_aggregation: str = "token_weighted"
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    bsz_warmup_ratio: float = 0.0
    eps: float = 1e-8
    betas: tuple = (0.9, 0.999)

    def __post_init__(self):
        # Only lists are converted; any other value is left exactly as given.
        if isinstance(self.betas, list):
            self.betas = tuple(self.betas)

206 

@dataclass
class CheckpointConfig:
    """``train.checkpoint.*`` — DCP save / load + HF export.

    Plain value holder; no validation is performed here.
    """
    output_dir: str = "outputs"
    save_steps: int = 500
    save_hf_weights: bool = True
    # ``None`` by default. NOTE(review): presumably means "do not resume" —
    # confirm against the checkpoint loader.
    load_path: Optional[str] = None
    save_async: bool = False

215 

@dataclass
class LoggingConfig:
    """``train.logging.*`` — console / metric output (consumed by LoggingCallback).

    Plain value holder; no validation is performed here.
    """
    report_to: str = "none"
    report_global_loss: bool = False
    log_steps: int = 10
    report_throughput: bool = True
    # Optional throughput-accounting inputs; both default to ``None``.
    model_flops_per_token: Optional[int] = None
    peak_tflops: Optional[float] = None  # e.g. 312.0 for A100 bf16

225 

@dataclass
class TensorBoardConfig:
    """``train.tensorboard.*`` — TB SummaryWriter on rank 0.

    Disabled by default.
    """
    enabled: bool = False
    output_dir: str = "tb_traces"
    log_steps: int = 1

232 

@dataclass
class WandbConfig:
    """``train.wandb.*`` — W&B run logging on rank 0.

    Disabled by default.
    """
    enabled: bool = False
    project: str = "hyper-parallel"
    run_name: Optional[str] = None
    log_steps: int = 1

240 

@dataclass
class ProfileConfig:
    """``train.profile.*`` — torch.profiler schedule.

    Schedule semantics: wait → warmup → active. Disabled by default.
    """
    enabled: bool = False
    output_dir: str = "profiler_traces"
    # Step counts for the three schedule phases, in order.
    wait_steps: int = 1
    warmup_steps: int = 1
    active_steps: int = 3

251 

@dataclass
class MemoryMonitorConfig:
    """``train.memory_monitor.*`` — periodic device-memory snapshot.

    Disabled by default.
    """
    enabled: bool = False
    log_steps: int = 50
    reset_peak_each_step: bool = False

258 

@dataclass
class MoEMonitorConfig:
    """``train.moe_monitor.*`` — MoE routing / load-balance monitor (stub).

    Only an on/off switch is defined; it is off by default.
    """
    enabled: bool = False

263 

@dataclass
class EvalConfig:
    """``train.eval.*`` — eval cadence + dataset."""
    # Defaults to 0. NOTE(review): presumably 0 disables evaluation —
    # confirm against the consumer of this field.
    eval_steps: int = 0
    eval_dataset: Optional[str] = None

269 

@dataclass
class DebugConfig:
    """``train.debug.*`` — reproducibility and numerical-stability knobs.

    All flags here are production-safe; they tune determinism (CI / paper
    reproducibility), guard against numerical blow-ups, and bound memory
    growth in long runs.

    Every flag is off (``False`` / ``0``) by default.
    """
    deterministic: bool = False
    deterministic_warn_only: bool = False
    check_nan_inf: bool = False
    # Defaults to 0. NOTE(review): presumably a garbage-collection interval
    # in steps with 0 meaning "off" — confirm against the consumer.
    gc_steps: int = 0

282 

283# ============================================================================ 

284# train: (top of the train section, holds the sub-configs) 

285# ============================================================================ 

286 

@dataclass
class TrainConfig:
    """``train.*`` — full training-section config.

    Flat fields cover the basic loop knobs (steps, batch shape, init device,
    seed, comm backend); nested sub-configs cover everything else.

    Each sub-config uses ``default_factory``, so every ``TrainConfig``
    instance gets its own fresh sub-config objects.
    """
    # Loop shape
    max_steps: int = 100
    num_train_epochs: int = 1
    global_batch_size: int = 8
    micro_batch_size: int = 1
    seed: int = 42

    # Runtime / device
    backend: str = "torch"
    init_device: str = "meta"
    comm_backend: Optional[str] = None
    local_rank: int = 0  # set from LOCAL_RANK env by parser

    # Sub-configs
    accelerator: AcceleratorConfig = field(default_factory=AcceleratorConfig)
    mixed_precision: MixedPrecisionConfig = field(default_factory=MixedPrecisionConfig)
    gradient_checkpointing: GradientCheckpointingConfig = field(
        default_factory=GradientCheckpointingConfig
    )
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    tensorboard: TensorBoardConfig = field(default_factory=TensorBoardConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    memory_monitor: MemoryMonitorConfig = field(default_factory=MemoryMonitorConfig)
    moe_monitor: MoEMonitorConfig = field(default_factory=MoEMonitorConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)

323 

324# ============================================================================ 

325# Top-level: model / data / train (and only these three) 

326# ============================================================================ 

327 

@dataclass
class HyperTrainerConfig:
    """Top-level config — strict three-tier.

    Allowed top-level keys: ``model``, ``data``, ``train``. Anything else in
    the YAML is rejected by the parser with a typo-suggestion message.

    ``train_steps`` is a derived value: ``__post_init__`` always overwrites
    it with ``train.max_steps``, so any value passed to the constructor is
    discarded.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    # Computed (no user input)
    train_steps: int = 0

    def __post_init__(self):
        # Mirror the configured step budget onto the top-level object.
        self.train_steps = self.train.max_steps

344 

345# ============================================================================== 

346# CLI / YAML parser 

347# ============================================================================== 

348# Configuration parser: YAML file + CLI dot-path overrides. 

349# 

350# Supports: 

351# - Unknown YAML/CLI keys emit a warning with difflib closest-match suggestions. 

352# - Bool fields accept string aliases: ``true/yes/y/on/1/t`` -> ``True``, 

353# ``false/no/n/off/0/f`` -> ``False``. Only applied when the dataclass 

354# field type resolves to ``bool`` or ``Optional[bool]`` to avoid ambiguity. 

355 

356def _string_to_bool(value: Any) -> bool: 

357 """Convert common string representations of booleans to ``bool``. 

358 

359 Accepts: ``true/yes/y/on/1/t`` → ``True``, 

360 ``false/no/n/off/0/f`` → ``False``. 

361 

362 Args: 

363 value: A string or bool value. 

364 

365 Returns: 

366 The corresponding ``bool``. 

367 

368 Raises: 

369 ValueError: When the string cannot be mapped to a bool. 

370 """ 

371 if isinstance(value, bool): 

372 return value 

373 if isinstance(value, str): 

374 lower = value.lower() 

375 if lower in _BOOL_TRUE_STRINGS: 

376 return True 

377 if lower in _BOOL_FALSE_STRINGS: 

378 return False 

379 raise ValueError( 

380 f"Cannot convert {value!r} to bool. " 

381 "Expected one of: true/false/yes/no/y/n/on/off/1/0/t/f" 

382 ) 

383 

384def _resolve_field_type(cls: Type, dot_path: str) -> Optional[Type]: 

385 """Walk a dataclass hierarchy to find the resolved type of a dot-path field. 

386 

387 Args: 

388 cls: Root dataclass class. 

389 dot_path: Dot-separated field path, e.g. ``"debug.deterministic"``. 

390 

391 Returns: 

392 The resolved Python type, or ``None`` if the path cannot be resolved. 

393 """ 

394 parts = dot_path.split(".") 

395 current_cls = cls 

396 for part in parts: 

397 if not is_dataclass(current_cls): 

398 return None 

399 found = None 

400 for f in fields(current_cls): 

401 if f.name == part: 

402 found = f 

403 break 

404 if found is None: 

405 return None 

406 field_type = found.type 

407 # Unwrap Optional[X] → X 

408 origin = get_origin(field_type) 

409 if origin is Union: 

410 unwrapped = [a for a in get_args(field_type) if a is not type(None)] 

411 field_type = unwrapped[0] if len(unwrapped) == 1 else field_type 

412 current_cls = field_type 

413 return current_cls 

414 

def _coerce_cli_value(raw: str, dot_path: str, root_class: Type) -> Any:
    """Parse a CLI string value, coercing to the correct type for the field.

    The declared type of the target field (looked up through
    ``_resolve_field_type``) drives the conversion:

    - ``str`` fields keep the raw text verbatim, so values such as
      ``"2024"``, ``"1e5"`` or ``"nan"`` are not turned into numbers.
    - ``bool`` fields accept the extended alias set handled by
      ``_string_to_bool``.
    - ``float`` fields are parsed as ``float`` (``"1"`` becomes ``1.0``).

    When the type is unknown, is none of the above, or the typed conversion
    fails, the generic int → float → ``true``/``false`` → str heuristic is
    used.

    Args:
        raw: Raw string from the CLI.
        dot_path: Dot-separated field path used for type lookup.
        root_class: Root dataclass class for type resolution.

    Returns:
        Coerced value.
    """
    field_type = _resolve_field_type(root_class, dot_path)

    if field_type is str:
        # A string field must never be reinterpreted as a number or bool.
        return raw
    if field_type is bool:
        try:
            return _string_to_bool(raw)
        except ValueError:
            pass  # not a bool spelling: fall through to the heuristic
    elif field_type is float:
        try:
            return float(raw)
        except ValueError:
            pass  # not a float: fall through to the heuristic

    # Generic heuristic: int → float → bool-string → str
    for convert in (int, float):
        try:
            return convert(raw)
        except ValueError:
            continue
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    return raw

447 

448def _deep_update(source: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: 

449 """Recursively update source dict with overrides dict.""" 

450 for key, value in overrides.items(): 

451 if isinstance(value, dict) and isinstance(source.get(key), dict): 

452 _deep_update(source[key], value) 

453 else: 

454 source[key] = value 

455 return source 

456 

457_ALLOWED_TOP_LEVEL_KEYS = frozenset({"model", "data", "train"}) 

458 

459def _validate_top_level(config: Dict[str, Any]) -> None: 

460 """Reject any top-level key other than ``model`` / ``data`` / ``train``. 

461 

462 Strict three-tier YAML shape. Any flat-style legacy key 

463 (``parallel``, ``optim``, ``mixed_precision``, ``runtime``, ``debug`` ...) 

464 must be moved under ``train.*`` — see schema.py for the canonical layout. 

465 

466 Raises: 

467 ValueError: With migration hints when forbidden top-level keys are 

468 present in the YAML. 

469 """ 

470 forbidden = sorted(set(config) - _ALLOWED_TOP_LEVEL_KEYS) 

471 if not forbidden: 

472 return 

473 

474 legacy_to_train_path = { 

475 "parallel": "train.accelerator", 

476 "optim": "train.optimizer", 

477 "mixed_precision": "train.mixed_precision", 

478 "memory": "train.gradient_checkpointing", 

479 "checkpoint": "train.checkpoint", 

480 "logging": "train.logging", 

481 "tensorboard": "train.tensorboard", 

482 "wandb": "train.wandb", 

483 "profiler": "train.profile", 

484 "memory_monitor": "train.memory_monitor", 

485 "moe_monitor": "train.moe_monitor", 

486 "eval": "train.eval", 

487 "runtime": "train (flatten init_device / backend / comm_backend)", 

488 "debug": "train.debug", 

489 "seed": "train.seed", 

490 } 

491 hints = [] 

492 for key in forbidden: 

493 new_path = legacy_to_train_path.get(key) 

494 if new_path: 

495 hints.append(f" - top-level '{key}:' → move under {new_path}") 

496 else: 

497 hints.append(f" - top-level '{key}:' is not allowed") 

498 raise ValueError( 

499 "Forbidden top-level YAML keys: %s. The schema is strict three-tier " 

500 "(model / data / train) — see config/schema.py. Migrate as follows:\n%s" 

501 % (forbidden, "\n".join(hints)) 

502 ) 

503 

def _instantiate_recursive(cls: Type[T], config_dict: Dict[str, Any]) -> T:
    """Recursively convert a dict into nested dataclass instances.

    Unknown keys in ``config_dict`` that have no corresponding field on
    ``cls`` emit a ``logger.warning`` with a closest-match suggestion from
    ``difflib``, helping users catch typos in YAML configs. Non-string keys
    are compared by their ``str()`` form so they cannot break the warning.

    An empty section — ``None``, which is what YAML yields for ``train:``
    with no body — is treated as "use the defaults": a ``None`` value for a
    non-``Optional`` dataclass field builds that dataclass with no
    overrides instead of storing ``None`` in the field.

    Args:
        cls: Target class. When it is not a dataclass, ``config_dict`` is
            returned unchanged.
        config_dict: Mapping of field names to raw values, or ``None``.

    Returns:
        An instance of ``cls`` built from the recognised keys; fields missing
        from ``config_dict`` keep their dataclass defaults.
    """
    if not is_dataclass(cls):
        return config_dict
    if config_dict is None:
        config_dict = {}

    known = {f.name for f in fields(cls)}
    unknown = set(config_dict) - known
    # key=str keeps the ordering well-defined for mixed / non-string keys.
    for name in sorted(unknown, key=str):
        matches = difflib.get_close_matches(str(name), known, n=1)
        suggestion = f" Did you mean '{matches[0]}'?" if matches else ""
        logger.warning(
            "Unknown config key '%s' for %s ignored.%s",
            name, cls.__name__, suggestion,
        )

    field_values = {}
    for field_info in fields(cls):
        if field_info.name not in config_dict:
            continue
        raw_value = config_dict[field_info.name]
        field_type = field_info.type

        # Unwrap Optional[X] → X, remembering whether None was allowed.
        allows_none = False
        if get_origin(field_type) is Union:
            members = get_args(field_type)
            unwrapped = [a for a in members if a is not type(None)]
            allows_none = len(unwrapped) != len(members)
            if len(unwrapped) == 1:
                field_type = unwrapped[0]

        if is_dataclass(field_type) and isinstance(raw_value, dict):
            field_values[field_info.name] = _instantiate_recursive(field_type, raw_value)
        elif is_dataclass(field_type) and raw_value is None and not allows_none:
            # Empty section: build the sub-config from its own defaults.
            field_values[field_info.name] = _instantiate_recursive(field_type, {})
        elif field_type is bool and isinstance(raw_value, str):
            field_values[field_info.name] = _string_to_bool(raw_value)
        else:
            field_values[field_info.name] = raw_value

    return cls(**field_values)

546 

def parse_args(root_class: Type[T]) -> T:
    """Parse training config from YAML file + CLI overrides.

    Usage::

        args = parse_args(HyperTrainerConfig)

    The first positional argument is the YAML config file path.
    CLI arguments use dot-path notation under the strict three-tier schema:
    ``--train.accelerator.dp_shard=8 --train.optimizer.lr=3e-4``

    Bool fields accept extended string aliases (``yes/no/on/off/y/n/t/f/1/0``).
    Unknown YAML keys emit a warning with a closest-match suggestion.
    Extra CLI tokens that are not of the form ``--dot.path=value`` are
    ignored with a warning.

    A missing config file is not fatal: a warning is logged and defaults
    (plus any CLI overrides) are used. ``train.local_rank`` is always set
    from the ``LOCAL_RANK`` environment variable (``0`` when unset).

    Args:
        root_class: The root config dataclass type.

    Returns:
        An instance of root_class populated from YAML + CLI.

    Raises:
        ValueError: If the YAML document or the ``train`` section is not a
            mapping, if two CLI overrides conflict (one treats a path as a
            value, another as a section), or if a forbidden top-level key
            is present.
    """
    parser = argparse.ArgumentParser(description="HyperParallel Trainer")
    parser.add_argument("config_file", nargs="?", help="Path to YAML config file")
    args, remaining = parser.parse_known_args()

    # Load YAML
    final_config: dict = {}
    if args.config_file:
        if not os.path.isfile(args.config_file):
            logger.warning(
                "Config file not found: %s (cwd=%s). Using all defaults.",
                args.config_file, os.getcwd(),
            )
        else:
            with open(args.config_file, encoding="utf-8") as f:
                yaml_config = yaml.safe_load(f)
            if yaml_config:
                # A list or scalar document cannot be merged or validated.
                if not isinstance(yaml_config, dict):
                    raise ValueError(
                        f"Top-level YAML in {args.config_file!r} must be a mapping, "
                        f"got {type(yaml_config).__name__}."
                    )
                final_config = yaml_config

    # Parse CLI dot-path overrides: --train.accelerator.dp=8 → nested dict
    cli_config: dict = {}
    for item in remaining:
        if not (item.startswith("--") and "=" in item):
            logger.warning(
                "Ignoring CLI argument %r: overrides must look like --dot.path=value.",
                item,
            )
            continue
        dot_key, raw_value = item[2:].split("=", 1)
        coerced = _coerce_cli_value(raw_value, dot_key, root_class)
        keys = dot_key.split(".")
        current = cli_config
        for k in keys[:-1]:
            child = current.setdefault(k, {})
            if not isinstance(child, dict):
                # e.g. ``--train=5`` followed by ``--train.seed=1``.
                raise ValueError(
                    f"Conflicting CLI overrides: '{k}' in '--{dot_key}' was "
                    "already given a plain value."
                )
            current = child
        current[keys[-1]] = coerced

    # CLI overrides YAML
    final_config = _deep_update(final_config, cli_config)

    # Strict three-tier validation — only model / data / train allowed.
    _validate_top_level(final_config)

    # local_rank from environment (torchrun sets it).
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    train_section = final_config.get("train")
    if train_section is None:
        # Missing section, or an empty ``train:`` that YAML loaded as None.
        train_section = {}
        final_config["train"] = train_section
    elif not isinstance(train_section, dict):
        raise ValueError(
            f"'train' must be a mapping, got {type(train_section).__name__}."
        )
    train_section["local_rank"] = local_rank

    return _instantiate_recursive(root_class, final_config)