Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/param.py: 66%

502 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py 

16# enhanced with fully_shard parameter management 

17# ============================================================================ 

18"""HSDP parameter""" 

19# pylint: disable=W0212 

20import itertools 

21from typing import Callable, List, Optional, Tuple, Union, cast 

22 

23import torch 

24import torch.distributed as dist 

25from torch import nn 

26from torch._prims_common import make_contiguous_strides_for 

27 

28from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

29from hyper_parallel.core.dtensor.dtensor import DTensor, SkipDTensorDispatch 

30from hyper_parallel.core.dtensor.layout import Layout 

31from hyper_parallel.core.dtensor.placement_types import Replicate, Shard, StridedShard 

32from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2 

33from hyper_parallel.core.fully_shard.hsdp_utils import ( 

34 FullyShardParamMode, 

35 GroupInfo, 

36 ParamModuleInfo, 

37 ShardedState, 

38 get_rank_list_for_axes, 

39 get_split_rank_lists_for_axes, 

40) 

41from hyper_parallel.core.fully_shard.utils import ( 

42 CPUOffloadPolicy, 

43 DDPMeshInfo, 

44 FSDPMeshInfo, 

45 MixedPrecisionPolicy, 

46 OffloadPolicy, 

47) 

48from hyper_parallel.platform import get_platform 

49from hyper_parallel.platform.torch.fully_shard.pack_utils import ( 

50 build_rs_plan, 

51 pack_for_reduce_scatter, 

52 unpack_from_all_gather, 

53) 

54 

# Process groups built by _build_group_info_from_rank_list, keyed by the
# sorted tuple of global ranks they span.
_GROUP_INFO_CACHE: dict = {}
# Platform backend; provides create_group() and split_group() used below.
platform = get_platform()

57 

58 

59def _copy_without_bumping_version(dst: torch.Tensor, src: torch.Tensor) -> None: 

60 """Copy into ``dst`` while preserving its autograd version counter.""" 

61 # pylint: disable=W0212 

62 with torch.autograd._unsafe_preserve_version_counter(dst): 

63 dst.copy_(src) 

64 

65 

def _build_group_info_from_rank_list(
    group_name: str,
    rank_list,
) -> GroupInfo:
    """Create (or reuse) group metadata for an explicit list of global ranks.

    Ranks are normalized to a sorted tuple of ``int``. That tuple is both the
    cache key in ``_GROUP_INFO_CACHE`` and the resolved group name.

    Args:
        group_name: Prefix for the placeholder name returned when the rank list
            holds at most one rank.
        rank_list: Iterable of rank ids, in any order.

    Returns:
        GroupInfo: ``(name, process_group, rank_size)``. The result is
        ``(f"{group_name}_invalid", None, 1)`` for lists of at most one rank.
        ``process_group`` is ``None`` when ``platform.create_group`` raised.
    """
    key = tuple(sorted(int(rank) for rank in rank_list))
    if len(key) <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    resolved_name = str(key)
    rank_size = len(key)
    if key in _GROUP_INFO_CACHE:
        return GroupInfo(resolved_name, _GROUP_INFO_CACHE[key], rank_size)
    try:
        group = platform.create_group(list(key))
    except (RuntimeError, ValueError):  # pragma: no cover - UT may run without dist init
        # Do not cache the failure. Caching ``None`` here would pin this rank set
        # to "no group" for the rest of the process, even once the default
        # process group is initialized and creation would succeed.
        return GroupInfo(resolved_name, None, rank_size)
    _GROUP_INFO_CACHE[key] = group
    return GroupInfo(resolved_name, group, rank_size)

83 

84 

def _build_group_info_from_process_group(
    group_name: str,
    process_group,
    rank_size: int,
) -> GroupInfo:
    """Wrap an already-created process group in ``GroupInfo`` metadata.

    The group is named after its sorted global ranks when they can be queried
    through ``dist.get_process_group_ranks``. Otherwise ``group_name`` is used
    verbatim.

    Args:
        group_name: Fallback name. It is also the prefix of the placeholder name
            returned for a missing or single-rank group.
        process_group: Existing process group, or ``None``.
        rank_size: Number of ranks in ``process_group``.

    Returns:
        GroupInfo: ``(name, process_group, rank_size)``. The result is
        ``(f"{group_name}_invalid", None, 1)`` when ``process_group`` is ``None``
        or ``rank_size <= 1``.
    """
    if process_group is None or rank_size <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    try:
        ranks = dist.get_process_group_ranks(process_group)
        name = str(tuple(sorted(ranks)))
    except (AssertionError, AttributeError, KeyError, RuntimeError, TypeError, ValueError):  # pragma: no cover
        # Naming is best-effort only (e.g. mocked process groups in UT).
        name = group_name
    return GroupInfo(name, process_group, rank_size)

100 

101 

102class TorchHSDPParamV2(HSDPParamV2): 

103 """ 

104 Torch HSDP parameter. 

105 """ 

106 

107 def __init__( 

108 self, 

109 param: nn.Parameter, 

110 module_info: ParamModuleInfo, 

111 mesh_info: FSDPMeshInfo, 

112 shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None, 

113 mp_policy: Optional[MixedPrecisionPolicy] = None, 

114 offload_policy: Optional[OffloadPolicy] = None, 

115 device: Optional[torch.device] = None, 

116 param_mode: Optional[FullyShardParamMode] = None, 

117 enable_fsdp_shard: bool = True, 

118 ): 

