Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / parallel_dims.py: 0%
139 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
16# Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
18"""ParallelDims — fail-fast parallel configuration validator + mesh builder.
20Centralises parallel-degree validation in a single dataclass so
21misconfigurations are caught before model construction.
23What this provides:
251. **Inference** — auto-fill ``dp_shard`` (when set to ``-1``) from the product
26 constraint ``dp_replicate * dp_shard * cp * tp * pp * ep == world_size``.
282. **Validation against world_size** — raises ``ValueError`` with a clear
29 message when the product mismatches.
313. **Validation against the model** — checks divisibility constraints that
32 would otherwise crash deep inside ``parallelize_module``:
34 - ``num_attention_heads % tp == 0`` (TP shards heads)
35 - ``num_key_value_heads % tp == 0`` (GQA constraint)
36 - ``num_experts % ep == 0`` (when MoE)
37 - ``ulysses_degree <= cp`` and ``cp % ulysses_degree == 0``
38 - ``seq_len % (cp * tp) == 0`` (sequence parallel + CP)
39 - ``etp == tp or etp == 1`` (expert TP rule)
414. **Mesh building** — returns a ``DeviceMesh`` with the canonical dim order
42 ``dp_replicate → dp_shard → ep → cp → tp → pp``. Backwards compatible with
43 the legacy single ``dp`` field (auto-collapses to ``dp_shard``).
45User experience:
47- Default config (no parallel section) — works on 1 GPU, runs as DDP-1.
48- Set ``tp=4`` with ``dp_shard=-1`` on world_size=8 → ``dp_shard`` auto-inferred to 2.
49- Set ``dp_shard=-1`` → fills remaining cards into FSDP shard dim.
50- Misconfig (heads=12, tp=8) → fails at ``_setup`` with a single readable
51 error before any model parallelization is attempted.
52"""
53from __future__ import annotations
55import logging
56from dataclasses import dataclass, field
57from typing import Optional
59from hyper_parallel import init_device_mesh
61logger = logging.getLogger(__name__)
@dataclass
class ParallelDims:
    """Validated parallel degrees plus a device-mesh builder.

    All degree validation runs from ``__post_init__``, so an invalid
    combination raises as soon as the object is constructed.

    Attributes:
        dp_replicate: DDP replication degree (HSDP outer dim).
        dp_shard: FSDP shard degree. ``-1`` means "fill the rest from
            ``world_size / (dp_replicate * cp * tp * pp * ep)``".
        cp: Context parallel degree.
        tp: Tensor parallel degree (dense path).
        pp: Pipeline parallel degree. Values > 1 are rejected with
            ``NotImplementedError``.
        ep: Expert parallel degree (MoE only). It is its own mesh axis and
            therefore counts towards the ``world_size`` product.
        etp: Expert tensor parallel degree. Must equal ``tp`` or ``1``
            (only enforced when ``ep > 1``).
        ulysses_degree: Ulysses sub-degree inside ``cp``. ``None`` means
            "pure Ulysses (degree == cp)". Only validated when ``cp > 1``.
        world_size: Total number of ranks.
    """

    dp_replicate: int = 1
    dp_shard: int = 1
    cp: int = 1
    tp: int = 1
    pp: int = 1
    ep: int = 1
    etp: int = 1
    ulysses_degree: Optional[int] = None
    world_size: int = 1
    # Populated by build_mesh(); stays None until then.
    _device_mesh: object = field(default=None, repr=False)

    # ------------------------------------------------------------------
    # Logging helper
    # ------------------------------------------------------------------
    @staticmethod
    def _log_info(msg: str, *args: object) -> None:
        """Log at INFO level, preferring ``info_rank0`` when it exists.

        A plain ``logging.Logger`` has no ``info_rank0`` method, so calling it
        unconditionally raises ``AttributeError`` unless something has patched
        the logger class. Fall back to ``info`` in that case.
        """
        # getLogger returns the same object for the same name, so this is the
        # module's logger.
        log = logging.getLogger(__name__)
        emit = getattr(log, "info_rank0", None)
        if not callable(emit):
            emit = log.info
        emit(msg, *args)

    # ------------------------------------------------------------------
    # Construction & inference
    # ------------------------------------------------------------------
    @classmethod
    def from_config(cls, parallel_cfg, world_size: int) -> "ParallelDims":
        """Build from a ``ParallelConfig`` (or any object with the same fields).

        Missing attributes and attributes set to ``None`` both fall back to
        the default. The legacy single ``dp`` field is used for ``dp_shard``
        when ``dp_shard`` itself is missing or ``None``. ``etp`` defaults to
        the resolved ``tp``.

        Args:
            parallel_cfg: Object exposing any subset of ``dp_replicate``,
                ``dp_shard``, ``dp``, ``cp``, ``tp``, ``pp``, ``ep``, ``etp``
                and ``ulysses_degree`` as attributes.
            world_size: Total number of ranks.

        Returns:
            A validated ``ParallelDims``.

        Raises:
            ValueError: If the resulting degrees are inconsistent.
            NotImplementedError: If ``pp > 1``.
        """
        def _get(name, default):
            value = getattr(parallel_cfg, name, None)
            return default if value is None else value

        dp_shard = _get("dp_shard", None)
        if dp_shard is None:
            # Backward-compat: legacy ``dp`` maps to ``dp_shard``.
            dp_shard = _get("dp", 1)

        tp = _get("tp", 1)
        return cls(
            dp_replicate=_get("dp_replicate", 1),
            dp_shard=dp_shard,
            cp=_get("cp", 1),
            tp=tp,
            pp=_get("pp", 1),
            ep=_get("ep", 1),
            etp=_get("etp", tp),
            ulysses_degree=getattr(parallel_cfg, "ulysses_degree", None),
            world_size=world_size,
        )

    def __post_init__(self) -> None:
        """Run inference and validation immediately after construction."""
        self._infer_and_validate()

    def _infer_and_validate(self) -> None:
        """Auto-fill ``dp_shard=-1`` then validate ``product == world_size``.

        Raises:
            ValueError: On a non-positive degree, an invalid ``dp_shard``,
                a product that does not equal ``world_size``, an ``etp`` that
                is neither ``tp`` nor ``1`` (when ``ep > 1``), or an invalid
                ``ulysses_degree`` (when ``cp > 1``).
            NotImplementedError: If ``pp > 1``.
        """
        if self.world_size < 1:
            raise ValueError(f"world_size={self.world_size} must be >= 1")

        for name in ("dp_replicate", "cp", "tp", "pp", "ep", "etp"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"Parallel degree {name}={value} must be >= 1")

        if self.dp_shard < -1 or self.dp_shard == 0:
            raise ValueError(
                f"dp_shard={self.dp_shard} must be -1 (auto) or a positive int"
            )

        # ep occupies its own mesh axis (see build_mesh), so it is part of the
        # product that must equal world_size and reduces the ranks left over
        # for dp_shard.
        if self.dp_shard == -1:
            non_dp = self.dp_replicate * self.cp * self.tp * self.pp * self.ep
            if self.world_size % non_dp != 0:
                raise ValueError(
                    f"Cannot auto-infer dp_shard: world_size={self.world_size} "
                    f"is not divisible by dp_replicate*cp*tp*pp*ep={non_dp}"
                )
            # world_size >= 1 and divisible by non_dp, so the quotient is >= 1.
            self.dp_shard = self.world_size // non_dp
            self._log_info(
                "Auto-inferred dp_shard=%d (world_size=%d / dp_replicate=%d * "
                "cp=%d * tp=%d * pp=%d * ep=%d)",
                self.dp_shard, self.world_size,
                self.dp_replicate, self.cp, self.tp, self.pp, self.ep,
            )

        product = (
            self.dp_replicate * self.dp_shard
            * self.cp * self.tp * self.pp * self.ep
        )
        if product != self.world_size:
            raise ValueError(
                f"Invalid parallel dims: dp_replicate({self.dp_replicate}) * "
                f"dp_shard({self.dp_shard}) * cp({self.cp}) * tp({self.tp}) * "
                f"pp({self.pp}) * ep({self.ep}) = {product} != "
                f"world_size({self.world_size}). Set dp_shard=-1 to auto-infer."
            )

        # Expert-TP compatibility rule; only meaningful when EP is active.
        if self.ep > 1 and self.etp not in (self.tp, 1):
            raise ValueError(
                f"etp={self.etp} must equal tp={self.tp} or 1 "
                f"(expert tensor-parallel must align with TP or be inactive)"
            )

        # No model has implemented PP wiring yet — fail fast.
        if self.pp > 1:
            raise NotImplementedError(
                f"Pipeline parallel (pp={self.pp}>1) is not yet implemented "
                "for any model. Set pp=1 or add a per-model PP path in "
                "models/<name>/parallelize.py before enabling."
            )

        # Ulysses must be a positive divisor of cp.
        if self.ulysses_degree is not None and self.cp > 1:
            if self.ulysses_degree < 1:
                # Without this guard 0 raises ZeroDivisionError below and
                # negative values slip through the modulo check.
                raise ValueError(
                    f"ulysses_degree={self.ulysses_degree} must be >= 1"
                )
            if self.ulysses_degree > self.cp:
                raise ValueError(
                    f"ulysses_degree={self.ulysses_degree} must be <= "
                    f"cp={self.cp}"
                )
            if self.cp % self.ulysses_degree != 0:
                raise ValueError(
                    f"cp={self.cp} must be divisible by "
                    f"ulysses_degree={self.ulysses_degree}"
                )

    # ------------------------------------------------------------------
    # Validate against an actual model
    # ------------------------------------------------------------------
    def validate_against_model(
        self,
        model,
        seq_len: Optional[int] = None,
    ) -> None:
        """Cross-check parallel degrees against a built model's hyperparams.

        Reads ``model.config`` for ``num_attention_heads``,
        ``num_key_value_heads`` and ``num_experts``; a field that is absent
        (or a model without ``.config``) skips the corresponding check. The
        ``seq_len`` check does not depend on the model config and runs
        whenever ``seq_len`` is given.

        Args:
            model: The built model; its ``.config`` attribute, if present,
                supplies the head / expert counts.
            seq_len: Optional maximum sequence length used for the
                ``cp * tp`` divisibility check.

        Raises:
            ValueError: With a single readable message when a constraint is
                violated.
        """
        cfg = getattr(model, "config", None)
        if cfg is not None:
            heads = getattr(cfg, "num_attention_heads", None)
            if heads is not None and self.tp > 1 and heads % self.tp != 0:
                raise ValueError(
                    f"num_attention_heads={heads} not divisible by tp={self.tp}. "
                    f"Pick tp from the divisors of {heads}."
                )

            kv_heads = getattr(cfg, "num_key_value_heads", None)
            if kv_heads is not None and self.tp > 1 and kv_heads % self.tp != 0:
                raise ValueError(
                    f"num_key_value_heads={kv_heads} not divisible by tp={self.tp} "
                    f"(GQA constraint). Pick tp from the divisors of {kv_heads}."
                )

            num_experts = getattr(cfg, "num_experts", None)
            if (
                num_experts is not None
                and self.ep > 1
                and num_experts % self.ep != 0
            ):
                raise ValueError(
                    f"num_experts={num_experts} not divisible by ep={self.ep}. "
                    f"Pick ep from the divisors of {num_experts}."
                )

        # Independent of model.config, so it must not be skipped when the
        # model exposes no config.
        divisor = self.cp * self.tp
        if seq_len is not None and divisor > 1 and seq_len % divisor != 0:
            raise ValueError(
                f"max_seq_len={seq_len} not divisible by cp*tp={divisor}. "
                f"Increase seq_len to a multiple of {divisor} or reduce "
                f"cp/tp."
            )

    # ------------------------------------------------------------------
    # Mesh building
    # ------------------------------------------------------------------
    def build_mesh(self, device_type: str):
        """Build the DeviceMesh with canonical dim order and flatten aliases.

        Order of base dims: ``dp_replicate → dp_shard → ep → cp → tp → pp``.
        Only base dims with degree > 1 are materialized; if all are 1, a 1D
        ``dp_shard`` mesh of size ``world_size`` is created instead.

        After construction ``_register_flatten_aliases`` is called with the
        materialized names; it registers the ``"fsdp"``, ``"dp"`` and
        ``"loss"`` aliases for whichever of ``dp_replicate`` / ``dp_shard`` /
        ``cp`` are present.

        NOTE(review): when neither ``dp_replicate`` nor ``dp_shard`` is
        materialized (e.g. ``tp == world_size``), no ``"fsdp"`` / ``"dp"``
        alias is registered — confirm callers of ``mesh["fsdp"]`` /
        ``mesh["dp"]`` handle that composition.

        Every call re-creates the mesh and overwrites the stored one.

        Args:
            device_type: Backend device string (``"npu"`` / ``"cuda"``).

        Returns:
            The mesh object returned by ``init_device_mesh``.
        """
        candidates = (
            ("dp_replicate", self.dp_replicate),
            ("dp_shard", self.dp_shard),
            ("ep", self.ep),
            ("cp", self.cp),
            ("tp", self.tp),
            ("pp", self.pp),
        )
        active = [(name, size) for name, size in candidates if size > 1]
        if not active:
            # Every degree is 1: fall back to a single dp_shard axis.
            active = [("dp_shard", self.world_size)]

        names = tuple(name for name, _ in active)
        shape = tuple(size for _, size in active)

        self._device_mesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=shape,
            mesh_dim_names=names,
        )
        self._register_flatten_aliases(names)
        self._log_info("DeviceMesh built: shape=%s, names=%s", shape, names)
        return self._device_mesh

    def _register_flatten_aliases(self, base_names) -> None:
        """Register named flatten aliases on the mesh built by ``build_mesh``.

        Aliases registered, depending on which base dims are in
        ``base_names``:

            ``"fsdp"``  – flatten of ``dp_shard`` (only if ``dp_shard`` is
                          present).
            ``"dp"``    – flatten of the present subset of
                          ``dp_replicate`` / ``dp_shard`` (skipped if neither
                          is present).
            ``"loss"``  – with ``cp`` present: flatten of the present dp axes
                          plus ``cp``. Without ``cp``: a flatten of the
                          ``"dp"`` alias, if that alias exists.

        Reserved names (intentionally not registered yet): ``"efsdp"``,
        ``"etp"``, ``"batch"``.

        A flatten is skipped when the alias is already a mesh dim name or is
        already in the root mesh's flatten mapping.

        Args:
            base_names: Sequence of base mesh-dim names that were
                materialized by ``build_mesh``.
        """
        # pylint: disable=protected-access
        mesh = self._device_mesh
        existing = set(mesh.mesh_dim_names or ())
        flatten_keys = set(mesh._get_root_mesh().get_flatten_mapping().keys())

        def _flatten_unique(source_dims, alias):
            if alias in existing or alias in flatten_keys:
                return
            mesh[source_dims].flatten(alias)
            flatten_keys.add(alias)

        def _as_index(axes):
            # One axis is indexed by its bare name, several by a tuple.
            return axes[0] if len(axes) == 1 else tuple(axes)

        dp_axes = [
            name for name in ("dp_replicate", "dp_shard") if name in base_names
        ]

        # ``fsdp`` — the dp_shard axis on its own.
        if "dp_shard" in base_names:
            _flatten_unique("dp_shard", "fsdp")

        # ``dp`` — whichever of replicate / shard are materialized.
        if dp_axes:
            _flatten_unique(_as_index(dp_axes), "dp")

        # ``loss`` — dp folded with cp when context parallelism is active.
        if "cp" in base_names:
            _flatten_unique(_as_index(dp_axes + ["cp"]), "loss")
        elif "loss" not in flatten_keys and "dp" in flatten_keys:
            # No CP: expose the existing ``dp`` alias under ``loss`` as well.
            mesh["dp"].flatten("loss")
            flatten_keys.add("loss")

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------
    @property
    def dp_size(self) -> int:
        """Combined data-parallel size = dp_replicate * dp_shard."""
        return self.dp_replicate * self.dp_shard

    @property
    def non_dp_size(self) -> int:
        """Product of model-side parallel dims (tp*cp*pp*ep)."""
        return self.tp * self.cp * self.pp * self.ep

    @property
    def tp_enabled(self) -> bool:
        """True when ``tp > 1``."""
        return self.tp > 1

    @property
    def cp_enabled(self) -> bool:
        """True when ``cp > 1``."""
        return self.cp > 1

    @property
    def ep_enabled(self) -> bool:
        """True when ``ep > 1``."""
        return self.ep > 1

    @property
    def pp_enabled(self) -> bool:
        """True when ``pp > 1``."""
        return self.pp > 1

    @property
    def fsdp_enabled(self) -> bool:
        """True when ``dp_shard > 1`` or ``dp_replicate > 1``."""
        return self.dp_shard > 1 or self.dp_replicate > 1

    def summary(self) -> str:
        """Compact one-line summary for logging."""
        return (
            f"dp_replicate={self.dp_replicate} dp_shard={self.dp_shard} "
            f"cp={self.cp} tp={self.tp} pp={self.pp} ep={self.ep} "
            f"etp={self.etp} | dp={self.dp_size} world={self.world_size}"
        )
446__all__ = ["ParallelDims"]