Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/fully_shard/state.py: 61%

312 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore HSDP cell state""" 

16from typing import Optional 

17import mindspore as ms 

18from mindspore import ops 

19import mindspore.mint.distributed as dist 

20from hyper_parallel.core.fully_shard.hsdp_state import HSDPState 

21from hyper_parallel.core.fully_shard.hsdp_utils import ( 

22 _get_param_module_infos, 

23 FullyShardParamMode, 

24 infer_fully_shard_param_mode, 

25) 

26from hyper_parallel.platform.mindspore.fully_shard.pack_utils import build_rs_plan 

27from hyper_parallel.platform.mindspore.fully_shard.param import MindSporeHSDPParamV2 

28from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version 

29from hyper_parallel.platform.mindspore.fully_shard.param_group import HSDPParamGroup, get_comm_ctx 

30from hyper_parallel.platform.mindspore.utils import normalize_runtime_device 

31from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy 

32 

33 

def _to_dtype_if_needed(
    tensor: ms.Tensor, dtype: Optional[ms.Type]
) -> ms.Tensor:
    """Return ``tensor`` converted to ``dtype``, or unchanged when no cast is required.

    Args:
        tensor: Tensor that may need a dtype conversion.
        dtype: Desired dtype. ``None`` means "keep the current dtype".

    Returns:
        ``tensor.to(dtype)`` when ``dtype`` is given and differs from
        ``tensor.dtype``; otherwise the very same ``tensor`` object.
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor

46 

47 