119 """ 

120 Initialize TorchHSDPParamV2 and shard the parameter. 

121 

122 Args: 

123 param (nn.Parameter): The original full parameter to shard. 

124 module_info (ParamModuleInfo): Ownership and shared-weight metadata. 

125 mesh_info (FSDPMeshInfo): Mesh topology for shard/replicate dimensions. 

126 shard_placement_fn (Callable, optional): Returns a Shard placement for the parameter, 

127 or None to use default (Shard(0)). 

128 mp_policy (MixedPrecisionPolicy, optional): Mixed precision dtype policy. 

129 offload_policy (OffloadPolicy, optional): CPU offload policy. 

130 device (torch.device, optional): Target device for the sharded parameter. 

131 """ 

132 self._module_info: ParamModuleInfo = module_info 

133 self.mesh_info = mesh_info 

134 self.mp_policy = mp_policy 

135 self.device = device 

136 if param_mode is None: 

137 raise AssertionError("param_mode must be resolved before TorchHSDPParamV2 initialization.") 

138 self.param_mode = param_mode 

139 self.enable_fsdp_shard = enable_fsdp_shard 

140 self.orig_dtype = None 

141 self.param_dtype = None 

142 self.reduce_dtype = None 

143 self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy) 

144 self.pin_memory = ( 

145 self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory 

146 ) 

147 self._orig_param_hooks: List[Callable] = [] 

148 self.grad_offload_event: Optional[torch.Event] = None 

149 self._orig_param_is_dtensor = isinstance(param, DTensor) 

150 self._orig_dtensor_mesh = param.device_mesh if self._orig_param_is_dtensor else None 

151 self._orig_dtensor_placements = tuple(param.placements) if self._orig_param_is_dtensor else None 

152 self._spmd_shard_mesh_dim = self.mesh_info.shard_mesh_dim 

153 self._spmd_replicate_mesh_dim = self.mesh_info.replicate_mesh_dim 

154 self._init_sharded_param(param, shard_placement_fn) 

155 self._init_group_infos() 

156 self.all_gather_outputs: List[torch.Tensor] = [] 

157 self.unsharded_accumulated_grad = None 

158 self._param_fqn: Optional[str] = None 

159 # Communication attributes for prefetch pattern 

160 self.prefetch_handle: Optional[dist.Work] = None 

161 self._post_load_hook_handle = ( 

162 module_info.module.register_load_state_dict_post_hook( 

163 lambda *args, **kwargs: self.reset_sharded_param() 

164 ) 

165 ) 

166 self._reduce_scatter_output = None 

167 self.reduce_scatter_handle = None 

168 self._all_reduce_output = None 

169 self.all_reduce_handle = None 

170 self._save_backward_hooks(param) 

171 self._grad = None 

172 self._accumulated_allreduced_grad = True 

173 

174 @property 

175 def uses_param_shard(self) -> bool: 

176 """Whether fully_shard should physically shard parameter storage for this param.""" 

177 return self.enable_fsdp_shard 

178 

179 @property 

180 def is_dtensor_compat_mode(self) -> bool: 

181 """Whether the parameter is managed through the DTensor compatibility path only.""" 

182 return self.param_mode == FullyShardParamMode.DTENSOR_COMPAT 

183 

184 def _get_base_spmd_placements(self) -> tuple: 

185 if self.param_mode == FullyShardParamMode.DTENSOR_UNIFIED and self._orig_param_is_dtensor: 

186 # DTENSOR_UNIFIED keeps the original distributed layout and prefixes 

187 # explicit DP/FSDP mesh dimensions ahead of it on the unified mesh. 

188 self._spmd_mesh = DeviceMesh.concatenate([self.mesh_info.mesh, self._orig_dtensor_mesh]) 

189 dp_prefix_placements = tuple(Replicate() for _ in range(self.mesh_info.mesh.ndim)) 

190 return dp_prefix_placements + tuple(self._orig_dtensor_placements) 

191 

192 if self.is_dtensor_compat_mode and self._orig_param_is_dtensor: 

193 self._spmd_mesh = self._orig_dtensor_mesh 

194 return tuple(self._orig_dtensor_placements) 

195 

196 self._spmd_mesh = self.mesh_info.mesh 

197 return tuple(Replicate() for _ in range(self._spmd_mesh.ndim)) 

198 

199 def _apply_data_parallel_placements(self, placements: list, shard_placement: Shard) -> tuple: 

200 if len(placements) != self._spmd_mesh.ndim: 

201 raise AssertionError( 

202 f"Expected {self._spmd_mesh.ndim} unified placements, got {len(placements)}: {placements}" 

203 ) 

204 if ( 

205 isinstance(self.mesh_info, DDPMeshInfo) 

206 and self._spmd_replicate_mesh_dim is not None 

207 and not self._orig_param_is_dtensor 

208 ): 

209 placements[self._spmd_replicate_mesh_dim] = Replicate() 

210 if ( 

211 self.uses_param_shard 

212 and isinstance(self.mesh_info, FSDPMeshInfo) 

213 and self._spmd_shard_mesh_dim is not None 

214 ): 

215 # If TP/EP already shards the same tensor dimension, fully_shard must 

216 # use StridedShard so the unified placement preserves the intended 

217 # shard order on the concatenated mesh. 

218 split_factor = 1 

219 for mesh_idx, placement in enumerate(placements): 

220 if mesh_idx == self._spmd_shard_mesh_dim: 

221 continue 

222 if placement.is_shard(shard_placement.dim): 

223 split_factor *= self._spmd_mesh.mesh_shape[mesh_idx] 

224 placements[self._spmd_shard_mesh_dim] = ( 

225 StridedShard(shard_placement.dim, split_factor=split_factor) 

226 if split_factor > 1 

227 else shard_placement 

228 ) 

229 return tuple(placements) 

230 

231 def _init_group_infos(self) -> None: 

232 if self.uses_param_shard and self.is_sharded and isinstance(self.mesh_info, FSDPMeshInfo): 

