Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/clip_grad.py: 0%

289 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed-aware gradient clipping for parallel training. 

16 

17Communication is driven by each parameter's DTensorSpec (device_mesh + 

18placements) rather than any specific parallelism strategy, so a single 

19implementation covers FSDP, HSDP, TP+FSDP, and other DTensor-expressed 

20parallelisms. 

21 

22Collective safety aligned with FSDP1; numerical precision aligned with 

23FSDP2's ``_NormPartial`` norm computation path: 

24 

25* Gradient norms from sharded parameters are all-reduced across the 

26 corresponding shard process group. 

27* Non-sharded / replicated norms contribute locally without communication. 

28* **All ranks participate in the same collectives** regardless of local 

29 gradient availability, preventing collective-misalignment deadlocks. 

30 

31Note: PP does not use DTensor layout for gradients today. Cross-stage 

32norm aggregation will require an additional manual all-reduce and is 

33left for future work. 

34""" 

35import functools 

36import math 

37import warnings 

38from collections import defaultdict 

39from typing import Dict, Iterable, List, Optional, Tuple, Union 

40 

41import torch 

42import torch.distributed as dist 

43 

44from hyper_parallel.core.dtensor.dtensor import DTensor 

45from hyper_parallel.core.dtensor.placement_types import Partial 

46 

47try: 

48 from torch.utils._foreach_utils import ( 

49 _device_has_foreach_support, 

50 _group_tensors_by_device_and_dtype, 

51 _has_foreach_support, 

52 ) 

53except ImportError: 

54 _device_has_foreach_support = None # type: ignore[assignment] 

55 _group_tensors_by_device_and_dtype = None # type: ignore[assignment] 

56 _has_foreach_support = None # type: ignore[assignment] 

57 

58__all__: list[str] = ["clip_grad_norm_"] 

59 

60 

# Key under which local gradients are bucketed for norm computation.
_GradGroupKey = Tuple[
    Optional[int],    # id(mesh), or None when there is no DeviceMesh
    Tuple[int, ...],  # shard_dims
]

# One entry per Partial placement of a parameter / gradient.
_PartialReduceInfo = Tuple[
    int,               # mesh_dim_index
    "dist.ReduceOp",   # collective op to apply on that mesh dim
    bool,              # needs_manual_avg
]

66 

67 

68# --------------------------------------------------------------------------- 

69# Reduce-op mapping 

70# --------------------------------------------------------------------------- 

71 

# Whether this torch.distributed build exposes ReduceOp.AVG at all.
_REDUCE_OP_AVG_SUPPORTED = hasattr(dist.ReduceOp, "AVG")

# Lower-case Partial reduce_op name -> collective op.
_STR_TO_REDUCE_OP: Dict[str, "dist.ReduceOp"] = dict(
    sum=dist.ReduceOp.SUM,
    max=dist.ReduceOp.MAX,
    min=dist.ReduceOp.MIN,
)
if _REDUCE_OP_AVG_SUPPORTED:
    _STR_TO_REDUCE_OP.update(avg=dist.ReduceOp.AVG)


def _str_to_reduce_op(op_str: str) -> Tuple["dist.ReduceOp", bool]:
    """Translate a ``Partial`` placement's *reduce_op* string.

    The lookup is case-insensitive.

    Returns:
        ``(reduce_op, needs_manual_avg)``. *needs_manual_avg* is ``True``
        only for ``"avg"`` when ``dist.ReduceOp.AVG`` is unavailable; in
        that case ``dist.ReduceOp.SUM`` is returned and the caller is
        expected to divide by the group size itself.

    Raises:
        ValueError: *op_str* is not one of the supported names.
    """
    name = op_str.lower()
    if name in _STR_TO_REDUCE_OP:
        return _STR_TO_REDUCE_OP[name], False
    if name == "avg":
        # "avg" is only absent from the table when AVG is unsupported.
        return dist.ReduceOp.SUM, True
    supported = sorted({*_STR_TO_REDUCE_OP, "avg"})
    raise ValueError(
        f"Unsupported Partial reduce_op: {op_str!r}. "
        f"Supported: {supported}"
    )

101 

102 

103# --------------------------------------------------------------------------- 

104# Helpers 

105# --------------------------------------------------------------------------- 

106 

def _normalize_parameters(
    parameters: Union["torch.nn.Module", torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return *parameters* as a flat list of tensors.

    A module is expanded through ``module.parameters()``, a lone tensor is
    wrapped in a one-element list, and any other iterable is materialised
    with ``list``.
    """
    if isinstance(parameters, torch.nn.Module):
        parameters = parameters.parameters()
    elif isinstance(parameters, torch.Tensor):
        # Must not fall through: list(tensor) would iterate its first dim.
        return [parameters]
    return list(parameters)

