Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / fully_shard / api.py: 43%

294 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""hybrid shard data parallel interface""" 

16import warnings 

17from collections import namedtuple 

18from typing import Any, List, Mapping, cast, Optional, Union 

19 

20from hyper_parallel.platform.platform import PlatformType 

21from hyper_parallel.core.fully_shard.utils import MixedPrecisionPolicy, OffloadPolicy 

22from hyper_parallel import DeviceMesh, init_device_mesh 

23from hyper_parallel.platform import get_platform 

24from hyper_parallel.core.dtensor.dtensor import DTensor, distribute_tensor 

25from hyper_parallel.core.fully_shard.hsdp_utils import ( 

26 get_managed_modules_parameters, 

27 is_dtensor_managed_param, 

28 get_dtensor_managed_mesh, 

29) 

30 

# Backend adapter resolved once at import time; every helper in this module
# goes through it for tensor/module types and runtime services.
platform = get_platform()

# Cache mapping an original module class to the HSDP-extended subclass that
# was generated for it, so each class is only extended once.
origin_class_to_extend_class: dict[type, type] = {}

34 

35 

def _resolve_comm_fusion_zero_copy_default(
    platform_type: PlatformType,
    comm_fusion: bool,
    comm_fusion_zero_copy: Optional[bool],
) -> bool:
    """
    Pick the effective zero-copy setting for the comm_fusion path.

    Args:
        platform_type (PlatformType): Backend the scheduler will run on.
        comm_fusion (bool): Whether communication fusion is enabled.
        comm_fusion_zero_copy (Optional[bool]): Explicit user choice, or
            ``None`` to request the backend-specific default.

    Returns:
        bool: The explicit value when one was given; otherwise ``True`` only
        for PyTorch with ``comm_fusion`` enabled, and ``False`` in every
        other case (including MindSpore).
    """
    # An explicit user setting always overrides the backend default.
    if comm_fusion_zero_copy is not None:
        return comm_fusion_zero_copy
    # Only the PyTorch backend turns zero-copy on by default, and only when
    # fusion itself is enabled.
    if comm_fusion and platform_type == PlatformType.PYTORCH:
        return True
    return False

51 

52 

def _check_strict_keys(
    module: platform.Module, state_dict: Mapping[str, Any],
) -> None:
    """
    Verify that *state_dict* has exactly the keys of ``module.state_dict()``.

    Args:
        module (platform.Module): Module whose ``state_dict()`` keys are the
            reference set.
        state_dict (Mapping[str, Any]): Candidate state dict to validate.

    Raises:
        RuntimeError: If any key is missing from or unexpected in
            *state_dict*; the message lists the offending keys in sorted order.
    """
    wanted = set(module.state_dict().keys())
    given = set(state_dict.keys())
    absent = sorted(wanted - given)
    extra = sorted(given - wanted)

    problems = []
    if absent:
        problems.append("Missing key(s): " + ", ".join(map(repr, absent)))
    if extra:
        problems.append("Unexpected key(s): " + ", ".join(map(repr, extra)))
    if not problems:
        return

    raise RuntimeError(
        f"Error(s) in loading state_dict for "
        f"{module.__class__.__name__}:\n\t"
        + "\n\t".join(problems)
    )

75 

76 

def _resolve_local_tensor(
    key: str, val: platform.Tensor, target: DTensor,
) -> platform.Tensor:
    """
    Produce the local shard that should be copied into *target*.

    Args:
        key (str): State-dict key, used only in the error message.
        val (platform.Tensor): Incoming value; a DTensor or a plain tensor.
        target (DTensor): Destination DTensor parameter or buffer.

    Returns:
        platform.Tensor: ``val.to_local()`` for a DTensor, *val* itself when
        it already has the local shard shape, or the local shard obtained by
        distributing *val* when it has the global shape.

    Raises:
        ValueError: If a plain tensor matches neither the local nor the
            global shape of *target*.
    """
    if isinstance(val, DTensor):
        return val.to_local()

    shard_shape = tuple(target.local_shape)
    full_shape = tuple(target.shape)
    incoming_shape = tuple(val.shape)

    # Already shaped like this rank's shard: use it unchanged.
    if incoming_shape == shard_shape:
        return val

    if incoming_shape != full_shape:
        raise ValueError(
            f"load '{key}': plain tensor shape {incoming_shape} "
            f"matches neither local shard {shard_shape} "
            f"nor global {full_shape}."
        )

    # Full-size tensor: shard it the same way the target is laid out.
    mesh = target.device_mesh
    placements = target.layout.alias_placements if target.layout else target.placements
    return distribute_tensor(val, mesh, placements).to_local()

100 

101 

102class _UnshardHandle: 

103 """Unshard handle for user call HSDPModule.unshard(async_op=True)""" 

104 def __init__(self, hsdp_state=None): 

105 """ 