233 self.sharded_group_info = _build_group_info_from_process_group( 

234 "fully_shard_sharded_group", 

235 self.mesh_info.shard_process_group, 

236 self.mesh_info.shard_mesh_size, 

237 ) 

238 else: 

239 self.sharded_group_info = GroupInfo("fully_shard_sharded_group_invalid", None, 1) 

240 

241 # The all-reduce group is always derived from the final materialized layout. 

242 # This keeps replicate_params, DTensor compat, and unified multi-dim layouts 

243 # on a single source of truth. 

244 self.unsharded_group_info = self._build_layout_driven_group_info() 

245 

246 self.shard_size = self.sharded_group_info.rank_size 

247 self.dp_size = self.unsharded_group_info.rank_size 

248 self.rank_size = max(1, self.shard_size * self.dp_size) 

249 

250 def _build_layout_driven_group_info(self): 

251 group_axes = [ 

252 axis 

253 for axis, placement in enumerate(self._spmd_placements) 

254 if placement.is_replicate() 

255 ] 

256 if self.uses_param_shard and self._spmd_shard_mesh_dim is not None: 

257 group_axes = [axis for axis in group_axes if axis != self._spmd_shard_mesh_dim] 

258 if not group_axes: 

259 return GroupInfo("fully_shard_unsharded_group_invalid", None, 1) 

260 group_dim_names = getattr(self._spmd_mesh, "mesh_dim_names", None) 

261 if group_dim_names: 

262 try: 

263 mesh_axis_names = tuple(group_dim_names[axis] for axis in group_axes) 

264 if len(mesh_axis_names) == 1: 

265 axis_name = mesh_axis_names[0] 

266 process_group = self._spmd_mesh.get_group(axis_name) 

267 if process_group is not None: 

268 rank_size = self._spmd_mesh.mesh_shape[group_dim_names.index(axis_name)] 

269 return _build_group_info_from_process_group( 

270 "fully_shard_unsharded_group", 

271 process_group, 

272 rank_size, 

273 ) 

274 

275 split_rank_lists = get_split_rank_lists_for_axes(self._spmd_mesh, group_axes) 

276 process_group = platform.split_group(split_ranks=split_rank_lists) 

277 if process_group is not None: 

278 rank_size = 1 

279 for axis in group_axes: 

280 rank_size *= self._spmd_mesh.mesh_shape[axis] 

281 return _build_group_info_from_process_group( 

282 "fully_shard_unsharded_group", 

283 process_group, 

284 rank_size, 

285 ) 

286 except ( 

287 AssertionError, 

288 AttributeError, 

289 KeyError, 

290 RuntimeError, 

291 TypeError, 

292 ValueError, 

293 ): 

294 # Fall back to the explicit rank-list path for mocked meshes in UT 

295 # or when a mesh implementation cannot materialize a reusable group. 

296 pass 

297 

298 rank_list = get_rank_list_for_axes(self._spmd_mesh, group_axes) 

299 return _build_group_info_from_rank_list("fully_shard_unsharded_group", rank_list) 

300 

301 def _to_local_unsharded_grad(self, grad): 

302 """Normalize a pending gradient to a local tensor expected by fully_shard collectives.""" 

303 if not isinstance(grad, DTensor): 

304 return grad 

305 

306 if any(placement.is_partial() for placement in grad.placements): 

307 grad = grad.reduce_partial() 

308 

309 if ( 

310 self._orig_dtensor_mesh is not None 

311 and grad.device_mesh.to_hash() != self._orig_dtensor_mesh.to_hash() 

312 ) or ( 

313 self._orig_dtensor_placements is not None 

314 and tuple(grad.placements) != tuple(self._orig_dtensor_placements) 

315 ): 

316 grad = grad.redistribute(self._orig_dtensor_mesh, self._orig_dtensor_placements) 

317 return grad.to_local() 

318 

319 @property 

320 def accumulated_allreduced_grad(self) -> bool: 

321 """Whether the parameter has accumulated all-reduced gradient.""" 

322 return self._accumulated_allreduced_grad 

323 

324 @accumulated_allreduced_grad.setter 

325 def accumulated_allreduced_grad(self, value: bool) -> None: 

326 self._accumulated_allreduced_grad = value 

327 

328 def _save_backward_hooks(self, param: nn.Parameter) -> None: 

329 """Save the backward hooks of the original parameter""" 

330 if not hasattr(param, '_backward_hooks') or param._backward_hooks is None: 

331 return 

332 

333 # Get the set of saved hook function IDs for deduplication 

334 if not hasattr(self, '_saved_hook_ids'): 

335 object.__setattr__(self, '_saved_hook_ids', set()) 

336 

337 for _, hook_func in param._backward_hooks.items(): 

338 # Use the id of hook_func to avoid adding the same function object repeatedly 

339 hook_func_id = id(hook_func) 

340 if hook_func_id not in self._saved_hook_ids: 

341 self._orig_param_hooks.append(hook_func) 

342 self._saved_hook_ids.add(hook_func_id) 

343 

344 def _migrate_backward_hooks(self, new_param: nn.Parameter) -> None: 

345 """Migrate backward hooks from the original parameter to the new parameter""" 

346 if not self._orig_param_hooks or hasattr(new_param, "migrate_backward_hooks_run_once"): 

347 return 

348 

349 # Properly register each hook using the register_hook method 

350 for hook_func in self._orig_param_hooks: 

351 try: 

352 if new_param.requires_grad: 

353 new_param.register_hook(hook_func) 

354 except RuntimeError: 

355 # Skip hook registration if the parameter does not require gradients 

356 pass 

357 new_param.migrate_backward_hooks_run_once = True 

358 

359 def reduce_scatter_output(self): 

360 """ 

361 Get the reduce-scatter output tensor and wait for asynchronous operation to complete. 

362 

363 Returns: 

364 torch.Tensor: The sharded gradient tensor after reduce-scatter operation. 

365 """ 