class MindSporeHSDPStateV2(HSDPState):
    """MindSpore HSDP cell state.

    Owns the ``fully_shard``-managed parameters (``hsdp_params`` plus
    ``replicate_params``) and drives their gradient reduction after backward:
    reduce-scatter (optionally followed by an all-reduce) for sharded
    parameters, all-reduce for replicate and DTensor-compat parameters, and a
    fused pipeline when ``comm_fusion`` is enabled.
    """

    # DTensor compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without materializing an
    # ``_unsharded_param``. Track those async all-reduces separately from the
    # standard unsharded-gradient queues.
    # This is a class attribute, so the queue is shared by every instance (like
    # the ``HSDPState.pre_*`` queues) and is drained by ``reduce_params``.
    # Entries: (handle, reduced_grad, target_grad, reduce_group_size, need_div).
    pre_direct_all_reduce_grads = []

    @staticmethod
    def _get_pending_unsharded_grad(hsdp_param):
        """Return the pending unsharded gradient tensor for reduction paths.

        Prefers the accumulated gradient when one exists, otherwise falls back
        to ``unsharded_grad_data``.
        """
        if hsdp_param.unsharded_accumulated_grad is not None:
            return hsdp_param.unsharded_accumulated_grad_data
        return hsdp_param.unsharded_grad_data

    @staticmethod
    def _has_pending_unsharded_grad(hsdp_param):
        """Whether the parameter currently has a gradient waiting for reduction."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return True
        if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
            return False
        return hsdp_param.unsharded_param.grad is not None

    @staticmethod
    def _get_local_sharded_grad(hsdp_param):
        """Return the local gradient tensor currently stored on ``sharded_param``.

        Gradients exposing a callable ``to_local`` (DTensor-like) are converted
        to their local tensor; other gradients are returned as-is. Returns
        ``None`` when ``sharded_param.grad`` is ``None``.
        """
        grad = hsdp_param.sharded_param.grad
        if grad is None:
            return None
        to_local = getattr(grad, "to_local", None)
        if callable(to_local):
            return to_local()
        return grad

    @staticmethod
    def _synchronize_current_stream_if_needed(need_synchronize: bool) -> None:
        """Synchronize the current device stream after non-blocking CPU offload."""
        if need_synchronize:
            ms.runtime.current_stream().synchronize()

    def __init__(self, cell, mesh_info, config, platform, device=None):
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do ReduceScatter/AllReduce for grad (see set_requires_grad_sync).
        self.reduce_grads = True
        # Reshard parameter after backward
        self.reshard_after_backward = True
        # Requires AllReduce for grad when HSDP
        self.requires_all_reduce = True
        # Keep historical AVG behavior for local parameters while DTensor-aware
        # paths default to SUM semantics without extra division: the collective
        # is always SUM, and division is enabled only if every managed
        # parameter is a LOCAL_PARAM.
        self.reduce_op_type = ops.ReduceOp.SUM
        self._need_div = not any(
            getattr(param, "param_mode", FullyShardParamMode.LOCAL_PARAM)
            != FullyShardParamMode.LOCAL_PARAM
            for param in self._iter_managed_params()
        )
        # (param, reduced_grad, reduce_group_size) for replicate_params
        # all-reduces, completed in _finish_ignored_allreduce.
        self._ignored_allreduce_works = []
        self._reset_sharded_params = False
        self._init_param_group()

    def _iter_managed_params(self):
        """Return all fully_shard-managed parameters, including replicate_params."""
        return [*self.hsdp_params, *self.replicate_params]

    @staticmethod
    def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
        """Return the reason why ``hsdp_param`` cannot participate in comm_fusion.

        Returns ``None`` when the parameter is supported. Support requires FSDP
        sharding, a LOCAL_PARAM / DTENSOR_UNIFIED mode, a local shard tensor,
        and a successfully built reduce-scatter pack plan.
        """
        if not hsdp_param.enable_fsdp_shard:
            return "non-sharded parameters such as replicate_params are not supported"
        if hsdp_param.param_mode not in (
            FullyShardParamMode.LOCAL_PARAM,
            FullyShardParamMode.DTENSOR_UNIFIED,
        ):
            return f"param_mode {hsdp_param.param_mode} is not supported"
        local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
        if local_shard is None:
            return "missing local shard tensor for comm_fusion plan validation"
        plan_world_size = getattr(hsdp_param, "shard_world_size", None)
        if plan_world_size is None:
            plan_world_size = getattr(hsdp_param, "shard_size", 1)
        try:
            build_rs_plan(hsdp_param, local_shard, plan_world_size)
        except NotImplementedError as exc:
            return str(exc)
        except (AssertionError, ValueError) as exc:
            return f"cannot build comm_fusion pack plan: {exc}"
        return None

    def _init_param_group(self):
        """Initialize the fused parameter group when comm_fusion is enabled.

        Raises:
            NotImplementedError: if comm_fusion is enabled and a parameter in
                ``hsdp_params`` cannot participate in fused communication.
        """
        if self.config.comm_fusion:
            for hsdp_param in self.hsdp_params:
                # Evaluate once per parameter: it builds a reduce-scatter plan.
                reason = self._comm_fusion_unsupported_reason(hsdp_param)
                if reason is not None:
                    param_fqn = getattr(hsdp_param, "_param_fqn", "<unknown>")
                    raise NotImplementedError(
                        f"comm_fusion does not support parameter {param_fqn}: {reason}."
                    )
            self.param_group = None
            if self.hsdp_params:
                self.param_group = HSDPParamGroup(
                    self.hsdp_params,
                    self.mesh_info,
                    self.device,
                    self.mp_policy,
                    self.config.comm_fusion_zero_copy,
                )

    def zero_grad(self):
        """Zero the gradients of every managed parameter (hsdp, then replicate)."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.zero_grad()

    @staticmethod
    def _div_if_needed(x, divisor, need_div: bool):
        """Apply gradient averaging only when the caller-provided policy requires it.

        ``need_div`` may come from the current state or from metadata captured when
        async reduce work was queued, so this helper is safe for both immediate and
        deferred gradient materialization paths. ``x`` is divided in place; a
        divisor of 1 is skipped.
        """
        if need_div and divisor != 1:
            x.div_(divisor)

    def _move_states_to_device(self):
        """Move parameters and buffers of the managed cells to ``self.device``.

        Parameters already marked ``_hsdp_param_initialized`` are skipped, as
        are tensors already on ``self.device`` or on the ``"meta"`` device.
        """
        for mod in self.modules:
            for param in mod.get_parameters():
                if getattr(param, "_hsdp_param_initialized", False):
                    continue
                param_device = normalize_runtime_device(param.device)
                if param_device in (self.device, "meta"):
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                # Normalize like parameters so equivalent device spellings are
                # recognized and buffers are not copied redundantly.
                buffer_device = normalize_runtime_device(buffer.device)
                if buffer_device in (self.device, "meta"):
                    continue
                buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Wrap the managed cells' parameters into ``MindSporeHSDPParamV2`` objects.

        Parameters already marked ``_hsdp_param_initialized`` and those in
        ``config.ignored_params`` are skipped; a parameter shared by several
        cells is wrapped only once. Parameters in ``config.replicate_params``
        are created with ``enable_fsdp_shard=False`` and appended to
        ``self.replicate_params``; all others go to ``self.hsdp_params``, and
        the sharded ones are also recorded in ``self.sharded_hsdp_params``.
        """
        # all parameters in the module tree(s), deduplicated
        visited_params = set()
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.parameters_and_names():
                if getattr(param, "_hsdp_param_initialized", False):
                    continue
                if param in ignored_params or param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            is_replicate = param in replicate_params
            hsdp_param = MindSporeHSDPParamV2(
                param,
                module_info,
                self.mesh_info,
                shard_placement_fn=self.config.shard_placement_fn,
                mp_policy=self.mp_policy,
                offload_policy=self.offload_policy,
                device=self.device,
                param_mode=param_mode,
                enable_fsdp_shard=not is_replicate,
            )
            if is_replicate:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
                if hsdp_param.is_sharded:
                    self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Initialize mixed-precision dtypes on all managed parameters.

        Sets ``self._orig_dtype`` / ``self._reduce_dtype`` from the trainable
        parameters (``None`` when there are none).

        Raises:
            AssertionError: if trainable parameters disagree on their original
                dtype or on their reduce dtype.
        """
        for hsdp_param in self._iter_managed_params():
            hsdp_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[MindSporeHSDPParamV2] = [
            p for p in self._iter_managed_params() if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None

    def lazy_init(self):
        """Refresh parameter views and validate runtime state before first execution.

        Sharded parameters get ``reset_sharded_param()`` once per state; the
        meta-device / CPU-offload validations and dtype initialization run on
        every call.
        """
        if self.is_shard and not self._reset_sharded_params:
            for hsdp_param in self.hsdp_params:
                if hsdp_param.is_sharded:
                    hsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        self._validate_no_meta_params()
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _validate_cpu_offload_params(self):
        """Validate that all parameters are on CPU when CPU offload policy is enabled.

        Raises:
            RuntimeError: listing every managed parameter whose device string
                does not start with ``"cpu"`` (case-insensitive).
        """
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if not str(hsdp_param.sharded_param.device).lower().startswith("cpu")
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                "For example, load a CPU state dict before training. "
                "Found following parameters on non-CPU device: "
                f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n"
            )

    def _validate_no_meta_params(self):
        """Validate that all parameters have been materialized from meta device.

        Raises:
            RuntimeError: listing the FQNs of parameters still on ``"meta"``.
        """
        param_names_on_meta = [
            hsdp_param._param_fqn
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "HSDP parameters should be materialized from meta device before training, "
                f"but the following were still on meta device: {param_names_on_meta}\n"
                "For example, initialize the module weights on a real device before running training."
            )

    def _allreduce_replicate_params(self, async_op=True) -> None:
        """DDP-style all-reduce for parameters in config.replicate_params.

        Use the parameter's layout-driven unsharded group so DTensor-aware
        compatibility and unified modes reduce over the correct axes.

        Every replicate parameter with a pending gradient is queued in
        ``self._ignored_allreduce_works``, even when no collective is issued
        (no group or a single-rank group), so ``_finish_ignored_allreduce``
        can still materialize its gradient.

        Args:
            async_op: forwarded to ``dist.all_reduce``.
        """
        for param in self.replicate_params:
            if not hasattr(param, "_unsharded_param") or param.unsharded_param is None:
                continue
            if (
                param.unsharded_accumulated_grad is None
                and param.unsharded_param.grad is None
            ):
                continue

            reduced_grad = param.unsharded_accumulated_grad_data
            if reduced_grad is None:
                reduced_grad = param.unsharded_grad_data
            reduced_grad = _to_dtype_if_needed(reduced_grad, self._reduce_dtype)
            reduce_group_info = getattr(param, "unsharded_group_info", None)
            reduce_group = reduce_group_info.group if reduce_group_info is not None else None
            reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1

            if reduce_group is not None and reduce_group_size > 1:
                # Ascend HCCL DistCommAllReduce rejects non-contiguous tensors;
                # reduced_grad here may still be a view from the no-reduce path
                # of ``unsharded_grad_data`` / ``_to_local_unsharded_grad``.
                # ``Tensor.contiguous()`` is a no-op when storage is already
                # contiguous, so the unconditional call is safe.
                reduced_grad = reduced_grad.contiguous()
                param.all_reduce_handle = dist.all_reduce(
                    reduced_grad, group=reduce_group, op=self.reduce_op_type, async_op=async_op
                )
            self._ignored_allreduce_works.append((param, reduced_grad, reduce_group_size))

    def _finish_ignored_allreduce(self) -> None:
        """Wait for async all-reduce of replicate_params and materialize param.grad.

        For each pending work this waits on ``param.all_reduce_handle`` (if
        set), divides by the group size when averaging is enabled, and hands
        the result plus ``self._orig_dtype`` to ``param.apply_reduced_grad``.
        The current stream is synchronized once at the end if any
        ``apply_reduced_grad`` call requested it; the work list is then cleared.
        """
        if not self._ignored_allreduce_works:
            return

        need_synchronize = False
        for param, reduced_grad, reduce_group_size in self._ignored_allreduce_works:
            if param.all_reduce_handle:
                param.all_reduce_handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, self._need_div)
            need_synchronize = (
                param.apply_reduced_grad(reduced_grad, self._orig_dtype)
                or need_synchronize
            )

        self._synchronize_current_stream_if_needed(need_synchronize)
        self._ignored_allreduce_works.clear()

    def reduce_params(self):
        """Drain pending sharded parameter reductions and materialize sharded grads.

        Processes, in FIFO order, the shared queues
        ``HSDPState.pre_reduce_scatter_params`` (divisor ``shard_world_size``),
        ``HSDPState.pre_all_reduce_params`` (divisor ``replicate_world_size``)
        and ``pre_direct_all_reduce_grads`` (waits, divides, then copies the
        result back into the original gradient tensor when it is a different
        tensor). Each entry uses the ``need_div`` captured when it was queued.
        """
        need_synchronize = False
        while HSDPState.pre_reduce_scatter_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = hsdp_param.reduce_scatter_output()
            self._div_if_needed(reduced_grad, hsdp_param.shard_world_size, need_div)
            hsdp_param.clear_reduce_scatter_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while HSDPState.pre_all_reduce_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = hsdp_param.all_reduce_output()
            self._div_if_needed(reduced_grad, hsdp_param.replicate_world_size, need_div)
            hsdp_param.clear_all_reduce_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while MindSporeHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad, reduce_group_size, need_div = (
                MindSporeHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            )
            if handle is not None:
                handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, need_div)
            # When the grad was cast to the reduce dtype, the reduction ran on a
            # separate tensor; write the result back into the original grad.
            if reduced_grad is not target_grad:
                if reduced_grad.dtype != target_grad.dtype:
                    reduced_grad = reduced_grad.to(target_grad.dtype)
                copy_without_bumping_version(target_grad, reduced_grad)
        self._synchronize_current_stream_if_needed(need_synchronize)

    def post_backward_for_comm_fusion(self):
        """Drive the fused gradient-reduction pipeline for sharded params.

        Completes the previously queued fused all-reduce and reduce-scatter
        stages held in the shared comm context, then issues this state's
        ``foreach_reduce`` and the replicate-param all-reduces.
        """
        self.reduce_params()
        comm_ctx = get_comm_ctx()
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type,
                needs_avg_div=self._need_div,
            )
        # NOTE(review): this queues work in ``_ignored_allreduce_works`` but no
        # method of this class calls ``_finish_ignored_allreduce`` on the
        # comm_fusion path; confirm it is drained elsewhere.
        self._allreduce_replicate_params()

    def _post_backward_without_reduce(self):
        """Finish backward when gradient communication is disabled."""
        if self.reshard_after_backward:
            self.shard()
        for hsdp_param in self._iter_managed_params():
            hsdp_param.to_accumulated_grad_if_needed()

    def _should_run_all_reduce(self, hsdp_param) -> bool:
        """Whether the current parameter should issue an all-reduce in this backward pass."""
        return self.requires_all_reduce and hsdp_param.dp_size > 1

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param):
        """Queue the standard FSDP/HSDP reduction path.

        Issues an async reduce-scatter. If an all-reduce is also required, the
        reduce-scatter result is consumed right away (and its queue entry
        removed), then an async all-reduce is queued on it instead.
        """
        hsdp_param.reduce_scatter_grad(
            async_op=True,
            dtype=self._reduce_dtype,
            reduce_op=self.reduce_op_type
        )
        HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype, self._need_div))
        if not self._should_run_all_reduce(hsdp_param):
            return
        reduced_grad = hsdp_param.reduce_scatter_output()
        if (
            HSDPState.pre_reduce_scatter_params
            and HSDPState.pre_reduce_scatter_params[-1][0] == hsdp_param
        ):
            HSDPState.pre_reduce_scatter_params.pop()
        hsdp_param.clear_reduce_scatter_output()
        # NOTE(review): reduce_params divides the same reduce-scatter output by
        # ``shard_world_size``; confirm ``shard_size`` is always equal to it.
        self._div_if_needed(reduced_grad, hsdp_param.shard_size, self._need_div)
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))

    def _queue_compat_all_reduce(self, hsdp_param):
        """Queue the compatibility all-reduce path without FSDP sharding."""
        if not self._should_run_all_reduce(hsdp_param):
            return
        hsdp_param.all_reduce_grad(
            grad=self._get_pending_unsharded_grad(hsdp_param),
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
            and hsdp_param.shard_size == 1
            and hsdp_param.sharded_param.requires_grad
            and self._should_run_all_reduce(hsdp_param)
            and self._get_local_sharded_grad(hsdp_param) is not None
        )

    def _queue_direct_compat_all_reduce(self, hsdp_param):
        """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``.

        Raises:
            RuntimeError: if the unsharded group reports ``rank_size > 1`` but
                has no group handle.
        """
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        reduced_grad = _to_dtype_if_needed(grad, self._reduce_dtype)
        reduce_group_info = getattr(hsdp_param, "unsharded_group_info", None)
        reduce_group = reduce_group_info.group if reduce_group_info is not None else None
        reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1
        handle = None
        if reduce_group_size > 1:
            if reduce_group is None:
                raise RuntimeError("Expected a valid unsharded all-reduce group when rank_size > 1")
            handle = dist.all_reduce(
                reduced_grad,
                group=reduce_group,
                op=self.reduce_op_type,
                async_op=True,
            )
        MindSporeHSDPStateV2.pre_direct_all_reduce_grads.append(
            (handle, reduced_grad, grad, reduce_group_size, self._need_div)
        )

    def post_backward(self, *_):
        """Reduce this state's gradients after backward, then optionally reshard.

        Unsharded gradients are accumulated first. If gradient sync is disabled,
        communication is skipped (see ``_post_backward_without_reduce``).
        Without comm_fusion, reductions still pending in the shared class-level
        queues are drained first; then this state's per-parameter reductions
        are issued and the replicate-param all-reduces are completed. With
        comm_fusion the fused pipeline is used instead.
        """
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            self._post_backward_without_reduce()
            return
        if not self.comm_fusion:
            self.reduce_params()
            self._allreduce_replicate_params()
            for hsdp_param in self.hsdp_params:
                if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                    # Pure-TP DTensor compat params keep their grad on sharded_param.
                    if self._can_direct_all_reduce_compat_grad(hsdp_param):
                        self._queue_direct_compat_all_reduce(hsdp_param)
                    continue
                if not hsdp_param.sharded_param.requires_grad:
                    continue
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                if hsdp_param.shard_size > 1:
                    self._queue_reduce_scatter_then_all_reduce(hsdp_param)
                elif self._should_run_all_reduce(hsdp_param):
                    self._queue_compat_all_reduce(hsdp_param)
                else:
                    # No communication needed: apply the local grad directly.
                    need_synchronize = hsdp_param.apply_reduced_grad(
                        self._get_pending_unsharded_grad(hsdp_param),
                        self._orig_dtype,
                    )
                    self._synchronize_current_stream_if_needed(need_synchronize)
            self._finish_ignored_allreduce()
        else:
            self.post_backward_for_comm_fusion()
        if self.reshard_after_backward:
            self.shard()

    def set_requires_grad_sync(self, requires_grad_sync):
        """set requires grad sync flag to control gradient sync."""
        self.reduce_grads = requires_grad_sync

    def set_reduce_op_type(self, reduce_op_type: str):
        """Set the reduce op type used for gradient reduction.

        Both ``"sum"`` and ``"avg"`` issue a SUM collective; ``"avg"``
        additionally enables dividing reduced gradients by the group size.
        Matching is case-insensitive and ignores surrounding whitespace.

        Raises:
            ValueError: if ``reduce_op_type`` is not a supported type.
        """
        fsdp_support_reduce_op = {
            "sum": ops.ReduceOp.SUM,
            "avg": ops.ReduceOp.SUM,
        }
        # Normalize before validating so spellings like "AVG" or " sum " are
        # accepted; non-string values fall through to the ValueError below.
        reduce_op = (
            reduce_op_type.lower().strip()
            if isinstance(reduce_op_type, str)
            else reduce_op_type
        )
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(
                f"Unsupported reduce op type {reduce_op_type}, "
                f"supported types are {list(fsdp_support_reduce_op.keys())}")
        self._need_div = reduce_op == "avg"
        self.reduce_op_type = fsdp_support_reduce_op[reduce_op]