121 

122 

def _param_device(param: torch.Tensor) -> torch.device:
    """Return the device holding *param*'s local data.

    For a ``DTensor`` this is the device of its local tensor; otherwise it
    is the tensor's own device.
    """
    # pylint: disable=protected-access
    local = param._local_tensor if isinstance(param, DTensor) else param
    return local.device

128 

129 

def _get_grad_obj(param: torch.nn.Parameter) -> Optional[torch.Tensor]:
    """Return the gradient object attached to *param*.

    A non-``None`` ``param.main_grad`` (used when
    ``MixedPrecisionPolicy.apply_grad_on_fp32_main_grad=True``) takes
    precedence; otherwise ``param.grad`` is returned, which may be ``None``.
    """
    main_grad = getattr(param, "main_grad", None)
    return param.grad if main_grad is None else main_grad

141 

142 

def _get_local_grad(param: torch.nn.Parameter) -> Optional[torch.Tensor]:
    """Return *param*'s local gradient tensor, or ``None`` when absent.

    Parameters with ``requires_grad=False`` always yield ``None``. The
    gradient is looked up via ``_get_grad_obj`` (so ``main_grad`` is
    honoured) and a ``DTensor`` gradient is unwrapped to its local tensor.
    """
    grad = _get_grad_obj(param) if param.requires_grad else None
    if isinstance(grad, DTensor):
        return grad._local_tensor  # pylint: disable=protected-access
    return grad

156 

157 

def _get_param_mesh_info(
    param: torch.nn.Parameter,
) -> Tuple[
    Optional[object],
    Tuple[int, ...],
    Tuple[_PartialReduceInfo, ...],
]:
    """Read mesh, Shard dims and Partial info off a parameter's DTensor spec.

    The gradient (as returned by ``_get_grad_obj``) is consulted first; when
    it is not a ``DTensor`` the parameter itself is used. If neither is a
    ``DTensor`` the result is ``(None, (), ())``.

    Returns:
        ``(mesh, shard_dims, partial_info)``: *shard_dims* lists the mesh
        dims whose placement reports ``is_shard()``; *partial_info* holds a
        ``(mesh_dim, dist.ReduceOp, needs_manual_avg)`` triple for every
        ``Partial`` placement, derived from its ``reduce_op`` through
        ``_str_to_reduce_op``.
    """
    grad = _get_grad_obj(param)
    if isinstance(grad, DTensor):
        spec_source = grad
    elif isinstance(param, DTensor):
        spec_source = param
    else:
        return None, (), ()

    shard_dims: List[int] = []
    partial_info: List[_PartialReduceInfo] = []
    for mesh_dim, placement in enumerate(spec_source.placements):
        if placement.is_shard():
            shard_dims.append(mesh_dim)
        if isinstance(placement, Partial):
            partial_info.append(
                (mesh_dim, *_str_to_reduce_op(placement.reduce_op)),
            )
    return spec_source.device_mesh, tuple(shard_dims), tuple(partial_info)

194 

195 