366 if self.reduce_scatter_handle is not None: 

367 self.reduce_scatter_handle.wait() 

368 self._grad.untyped_storage().resize_(0) 

369 self._grad = None 

370 self.reduce_scatter_handle = None 

371 return self._reduce_scatter_output 

372 

373 def clear_reduce_scatter_output(self): 

374 """Clear the reduce-scatter output tensor to free memory.""" 

375 self._reduce_scatter_output = None 

376 

377 def all_reduce_output(self): 

378 """ 

379 Get the all-reduce output tensor and wait for asynchronous operation to complete. 

380 

381 Returns: 

382 torch.Tensor: The reduced gradient tensor after all-reduce operation. 

383 """ 

384 if self.all_reduce_handle is not None: 

385 self.all_reduce_handle.wait() 

386 self.all_reduce_handle = None 

387 return self._all_reduce_output 

388 

389 def clear_all_reduce_output(self): 

390 """Clear the all-reduce output tensor to free memory.""" 

391 self._all_reduce_output = None 

392 

393 def apply_reduced_grad(self, reduced_grad, param_type): 

394 """ 

395 Apply reduced gradient to the sharded parameter. 

396 

397 Reshapes ``reduced_grad`` to match the local shard, optionally 

398 offloads to CPU, then accumulates or assigns onto 

399 ``hsdp_param.sharded_param.grad``. 

400 

401 Args: 

402 reduced_grad (torch.Tensor): Gradient after reduce-scatter 

403 and/or all-reduce. 

404 param_type (Optional[torch.dtype]): Target dtype for the gradient (if conversion is needed). 

405 """ 

406 sharded_grad = None 

407 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

408 sharded_grad = self.sharded_param.grad 

409 else: 

410 if not hasattr(self.sharded_param, "main_grad"): 

411 self.sharded_param.main_grad = None 

412 sharded_grad = self.sharded_param.main_grad 

413 sharded_param_local_shape = ( 

414 self.sharded_param.local_shape 

415 if isinstance(self.sharded_param, DTensor) 

416 else self.sharded_param.shape 

417 ) 

418 reduced_grad = reduced_grad.view(sharded_param_local_shape) 

419 if (not self.mp_policy.apply_grad_on_fp32_main_grad and param_type is not None 

420 and reduced_grad.dtype != param_type): 

421 reduced_grad = reduced_grad.to(param_type) 

422 to_accumulate_grad = sharded_grad is not None 

423 need_synchronize = False 

424 if self.offload_to_cpu: 

425 non_blocking = self.pin_memory and not to_accumulate_grad 

426 reduced_grad = reduced_grad.to( 

427 torch.device("cpu"), non_blocking=non_blocking 

428 ) 

429 need_synchronize = True 

430 if sharded_grad is None: 

431 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

432 self.sharded_param.grad = self.to_sharded_dtensor(reduced_grad) 

433 else: 

434 self.sharded_param.main_grad = self.to_sharded_dtensor(reduced_grad) 

435 self.sharded_param.grad = None 

436 else: 

437 with SkipDTensorDispatch(): 

438 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

439 self.sharded_param.grad._local_tensor += reduced_grad 

440 else: 

441 self.sharded_param.main_grad._local_tensor += reduced_grad 

442 self.sharded_param.grad = None 

443 if self.unsharded_accumulated_grad_data is not None: 

444 self.unsharded_accumulated_grad = None 

445 elif self.unsharded_param.grad is not None: 

446 self.unsharded_param.grad = None 

447 return need_synchronize 

448 

449 @torch.no_grad() 

450 def _init_sharded_param( 

451 self, 

452 param: nn.Parameter, 

453 shard_placement_fn: Optional[Callable], 

454 ) -> None: 

455 if param.device != self.device and param.device.type != "meta": 

456 raise AssertionError( 

457 f"Expects the parameter to already be moved to device {self.device} but got {param.device}" 

458 ) 

459 

460 hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None 

461 if hsdp_placement is None: 

462 hsdp_placement = Shard(0) 

463 elif hsdp_placement.dim < 0: 

464 # if dim is negative, add the number of dimensions of the parameter 

465 hsdp_placement = Shard(hsdp_placement.dim + param.ndim) 

466 

467 if not isinstance(hsdp_placement, Shard): 

468 raise AssertionError( 

469 f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}" 

470 ) 

471 

472 self.hsdp_placement = hsdp_placement 

473 base_placements = list(self._get_base_spmd_placements()) 

474 self._spmd_placements = self._apply_data_parallel_placements(base_placements, hsdp_placement) 

475 param_data = param.to_local() if self._orig_param_is_dtensor else param 

476 

477 shard_dim = hsdp_placement.dim 

478 self._orig_size = param_data.size() 

479 self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) 

480 

481 if self.uses_param_shard and isinstance(self.mesh_info, FSDPMeshInfo): 

482 shard_rank = self.mesh_info.shard_mesh_rank 

483 shard_world_size = self.mesh_info.shard_mesh_size 

484 else: 

485 shard_rank = 0 

486 shard_world_size = 1 

487 

488 if isinstance(param_data, DTensor) and isinstance(self.mesh_info, DDPMeshInfo): 

489 param_data.data = param_data.full_tensor() 

490 

491 self.is_sharded = bool(self.uses_param_shard and shard_world_size > 1) 

492 

493 if param_data.size(shard_dim) % shard_world_size != 0: 

494 raise NotImplementedError( 

495 f"Uneven sharding on dim {shard_dim} not supported: " 

496 f"shape={param_data.shape}, world_size={shard_world_size}" 

497 ) 

498 chunks = torch.chunk(param_data, shard_world_size, dim=shard_dim) 

