Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/state.py: 59%

352 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch HSDP cell state""" 

16# pylint: disable=protected-access 

17 

18from typing import Optional, List 

19from collections import defaultdict 

20import torch 

21 

22from hyper_parallel.core.fully_shard.hsdp_state import HSDPState 

23from hyper_parallel.core.fully_shard.hsdp_utils import ( 

24 FullyShardParamMode, 

25 _get_param_module_infos, 

26 infer_fully_shard_param_mode, 

27) 

28from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy 

29from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2 

30from hyper_parallel.platform.torch.fully_shard.pack_utils import build_rs_plan 

31from hyper_parallel.platform.torch.fully_shard.param_group import get_comm_ctx, HSDPParamGroup, AllReduceParamGroup 

32 

33 

34def _to_dtype_if_needed( 

35 tensor: torch.Tensor, dtype: Optional[torch.dtype] 

36) -> torch.Tensor: 

37 """Cast tensor to the given dtype if it differs from current dtype. 

38 

39 Args: 

40 tensor: The input tensor to potentially cast. 

41 dtype: Target dtype. If None or same as tensor dtype, no-op. 

42 """ 

43 if dtype is not None and tensor.dtype != dtype: 

44 return tensor.to(dtype) 

45 return tensor 

46 

47 

class TorchHSDPStateV2(HSDPState):
    """Torch HSDP cell state.

    Owns the ``TorchHSDPParamV2`` wrappers for the managed module(s) and drives
    gradient reduction after backward along one of three paths:

    * per-parameter reduce-scatter, followed by a fused all-reduce across
      replicas (``AllReduceParamGroup``) when ``_should_run_all_reduce`` holds;
    * a fully fused ``HSDPParamGroup`` path when ``config.comm_fusion`` is set;
    * a direct all-reduce of ``sharded_param.grad`` for pure-TP
      ``DTENSOR_COMPAT`` parameters that never materialize an unsharded param.

    Reduction is pipelined across modules: work issued during one module's
    ``post_backward`` is waited on and applied during a later module's
    ``post_backward`` (or in ``delay_apply_reduce_grads``). The queues below
    are therefore class-level and shared by every instance.
    """

    # DTensor compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without ever materializing an
    # ``_unsharded_param``. Track their async all-reduce work separately from
    # the standard unsharded-grad queues.
    # Entries are ``(work_handle_or_None, reduced_grad, target_grad)``.
    pre_direct_all_reduce_grads = []
    # AllReduceParamGroups whose reduce_scatter has been issued, waiting for
    # the next post_backward to wait on it.
    pre_all_reduce_groups: List[AllReduceParamGroup] = []
    # AllReduceParamGroups whose all_reduce has been issued, waiting for
    # delay_apply_reduce_grads (run from root_backward_hook) to apply them.
    pending_all_reduce_groups: List[AllReduceParamGroup] = []

    @staticmethod
    def _has_materialized_unsharded_param(hsdp_param) -> bool:
        """Whether ``hsdp_param`` currently holds a non-None unsharded parameter."""
        return hasattr(hsdp_param, "_unsharded_param") and hsdp_param.unsharded_param is not None

    @staticmethod
    def _get_pending_unsharded_grad(hsdp_param):
        """Return the pending unsharded gradient tensor for all-reduce-based paths.

        Prefers the accumulated gradient (gradient-accumulation case) over the
        gradient currently stored on the unsharded parameter.
        """
        if hsdp_param.unsharded_accumulated_grad is not None:
            return hsdp_param.unsharded_accumulated_grad_data
        return hsdp_param.unsharded_grad_data

    @staticmethod
    def _has_pending_unsharded_grad(hsdp_param):
        """Whether the parameter currently has a gradient waiting for reduction."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return True
        if not TorchHSDPStateV2._has_materialized_unsharded_param(hsdp_param):
            return False
        return hsdp_param.unsharded_param.grad is not None

    @staticmethod
    def _get_local_sharded_grad(hsdp_param):
        """Return the local gradient tensor currently stored on ``sharded_param``.

        DTensor-like gradients (anything exposing a callable ``to_local``) are
        converted to their local tensor; plain tensors are returned as-is.
        """
        grad = hsdp_param.sharded_param.grad
        if grad is None:
            return None
        to_local = getattr(grad, "to_local", None)
        if callable(to_local):
            return to_local()
        return grad

    @staticmethod
    def _synchronize_device_stream(device: torch.device) -> None:
        """Synchronize the current stream of ``device``.

        Called when ``apply_reduced_grad`` / ``wait_and_apply_grads`` report that
        synchronization is required (per the error message, after CPU offload).

        Raises:
            NotImplementedError: if ``device`` is neither ``npu`` nor ``cuda``.
        """
        if device.type == "npu":
            torch.npu.current_stream().synchronize()
        elif device.type == "cuda":
            torch.cuda.current_stream().synchronize()
        else:
            raise NotImplementedError(
                f"Unsupported device type {device.type} for synchronization after CPU offload."
            )

    def __init__(self, cell, mesh_info, config, platform, device):
        """
        Initialize TorchHSDPStateV2.

        Args:
            cell (nn.Module): The module whose parameters are managed by this state.
            mesh_info: Mesh topology for shard/replicate dimensions.
            config (HSDPConfigV2): HSDP configuration.
            platform (TorchPlatform): Torch platform abstraction.
            device (torch.device): Target device.
        """
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        self.device = device
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do ReduceScatter/AllReduce for grad (toggled by set_requires_grad_sync).
        self.reduce_grads = True
        # Reshard parameter after backward.
        self.reshard_after_backward = True
        # Requires AllReduce for grad when HSDP.
        self.requires_all_reduce = True
        # Default reduce op is decided at the fully_shard-state level:
        # if any managed parameter is DTensor-backed, use SUM; otherwise AVG.
        self._user_reduce_op_type = None
        self.reduce_op_type = self._resolve_default_reduce_op()
        self._reset_sharded_params = False
        self._init_param_group()

    @staticmethod
    def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
        """Return the reason why ``hsdp_param`` cannot participate in comm_fusion.

        Returns:
            A human-readable reason string, or ``None`` when the parameter is
            supported (a reduce-scatter pack plan can be built for it).
        """
        if not hsdp_param.enable_fsdp_shard:
            return "non-sharded parameters such as replicate_params are not supported"
        if hsdp_param.param_mode not in (
            FullyShardParamMode.LOCAL_PARAM,
            FullyShardParamMode.DTENSOR_UNIFIED,
        ):
            return (
                "param_mode "
                f"{hsdp_param.param_mode} is not supported"
            )
        local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
        if local_shard is None:
            return "missing local shard tensor for comm_fusion plan validation"
        plan_world_size = getattr(hsdp_param, "shard_world_size", None)
        if plan_world_size is None:
            plan_world_size = getattr(hsdp_param, "shard_size", 1)
        try:
            build_rs_plan(hsdp_param, local_shard, plan_world_size)
        except NotImplementedError as exc:
            return str(exc)
        except (AssertionError, ValueError) as exc:
            return f"cannot build comm_fusion pack plan: {exc}"
        return None

    def _init_param_group(self):
        """Initialize fused parameter group for communication fusion.

        When ``comm_fusion`` is enabled, creates an ``HSDPParamGroup`` that packs all
        parameters into a single buffer for fused all-gather and reduce-scatter,
        replacing the per-parameter communication pattern.

        Raises:
            NotImplementedError: if comm_fusion is enabled and any managed
                parameter is unsupported (see ``_comm_fusion_unsupported_reason``).
        """
        if self.config.comm_fusion:
            # Compute each reason exactly once (building a pack plan is not free).
            for hsdp_param in self.hsdp_params:
                reason = self._comm_fusion_unsupported_reason(hsdp_param)
                if reason is not None:
                    param_fqn = getattr(hsdp_param, "_param_fqn", "<unknown>")
                    raise NotImplementedError(
                        f"comm_fusion does not support parameter {param_fqn}: {reason}."
                    )
        self.param_group = None
        if self.hsdp_params:
            # pylint: disable=E1128
            self.param_group = HSDPParamGroup(
                self.hsdp_params,
                self.mesh_info,
                self.device,
                self.mp_policy,
                self.config.comm_fusion_zero_copy,
            )

    def _move_states_to_device(self):
        """Move not-yet-initialized parameters and all buffers to ``self.device``.

        Tensors already on the target device or on the ``meta`` device are left
        untouched; parameters flagged ``_hsdp_param_initialized`` are skipped.
        """
        for mod in self.modules:
            for param in mod.parameters():
                if getattr(param, "_hsdp_param_initialized", False):
                    continue
                if param.device == self.device or param.device.type == "meta":
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                if buffer.device == self.device or buffer.device.type == "meta":
                    continue
                buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Init hsdp parameters and replicate parameters for cell.

        Parameters are collected across all managed modules (deduplicated,
        skipping ignored and already-initialized ones). Those listed in
        ``config.replicate_params`` go to ``self.replicate_params`` with FSDP
        sharding disabled; all others go to ``self.hsdp_params``.
        """
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        # All parameters in the module tree(s), deduplicated, in visit order.
        visited_params = set()
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.named_parameters():
                if param in ignored_params:
                    continue
                if getattr(param, "_hsdp_param_initialized", False):
                    continue
                if param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            enable_fsdp_shard = param not in replicate_params
            hsdp_param = TorchHSDPParamV2(param,
                                          module_info,
                                          self.mesh_info,
                                          shard_placement_fn=self.config.shard_placement_fn,
                                          mp_policy=self.mp_policy,
                                          offload_policy=self.offload_policy,
                                          device=self.device,
                                          param_mode=param_mode,
                                          enable_fsdp_shard=enable_fsdp_shard,
                                          )
            if not enable_fsdp_shard:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
                if hsdp_param.is_sharded:
                    self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Init mixed-precision dtypes for hsdp parameters and replicate parameters.

        Also records ``self._orig_dtype`` / ``self._reduce_dtype`` shared by all
        trainable managed parameters (``None`` when there are none).

        Raises:
            AssertionError: if trainable parameters disagree on original or
                reduce dtype.
        """
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        for replicate_param in self.replicate_params:
            replicate_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[TorchHSDPParamV2] = [
            p for p in self._iter_managed_params() if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if trainable_params and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if trainable_params and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None

    def _validate_cpu_offload_params(self):
        """Validate that all parameters are on CPU when CPU offload policy is enabled.

        Raises:
            RuntimeError: if a managed parameter lives on a non-CPU device while
                ``offload_policy`` is a ``CPUOffloadPolicy``.
        """
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device.type != "cpu"
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
                "Found following parameters on non-CPU device: "
                f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n"
            )

    def lazy_init(self):
        """Finish initialization that depends on materialized parameters.

        Resets sharded parameters once (when ``is_shard``), then validates that
        nothing is left on the meta device, checks CPU-offload placement and
        resolves the mixed-precision dtypes.
        """
        if self.is_shard and not self._reset_sharded_params:
            for hsdp_param in self.hsdp_params:
                if hsdp_param.is_sharded:
                    hsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        self._validate_no_meta_params()
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _validate_no_meta_params(self):
        """Raise ``RuntimeError`` if any managed parameter is still on the meta device."""
        param_names_on_meta = [
            hsdp_param._param_fqn
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device.type == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "HSDP parameters should be materialized from meta device before training, "
                f"but the following were still on meta device: {param_names_on_meta}\n"
                "For example, call module.to_empty(device) to materialize to device and "
                "call module.reset_parameters() on each module to initialize values."
            )

    def post_backward_for_comm_fusion(self):
        """Post-backward gradient reduction for the fused (comm_fusion) path.

        Replicate-only params still use the non-fused compat all-reduce path.
        """
        # Drain any pending side-path reductions before advancing the fused
        # param-group pipeline for sharded params.
        self.reduce_params()
        # Fused gradient reduction path: first apply any pending async reduction
        # from the previous module's backward (pipelined overlap), then issue
        # this module's fused reduce-scatter (+ all-reduce for HSDP).
        comm_ctx = get_comm_ctx()
        # Phase 2: apply grads for the param group whose all_reduce is done.
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        # Phase 1: wait reduce_scatter, issue async all_reduce for previous layer.
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type
            )
        for hsdp_param in self.replicate_params:
            if not self._has_materialized_unsharded_param(hsdp_param):
                continue
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if not self._has_pending_unsharded_grad(hsdp_param):
                continue
            reduce_op = self._resolve_reduce_op(hsdp_param)
            self._queue_compat_all_reduce(hsdp_param, reduce_op)

    def _resolve_default_reduce_op(self):
        """Resolve the default reduce op for the whole fully_shard state.

        Returns ``ReduceOp.SUM`` if any managed parameter is DTensor-backed
        (``DTENSOR_COMPAT`` or ``DTENSOR_UNIFIED``), otherwise ``ReduceOp.AVG``.
        """
        for hsdp_param in self._iter_managed_params():
            if hsdp_param.param_mode in (
                FullyShardParamMode.DTENSOR_COMPAT,
                FullyShardParamMode.DTENSOR_UNIFIED,
            ):
                return torch.distributed.ReduceOp.SUM
        return torch.distributed.ReduceOp.AVG

    def _resolve_reduce_op(self, hsdp_param=None):  # pylint: disable=unused-argument
        """Resolve the gradient reduction op for the current fully_shard state.

        A user override from ``set_reduce_op_type`` wins over the default.
        ``hsdp_param`` is currently unused.
        """
        if self._user_reduce_op_type is not None:
            return self._user_reduce_op_type
        return self.reduce_op_type

    def _should_run_all_reduce(self, hsdp_param) -> bool:
        """Whether the current parameter should issue an all-reduce in this backward pass."""
        return self.requires_all_reduce and hsdp_param.dp_size > 1

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param, reduce_op):
        """Queue the standard FSDP/HSDP reduction path.

        Issues a reduce-scatter and queues it. If an all-reduce is also needed,
        the reduce-scatter result is fetched immediately, its queue entry is
        withdrawn, and an all-reduce is queued instead.
        """
        hsdp_param.reduce_scatter_grad(
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype))
        if not self._should_run_all_reduce(hsdp_param):
            return
        reduced_grad = hsdp_param.reduce_scatter_output()
        if (
            HSDPState.pre_reduce_scatter_params
            and HSDPState.pre_reduce_scatter_params[-1][0] is hsdp_param
        ):
            HSDPState.pre_reduce_scatter_params.pop()
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))

    def _queue_compat_all_reduce(self, hsdp_param, reduce_op):
        """Queue the compatibility all-reduce path without FSDP sharding."""
        if not self._should_run_all_reduce(hsdp_param):
            return
        hsdp_param.all_reduce_grad(
            grad=self._get_pending_unsharded_grad(hsdp_param),
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
            and hsdp_param.shard_size == 1
            and hsdp_param.sharded_param.requires_grad
            and self._should_run_all_reduce(hsdp_param)
            and self._get_local_sharded_grad(hsdp_param) is not None
        )

    def _queue_direct_compat_all_reduce(self, hsdp_param, reduce_op):
        """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``.

        The gradient is cast to the reduce dtype when needed (a copy), reduced
        asynchronously, and written back into the original gradient tensor by
        ``reduce_params``.
        """
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        reduced_grad = _to_dtype_if_needed(grad, self._reduce_dtype)
        handle = None
        if hsdp_param.unsharded_group_info.group is not None and hsdp_param.dp_size > 1:
            handle = torch.distributed.all_reduce(
                reduced_grad,
                op=reduce_op,
                group=hsdp_param.unsharded_group_info.group,
                async_op=True,
            )
        TorchHSDPStateV2.pre_direct_all_reduce_grads.append((handle, reduced_grad, grad))

    def post_backward(self, *unused):  # pylint: disable=unused-argument
        """Reduce gradients and reshard parameters after backward.

        With gradient sync disabled (``set_requires_grad_sync(False)``) only
        resharding and gradient accumulation bookkeeping happen. Otherwise
        the fused (comm_fusion) or the pipelined per-parameter path is run.
        """
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            if self.reshard_after_backward:
                self.shard()
            for hsdp_param in self._iter_managed_params():
                hsdp_param.to_accumulated_grad_if_needed()
            return
        if self.comm_fusion:
            self.post_backward_for_comm_fusion()
        else:
            # Handle user config replicate params and mirror params.
            self.reduce_params()
            for hsdp_param in self._iter_managed_params():
                if (
                    not self._has_materialized_unsharded_param(hsdp_param)
                    and self._can_direct_all_reduce_compat_grad(hsdp_param)
                ):
                    reduce_op = self._resolve_reduce_op(hsdp_param)
                    self._queue_direct_compat_all_reduce(hsdp_param, reduce_op)

            # Step 1: wait prev reduce_scatter (for params needing allreduce).
            prev_groups = self._wait_prev_reduce_scatter()
            # Step 2: wait and apply prev reduce_scatter (for params NOT needing allreduce).
            self._wait_and_apply_prev_no_allreduce_params()
            # Step 3: issue current reduce_scatter.
            self._issue_reduce_scatter_for_current_module()
            # Step 4: issue prev fused allreduce (async) using the saved groups.
            self._issue_prev_fused_allreduce(prev_groups)
        if self.reshard_after_backward:
            self.shard()

    def _issue_reduce_scatter_for_current_module(self):
        """Issue reduce_scatter for current module's parameters with fused all-reduce support.

        This method groups parameters by their replicate process group and:
        1. For params without all_reduce needs: issue reduce_scatter directly
        2. For params with all_reduce needs: allocate fused buffer and issue reduce_scatter
           into aligned views, enabling zero-copy fused all_reduce later.
        """
        # Collect parameters that have a gradient that needs reduction.
        params_to_reduce = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if self._has_materialized_unsharded_param(hsdp_param)
            and hsdp_param.sharded_param.requires_grad
            and not self._can_direct_all_reduce_compat_grad(hsdp_param)
            and self._has_pending_unsharded_grad(hsdp_param)
        ]
        if not params_to_reduce:
            return

        # Group by communication group for fused all-reduce.
        # Key: id of the process group, or None for params that don't need all_reduce.
        groups_by_comm = defaultdict(list)
        for hsdp_param in params_to_reduce:
            if self._should_run_all_reduce(hsdp_param):
                groups_by_comm[id(hsdp_param.unsharded_group_info.group)].append(hsdp_param)
            else:
                groups_by_comm[None].append(hsdp_param)

        reduce_op = self._resolve_reduce_op()

        # Params that don't need all_reduce (FSDP or single replica).
        for hsdp_param in groups_by_comm.pop(None, []):
            hsdp_param.reduce_scatter_grad(
                dtype=self._reduce_dtype,
                reduce_op=reduce_op,
            )
            HSDPState.pre_reduce_scatter_params.append(
                (hsdp_param, self._orig_dtype))

        # Params that need all_reduce (HSDP with multiple replicas).
        for hsdp_params in groups_by_comm.values():
            group = AllReduceParamGroup(
                replicate_group=hsdp_params[0].unsharded_group_info.group,
                hsdp_params=hsdp_params,
                orig_dtypes=[self._orig_dtype] * len(hsdp_params),
                reduce_dtype=self._reduce_dtype,
                reduce_op=reduce_op,
                mp_policy=self.mp_policy,
            )
            # Allocate the fused buffer the reduce-scatter outputs are written into.
            group.allocate_fused_buffer(self.device)
            # Issue reduce_scatter with output directly into fused buffer views.
            for idx, hsdp_param in enumerate(hsdp_params):
                hsdp_param.reduce_scatter_grad(
                    dtype=self._reduce_dtype,
                    reduce_op=reduce_op,
                    output_buffer=group.get_param_buffer_view(idx),
                )
            # Saved for the next post_backward (_wait_prev_reduce_scatter).
            TorchHSDPStateV2.pre_all_reduce_groups.append(group)

    def _wait_prev_reduce_scatter(self) -> List[AllReduceParamGroup]:
        """Step 1: wait prev reduce_scatter.

        This enables overlapping:
        - Layer N-1's reduce_scatter wait with Layer N's backward compute

        Returns:
            List of previous AllReduceParamGroups (one per communication group).
        """
        if not TorchHSDPStateV2.pre_all_reduce_groups:
            return []
        prev_groups = list(TorchHSDPStateV2.pre_all_reduce_groups)
        TorchHSDPStateV2.pre_all_reduce_groups.clear()
        for prev_group in prev_groups:
            for hsdp_param in prev_group.hsdp_params:
                # The result already lives in the group's fused buffer view;
                # this call is made for its waiting side effect (presumably).
                hsdp_param.reduce_scatter_output()
                hsdp_param.clear_reduce_scatter_output()
                # Release the unsharded gradient that was just reduced. Check
                # the raw accumulated grad (as the pending-grad helpers do)
                # instead of its ``_data`` accessor.
                if hsdp_param.unsharded_accumulated_grad is not None:
                    hsdp_param.unsharded_accumulated_grad = None
                elif hsdp_param.unsharded_param.grad is not None:
                    hsdp_param.unsharded_param.grad = None
        return prev_groups

    def _issue_prev_fused_allreduce(self, prev_groups: List[AllReduceParamGroup]):
        """Step 4: issue previous module's fused allreduce (async).

        The allreduce handle is collected in pending_all_reduce_groups,
        and will be processed in root_backward_hook's delay_apply_reduce_grads().

        Args:
            prev_groups: List of previous AllReduceParamGroups to issue allreduce for.
        """
        for prev_group in prev_groups:
            prev_group.accumulate_existing_grads_to_buffer()
            prev_group.issue_async_allreduce()
            # Move to pending queue for root_backward_hook to process.
            TorchHSDPStateV2.pending_all_reduce_groups.append(prev_group)

    def _apply_pending_reduce_scatter_grads(self):
        """Drain ``HSDPState.pre_reduce_scatter_params`` in FIFO order.

        Fetches each reduce-scatter result, clears the cached output, applies
        the gradient to the sharded parameter and synchronizes the device
        stream if any application requested it.
        """
        need_synchronize = False
        while HSDPState.pre_reduce_scatter_params:
            pre_hsdp_param, pre_orig_dtype = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = pre_hsdp_param.reduce_scatter_output()
            pre_hsdp_param.clear_reduce_scatter_output()
            need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize
            pre_hsdp_param.accumulated_allreduced_grad = False
        if need_synchronize:
            self._synchronize_device_stream(self.device)

    def _wait_and_apply_prev_no_allreduce_params(self):
        """Step 2: wait and apply previous reduce_scatter for params NOT needing allreduce.

        These are FSDP params or single-replica HSDP params that don't need
        cross-replica allreduce. Their reduce_scatter was issued by the previous
        module's _issue_reduce_scatter_for_current_module(), and we wait and apply here.
        """
        self._apply_pending_reduce_scatter_grads()

    @classmethod
    def delay_apply_reduce_grads(cls, device: torch.device):
        """Apply all pending allreduce gradients in root_backward_hook.

        This is called at the end of root_backward_hook to wait for all
        async allreduce operations and apply gradients to sharded parameters.

        Args:
            device: Device for CPU offload synchronization.
        """
        need_synchronize = False
        for group in cls.pending_all_reduce_groups:
            need_synchronize = group.wait_and_apply_grads() or need_synchronize
        cls.pending_all_reduce_groups.clear()
        if need_synchronize:
            cls._synchronize_device_stream(device)

    def reduce_scattered_params(self):
        """Wait for and apply all queued reduce-scatter gradients (FIFO order)."""
        self._apply_pending_reduce_scatter_grads()

    def reduce_params(self):
        """Apply reduced gradients from pre-staged HSDP parameters to sharded parameters.

        This function drains two queues: ``HSDPState.pre_all_reduce_params``
        (all-reduced unsharded/reduce-scattered grads) and
        ``TorchHSDPStateV2.pre_direct_all_reduce_grads`` (direct all-reduce of
        ``sharded_param.grad`` for pure-TP DTensor-compat params).

        Note:
            - Entries are processed in **FIFO (First-In-First-Out)** order (via `pop(0)`),
              ensuring gradient application order matches the order of gradient reduction operations.
            - After retrieving the reduced gradient, the cached all_reduce_output is cleared
              to free memory and avoid stale data.
            - Gradient application logic (in `apply_reduced_grad`) includes:
                1. Reshaping the flat reduced gradient to match the local shard shape
                2. Optional dtype conversion to `param_type`
                3. Optional CPU offloading (per the HSDP parameter's offload policy)
                4. Assigning or accumulating the gradient to `sharded_param.grad`
            - Direct all-reduce results are written back in place into the
              original gradient tensor (casting back when a reduce dtype copy was used).
        """
        need_synchronize = False
        while HSDPState.pre_all_reduce_params:
            pre_hsdp_param, pre_orig_dtype = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = pre_hsdp_param.all_reduce_output()
            pre_hsdp_param.clear_all_reduce_output()
            need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize

        while TorchHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad = TorchHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            if handle is not None:
                handle.wait()
            if reduced_grad is not target_grad:
                target_grad.copy_(_to_dtype_if_needed(reduced_grad, target_grad.dtype))
        if need_synchronize:
            self._synchronize_device_stream(self.device)

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set requires grad sync flag to control gradient sync."""
        self.reduce_grads = requires_grad_sync

    @property
    def _is_hsdp(self) -> bool:
        """Whether ``mesh_info`` is an ``HSDPMeshInfo`` (or a subclass of it).

        ``HSDPMeshInfo`` is not imported by this module, so a plain
        ``isinstance`` check would raise ``NameError``; match on the class name
        across the MRO instead.
        """
        # TODO(review): import HSDPMeshInfo from its defining module and use isinstance.
        return any(
            klass.__name__ == "HSDPMeshInfo" for klass in type(self.mesh_info).__mro__
        )

    def set_reduce_op_type(self, reduce_op_type: str):
        """Set reduce op type for gradient reduction.

        Args:
            reduce_op_type: ``"sum"`` or ``"avg"``; case and surrounding
                whitespace are ignored.

        Raises:
            ValueError: if ``reduce_op_type`` is not a supported op name.
        """
        fsdp_support_reduce_op = {
            "sum": torch.distributed.ReduceOp.SUM,
            "avg": torch.distributed.ReduceOp.AVG,
        }
        # Normalize BEFORE the membership check so "SUM" / " avg " are accepted.
        reduce_op = reduce_op_type.lower().strip() if isinstance(reduce_op_type, str) else None
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(
                f"Unsupported reduce op type {reduce_op_type}, "
                f"supported types are {list(fsdp_support_reduce_op.keys())}"
            )
        self._user_reduce_op_type = fsdp_support_reduce_op[reduce_op]
        self.reduce_op_type = self._user_reduce_op_type