106 Initialize an async unshard handle. 

107 

108 Args: 

109 hsdp_state (HSDPState, optional): The state to wait on. None means a no-op handle. 

110 """ 

111 self._hsdp_state = hsdp_state 

112 

113 def wait(self): 

114 """Block until the async unshard operation completes.""" 

115 if self._hsdp_state is not None: 

116 self._hsdp_state.wait_for_unshard() 

117 self._hsdp_state = None 

118 

119 

class HSDPModule:
    """
    Mixin that adds the HSDP / fully_shard control interface to a network block.

    The class is combined with a concrete module class at runtime, and the
    module's ``__class__`` may be swapped without ``__init__`` ever running.
    Methods therefore look up ``hsdp_scheduler`` defensively and treat both
    "attribute missing" and "attribute is None" as "not initialized yet".

    Supported Platforms:
        ``MindSpore`` ``torch``
    """

    def __init__(self):
        """Initialize HSDPModule."""
        self.hsdp_scheduler = None  # Initialized in hsdp_init()

    # pylint: disable=C0415
    def hsdp_init(self, platform_type, module, mesh, reshard_after_forward,
                  shard_placement_fn, mp_policy, offload_policy, ignored_params, replicate_params, device,
                  comm_fusion, comm_fusion_zero_copy: Optional[bool] = None):
        """
        Create the platform-specific HSDP scheduler and store it on ``self``.

        The scheduler class is imported lazily so only the active backend's
        package is loaded. ``comm_fusion_zero_copy`` is resolved to a concrete
        bool (backend-specific default when ``None``) before being forwarded;
        every other argument is passed to the scheduler unchanged.
        """
        scheduler_class = None
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.fully_shard.scheduler import MindSporeHSDPSchedulerV2
            scheduler_class = MindSporeHSDPSchedulerV2
        else:
            from hyper_parallel.platform.torch.fully_shard.scheduler import TorchHSDPSchedulerV2
            scheduler_class = TorchHSDPSchedulerV2

        resolved_comm_fusion_zero_copy = _resolve_comm_fusion_zero_copy_default(
            platform_type,
            comm_fusion,
            comm_fusion_zero_copy,
        )

        self.hsdp_scheduler = scheduler_class(module,
                                              mesh,
                                              reshard_after_forward,
                                              shard_placement_fn,
                                              mp_policy,
                                              offload_policy,
                                              ignored_params,
                                              replicate_params,
                                              device,
                                              comm_fusion,
                                              resolved_comm_fusion_zero_copy,
                                              )

    def set_requires_gradient_sync(self, requires_grad_sync):
        r"""
        Set the requires-grad-sync flag on this module and all HSDP submodules.

        Args:
            requires_grad_sync(bool): requires_grad_sync is used to control gradient sync process.
        Raises:
            ValueError: If `requires_grad_sync` is not bool, or if the
                scheduler has not been initialized yet.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        # getattr(...) is None also covers the attribute being present but
        # still None, which a plain hasattr() check lets through.
        if getattr(self, "hsdp_scheduler", None) is None:
            raise ValueError("call hsdp interface first.")

        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

    def zero_grad(self):
        """
        Zero the accumulated gradients.

        On PyTorch this defers to the next class in the MRO; on other
        platforms every HSDP submodule's scheduler is asked to zero its grads.

        Raises:
            ValueError: If the scheduler has not been initialized yet.
        """
        if getattr(self, "hsdp_scheduler", None) is None:
            raise ValueError("call hsdp interface first.")
        if platform.platform_type == PlatformType.PYTORCH:
            return super().zero_grad()
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.zero_grad()

    def set_modules_to_forward_prefetch(self, modules):
        """
        Set the modules whose all-gather is prefetched during forward.

        Args:
            modules (list or tuple of HSDPModule): Modules to prefetch.
        Raises:
            ValueError: If `modules` is not a list/tuple of HSDPModule, or if
                the scheduler has not been initialized yet.
        """
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if getattr(self, "hsdp_scheduler", None) is None:
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_forward_prefetch_cells(modules)

    def set_modules_to_backward_prefetch(self, modules):
        """
        Set the modules whose all-gather is prefetched during backward.

        Args:
            modules (list or tuple of HSDPModule): Modules to prefetch.
        Raises:
            ValueError: If `modules` is not a list/tuple of HSDPModule, or if
                the scheduler has not been initialized yet.
        """
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if getattr(self, "hsdp_scheduler", None) is None:
            raise ValueError("call fully_shard interface first.")
        self.hsdp_scheduler.set_backward_prefetch_cells(modules)

    def reshard(self) -> None:
        """
        Reshard all sharded parameters.

        Raises:
            ValueError: If the scheduler has not been initialized yet.
        """
        # Use getattr so a never-initialized instance reports the documented
        # ValueError instead of an AttributeError.
        scheduler = getattr(self, "hsdp_scheduler", None)
        if not scheduler:
            raise ValueError("hsdp_scheduler is None")
        hsdp_state = scheduler.hsdp_state
        if hsdp_state:
            hsdp_state.shard()

    def unshard(self, async_op: bool = False):
        """
        Unshard all sharded parameters.

        Args:
            async_op (bool): When ``True``, return a handle whose ``wait()``
                blocks until the unshard completes.
        Returns:
            _UnshardHandle when ``async_op`` is ``True``, otherwise ``None``.
        Raises:
            ValueError: If `async_op` is not bool, or if the scheduler has
                not been initialized yet.
        """
        if not isinstance(async_op, bool):
            raise ValueError(f"async_op should be a bool, got {type(async_op)}")
        scheduler = getattr(self, "hsdp_scheduler", None)
        if not scheduler:
            raise ValueError("hsdp_scheduler is None")
        hsdp_state = scheduler.hsdp_state
        if hsdp_state:
            hsdp_state.unshard(async_op)  # pylint: disable=too-many-function-args
        if async_op:
            # hsdp_state may be falsy here; the handle then acts as a no-op
            # only if it is None.
            return _UnshardHandle(hsdp_state=hsdp_state)
        return None

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        """
        Load state dict by copying directly into local shards.

        Bypasses ``super().load_state_dict()`` because the standard PyTorch
        implementation triggers ``copy_`` through the DTensor dispatcher, which
        is not registered in the hyper-parallel layout system.

        Each value in ``state_dict`` is dispatched by type:
        - hyper DTensor: extract local shard and copy directly.
        - plain Tensor whose shape == local shard shape: copy as-is.
        - plain Tensor whose shape == global shape: distribute via
          ``distribute_tensor``, then copy the local shard.

        Keys that match no parameter or buffer of this module are skipped
        (they are only reported when ``strict`` is ``True``).

        Args:
            state_dict (Mapping[str, Any]): Fully-qualified parameter/buffer
                names mapped to tensors (DTensor or plain Tensor).
            strict (bool): If ``True`` (default), missing or unexpected keys
                raise ``RuntimeError``, matching ``nn.Module.load_state_dict``
                semantics.
            assign (bool): Accepted for API compatibility with
                ``nn.Module.load_state_dict(assign=True)`` but currently
                ignored; HSDP always copies into existing DTensor storage.

        Raises:
            RuntimeError: When ``strict`` is ``True`` and keys do not match.
            ValueError: When a plain tensor shape matches neither the local
                shard shape nor the global shape of the target DTensor.
        """
        if assign:
            warnings.warn(
                "HSDPModule.load_state_dict: assign=True is ignored; "
                "HSDP always copies into existing DTensor parameters.",
                stacklevel=2,
            )
        self_module = cast(platform.Module, self)

        # Name -> destination tensor; buffers are added after parameters.
        target_map = {}
        for name, p in platform.parameters_dict(self_module):
            target_map[name] = p
        for name, b in self_module.named_buffers():
            target_map[name] = b

        if strict:
            _check_strict_keys(self_module, state_dict)

        with platform.no_grad():
            for key, val in state_dict.items():
                target = target_map.get(key)
                if target is None:
                    continue

                if isinstance(target, DTensor):
                    val = _resolve_local_tensor(key, val, target)
                platform.load_into_param(target, val)

        # Trigger load_state_dict post-hooks so that HSDP internal
        # bookkeeping (e.g. _sharded_param_data) stays in sync.
        # Pass an IncompatibleKeys with the same attribute names as PyTorch
        # so external hooks can safely read .missing_keys/.unexpected_keys.
        # NOTE(review): both lists are always empty, even when strict=False
        # and keys actually differ - confirm whether hooks rely on them.
        _IK = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
        incompatible_keys = _IK([], [])
        for _, module in platform.get_cells_and_names(self_module):
            hooks = module._load_state_dict_post_hooks  # pylint: disable=protected-access
            for hook in hooks.values():
                hook(module, incompatible_keys)

    def set_is_last_backward(self, is_last_backward: bool):
        """Set the ``is_last_backward`` flag on the scheduler context."""
        self.hsdp_scheduler.scheduler_ctx.is_last_backward = is_last_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool, *, recurse: bool = True) -> None:
        """
        Set the requires_all_reduce flag on this module and all HSDP submodules.

        Raises:
            ValueError: If `requires_all_reduce` is not bool.
            NotImplementedError: If `recurse` is ``False``.
        """
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(
                f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, "
                "need support module_param mapping."
            )
        self_module = cast(platform.Module, self)
        for _, module in platform.get_cells_and_names(self_module):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_all_reduce(requires_all_reduce)

    def set_reshard_after_forward(self, reshard_after_forward: bool, recurse: bool = True) -> None:
        """
        Set the reshard_after_forward flag on this module and all HSDP submodules.

        Raises:
            ValueError: If `reshard_after_forward` is not bool.
            NotImplementedError: If `recurse` is ``False``.
        """
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, "
                "need support module_param mapping."
            )
        self_module = cast(platform.Module, self)
        for _, module in platform.get_cells_and_names(self_module):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_forward(reshard_after_forward)

    def set_reshard_after_backward(self, reshard_after_backward: bool, recurse: bool = True) -> None:
        """
        Set the reshard_after_backward flag on this module and all HSDP submodules.

        Raises:
            ValueError: If `reshard_after_backward` is not bool.
            NotImplementedError: If `recurse` is ``False``.
        """
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(
                f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, "
                "need support module_param mapping."
            )
        self_module = cast(platform.Module, self)
        for _, module in platform.get_cells_and_names(self_module):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_backward(reshard_after_backward)

    def set_reduce_op_type(self, reduce_op_type) -> None:
        """
        Set reduce_op_type for all gradient reductions in fully_shard.

        Supports ``"avg"`` and ``"sum"``. Local-parameter FSDP/HSDP keeps the
        historical ``"avg"`` default, while DTensor-based paths default to ``"sum"``.
        """
        if hsdp_state := self.hsdp_scheduler.hsdp_state:
            hsdp_state.set_reduce_op_type(reduce_op_type)