499 sharded_param = chunks[shard_rank].clone().contiguous() 

500 self.sharded_size = sharded_param.size() 

501 self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) 

502 if self.offload_to_cpu and not sharded_param.is_meta: 

503 sharded_param = sharded_param.cpu() 

504 if self.pin_memory: 

505 sharded_param = sharded_param.pin_memory() 

506 self._sharded_param_data = sharded_param.view(-1) 

507 

508 self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh) 

509 self._sharding_spec.set_placements(self._spmd_placements) 

510 self._sharding_spec.placement_to_tensor_map(param.ndim) 

511 

512 self.sharded_param = nn.Parameter(DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements)) 

513 self.sharded_param.requires_grad_(param.requires_grad) 

514 self._setattr_on_modules(self.sharded_param) 

515 # after init, self.sharded_param replaces original param, gradients must accumulate to this Parameter's grad 

516 self.sharded_param._hsdp_param_initialized = True 

517 self.sharded_state = ShardedState.SHARDED 

518 self.param_dtype = None 

519 

520 def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): 

521 """Initialize param_dtype and reduce_dtype from the mixed precision policy.""" 

522 param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype) 

523 self.orig_dtype = self.sharded_param.dtype 

524 if reduce_dtype == param_dtype: 

525 reduce_dtype = None 

526 if param_dtype == self.orig_dtype: 

527 param_dtype = None 

528 self.param_dtype = param_dtype 

529 self.reduce_dtype = reduce_dtype 

530 

531 def init_all_gather_outputs( 

532 self, 

533 all_gather_input_numels: list[int], 

534 all_gather_input_dtypes: list[torch.dtype], 

535 world_size: int, 

536 device: torch.device, 

537 force_recreate: bool = False, 

538 ): 

539 """ 

540 Allocate output buffers for all-gather communication. 

541 

542 Args: 

543 all_gather_input_numels: Number of elements per input shard. 

544 all_gather_input_dtypes: Dtype of each input shard. 

545 world_size: Number of ranks in the shard process group. 

546 device: Device on which to allocate the output buffers. 

547 force_recreate: If True, always recreate buffers even if already initialized. 

548 """ 

549 if not force_recreate and len(self.all_gather_outputs) > 0: 

550 return # already initialized 

551 self.all_gather_outputs = [ 

552 torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) 

553 for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes) 

554 ] 

555 

556 def init_unsharded_param(self): 

557 """ 

558 Initialize unsharded parameter from all-gather outputs. 

559 

560 This reconstructs the full parameter after all-gather by unpacking the 

561 gathered flat buffer back to the original tensor layout. 

562 """ 

563 unsharded_param = self._get_unsharded_param_from_all_gather_output() 

564 # Always refresh the unsharded Parameter from the latest all-gather output. 

565 # Non-dim0 unpack currently materializes a contiguous tensor copy, so 

566 # keeping stale .data would otherwise reuse old weights after optimizer.step() 

567 # mutates only the sharded local shard. Preserve the Parameter object identity 

568 # so autograd-facing module state stays stable across unshard cycles. 

569 if hasattr(self, "_unsharded_param"): 

570 # pylint: disable=access-member-before-definition 

571 self._unsharded_param.data = unsharded_param 

572 self._unsharded_param.requires_grad_(self.sharded_param.requires_grad) 

573 self._unsharded_param.grad = None 

574 return 

575 self._unsharded_param = nn.Parameter( 

576 unsharded_param, 

577 requires_grad=self.sharded_param.requires_grad, 

578 ) 

579 

580 def _get_unsharded_param_from_all_gather_output(self) -> torch.Tensor: 

581 """Reconstruct the full local parameter view from the packed all-gather output.""" 

582 if len(self.all_gather_outputs) != 1: 

583 raise AssertionError( 

584 f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}" 

585 ) 

586 unsharded_tensor = self.all_gather_outputs[0] 

587 plan = build_rs_plan( 

588 self, 

589 self._sharded_local_tensor, 

590 self.shard_world_size if self.is_sharded else 1, 

591 ) 

592 unsharded_param = unpack_from_all_gather(unsharded_tensor, plan) 

593 if self._orig_param_is_dtensor: 

594 # Rebuild the original DTensor view after all-gather so gradient 

595 # consumers keep seeing the source DTensor layout. 

596 unsharded_param = DTensor.from_local( 

597 unsharded_param, 

598 self._orig_dtensor_mesh, 

599 self._orig_dtensor_placements, 

600 ) 

601 return unsharded_param 

602 

603 def to_sharded(self) -> None: 

604 if not self.uses_param_shard and self._unsharded_param is not None: 

605 # Replicate params keep the same local shape across shard/unshard, 

606 # so persist forward-time state updates before switching objects. 

607 src = self._unsharded_param.to_local() if isinstance(self._unsharded_param, DTensor) \ 

608 else self._unsharded_param 

609 dst = self.sharded_param.to_local() if isinstance(self.sharded_param, DTensor) else self.sharded_param 

610 _copy_without_bumping_version(dst, src) 

611 self._setattr_on_modules(self.sharded_param) 

612 self.free_unsharded_param() 

613 self.sharded_state = ShardedState.SHARDED 

614 

615 def to_unsharded(self) -> None: 

616 set_requires_grad_if_needed(self.sharded_param, self._unsharded_param) 

617 self._setattr_on_modules(self._unsharded_param) 

618 self.sharded_state = ShardedState.UNSHARDED 

619 

620 def _setattr_on_modules(self, param: nn.Parameter) -> None: 

621 """Set parameter on module and shared modules, preserving pointer consistency.""" 

622 if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Module.__setattr__: 

623 # fast path 

624 self._module_info.module._parameters[self._module_info.param_name] = param 

625 else: 