def _sum_p_norms(
    dev_grads: List[torch.Tensor],
    norm_type: float,
    device: torch.device,
    total: torch.Tensor,
) -> None:
    """Accumulate the sum of p-th powers of *dev_grads* into *total*.

    Each gradient's p-norm is computed with at least float32 precision
    (``vector_norm(..., dtype=float32)`` for fp16/bf16 inputs), moved to
    *device*, raised to *norm_type* and added to *total* in place.

    Args:
        dev_grads: Gradients to accumulate.
        norm_type: Finite norm order ``p``.
        device: Device on which *total* lives.
        total: Scalar accumulator, updated in place.
    """
    for grad in dev_grads:
        # Without an explicit dtype a bf16/fp16 norm is rounded to the
        # gradient dtype before being raised to the p-th power.
        # promote_types (instead of a hard float32) keeps float64 inputs
        # valid, since vector_norm rejects a narrowing ``dtype``.
        acc_dtype = torch.promote_types(grad.dtype, torch.float32)
        norm = torch.linalg.vector_norm(grad, norm_type, dtype=acc_dtype)
        total.add_(norm.to(device=device) ** norm_type)

206 

207 

208def _foreach_p_norms( 

209 grads: List[torch.Tensor], 

210 norm_type: float, 

211 device: torch.device, 

212) -> torch.Tensor: 

213 """Fast path: fuse per-tensor norms via ``_foreach_norm``. 

214 

215 Restricted to float32 tensors to preserve the same numerical 

216 precision as ``vector_norm(dtype=float32)``. Non-float32 tensors 

217 and backends that raise ``RuntimeError`` fall back to per-tensor 

218 ``vector_norm``. 

219 """ 

220 total = torch.tensor(0.0, device=device, dtype=torch.float32) 

221 grouped = _group_tensors_by_device_and_dtype( 

222 [[g.detach() for g in grads]], 

223 ) 

224 for (dev, _), ([dev_grads], _) in grouped.items(): 

225 if ( 

226 dev_grads[0].dtype == torch.float32 

227 and _has_foreach_support(dev_grads, dev) 

228 ): 

229 try: 

230 per_norms = torch._foreach_norm( # pylint: disable=W0212 

231 dev_grads, norm_type, 

232 ) 

233 except RuntimeError: 

234 per_norms = None 

235 if per_norms is not None: 

236 total.add_( 

237 torch.stack([ 

238 n.to(device=device) ** norm_type 

239 for n in per_norms 

240 ]).sum(), 

241 ) 

242 else: 

243 _sum_p_norms(dev_grads, norm_type, device, total) 

244 else: 

245 _sum_p_norms(dev_grads, norm_type, device, total) 

246 return total 

247 

248 

249def _per_tensor_norms( 

250 grads: List[torch.Tensor], 

251 norm_type: float, 

252 device: torch.device, 

253) -> List[torch.Tensor]: 

254 """Return per-tensor norms as a list of scalar tensors on *device*.""" 

255 if not grads: 

256 return [] 

257 

258 if _group_tensors_by_device_and_dtype is None or not hasattr(torch, "_foreach_norm"): 

259 return [ 

260 torch.linalg.vector_norm(g.detach(), norm_type).to(device=device) 

261 for g in grads 

262 ] 

263 

264 norms: List[torch.Tensor] = [] 

265 grouped = _group_tensors_by_device_and_dtype( 

266 [[g.detach() for g in grads]], 

267 ) 

268 for (dev, _), ([dev_grads], _) in grouped.items(): 

269 if dev_grads and _has_foreach_support(dev_grads, dev): 

270 try: 

271 per_norms = torch._foreach_norm( # pylint: disable=W0212 

272 dev_grads, norm_type, 

273 ) 

274 except RuntimeError: 

275 per_norms = None 

276 if per_norms is not None: 

277 norms.extend( 

278 [n.to(device=device) for n in per_norms], 

279 ) 

280 continue 

281 norms.extend([ 

282 torch.linalg.vector_norm(g, norm_type).to(device=device) 

283 for g in dev_grads 

284 ]) 

285 return norms 

286 

287 