367 

368 

def _extend_module_with_hsdp_interface(module):
    """
    Rebind *module* to a subclass that mixes in ``HSDPModule``.

    The generated subclass (named ``HSDP<OriginalName>``) is cached per
    original class in ``origin_class_to_extend_class`` so that modules of the
    same class share one extended type. The module is modified in place by
    assigning its ``__class__``.
    """
    base_cls = module.__class__
    hsdp_cls = origin_class_to_extend_class.get(base_cls, None)
    if hsdp_cls is None:
        # First time this class is seen: build the mixed-in subclass and cache it.
        hsdp_cls = type(f"HSDP{base_cls.__name__}", (HSDPModule, base_cls), {})
        origin_class_to_extend_class[base_cls] = hsdp_cls
    module.__class__ = hsdp_cls

377 

378 

def _get_root_modules(modules: List[platform.Module]) -> List[platform.Module]:
    """
    Filter ``modules`` down to those that are roots within the given set.

    A module is a root when it is not contained in the submodule tree of any
    *other* module in ``modules``. Input order is preserved.

    Aligned with PyTorch torch.distributed.utils._get_root_modules.
    """

    def _descendants(mod):
        # Both traversals include ``mod`` itself, hence the identity check below.
        if platform.platform_type == PlatformType.MINDSPORE:
            return set(cell for _, cell in mod.cells_and_names())
        return set(mod.modules())

    subtree_of = {owner: _descendants(owner) for owner in modules}

    return [
        candidate
        for candidate in modules
        if not any(
            candidate is not owner and candidate in subtree
            for owner, subtree in subtree_of.items()
        )
    ]