626 # slow path 

627 setattr(self._module_info.module, self._module_info.param_name, param) 

628 self._save_backward_hooks(self.sharded_param) 

629 self._migrate_backward_hooks(param) 

630 # Iterate through all modules that share this parameter to prevent pointer desync. 

631 for shared_module, shared_param_name in zip( 

632 self._module_info.shared_modules, self._module_info.shared_param_names 

633 ): 

634 if getattr(shared_module.__setattr__, "__func__", None) is nn.Module.__setattr__: 

635 shared_module._parameters[shared_param_name] = param 

636 else: 

637 setattr(shared_module, shared_param_name, param) 

638 

639 def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: 

640 """ 

641 Converts a local tensor representing either the sharded parameter or 

642 sharded gradient to DTensor. 

643 """ 

644 return DTensor.from_local( 

645 tensor, 

646 self._sharding_spec.mesh, 

647 self._sharding_spec.placements 

648 ) 

649 

650 def to_accumulated_grad_if_needed(self) -> None: 

651 if self._unsharded_param.grad is None: 

652 return 

653 # Keep local gradients alive across no-sync / delayed-sync steps even 

654 # after the parameter transitions back to the sharded view. 

655 unsharded_grad = self._unsharded_param.grad 

656 self._unsharded_param.grad = None 

657 if self.reduce_dtype is not None and unsharded_grad.dtype != self.reduce_dtype: 

658 unsharded_grad = unsharded_grad.to(self.reduce_dtype) 

659 if self.unsharded_accumulated_grad is None: 

660 self.unsharded_accumulated_grad = unsharded_grad 

661 else: 

662 self.unsharded_accumulated_grad += unsharded_grad 

663 

664 def accumulate_unsharded_grad_if_needed(self) -> None: 

665 if ( 

666 self.unsharded_accumulated_grad is not None 

667 and self.unsharded_param.grad is not None 

668 ): 

669 grad = self.unsharded_param.grad 

670 if self.reduce_dtype is not None and grad.dtype != self.reduce_dtype: 

671 grad = grad.to(self.reduce_dtype) 

672 self.unsharded_accumulated_grad += grad 

673 self.unsharded_param.grad = None 

674 

675 def alloc_all_gather_outputs(self) -> None: 

676 """Resize all-gather output buffers to their full capacity for communication.""" 

677 for tensor in self.all_gather_outputs: 

678 expected_size = tensor.numel() * tensor.itemsize 

679 storage = tensor.untyped_storage() 

680 if storage.size() != expected_size: 

681 storage.resize_(expected_size) 

682 

683 def free_unsharded_param(self) -> None: 

684 """Release storage of all-gather outputs to free device memory.""" 

685 for tensor in self.all_gather_outputs: 

686 storage = tensor.untyped_storage() 

687 if storage.size() != 0: 

688 storage.resize_(0) 

689 

690 @property 

691 def all_gather_inputs(self) -> list[torch.Tensor]: 

692 """Return the local sharded tensor to use as input for all-gather, applying dtype cast if needed.""" 

693 self._assert_in_states(ShardedState.SHARDED) 

694 sharded_param_data = self._sharded_param_data 

695 if self.offload_to_cpu: 

696 sharded_param_data = sharded_param_data.to( 

697 self.device, non_blocking=True 

698 ) 

699 if self.param_dtype is not None and self.param_dtype != sharded_param_data.dtype: 

700 return [sharded_param_data.to(self.param_dtype)] 

701 return [sharded_param_data] 

702 

703 @property 

704 def unsharded_param(self) -> nn.Parameter: 

705 """Return the full unsharded parameter after all-gather.""" 

706 return self._unsharded_param 

707 

708 @property 

709 def unsharded_grad_data(self) -> torch.Tensor: 

710 """ 

711 Get the unsharded gradient data as a local tensor. 

712 """ 

713 grad = self.unsharded_param.grad 

714 if grad is None: 

715 raise AssertionError("Expects unsharded_param.grad to not be None") 

716 return self._to_local_unsharded_grad(grad) 

717 

718 @property 

719 def unsharded_accumulated_grad_data(self) -> torch.Tensor: 

720 """ 

721 Get the unsharded accumulated gradient data as a local tensor. 

722 """ 

723 grad = self.unsharded_accumulated_grad 

724 return self._to_local_unsharded_grad(grad) 

725 

726 @property 

727 def _sharded_local_tensor(self) -> torch.Tensor: 

728 """Return the underlying local tensor of the sharded DTensor parameter.""" 

729 return cast(DTensor, self.sharded_param)._local_tensor 

730 

731 @property 

732 def shard_world_size(self) -> int: 

733 """Get the world size for shard dimension.""" 

734 return self.shard_size 

735 

736 @property 

737 def replicate_world_size(self) -> int: 

738 """Get the world size for replicate dimension (HSDP only).""" 

739 return self.dp_size 

740 

741 def _assert_in_states(self, *states: ShardedState) -> None: 

742 """Assert current state is one of expected states.""" 

743 if self.sharded_state not in states: 

744 raise AssertionError( 

745 f"Expected sharded_state in {states}, got {self.sharded_state}" 

746 ) 

747 

