Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / parallel_dims.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15 

16# Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py 

17 

"""ParallelDims — fail-fast parallel configuration validator + mesh builder.

Centralises parallel-degree validation in a single dataclass so
misconfigurations are caught before model construction.

What this provides:

1. **Inference** — auto-fill ``dp_shard`` when it is ``-1`` (set directly, or
   through the legacy ``dp`` field) from the product constraint
   ``dp_replicate * dp_shard * ep * cp * tp * pp == world_size``.

2. **Validation against world_size** — raises ``ValueError`` with a clear
   message when the product mismatches.

3. **Validation against the model** — checks divisibility constraints that
   would otherwise crash deep inside ``parallelize_module``:

   - ``num_attention_heads % tp == 0`` (TP shards heads)
   - ``num_key_value_heads % tp == 0`` (GQA constraint)
   - ``num_experts % ep == 0`` (when MoE)
   - ``ulysses_degree <= cp`` and ``cp % ulysses_degree == 0``
     (checked at construction time)
   - ``seq_len % (cp * tp) == 0`` (sequence parallel + CP)
   - ``etp == tp or etp == 1`` (expert TP rule, checked at construction time)

4. **Mesh building** — returns a ``DeviceMesh`` with the canonical dim order
   ``dp_replicate → dp_shard → ep → cp → tp → pp``. Backwards compatible with
   the legacy single ``dp`` field (auto-collapses to ``dp_shard``).

User experience:

- Default config (no parallel section) — works on 1 GPU, runs as DDP-1.
- Set ``tp=4`` and ``dp_shard=-1`` on world_size=8 → ``dp_shard``
  auto-inferred to 2.
- Set ``dp_shard=-1`` → fills remaining cards into FSDP shard dim.
- Misconfig (heads=12, tp=8) → fails at ``_setup`` with a single readable
  error before any model parallelization is attempted.
"""

53from __future__ import annotations 

54 

55import logging 

56from dataclasses import dataclass, field 

57from typing import Optional 

58 

59from hyper_parallel import init_device_mesh 

60 

61logger = logging.getLogger(__name__) 

62 