406 

407 

def _check_module_valid(platform_type, module):
    """
    Ensure *module* is an instance of the active backend's module base class.

    The backend package is imported lazily so only the selected framework
    needs to be installed.

    Raises:
        ValueError: If *module* is not a MindSpore ``Cell`` (MindSpore
            platform) or a ``torch.nn.Module`` (any other platform).
    """
    if platform_type == PlatformType.MINDSPORE:
        from mindspore.nn.cell import Cell
        if isinstance(module, Cell):
            return
        raise ValueError(f"module's type must be nn.cell but got {type(module)}.")

    from torch.nn import Module
    if isinstance(module, Module):
        return
    raise ValueError(f"module's type must be nn.Module but got {type(module)}.")

418 

419 

def _validate_module_for_fully_shard(
    module: Union[platform.Module, List[platform.Module]], platform_type
) -> None:
    """
    Validate the ``module`` argument of ``fully_shard``.

    A single module is checked directly against the platform's module type.
    A list must be non-empty and every element must pass the same check.

    Raises:
        ValueError: For an empty list, for a list element of the wrong type
            (the message names the element type and index), or for a single
            module of the wrong type.
    """
    if not isinstance(module, list):
        _check_module_valid(platform_type, module)
        return

    if len(module) == 0:
        raise ValueError("fully_shard does not support empty list of modules.")

    for index, item in enumerate(module):
        try:
            _check_module_valid(platform_type, item)
        except ValueError:
            # Replace the per-module message with one that pinpoints the
            # offending list entry; the original error adds nothing.
            raise ValueError(
                f"fully_shard expects nn.Module or list[nn.Module], "
                f"but got list with {type(item).__name__} at index {index}."
            ) from None