def reset_sharded_param(self) -> None:
    """
    Re-sync the sharded parameter bookkeeping after ``load_state_dict``.

    The parameter is looked up again on its owning module:

    * If it is a different ``DTensor`` object, it replaces
      ``self.sharded_param`` and gets ``_hsdp_param_initialized = True``.
    * If it is a plain ``torch.Tensor``, ``self.sharded_param`` keeps its
      reference and only its ``_local_tensor`` (and ``_sharded_param_data``)
      are refreshed.

    If the resulting local tensor is on the meta device, the method returns
    before touching any local data.

    Raises:
        AssertionError: if swap-on-conversion is enabled but the parameter
            object changed; if a new local tensor does not have
            ``self.sharded_size``; if ``self.sharded_param`` is not a
            ``DTensor``; or if the refreshed local tensor is not contiguous.
    """
    info = self._module_info
    current = getattr(info.module, info.param_name)
    if current is not self.sharded_param:
        # With swap-on-conversion enabled, the parameter object identity
        # is expected to survive the conversion.
        if torch.__future__.get_swap_module_params_on_conversion():
            raise AssertionError(
                f"Expects swap_tensors to preserve object but got {current} "
                f"instead of {self.sharded_param}"
            )
        if isinstance(current, DTensor):
            self.sharded_param = current
            if not getattr(self.sharded_param, "_hsdp_param_initialized", None):
                self.sharded_param._hsdp_param_initialized = True
        # A plain torch.Tensor deliberately keeps the existing
        # ``self.sharded_param`` reference; only its local data is
        # refreshed below.

    local = current._local_tensor if isinstance(current, DTensor) else current
    if local.is_meta:
        return

    # The local tensor may be processed twice: once in fully_shard(model)
    # and again in the lazy init of the first model(input). The second pass
    # should be a no-op when parameters are unchanged, but must pick up new
    # data if model.load_state_dict(...) ran in between. That lets a trainer
    # take ``sd = model.state_dict()`` once before the loop and reuse it.
    reuses_storage = False
    if isinstance(self._sharded_param_data, torch.Tensor):
        # When a (1, ...) parameter is split over 2 ranks, the local shard on
        # rank 1 can be empty with data_ptr() == 0; that must not count as a
        # match.
        old_ptr = self._sharded_param_data.untyped_storage().data_ptr()
        reuses_storage = (
            old_ptr > 0 and old_ptr == local.untyped_storage().data_ptr()
        )

    expected_size = self.sharded_size
    shard_dim = self.hsdp_placement.dim
    length = local.size(shard_dim) if local.numel() > 0 else 0

    needs_refresh = not reuses_storage
    if needs_refresh and local.size() != expected_size:
        raise AssertionError(
            f"Expected sharded_size to be {expected_size}, got {local.size()}"
        )
    if self.pin_memory and not local.is_pinned():
        local = local.cpu().pin_memory()
        needs_refresh = True
    if not reuses_storage:
        self._sharded_param_data = local.view(-1)
    if not isinstance(self.sharded_param, DTensor):
        raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
    if needs_refresh:
        # Swap the local tensor object only when something actually changed.
        self.sharded_param._local_tensor = local.narrow(
            dim=shard_dim, start=0, length=length
        )
        if not self.sharded_param._local_tensor.is_contiguous():
            raise AssertionError(
                "Expected sharded_param._local_tensor to be contiguous"
            )
    self._sharding_spec = cast(DTensor, self.sharded_param).layout

815 