def _compute_local_norm(  # pylint: disable=R0911
    grads: List[torch.Tensor],
    norm_type: float,
    device: torch.device,
) -> torch.Tensor:
    """Compute the combined local norm statistic of *grads* in FP32.

    The result is always a float32 scalar on *device*, so the tensor later
    handed to an all-reduce has the same dtype on every rank, whether or
    not the rank holds gradients.

    When *grads* is empty the **identity element** of the subsequent
    all-reduce is returned:

    * ``inf``  -> 0 (neutral for MAX; norms are non-negative)
    * ``-inf`` -> +inf (neutral for MIN)
    * ``0``    -> 0 (neutral for SUM)
    * finite   -> 0 (neutral for SUM)

    Otherwise the return value is the max (``inf``), min (``-inf``) or sum
    (``0``) of the per-tensor norms, and for a finite ``p`` the **sum of
    p-th powers** (the caller takes the p-th root).
    """
    if not grads:
        identity = float("inf") if norm_type == -math.inf else 0.0
        return torch.tensor(identity, device=device, dtype=torch.float32)

    def _norm_fp32(grad: torch.Tensor) -> torch.Tensor:
        # Accumulate at >= fp32 so low-precision grads are not rounded
        # (e.g. an L0 count above 256 is not exact in bf16). promote_types
        # keeps float64 inputs legal, as vector_norm rejects a narrowing
        # ``dtype``. Moving to *device* before stacking also keeps
        # torch.stack valid when grads live on different devices.
        acc_dtype = torch.promote_types(grad.dtype, torch.float32)
        norm = torch.linalg.vector_norm(
            grad.detach(), norm_type, dtype=acc_dtype,
        )
        return norm.to(device=device, dtype=torch.float32)

    if norm_type == math.inf:
        return torch.stack([_norm_fp32(g) for g in grads]).max()

    if norm_type == -math.inf:
        return torch.stack([_norm_fp32(g) for g in grads]).min()

    if norm_type == 0:
        return torch.stack([_norm_fp32(g) for g in grads]).sum()

    # Finite p-norm: return sum of p-th powers.
    if (
        len(grads) > 1
        and _group_tensors_by_device_and_dtype is not None
        and hasattr(torch, "_foreach_norm")
    ):
        return _foreach_p_norms(grads, norm_type, device)

    # Scalar fallback when foreach utilities are unavailable (or there is
    # only a single gradient).
    return torch.stack([_norm_fp32(g) ** norm_type for g in grads]).sum()

347 

348 

349# --------------------------------------------------------------------------- 

350# Total norm aggregation with collectives 

351# --------------------------------------------------------------------------- 

352 