437 

438 

def _check_hsdp_input_valid(platform_type, module, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """
    Validate the argument set of the hsdp interface.

    Checks, in order: module type, ``shard_size`` (positive int or ``-1``),
    ``threshold`` (int >= 0), ``optimizer_level`` (one of ``level1``..``level3``),
    ``enable_grad_accumulation`` (bool), ``grad_scale`` (float),
    ``reduce_dtype`` (``None`` or the backend's dtype type), ``comm_async``
    (bool), ``comm_fusion`` (bool) and ``bucket_size`` (int >= 0 or ``-1``).

    Raises:
        ValueError: On the first argument that fails its check.
    """

    def _is_plain_int(value):
        # bool is a subclass of int; True/False must not pass as sizes.
        return isinstance(value, int) and not isinstance(value, bool)

    _check_module_valid(platform_type, module)
    if not _is_plain_int(shard_size) or (shard_size <= 0 and shard_size != -1):
        raise ValueError(f"shard_size must be a positive integer or -1, but got {shard_size}.")
    if not _is_plain_int(threshold) or threshold < 0:
        raise ValueError(f"threshold must be a positive integer or 0, but got {threshold}.")
    if optimizer_level not in ["level1", "level2", "level3"]:
        raise ValueError(f"Optimizer level should in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    # The dtype class differs per backend, so import it lazily.
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        import torch
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    if not _is_plain_int(bucket_size) or (bucket_size < 0 and bucket_size != -1):
        raise ValueError(f"bucket_size must be a non-negative integer or -1, but got {bucket_size}.")

467 

468 

def _get_device_from_mesh(mesh: DeviceMesh):
    """
    Derive the device to use from ``mesh.device_type``.

    Returns:
        On PyTorch, a ``torch.device`` built from the device handle's current
        device when the handle reports it is available, otherwise ``None``.
        On other platforms, the mesh's device-type string.

    Raises:
        AssertionError: If the mesh device type is neither ``"npu"`` nor ``"cuda"``.
        ValueError: On PyTorch, if no device handle exists for the device type.
    """
    device_type = mesh.device_type
    if device_type not in ("npu", "cuda"):
        raise AssertionError(
            f"hyper_parallel.fully_shard support device in [torch.npu, torch.cuda], "
            f"but got '{device_type}'"
        )

    # Non-PyTorch backends identify the device by its type string alone.
    if platform.platform_type != PlatformType.PYTORCH:
        return device_type

    device_handle = platform.get_device_handle(device_type)
    if device_handle is None:
        raise ValueError(
            f"hyper_parallel.fully_shard can't find device_handle of "
            f"'torch.{device_type}', check the environment."
        )
    if not device_handle.is_available():
        return None

    import torch
    return torch.device(device_handle.current_device())

491 

492 

def _normalize_replicate_params(
    replicate_params: Optional[set[platform.Parameter]],
) -> set[platform.Parameter]:
    """
    Turn the ``replicate_params`` argument of fully_shard into a checked set.

    Args:
        replicate_params (Optional[set[nn.Parameter]]): Parameters to keep
            replicated, or ``None``.
    Returns:
        set[nn.Parameter]: A new set with the same members (empty for ``None``).
    Raises:
        TypeError: If any member is neither a parameter nor a DTensor.
    """
    if replicate_params is None:
        return set()

    normalized = set(replicate_params)
    for entry in normalized:
        if isinstance(entry, (platform.Parameter, DTensor)):
            continue
        raise TypeError(
            "replicate_params must contain only nn.Parameter or DTensor, "
            f"got {type(entry).__name__}."
        )
    return normalized

513 

514 

def _get_modules_parameters(modules, ignored_params=None):
    """
    Gather the parameters fully_shard should manage for the given roots.

    Thin wrapper that forwards both arguments to
    ``get_managed_modules_parameters`` and returns its result unchanged.
    """
    managed = get_managed_modules_parameters(modules, ignored_params)
    return managed

518 

519 

def fully_shard(
    module: Union[platform.Module, List[platform.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: bool = True,
    shard_placement_fn: None = None,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
    ignored_params: Optional[set[platform.Parameter]] = None,
    replicate_params: Optional[set[platform.Parameter]] = None,
    comm_fusion: bool = False,
    comm_fusion_zero_copy: Optional[bool] = None,
) -> Union[platform.Module, List[platform.Module]]:
    """
    Enable HSDP (Hybrid Sharded Data Parallelism) on a module or a list of modules.

    The module(s) are modified in place: their class is extended with the
    ``HSDPModule`` interface and an HSDP scheduler is attached that handles
    parameter sharding, gradient synchronization and memory management.
    A list of modules is treated as a single FSDP unit whose parameters are
    grouped together; this is supported on both PyTorch and MindSpore.

    Args:
        module (nn.Module or List[nn.Module]):
            Module(s) to shard. For a list, only the root modules of the list
            are extended and they all share one scheduler.

        mesh (Optional[DeviceMesh], default=None):
            Device mesh describing the process topology. When ``None``,
            modules whose managed parameters include DTensors keep their
            existing distributed layout; otherwise a default 1D mesh spanning
            the whole world size is created.

        reshard_after_forward (bool, default=True):
            Whether parameters are resharded right after forward to free
            memory. Use ``False`` to keep them unsharded for backward or for
            manual control.

        shard_placement_fn (Callable, default=None):
            Callable taking a parameter and returning a Shard object that
            selects the sharding dimension, or ``None`` for the default
            (dimension 0).

        mp_policy (MixedPrecisionPolicy, default=MixedPrecisionPolicy()):
            Mixed precision policy controlling data type conversions.

        offload_policy (OffloadPolicy, default=OffloadPolicy()):
            Offload policy for reducing device memory usage.

        ignored_params (Optional[set[nn.Parameter]], default=None):
            Parameters excluded from fully_shard entirely: they stay regular
            parameters, are not sharded and take no part in fully_shard
            gradient synchronization.

        replicate_params (Optional[set[nn.Parameter]], default=None):
            Parameters that remain replicated but are still managed by
            fully_shard: they are not sharded, yet their gradients are
            synchronized with DDP-style all-reduce over the current
            fully_shard communication domain. Unlike ``ignored_params``,
            these are not skipped by gradient synchronization.

        comm_fusion (bool, default=False):
            Whether to enable all_gather fusion and reduce_scatter fusion.

        comm_fusion_zero_copy (Optional[bool], default=None):
            Whether to allow the experimental zero-copy path for
            ``comm_fusion``. ``None`` selects a backend-specific default:
            - PyTorch: enabled automatically when ``comm_fusion=True``
            - MindSpore: disabled automatically even when ``comm_fusion=True``
            When enabled, fully_shard may rebase sharded local parameter
            storage into one shared flat buffer so fused all-gather can read
            directly from contiguous memory. This path depends on optimizer
            compatibility with view-backed parameters.

    Returns:
        nn.Module or List[nn.Module]: The same object that was passed as
        ``module``, now with HSDP capabilities.

    Raises:
        ValueError: If ``module`` fails validation, or if no mesh was given
            and none could be resolved from the DTensor parameters.
    """
    platform_type = platform.platform_type
    _validate_module_for_fully_shard(module, platform_type)
    if platform_type == PlatformType.MINDSPORE:
        from hyper_parallel.platform.mindspore.autograd_compat import enable_mindspore_backward_compat

        enable_mindspore_backward_compat()

    # A list is reduced to its root modules; a single module is its own root.
    roots = tuple(_get_root_modules(module)) if isinstance(module, list) else (module,)
    for root in roots:
        _extend_module_with_hsdp_interface(root)

    managed_params = _get_modules_parameters(roots, ignored_params)
    has_dtensor_param = any(is_dtensor_managed_param(candidate) for candidate in managed_params)
    replicate_params = _normalize_replicate_params(replicate_params)

    # Only purely local parameters get an automatically created 1D mesh.
    if mesh is None and not has_dtensor_param:
        mesh = init_device_mesh(device_type="npu", mesh_shape=(platform.get_world_size(),))

    # The device comes from the explicit/default mesh when there is one,
    # otherwise from the first managed parameter that exposes a DTensor mesh.
    device_mesh = mesh
    if device_mesh is None:
        for candidate in managed_params:
            device_mesh = get_dtensor_managed_mesh(candidate)
            if device_mesh is not None:
                break
        if device_mesh is None:
            raise ValueError("fully_shard could not resolve a DTensor mesh for compatibility mode.")
    device = _get_device_from_mesh(device_mesh)

    primary = roots[0]
    primary.hsdp_init(
        platform_type,
        roots,
        mesh,
        reshard_after_forward,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        ignored_params,
        replicate_params,
        device,
        comm_fusion,
        comm_fusion_zero_copy,
    )
    # Share the same scheduler handle with other roots so mods[i].unshard()/prefetch work
    for sibling in roots[1:]:
        sibling.hsdp_scheduler = primary.hsdp_scheduler
    return module

658 

659 

def get_model_state_dict(model, *, options=None):
    """
    Return the model state dict produced by the active platform.

    This is the public entry point: it forwards ``model`` and ``options`` to
    ``platform.get_model_state_dict`` so callers do not need to import
    platform internals.
    """
    state = platform.get_model_state_dict(model, options=options)
    return state

667 

668 

def hsdp_sync_stream():
    """
    Wait for the hsdp gradient handle to be completed.

    Delegates to ``platform.wait_grad_handle()``; its result is discarded.
    """
    platform.wait_grad_handle()
    return None