def _get_unsharded_param_data(self, async_op: bool = False) -> Tuple[torch.Tensor, Optional[dist.Work]]:
    """
    Gather this rank's shard into the full-parameter output buffer.

    The all-gather output buffer is initialised and allocated on every call.
    No collective is issued in three cases: the parameter is not sharded
    (below threshold), it has no shard group, or its shard group has at most
    one rank. In those cases the input is copied into the buffer locally.

    Args:
        async_op: Whether the all-gather collective runs asynchronously.

    Returns:
        (unsharded_param, handle): ``self.all_gather_outputs[0]`` and the
        collective's work handle. The handle is ``None`` when no collective
        is issued or ``async_op`` is False.
    """
    sharded = self.is_sharded
    gather_input = self.all_gather_inputs[0]
    self.init_all_gather_outputs(
        all_gather_input_numels=[gather_input.numel()],
        all_gather_input_dtypes=[gather_input.dtype],
        # An unsharded parameter "gathers" over just the local rank.
        world_size=self.shard_world_size if sharded else 1,
        device=self.device,
    )
    self.alloc_all_gather_outputs()
    gather_output = self.all_gather_outputs[0]

    local_only = (
        not sharded
        or self.sharded_group_info.group is None
        or self.shard_world_size <= 1
    )
    if local_only:
        _copy_without_bumping_version(gather_output, gather_input)
        return gather_output, None

    work = dist.all_gather_into_tensor(
        gather_output,
        gather_input,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return gather_output, work

865 

def unshard(self, async_op: bool = False) -> None:
    """
    Start the all-gather for this parameter unless it is already in flight.

    A non-None ``self.prefetch_handle`` means ``HSDPState.prefetch()`` already
    launched the gather, so nothing is done. Otherwise the gather is started
    and its handle (possibly ``None``) is stored on ``self.prefetch_handle``.

    Args:
        async_op: Whether to execute the all-gather asynchronously.
    """
    if self.prefetch_handle is None:
        _, self.prefetch_handle = self._get_unsharded_param_data(async_op=async_op)

873 

def wait_for_unshard(self) -> None:
    """
    Finish any pending all-gather and switch the parameter to unsharded.

    Requires the SHARDED state. A pending ``self.prefetch_handle`` is waited
    on and cleared, then ``init_unsharded_param()`` and ``to_unsharded()``
    are called.
    """
    self._assert_in_states(ShardedState.SHARDED)

    pending = self.prefetch_handle
    if pending is not None:
        pending.wait()
        self.prefetch_handle = None

    self.init_unsharded_param()
    self.to_unsharded()

883 

def shard(self) -> None:
    """Move the parameter from the UNSHARDED state back to sharded."""
    expected = ShardedState.UNSHARDED
    self._assert_in_states(expected)
    self.to_sharded()

890 

def reduce_scatter_grad(
    self,
    async_op: bool = True,
    dtype: Optional[torch.dtype] = None,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG,
    output_buffer: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[dist.Work]]:
    """
    Reduce-scatter the unsharded gradient into this rank's gradient shard.

    The gradient source is ``self.unsharded_accumulated_grad`` when set,
    otherwise ``self.unsharded_param.grad``. It is cast to the reduce dtype,
    stored on ``self._grad``, packed per the plan from ``build_rs_plan``, and
    flattened before communication.

    Args:
        async_op: Whether to execute the collective asynchronously.
        dtype: Reduce dtype; defaults to the gradient's own dtype.
        reduce_op: Reduction for reduce-scatter (e.g. AVG or SUM).
        output_buffer: Optional pre-allocated output buffer for fused
            all-reduce. When provided, the result is written into this
            buffer and the buffer itself is returned on every path. This
            enables zero-copy fusion with subsequent all_reduce operations.
            When a collective is issued, the buffer must have exactly
            ``packed_grad_numel // shard_world_size`` elements and the
            reduce dtype.

    Returns:
        (sharded_grad, handle): the tensor holding the result (also stored on
        ``self._reduce_scatter_output``) and the communication handle. The
        handle is ``None`` when no collective is issued or ``async_op`` is
        False.

    Raises:
        AssertionError: if the parameter is not in the UNSHARDED state.
        ValueError: if ``output_buffer`` has the wrong size or dtype for the
            collective.
    """
    self._assert_in_states(ShardedState.UNSHARDED)

    # Prefer the accumulated gradient when one is being kept.
    if self.unsharded_accumulated_grad is not None:
        grad = self.unsharded_accumulated_grad_data
    else:
        grad = self.unsharded_grad_data
    reduce_dtype = dtype or grad.dtype
    self._grad = grad.to(reduce_dtype)

    # A collective is only issued for a sharded parameter whose shard group
    # exists and spans more than one rank; otherwise pack for a single rank.
    needs_collective = (
        self.is_sharded
        and self.sharded_group_info.group is not None
        and self.shard_world_size > 1
    )
    plan_world_size = self.shard_world_size if needs_collective else 1
    plan = build_rs_plan(self, self._grad, plan_world_size)
    grad_flat = pack_for_reduce_scatter(self._grad, plan).reshape(-1)

    if not needs_collective:
        # Nothing to reduce across ranks: the packed gradient already is the
        # local shard. When a buffer is supplied, copy into it and return
        # that buffer, exactly like the collective path below. This keeps
        # callers that fuse a later all-reduce over the buffer consistent.
        if output_buffer is not None:
            output_buffer.copy_(grad_flat)
            self._reduce_scatter_output = output_buffer
        else:
            self._reduce_scatter_output = grad_flat
        self.reduce_scatter_handle = None
        return self._reduce_scatter_output, None

    output_numel = grad_flat.numel() // self.shard_world_size
    if output_buffer is not None:
        if output_buffer.numel() != output_numel:
            raise ValueError(
                f"output_buffer size mismatch: expected {output_numel}, got {output_buffer.numel()}"
            )
        if output_buffer.dtype != reduce_dtype:
            raise ValueError(
                f"output_buffer dtype mismatch: expected {reduce_dtype}, got {output_buffer.dtype}"
            )
        self._reduce_scatter_output = output_buffer
    else:
        self._reduce_scatter_output = torch.empty(
            output_numel, dtype=reduce_dtype, device=self._grad.device
        )

    self.reduce_scatter_handle = dist.reduce_scatter_tensor(
        self._reduce_scatter_output,
        grad_flat,
        op=reduce_op,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return self._reduce_scatter_output, self.reduce_scatter_handle

977 

def all_reduce_grad(
    self,
    grad: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    async_op: bool = True,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
) -> Tuple[torch.Tensor, Optional[dist.Work]]:
    """
    All-reduce a gradient across the replicate dimension (HSDP mode).

    Args:
        grad: Gradient tensor to reduce in place. If None, the local form of
            ``self.unsharded_accumulated_grad`` is used when it is set,
            otherwise the local form of ``self.unsharded_param.grad``.
        dtype: Optional dtype to cast ``grad`` to before reducing. A cast
            yields a new tensor, so the caller's tensor is then not updated.
        async_op: Whether to execute the all-reduce asynchronously.
        reduce_op: Reduction to apply; ``dist.ReduceOp.AVG`` by default.

    Returns:
        (reduced_grad, handle): the reduced tensor and the communication
        handle. The handle is ``None`` when no collective is issued (no
        replicate group, or replicate world size <= 1) or ``async_op`` is
        False. Both are also recorded on ``self._all_reduce_output`` and
        ``self.all_reduce_handle``.
    """
    if grad is None:
        if self.unsharded_accumulated_grad is not None:
            grad = self.unsharded_accumulated_grad_data
        else:
            grad = self.unsharded_grad_data

    if dtype is not None and dtype != grad.dtype:
        grad = grad.to(dtype)

    if self.unsharded_group_info.group is None or self.replicate_world_size <= 1:
        # No replicate peers: the gradient is already fully reduced. Record
        # it so the bookkeeping does not keep a previous call's output or
        # handle alive (mirrors reduce_scatter_grad's no-comm path).
        self.all_reduce_handle = None
        self._all_reduce_output = grad
        return grad, None

    self.all_reduce_handle = dist.all_reduce(
        grad, op=reduce_op, group=self.unsharded_group_info.group, async_op=async_op
    )
    self._all_reduce_output = grad
    return grad, self.all_reduce_handle

1014 

1015 

def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """Make ``dst_tensor.requires_grad`` match ``src_tensor``, touching it only if they differ."""
    wanted = src_tensor.requires_grad
    if dst_tensor.requires_grad == wanted:
        return
    dst_tensor.requires_grad_(wanted)