def _get_total_norm(
    grad_groups: Dict[_GradGroupKey, List[torch.Tensor]],
    norm_type: float,
    mesh_cache: Dict[int, object],
    device: torch.device,
    all_grads: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    """Dispatch to the total-norm routine matching *norm_type*.

    * ``+inf`` / ``-inf`` -> ``_total_norm_inf`` with MAX / MIN reduction
    * ``0``               -> ``_total_norm_sum``
    * any other value     -> ``_total_norm_fsdp2_aligned`` followed by the
      ``1 / norm_type`` power
    """
    if math.isinf(norm_type):
        reduce_op = (
            dist.ReduceOp.MAX if norm_type > 0 else dist.ReduceOp.MIN
        )
        return _total_norm_inf(
            grad_groups, norm_type, mesh_cache, device, reduce_op,
        )

    if norm_type == 0:
        return _total_norm_sum(grad_groups, norm_type, mesh_cache, device)

    # Finite p-norm: the helper yields the sum of p-th powers.
    sum_of_powers = _total_norm_fsdp2_aligned(
        grad_groups, norm_type, mesh_cache, device, all_grads,
    )
    return sum_of_powers ** (1.0 / norm_type)

383 

384 

def _total_norm_inf(  # pylint: disable=R0913,R0917
    grad_groups, norm_type, mesh_cache, device, reduce_op,
):
    """Total norm for ``inf`` / ``-inf``: reduce per group, then combine.

    Every group's local norm comes from ``_compute_local_norm`` and, for a
    mesh-backed group, is all-reduced with *reduce_op* over each of its
    shard dims. The per-group values are combined with ``max`` when
    *reduce_op* is ``dist.ReduceOp.MAX`` and with ``min`` otherwise. With
    no groups at all the identity (``+inf`` for ``-inf`` norm, else 0) is
    returned.
    """
    per_group: List[torch.Tensor] = []
    for (mesh_id, shard_dims), grads in grad_groups.items():
        value = _compute_local_norm(grads, norm_type, device)
        if mesh_id is not None:
            mesh = mesh_cache[mesh_id]
            for dim in shard_dims:
                dist.all_reduce(
                    value, op=reduce_op, group=mesh.get_group(dim),
                )
        per_group.append(value)

    if not per_group:
        identity = float("inf") if norm_type == -math.inf else 0.0
        return torch.tensor(identity, device=device)

    stacked = torch.stack(per_group)
    if reduce_op == dist.ReduceOp.MAX:
        return stacked.max()
    return stacked.min()

406 

407 

def _total_norm_sum(grad_groups, norm_type, mesh_cache, device):
    """Total for norms aggregated by SUM (L0 and finite p).

    Each group's local value from ``_compute_local_norm`` is SUM
    all-reduced over the group's shard dims (mesh-backed groups only) and
    then added to a running scalar on *device*.
    """
    running = torch.tensor(0.0, device=device)
    for (mesh_id, shard_dims), grads in grad_groups.items():
        value = _compute_local_norm(grads, norm_type, device)
        if mesh_id is not None:
            mesh = mesh_cache[mesh_id]
            for dim in shard_dims:
                dist.all_reduce(
                    value, op=dist.ReduceOp.SUM, group=mesh.get_group(dim),
                )
        running.add_(value)
    return running

422 

423 

def _total_norm_fsdp2_aligned(grad_groups, norm_type, mesh_cache, device,
                              all_grads=None):  # pylint: disable=W0613
    """FSDP2-aligned sum of p-th powers for finite p-norms.

    For every grad group the per-tensor norms are taken in the group's
    stored order, combined with ONE stack and ONE ``vector_norm``, raised
    to the p-th power and SUM all-reduced over **that group's own** shard
    dims. The group results are then added together.

    With a single group (all parameters share mesh and shard dims) this is
    the same stack / ``vector_norm`` / all-reduce sequence as a one-shot
    computation over all gradients. With several groups, each one is
    reduced over its own process groups only, so a group without a mesh
    (plain / replicated tensors) contributes locally without communication.

    Norms are taken from *grad_groups*, whose entries already hold the
    Partial-reduced gradient views where applicable.

    Args:
        grad_groups: ``(mesh_id, shard_dims)`` -> local gradients. Iteration
            order must be identical on all ranks, since it fixes the order
            of the collectives.
        norm_type: Finite, non-zero norm order ``p``.
        mesh_cache: ``mesh_id`` -> DeviceMesh.
        device: Device for the scalar results.
        all_grads: Unused; kept so existing callers keep working.

    Returns:
        The global sum of p-th powers (caller takes the p-th root).
    """
    total = None
    for (mesh_id, shard_dims), grads in grad_groups.items():
        norms = _per_tensor_norms(grads, norm_type, device) if grads else []
        if norms:
            combined = torch.linalg.vector_norm(
                torch.stack(norms), norm_type,
            )
            group_p = combined ** norm_type
            # Ranks without grads for this group contribute an fp32 zero;
            # lift fp16/bf16 results so every rank issues the collective
            # with the same dtype (float64 is left untouched).
            group_p = group_p.to(
                torch.promote_types(group_p.dtype, torch.float32),
            )
        else:
            # Neutral element for SUM.
            group_p = torch.tensor(
                0.0, device=device, dtype=torch.float32,
            )

        if mesh_id is not None:
            mesh = mesh_cache[mesh_id]
            for dim in shard_dims:
                dist.all_reduce(group_p, op=dist.ReduceOp.SUM,
                                group=mesh.get_group(dim))

        total = group_p if total is None else total + group_p

    if total is None:
        return torch.tensor(0.0, device=device, dtype=torch.float32)
    return total

472 

473 

def _build_coalesce_buffer(
    param_infos: List[Tuple],
    indices: List[int],
) -> Tuple[List[torch.Tensor], List[int], List[bool], List[int]]:
    """Collect the flat fp32 pieces that make up one coalesce group.

    For each index in *indices*:

    * a parameter with a local gradient contributes that gradient,
      detached, flattened and cast to float32;
    * a trainable parameter without a gradient contributes zeros sized
      like its local data;
    * a frozen parameter without a gradient is skipped.

    Returns:
        ``(chunks, chunk_sizes, has_grad, active_indices)`` as parallel
        lists, one entry per contributing parameter.
    """
    chunks: List[torch.Tensor] = []
    chunk_sizes: List[int] = []
    has_grad: List[bool] = []
    active_indices: List[int] = []

    for idx in indices:
        param, local_grad = param_infos[idx][:2]
        if local_grad is not None:
            numel = local_grad.numel()
            flat = local_grad.detach().reshape(-1).to(torch.float32)
        elif param.requires_grad:
            local_param = (
                param._local_tensor  # pylint: disable=W0212
                if isinstance(param, DTensor) else param.data
            )
            numel = local_param.numel()
            flat = torch.zeros(
                numel, device=local_param.device, dtype=torch.float32,
            )
        else:
            continue

        chunks.append(flat)
        chunk_sizes.append(numel)
        has_grad.append(local_grad is not None)
        active_indices.append(idx)

    return chunks, chunk_sizes, has_grad, active_indices

516 

517 

def _coalesce_partial_reduce(  # pylint: disable=R0914
    param_infos: List[Tuple],
    mesh_cache: Dict[int, object],
) -> Dict[int, torch.Tensor]:
    """Reduce Partial gradients with one buffer per ``(mesh, partial_info)``.

    Parameters that carry Partial info are bucketed by
    ``(id(mesh), partial_info)``. For each bucket the chunks returned by
    ``_build_coalesce_buffer`` are concatenated into a single float32
    buffer, which is all-reduced once per Partial mesh dim (and divided by
    the group size when manual averaging is requested). A bucket that
    yields no chunks is skipped without any collective.

    Returns:
        Mapping from *param_infos* index to a 1-D fp32 slice of the reduced
        buffer. Only parameters that actually had a gradient are included.

    Raises:
        RuntimeError: A parameter has Partial info but no mesh, or the
            chunks of one bucket live on different devices.
    """
    buckets: Dict[Tuple, List[int]] = defaultdict(list)
    for idx, info in enumerate(param_infos):
        mesh, partial_info = info[2], info[3]
        if not partial_info:
            continue
        if mesh is None:
            raise RuntimeError(
                "clip_grad_norm_: parameter has Partial placements "
                "but no DeviceMesh. This is a DTensor invariant "
                "violation."
            )
        buckets[(id(mesh), partial_info)].append(idx)

    reduced: Dict[int, torch.Tensor] = {}

    for (mesh_id, partial_info), indices in buckets.items():
        mesh = mesh_cache[mesh_id]
        chunks, chunk_sizes, has_grad, active_indices = (
            _build_coalesce_buffer(param_infos, indices)
        )
        if not chunks:
            continue

        # Fail fast if the pieces of one buffer are spread over devices.
        buf_device = chunks[0].device
        stray = next(
            (c for c in chunks[1:] if c.device != buf_device), None,
        )
        if stray is not None:
            raise RuntimeError(
                f"clip_grad_norm_: parameters in the same Partial "
                f"coalesce group are on different devices "
                f"({buf_device} vs {stray.device}). All parameters "
                f"sharing the same DeviceMesh must reside on the "
                f"same local device."
            )

        buf = torch.cat(chunks)
        for pdim, reduce_op, needs_avg in partial_info:
            group = mesh.get_group(pdim)
            dist.all_reduce(buf, op=reduce_op, group=group)
            if needs_avg:
                buf /= dist.get_world_size(group=group)

        # Hand out slices; zero-filled placeholders are skipped over.
        offset = 0
        for idx, numel, present in zip(
            active_indices, chunk_sizes, has_grad,
        ):
            if present:
                reduced[idx] = buf[offset:offset + numel]
            offset += numel

    return reduced

602 

603 

def _build_grad_groups(  # pylint: disable=R0914
    params: List[torch.Tensor],
) -> Tuple[
    Dict[_GradGroupKey, List[torch.Tensor]],
    List[torch.Tensor],
    Dict[int, object],
    torch.device,
    bool,
]:
    """Bucket parameters into grad groups and pre-reduce Partial gradients.

    Steps:

    1. For every parameter, read ``(mesh, shard_dims, partial_info)`` via
       ``_get_param_mesh_info`` and its local gradient via
       ``_get_local_grad``. The first parameter's local device becomes the
       working device (CPU when *params* is empty).
    2. Run ``_coalesce_partial_reduce`` over the collected entries.
    3. Register each parameter's ``(mesh_id, shard_dims)`` key, even when
       it has no gradient, so the group exists on this rank. Append the
       gradient (or its Partial-reduced view) to the group.

    Returns:
        ``(grad_groups, all_grads, mesh_cache, device, has_dtensor_grad)``.
        *all_grads* holds the original local gradients in parameter order;
        *has_dtensor_grad* tells whether any gradient object was a
        ``DTensor``.
    """
    # --- Step 1: classify all parameters ---
    param_infos: List[Tuple] = []
    mesh_cache: Dict[int, object] = {}
    device: Optional[torch.device] = None

    for param in params:
        mesh, shard_dims, partial_info = _get_param_mesh_info(param)
        mesh_id: Optional[int] = None
        if mesh is not None:
            mesh_id = id(mesh)
            mesh_cache[mesh_id] = mesh
        if device is None:
            device = _param_device(param)
        key: _GradGroupKey = (mesh_id, shard_dims)
        param_infos.append(
            (param, _get_local_grad(param), mesh, partial_info, key),
        )

    if device is None:
        device = torch.device("cpu")

    # --- Step 2: coalesced Partial reduction ---
    reduced = _coalesce_partial_reduce(param_infos, mesh_cache)

    # --- Step 3: build grad_groups ---
    grad_groups: Dict[_GradGroupKey, List[torch.Tensor]] = defaultdict(
        list,
    )
    all_grads: List[torch.Tensor] = []
    has_dtensor_grad = False

    for idx, info in enumerate(param_infos):
        param, local_grad, key = info[0], info[1], info[4]
        # Touching the key registers the group even without a gradient.
        bucket = grad_groups[key]
        if local_grad is None:
            continue

        if isinstance(_get_grad_obj(param), DTensor):
            has_dtensor_grad = True
        all_grads.append(local_grad)
        # Prefer the Partial-reduced view when one was produced.
        bucket.append(reduced.get(idx, local_grad))

    return grad_groups, all_grads, mesh_cache, device, has_dtensor_grad

681 

682 

683def _clip_grads_with_norm_( 

684 all_grads: List[torch.Tensor], 

685 max_norm: float, 

686 total_norm: torch.Tensor, 

687 foreach: Optional[bool] = None, 

688) -> None: 

689 """Scale gradients in-place so the total norm <= *max_norm*.""" 

690 clip_coef = max_norm / (total_norm + 1e-6) 

691 clip_coef_clamped = torch.clamp(clip_coef, max=1.0) 

692 

693 if _group_tensors_by_device_and_dtype is not None: 

694 grouped_grads = _group_tensors_by_device_and_dtype( 

695 [all_grads], 

696 ) 

697 for (device, dtype), ([device_grads], _) in grouped_grads.items(): 

698 if ( 

699 foreach is None 

700 and _has_foreach_support(device_grads, device) 

701 ) or ( 

702 foreach 

703 and _device_has_foreach_support(device) 

704 ): 

705 torch._foreach_mul_( # pylint: disable=W0212 

706 device_grads, 

707 clip_coef_clamped.to(device=device, dtype=dtype), 

708 ) 

709 elif foreach: 

710 raise RuntimeError( 

711 f"foreach=True was passed, but can't use the " 

712 f"foreach API on {device.type} tensors" 

713 ) 

714 else: 

715 clip_coef_clamped_cast = clip_coef_clamped.to(device=device, dtype=dtype) 

716 for g in device_grads: 

717 g.mul_(clip_coef_clamped_cast) 

718 else: 

719 # Fallback when _foreach_utils is unavailable. 

720 if foreach: 

721 raise RuntimeError( 

722 "foreach=True was passed, but " 

723 "torch.utils._foreach_utils is not available" 

724 ) 

725 for grad in all_grads: 

726 grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype)) 