@dataclass
class ParallelDims:
    """Validated parallel degrees + lazy mesh builder.

    Validation runs from ``__post_init__``, so every successfully constructed
    instance satisfies
    ``dp_replicate * dp_shard * cp * tp * pp * ep == world_size``.

    Attributes:
        dp_replicate: DDP replication degree (HSDP outer dim).
        dp_shard: FSDP shard degree. ``-1`` means "fill the rest from
            ``world_size / (dp_replicate * cp * tp * pp * ep)``"; it is
            replaced by the inferred value during construction.
        cp: Context parallel degree.
        tp: Tensor parallel degree (dense path).
        pp: Pipeline parallel degree. Values above 1 are rejected with
            ``NotImplementedError``.
        ep: Expert parallel degree (MoE only).
        etp: Expert tensor parallel degree. Must equal ``tp`` or ``1`` when
            ``ep > 1``.
        ulysses_degree: Ulysses sub-degree inside ``cp``. ``None`` means
            "pure Ulysses (degree == cp)". Only validated when ``cp > 1``.
        world_size: Total number of ranks.
    """

    dp_replicate: int = 1
    dp_shard: int = 1
    cp: int = 1
    tp: int = 1
    pp: int = 1
    ep: int = 1
    etp: int = 1
    ulysses_degree: Optional[int] = None
    world_size: int = 1
    # Holds the mesh returned by the most recent build_mesh() call.
    _device_mesh: object = field(default=None, repr=False)

    # ------------------------------------------------------------------
    # Construction & inference
    # ------------------------------------------------------------------
    @classmethod
    def from_config(cls, parallel_cfg, world_size: int) -> "ParallelDims":
        """Build from a ``ParallelConfig`` (or any object with the same fields).

        Every field is read with ``getattr`` and falls back to its default
        when the attribute is missing. Two fields have special fallbacks:

        - ``dp_shard``: when the attribute is missing or ``None``, the legacy
          single ``dp`` field is used instead (FSDP behavior); if that is
          missing or ``None`` too, ``1`` is used.
        - ``etp``: when the attribute is missing, ``tp`` is used.

        Args:
            parallel_cfg: Object exposing the parallel degrees as attributes.
            world_size: Total number of ranks.

        Returns:
            A validated ``ParallelDims``.

        Raises:
            ValueError: If the degrees are inconsistent with ``world_size``.
            NotImplementedError: If ``pp > 1``.
        """
        dp_shard = getattr(parallel_cfg, 'dp_shard', None)
        if dp_shard is None:
            # Backward-compat: legacy ``dp`` stands in for ``dp_shard``.
            legacy_dp = getattr(parallel_cfg, 'dp', None)
            dp_shard = 1 if legacy_dp is None else legacy_dp

        tp = getattr(parallel_cfg, 'tp', 1)
        return cls(
            dp_replicate=getattr(parallel_cfg, 'dp_replicate', 1),
            dp_shard=dp_shard,
            cp=getattr(parallel_cfg, 'cp', 1),
            tp=tp,
            pp=getattr(parallel_cfg, 'pp', 1),
            ep=getattr(parallel_cfg, 'ep', 1),
            etp=getattr(parallel_cfg, 'etp', tp),
            ulysses_degree=getattr(parallel_cfg, 'ulysses_degree', None),
            world_size=world_size,
        )

    def __post_init__(self) -> None:
        """Validate (and, for ``dp_shard=-1``, complete) the degrees."""
        self._infer_and_validate()

    @staticmethod
    def _log_info_rank0(msg: str, *args) -> None:
        """Log at INFO through ``logger.info_rank0``, else ``logger.info``.

        The module logger comes from the stdlib ``logging.getLogger``, whose
        ``Logger`` class has no ``info_rank0`` method unless something adds
        one. NOTE(review): confirm whether ``hyper_parallel`` patches it in.
        The fallback keeps a log call from aborting validation or mesh
        building with ``AttributeError``; it logs on every rank.
        """
        emit = getattr(logger, "info_rank0", None)
        if emit is None:
            emit = logger.info
        emit(msg, *args)

    def _infer_and_validate(self) -> None:
        """Auto-fill ``dp_shard=-1`` then validate ``product == world_size``.

        Raises:
            ValueError: On a non-positive degree, an invalid ``dp_shard``, a
                product that does not match ``world_size``, an ``etp`` that is
                neither ``tp`` nor ``1`` (when ``ep > 1``), or an invalid
                ``ulysses_degree`` (when ``cp > 1``).
            NotImplementedError: If ``pp > 1``.
        """
        for name in ("dp_replicate", "cp", "tp", "pp", "ep", "etp"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"Parallel degree {name}={value} must be >= 1")

        if self.dp_shard < -1 or self.dp_shard == 0:
            raise ValueError(
                f"dp_shard={self.dp_shard} must be -1 (auto) or a positive int"
            )

        # ep is a peer mesh dim with its own axis (see build_mesh), so it
        # consumes ranks like every other dim and belongs in the divisor.
        if self.dp_shard == -1:
            non_dp = self.dp_replicate * self.cp * self.tp * self.pp * self.ep
            if self.world_size % non_dp != 0:
                raise ValueError(
                    f"Cannot auto-infer dp_shard: world_size={self.world_size} "
                    f"is not divisible by dp_replicate*cp*tp*pp*ep={non_dp}"
                )
            self.dp_shard = max(self.world_size // non_dp, 1)
            self._log_info_rank0(
                "Auto-inferred dp_shard=%d (world_size=%d / dp_replicate=%d * "
                "cp=%d * tp=%d * pp=%d * ep=%d)",
                self.dp_shard, self.world_size,
                self.dp_replicate, self.cp, self.tp, self.pp, self.ep,
            )

        product = (
            self.dp_replicate * self.dp_shard
            * self.cp * self.tp * self.pp * self.ep
        )
        if product != self.world_size:
            raise ValueError(
                f"Invalid parallel dims: dp_replicate({self.dp_replicate}) * "
                f"dp_shard({self.dp_shard}) * cp({self.cp}) * tp({self.tp}) * "
                f"pp({self.pp}) * ep({self.ep}) = {product} != "
                f"world_size({self.world_size}). Set dp_shard=-1 to auto-infer."
            )

        # ep does not have to divide the dp_shard*cp pool because it is not
        # carved out of it; only the expert-TP compatibility rule applies.
        if self.ep > 1 and self.etp not in (self.tp, 1):
            raise ValueError(
                f"etp={self.etp} must equal tp={self.tp} or 1 "
                f"(expert tensor-parallel must align with TP or be inactive)"
            )

        # No model has implemented PP wiring yet — fail fast.
        if self.pp > 1:
            raise NotImplementedError(
                f"Pipeline parallel (pp={self.pp}>1) is not yet implemented "
                "for any model. Set pp=1 or add a per-model PP path in "
                "models/<name>/parallelize.py before enabling."
            )

        # Ulysses must be a positive divisor of cp.
        if self.ulysses_degree is not None and self.cp > 1:
            # Guard first: 0 would raise ZeroDivisionError in the modulo
            # below and a negative divisor could pass it.
            if self.ulysses_degree < 1:
                raise ValueError(
                    f"ulysses_degree={self.ulysses_degree} must be >= 1"
                )
            if self.ulysses_degree > self.cp:
                raise ValueError(
                    f"ulysses_degree={self.ulysses_degree} must be <= "
                    f"cp={self.cp}"
                )
            if self.cp % self.ulysses_degree != 0:
                raise ValueError(
                    f"cp={self.cp} must be divisible by "
                    f"ulysses_degree={self.ulysses_degree}"
                )

    # ------------------------------------------------------------------
    # Validate against an actual model
    # ------------------------------------------------------------------
    def validate_against_model(
        self,
        model,
        seq_len: Optional[int] = None,
    ) -> None:
        """Cross-check parallel degrees against a built model's hyperparams.

        Reads ``model.config`` for standard transformer fields. Skips silently
        if a field (or ``config`` itself) is absent. The ``seq_len`` check does
        not depend on the model config and runs whenever ``seq_len`` is given.
        Model-specific validation (e.g. "TP unsupported for linear-attn
        layers") is inlined at the top of each ``parallelize_<model>()``
        function — convention.

        Args:
            model: The built ``nn.Module`` (must expose ``.config`` to trigger
                the head / expert checks).
            seq_len: Optional maximum sequence length used for cp/tp
                divisibility checks.

        Raises:
            ValueError: With a single readable message when a constraint is
                violated. Stops here so the user sees the real cause instead
                of a stack trace from inside ``parallelize_module``.
        """
        cfg = getattr(model, 'config', None)
        if cfg is not None:
            self._validate_model_config(cfg)

        if seq_len is not None and self.cp * self.tp > 1:
            divisor = self.cp * self.tp
            if seq_len % divisor != 0:
                raise ValueError(
                    f"max_seq_len={seq_len} not divisible by cp*tp={divisor}. "
                    f"Increase seq_len to a multiple of {divisor} or reduce "
                    f"cp/tp."
                )

    def _validate_model_config(self, cfg) -> None:
        """Check head / expert counts on ``cfg`` against ``tp`` and ``ep``."""
        heads = getattr(cfg, 'num_attention_heads', None)
        if heads is not None and self.tp > 1 and heads % self.tp != 0:
            raise ValueError(
                f"num_attention_heads={heads} not divisible by tp={self.tp}. "
                f"Pick tp from the divisors of {heads}."
            )

        kv_heads = getattr(cfg, 'num_key_value_heads', None)
        if kv_heads is not None and self.tp > 1 and kv_heads % self.tp != 0:
            raise ValueError(
                f"num_key_value_heads={kv_heads} not divisible by tp={self.tp} "
                f"(GQA constraint). Pick tp from the divisors of {kv_heads}."
            )

        num_experts = getattr(cfg, 'num_experts', None)
        if num_experts is not None and self.ep > 1 and num_experts % self.ep != 0:
            raise ValueError(
                f"num_experts={num_experts} not divisible by ep={self.ep}. "
                f"Pick ep from the divisors of {num_experts}."
            )

    # ------------------------------------------------------------------
    # Mesh building
    # ------------------------------------------------------------------
    def build_mesh(self, device_type: str):
        """Build the DeviceMesh with canonical dim order and named flatten aliases.

        Order of base dims: ``dp_replicate → dp_shard → ep → cp → tp → pp``.
        Only base dims with degree > 1 are materialized; if all are 1, a 1D
        ``dp_shard`` mesh of the world is created so the FSDP code path runs
        unchanged on single-card.

        After construction, flatten aliases are registered on the mesh by
        ``_register_flatten_aliases``:

            ``"fsdp"``  – the ``dp_shard`` axis, when it is materialized.
            ``"dp"``    – ``dp_replicate × dp_shard`` when both are
                          materialized (HSDP); otherwise the single
                          materialized DP axis.
            ``"loss"``  – the DP axes folded with ``cp`` when CP is
                          materialized; otherwise the same group as ``"dp"``.

        Each call builds a new mesh and stores it on the instance.

        Args:
            device_type: Backend device string (``"npu"`` / ``"cuda"``).

        Returns:
            ``DeviceMesh`` instance.
        """
        active = [
            (name, size)
            for name, size in (
                ("dp_replicate", self.dp_replicate),
                ("dp_shard", self.dp_shard),
                ("ep", self.ep),
                ("cp", self.cp),
                ("tp", self.tp),
                ("pp", self.pp),
            )
            if size > 1
        ]
        if active:
            names = [name for name, _ in active]
            dims = [size for _, size in active]
        else:
            # Every degree is 1: fall back to a 1D dp_shard mesh of the world.
            names = ["dp_shard"]
            dims = [self.world_size]

        self._device_mesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=tuple(dims),
            mesh_dim_names=tuple(names),
        )
        self._register_flatten_aliases(names)
        self._log_info_rank0(
            "DeviceMesh built: shape=%s, names=%s",
            tuple(dims), tuple(names),
        )
        return self._device_mesh

    def _register_flatten_aliases(self, base_names) -> None:
        """Register named flatten aliases on the root mesh.

        These aliases give the rest of the trainer a stable, intent-named
        handle on combined parallel axes:

            ``"fsdp"``  – the axis FSDP shards along (= ``dp_shard``).
            ``"dp"``    – combined data-parallel mesh (replicate × shard).
            ``"loss"``  – the mesh over which loss / token counts are
                          all-reduced. Equals ``dp × cp`` when CP is enabled
                          (CP-sharded ranks see different sub-sequences and
                          must contribute their token counts to the global
                          denominator); otherwise equals ``dp``.

        No alias is registered for an axis combination that has no
        materialized member (e.g. no ``"dp"`` when only ``tp`` is present).

        Reserved names (intentionally not registered yet):
            ``"efsdp"`` – FSDP mesh for expert layers when EP > 1.
            ``"etp"``   – expert TP mesh alongside dense ``tp``.
            ``"batch"`` – per-DP batch dispatch mesh; today identical to
                          ``"dp"``.

        Every flatten call is skipped when the alias is already present, so
        repeated calls do not register an alias twice.

        Args:
            base_names: Sequence of base mesh-dim names that were materialized
                by ``build_mesh``.
        """
        # pylint: disable=protected-access
        mesh = self._device_mesh
        existing = set(mesh.mesh_dim_names or ())
        flatten_keys = set(mesh._get_root_mesh().get_flatten_mapping().keys())

        def _flatten_unique(source_dims, alias):
            if alias in existing or alias in flatten_keys:
                return
            mesh[source_dims].flatten(alias)
            flatten_keys.add(alias)

        def _mesh_key(axes):
            # One axis is addressed by its bare name, several by a tuple.
            return axes[0] if len(axes) == 1 else axes

        dp_axes = tuple(
            name for name in ("dp_replicate", "dp_shard") if name in base_names
        )

        # ``fsdp`` — the axis ``fully_shard`` actually shards along.
        if "dp_shard" in base_names:
            _flatten_unique("dp_shard", "fsdp")

        # ``dp`` — combined replicate×shard data-parallel mesh.
        if dp_axes:
            _flatten_unique(_mesh_key(dp_axes), "dp")

        # ``loss`` — dp folded with cp when context parallelism is active so
        # loss / token counts include CP-sharded contributions.
        if "cp" in base_names:
            _flatten_unique(_mesh_key(dp_axes + ("cp",)), "loss")
        elif "loss" not in flatten_keys and "dp" in flatten_keys:
            # No CP — ``loss`` and ``dp`` are the same group. Re-flatten the
            # existing dp mesh under the ``loss`` alias so both names resolve
            # via ``__getitem__``.
            mesh["dp"].flatten("loss")
            flatten_keys.add("loss")

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------
    @property
    def dp_size(self) -> int:
        """Combined data-parallel size = dp_replicate * dp_shard."""
        return self.dp_replicate * self.dp_shard

    @property
    def non_dp_size(self) -> int:
        """Product of model-side parallel dims (tp*cp*pp*ep)."""
        return self.tp * self.cp * self.pp * self.ep

    @property
    def tp_enabled(self) -> bool:
        """Whether tensor parallelism is active (``tp > 1``)."""
        return self.tp > 1

    @property
    def cp_enabled(self) -> bool:
        """Whether context parallelism is active (``cp > 1``)."""
        return self.cp > 1

    @property
    def ep_enabled(self) -> bool:
        """Whether expert parallelism is active (``ep > 1``)."""
        return self.ep > 1

    @property
    def pp_enabled(self) -> bool:
        """Whether pipeline parallelism is active (``pp > 1``)."""
        return self.pp > 1

    @property
    def fsdp_enabled(self) -> bool:
        """FSDP is on whenever there's a shard dim or HSDP outer dim."""
        return self.dp_shard > 1 or self.dp_replicate > 1

    def summary(self) -> str:
        """Compact one-line summary for logging."""
        return (
            f"dp_replicate={self.dp_replicate} dp_shard={self.dp_shard} "
            f"cp={self.cp} tp={self.tp} pp={self.pp} ep={self.ep} "
            f"etp={self.etp} | dp={self.dp_size} world={self.world_size}"
        )

445 

446__all__ = ["ParallelDims"]