727 

728 

729# --------------------------------------------------------------------------- 

730# Public API 

731# --------------------------------------------------------------------------- 

732 

@torch.no_grad()
def clip_grad_norm_(
    parameters: Union[
        "torch.nn.Module", torch.Tensor, Iterable[torch.Tensor],
    ],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    """Clip gradients of DTensor-aware models by their total norm.

    Intended as a drop-in replacement for the standard
    ``clip_grad_norm_``. The communication needed for the norm is derived
    from each parameter's DTensorSpec:

    * ``Shard`` on mesh dim *d* -- norm statistics are all-reduced across
      ``device_mesh.get_group(d)``
    * ``Partial`` on mesh dim *d* -- gradient values are all-reduced with
      the placement's ``reduce_op`` before the norm is computed
    * ``Replicate`` / plain tensor -- no communication

    .. warning:: Collectives are issued, so this function **must be
       called on all ranks**, including ranks without local gradients.

    PP cross-stage norm aggregation is not handled here.

    Args:
        parameters: An ``nn.Module``, a single ``Tensor``, or an iterable
            of ``Tensor`` s whose gradients to clip.
        max_norm: Maximum allowed gradient norm.
        norm_type: Type of the norm (default ``2.0``).
        error_if_nonfinite: Raise ``RuntimeError`` when the total norm is
            NaN or infinite. Default ``False``.
        foreach: Selects the foreach-based scaling step. ``None`` uses it
            where supported, except that it is turned off when any
            gradient is a DTensor. Default ``None``.

    Returns:
        The total (unclipped) gradient norm as a scalar tensor, cast to
        the promoted dtype of all local gradients. A rank without
        gradients emits a warning and returns the norm uncast.

    Raises:
        RuntimeError: *error_if_nonfinite* is set and the norm is
            non-finite.
    """
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    grad_groups, all_grads, mesh_cache, device, has_dtensor_grad = (
        _build_grad_groups(_normalize_parameters(parameters))
    )

    # Every rank computes the norm, with or without local gradients.
    total_norm = _get_total_norm(
        grad_groups, norm_type, mesh_cache, device, all_grads,
    )

    if error_if_nonfinite and torch.logical_or(
        total_norm.isnan(), total_norm.isinf(),
    ):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To "
            "disable this error and scale the gradients by the "
            "non-finite norm anyway, set "
            "`error_if_nonfinite=False`"
        )

    if not all_grads:
        # Nothing to scale on this rank; hand back the norm as computed.
        warnings.warn(
            "clip_grad_norm_ called on this rank with no gradients -- "
            "returning the local norm in the default dtype "
            f"{total_norm.dtype}",
            stacklevel=2,
        )
        return total_norm

    # Leave foreach off by default when DTensor gradients are involved.
    effective_foreach = foreach
    if has_dtensor_grad and foreach is None:
        effective_foreach = False
    _clip_grads_with_norm_(
        all_grads, max_norm, total_norm, effective_foreach,
    )

    result_dtype = functools.reduce(
        torch.promote_types, [g.dtype for g in all_grads],
    )
    return total_norm.to(result_dtype)