Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / mindspore / fully_shard / state.py: 61%
312 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""MindSpore HSDP cell state"""
16from typing import Optional
17import mindspore as ms
18from mindspore import ops
19import mindspore.mint.distributed as dist
20from hyper_parallel.core.fully_shard.hsdp_state import HSDPState
21from hyper_parallel.core.fully_shard.hsdp_utils import (
22 _get_param_module_infos,
23 FullyShardParamMode,
24 infer_fully_shard_param_mode,
25)
26from hyper_parallel.platform.mindspore.fully_shard.pack_utils import build_rs_plan
27from hyper_parallel.platform.mindspore.fully_shard.param import MindSporeHSDPParamV2
28from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version
29from hyper_parallel.platform.mindspore.fully_shard.param_group import HSDPParamGroup, get_comm_ctx
30from hyper_parallel.platform.mindspore.utils import normalize_runtime_device
31from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy
def _to_dtype_if_needed(
    tensor: ms.Tensor, dtype: Optional[ms.Type]
) -> ms.Tensor:
    """Return ``tensor`` converted to ``dtype`` only when a conversion is required.

    Args:
        tensor: Tensor whose dtype may need to change.
        dtype: Desired dtype. ``None`` means "leave the dtype unchanged".

    Returns:
        ``tensor`` itself when ``dtype`` is ``None`` or already equals
        ``tensor.dtype``; otherwise the result of ``tensor.to(dtype)``.
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor
class MindSporeHSDPStateV2(HSDPState):
    """MindSpore HSDP cell state.

    Drives per-cell gradient reduction after backward:

    * reduce-scatter and/or all-reduce for ``hsdp_params``;
    * DDP-style all-reduce for ``replicate_params``;
    * the fused ``HSDPParamGroup`` pipeline when ``config.comm_fusion`` is set.
    """

    # DTensor compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without materializing an
    # ``_unsharded_param``. Track those async all-reduces separately from the
    # standard unsharded-gradient queues. This is a class-level list, so it is
    # shared by all instances and drained by ``reduce_params``.
    pre_direct_all_reduce_grads = []

    @staticmethod
    def _get_pending_unsharded_grad(hsdp_param):
        """Return the pending unsharded gradient tensor for reduction paths.

        Prefers the accumulated gradient when one exists; otherwise returns the
        gradient of the unsharded parameter.
        """
        if hsdp_param.unsharded_accumulated_grad is not None:
            return hsdp_param.unsharded_accumulated_grad_data
        return hsdp_param.unsharded_grad_data

    @staticmethod
    def _has_pending_unsharded_grad(hsdp_param):
        """Whether the parameter currently has a gradient waiting for reduction."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return True
        if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
            return False
        return hsdp_param.unsharded_param.grad is not None

    @staticmethod
    def _get_local_sharded_grad(hsdp_param):
        """Return the local gradient tensor currently stored on ``sharded_param``.

        DTensor-like gradients (anything exposing a callable ``to_local``) are
        converted to their local tensor; plain tensors are returned as-is.
        Returns ``None`` when there is no gradient.
        """
        grad = hsdp_param.sharded_param.grad
        if grad is None:
            return None
        to_local = getattr(grad, "to_local", None)
        if callable(to_local):
            return to_local()
        return grad

    @staticmethod
    def _synchronize_current_stream_if_needed(need_synchronize: bool) -> None:
        """Synchronize the current device stream after non-blocking CPU offload."""
        if not need_synchronize:
            return
        ms.runtime.current_stream().synchronize()

    def __init__(self, cell, mesh_info, config, platform, device=None):
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do ReduceScatter/AllReduce for grads (toggled by set_requires_grad_sync).
        self.reduce_grads = True
        # Reshard parameters after backward.
        self.reshard_after_backward = True
        # Requires AllReduce for grads when running HSDP (dp_size > 1).
        self.requires_all_reduce = True
        # Keep historical AVG behavior for local parameters, while DTensor-aware
        # paths default to SUM semantics without extra division.
        self.reduce_op_type = ops.ReduceOp.SUM
        self._need_div = not any(
            getattr(param, "param_mode", FullyShardParamMode.LOCAL_PARAM)
            != FullyShardParamMode.LOCAL_PARAM
            for param in self._iter_managed_params()
        )
        # Pending replicate-param all-reduces: (param, reduced_grad, group_size).
        self._ignored_allreduce_works = []
        self._reset_sharded_params = False
        self._init_param_group()

    def _iter_managed_params(self):
        """Return all fully_shard-managed parameters, including replicate_params."""
        return [*self.hsdp_params, *self.replicate_params]

    @staticmethod
    def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
        """Return why ``hsdp_param`` cannot participate in comm_fusion.

        Returns ``None`` when the parameter is supported. Otherwise returns a
        human-readable reason. Building the reduce-scatter pack plan is
        attempted as part of the check.
        """
        if not hsdp_param.enable_fsdp_shard:
            return "non-sharded parameters such as replicate_params are not supported"
        if hsdp_param.param_mode not in (
            FullyShardParamMode.LOCAL_PARAM,
            FullyShardParamMode.DTENSOR_UNIFIED,
        ):
            return f"param_mode {hsdp_param.param_mode} is not supported"
        local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
        if local_shard is None:
            return "missing local shard tensor for comm_fusion plan validation"
        plan_world_size = getattr(hsdp_param, "shard_world_size", None)
        if plan_world_size is None:
            plan_world_size = getattr(hsdp_param, "shard_size", 1)
        try:
            build_rs_plan(hsdp_param, local_shard, plan_world_size)
        except NotImplementedError as exc:
            return str(exc)
        except (AssertionError, ValueError) as exc:
            return f"cannot build comm_fusion pack plan: {exc}"
        return None

    def _init_param_group(self):
        """Validate comm_fusion support and build the ``HSDPParamGroup``.

        Raises:
            NotImplementedError: if comm_fusion is enabled and any hsdp param is
                unsupported. The first offending parameter is reported.
        """
        if self.config.comm_fusion:
            for hsdp_param in self.hsdp_params:
                # Compute the reason once; it may build a pack plan.
                reason = self._comm_fusion_unsupported_reason(hsdp_param)
                if reason is not None:
                    param_fqn = getattr(hsdp_param, "_param_fqn", "<unknown>")
                    raise NotImplementedError(
                        f"comm_fusion does not support parameter {param_fqn}: {reason}."
                    )
        self.param_group = None
        if self.hsdp_params:
            self.param_group = HSDPParamGroup(
                self.hsdp_params,
                self.mesh_info,
                self.device,
                self.mp_policy,
                self.config.comm_fusion_zero_copy,
            )

    def zero_grad(self):
        """Zero the gradients of every managed parameter (hsdp and replicate)."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.zero_grad()

    @staticmethod
    def _div_if_needed(x, divisor, need_div: bool):
        """Divide ``x`` in place by ``divisor`` when the reduce policy requires it.

        ``need_div`` may come from the current state or from metadata captured
        when async reduce work was queued. This makes the helper safe for both
        immediate and deferred gradient materialization. A divisor of 1 is a
        no-op.
        """
        if not need_div:
            return
        if divisor == 1:
            return
        x.div_(divisor)

    def _move_states_to_device(self):
        """Move not-yet-initialized parameters and all buffers to ``self.device``.

        Tensors already on ``self.device`` or on the ``meta`` device are skipped.
        """
        for mod in self.modules:
            for param in mod.get_parameters():
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                param_device = normalize_runtime_device(param.device)
                if param_device in (self.device, "meta"):
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                # Normalize like parameters so equivalent device spellings are
                # not treated as a mismatch, which would cause a redundant copy.
                buffer_device = normalize_runtime_device(buffer.device)
                if buffer_device in (self.device, "meta"):
                    continue
                buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Wrap the cell's parameters into hsdp params and replicate params.

        Already-initialized, ignored, and duplicate parameters are skipped.
        Parameters listed in ``config.replicate_params`` are created with
        ``enable_fsdp_shard=False`` and stored in ``self.replicate_params``.
        All others go to ``self.hsdp_params``. Sharded ones are also added to
        ``self.sharded_hsdp_params``.
        """
        visited_params = set()
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.parameters_and_names():
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                if param in ignored_params:
                    continue
                if param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            enable_fsdp_shard = param not in replicate_params
            hsdp_param = MindSporeHSDPParamV2(
                param,
                module_info,
                self.mesh_info,
                shard_placement_fn=self.config.shard_placement_fn,
                mp_policy=self.mp_policy,
                offload_policy=self.offload_policy,
                device=self.device,
                param_mode=param_mode,
                enable_fsdp_shard=enable_fsdp_shard,
            )
            if not enable_fsdp_shard:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
                if hsdp_param.is_sharded:
                    self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Initialize mixed-precision dtypes and record the state-wide dtypes.

        Raises:
            AssertionError: if trainable parameters disagree on their original
                dtype or on their reduce dtype.
        """
        for hsdp_param in self._iter_managed_params():
            hsdp_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[MindSporeHSDPParamV2] = [
            p for p in self._iter_managed_params() if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None

    def lazy_init(self):
        """Refresh parameter views and validate runtime state before first execution."""
        if self.is_shard and not self._reset_sharded_params:
            for hsdp_param in self.hsdp_params:
                if hsdp_param.is_sharded:
                    hsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        self._validate_no_meta_params()
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _validate_cpu_offload_params(self):
        """Validate that all parameters are on CPU when CPU offload policy is enabled.

        Raises:
            RuntimeError: if any managed parameter's sharded tensor is not on CPU.
        """
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if not str(hsdp_param.sharded_param.device).lower().startswith("cpu")
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                "For example, load a CPU state dict before training. "
                "Found following parameters on non-CPU device: "
                f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n"
            )

    def _validate_no_meta_params(self):
        """Validate that all parameters have been materialized from meta device.

        Raises:
            RuntimeError: listing the fqns of parameters still on ``meta``.
        """
        param_names_on_meta = [
            hsdp_param._param_fqn
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "HSDP parameters should be materialized from meta device before training, "
                "but the following were still on meta device: {param_names_on_meta}\n".replace(
                    "{param_names_on_meta}", str(param_names_on_meta)
                )
                + "For example, initialize the module weights on a real device before running training."
            )

    def _allreduce_replicate_params(self, async_op=True) -> None:
        """DDP-style all-reduce for parameters in ``config.replicate_params``.

        Uses the parameter's layout-driven unsharded group, so DTensor-aware
        compat and unified modes reduce over the correct axes. Every parameter
        with a pending gradient is recorded in ``_ignored_allreduce_works``,
        which ``_finish_ignored_allreduce`` later drains. This happens even when
        no collective is issued (no group or group size 1).
        """
        for param in self.replicate_params:
            if not hasattr(param, "_unsharded_param") or param.unsharded_param is None:
                continue
            if (
                param.unsharded_accumulated_grad is None
                and param.unsharded_param.grad is None
            ):
                continue

            reduced_grad = param.unsharded_accumulated_grad_data
            if reduced_grad is None:
                reduced_grad = param.unsharded_grad_data
            reduced_grad = _to_dtype_if_needed(reduced_grad, self._reduce_dtype)
            reduce_group_info = getattr(param, "unsharded_group_info", None)
            reduce_group = reduce_group_info.group if reduce_group_info is not None else None
            reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1

            handle = None
            if reduce_group is not None and reduce_group_size > 1:
                # Ascend HCCL DistCommAllReduce rejects non-contiguous tensors.
                # reduced_grad may still be a view from the no-reduce path of
                # ``unsharded_grad_data`` / ``_to_local_unsharded_grad``.
                # ``Tensor.contiguous()`` is a no-op when storage is already
                # contiguous, so the unconditional call is safe.
                reduced_grad = reduced_grad.contiguous()
                handle = dist.all_reduce(
                    reduced_grad, group=reduce_group, op=self.reduce_op_type, async_op=async_op
                )
            # Always overwrite the handle, so one left over from an earlier step
            # is never waited on again for a param that issued no collective now.
            param.all_reduce_handle = handle
            self._ignored_allreduce_works.append((param, reduced_grad, reduce_group_size))

    def _finish_ignored_allreduce(self) -> None:
        """Wait for replicate_params all-reduces and materialize their grads.

        For each pending work item this:

        1. waits on the recorded all-reduce handle, if any, then clears it;
        2. divides by the group size when AVG semantics are active;
        3. applies the reduced gradient via ``apply_reduced_grad``, cast to
           ``_orig_dtype``.

        The current stream is synchronized once at the end if any apply
        requested it.
        """
        if not self._ignored_allreduce_works:
            return

        need_synchronize = False
        for param, reduced_grad, reduce_group_size in self._ignored_allreduce_works:
            handle = param.all_reduce_handle
            if handle:
                handle.wait()
                param.all_reduce_handle = None
            self._div_if_needed(reduced_grad, reduce_group_size, self._need_div)
            need_synchronize = (
                param.apply_reduced_grad(reduced_grad, self._orig_dtype)
                or need_synchronize
            )

        self._synchronize_current_stream_if_needed(need_synchronize)
        self._ignored_allreduce_works.clear()

    def reduce_params(self):
        """Drain pending sharded parameter reductions and materialize sharded grads.

        Processes, in FIFO order, the shared class-level queues
        ``HSDPState.pre_reduce_scatter_params``,
        ``HSDPState.pre_all_reduce_params`` and
        ``MindSporeHSDPStateV2.pre_direct_all_reduce_grads``. Each entry
        carries the dtype and ``need_div`` flag captured when it was queued.
        """
        need_synchronize = False
        while HSDPState.pre_reduce_scatter_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = hsdp_param.reduce_scatter_output()
            self._div_if_needed(reduced_grad, hsdp_param.shard_world_size, need_div)
            hsdp_param.clear_reduce_scatter_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while HSDPState.pre_all_reduce_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = hsdp_param.all_reduce_output()
            self._div_if_needed(reduced_grad, hsdp_param.replicate_world_size, need_div)
            hsdp_param.clear_all_reduce_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while MindSporeHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad, reduce_group_size, need_div = (
                MindSporeHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            )
            if handle is not None:
                handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, need_div)
            # A dtype cast or contiguous copy produced a separate buffer, so
            # write the result back into the gradient that lives on the param.
            if reduced_grad is not target_grad:
                if reduced_grad.dtype != target_grad.dtype:
                    reduced_grad = reduced_grad.to(target_grad.dtype)
                copy_without_bumping_version(target_grad, reduced_grad)
        self._synchronize_current_stream_if_needed(need_synchronize)

    def post_backward_for_comm_fusion(self):
        """Drive the fused gradient-reduction pipeline for sharded params.

        Drains the unfused queues, then advances the fused pipeline stored in
        the comm context: it finishes the group that is all-reducing and
        promotes the group that is reduce-scattering to all-reduce. Finally it
        starts this state's own ``foreach_reduce`` and the replicate-param
        all-reduces.
        """
        self.reduce_params()
        comm_ctx = get_comm_ctx()
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type,
                needs_avg_div=self._need_div,
            )
        # NOTE(review): unlike the non-fused path, nothing visible here calls
        # ``_finish_ignored_allreduce`` after this; confirm that a later
        # finalize step drains ``_ignored_allreduce_works`` for replicate_params.
        self._allreduce_replicate_params()

    def _post_backward_without_reduce(self):
        """Finish backward when gradient communication is disabled."""
        if self.reshard_after_backward:
            self.shard()
        for hsdp_param in self._iter_managed_params():
            hsdp_param.to_accumulated_grad_if_needed()

    def _should_run_all_reduce(self, hsdp_param) -> bool:
        """Whether the current parameter should issue an all-reduce in this backward pass."""
        return self.requires_all_reduce and hsdp_param.dp_size > 1

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param):
        """Queue the standard FSDP/HSDP reduction path.

        Issues an async reduce-scatter. If no all-reduce is required, the work
        is queued on ``HSDPState.pre_reduce_scatter_params`` for
        ``reduce_params``. Otherwise the reduce-scatter output is consumed now,
        an async all-reduce is issued on it, and that is queued on
        ``HSDPState.pre_all_reduce_params``.
        """
        hsdp_param.reduce_scatter_grad(
            async_op=True,
            dtype=self._reduce_dtype,
            reduce_op=self.reduce_op_type,
        )
        if not self._should_run_all_reduce(hsdp_param):
            HSDPState.pre_reduce_scatter_params.append(
                (hsdp_param, self._orig_dtype, self._need_div)
            )
            return
        # The all-reduce needs the reduce-scatter result right away, so it is
        # not deferred to ``reduce_params``.
        reduced_grad = hsdp_param.reduce_scatter_output()
        hsdp_param.clear_reduce_scatter_output()
        # NOTE(review): this divides by ``shard_size`` while ``reduce_params``
        # divides reduce-scatter output by ``shard_world_size``; confirm the
        # two attributes always agree.
        self._div_if_needed(reduced_grad, hsdp_param.shard_size, self._need_div)
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))

    def _queue_compat_all_reduce(self, hsdp_param):
        """Queue the compatibility all-reduce path without FSDP sharding."""
        if not self._should_run_all_reduce(hsdp_param):
            return
        hsdp_param.all_reduce_grad(
            grad=self._get_pending_unsharded_grad(hsdp_param),
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
            and hsdp_param.shard_size == 1
            and hsdp_param.sharded_param.requires_grad
            and self._should_run_all_reduce(hsdp_param)
            and self._get_local_sharded_grad(hsdp_param) is not None
        )

    def _queue_direct_compat_all_reduce(self, hsdp_param):
        """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``.

        Raises:
            RuntimeError: if the unsharded group reports ``rank_size > 1`` but
                has no process group.
        """
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        reduced_grad = _to_dtype_if_needed(grad, self._reduce_dtype)
        reduce_group_info = getattr(hsdp_param, "unsharded_group_info", None)
        reduce_group = reduce_group_info.group if reduce_group_info is not None else None
        reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1
        handle = None
        if reduce_group_size > 1:
            if reduce_group is None:
                raise RuntimeError("Expected a valid unsharded all-reduce group when rank_size > 1")
            # Same HCCL contiguity requirement as the replicate path. If this
            # yields a copy, ``reduce_params`` copies the result back into
            # ``grad`` because the two objects then differ.
            reduced_grad = reduced_grad.contiguous()
            handle = dist.all_reduce(
                reduced_grad,
                group=reduce_group,
                op=self.reduce_op_type,
                async_op=True,
            )
        MindSporeHSDPStateV2.pre_direct_all_reduce_grads.append(
            (handle, reduced_grad, grad, reduce_group_size, self._need_div)
        )

    def _post_backward_without_comm_fusion(self):
        """Issue per-parameter reductions for the non-fused path.

        Drains the shared reduction queues, starts replicate-param all-reduces,
        then queues or directly applies each hsdp param's gradient. It finishes
        with a single stream synchronize and completes the replicate-param
        all-reduces.
        """
        self.reduce_params()
        self._allreduce_replicate_params()
        need_synchronize = False
        for hsdp_param in self.hsdp_params:
            if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                if self._can_direct_all_reduce_compat_grad(hsdp_param):
                    self._queue_direct_compat_all_reduce(hsdp_param)
                continue
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if not self._has_pending_unsharded_grad(hsdp_param):
                continue
            if hsdp_param.shard_size > 1:
                self._queue_reduce_scatter_then_all_reduce(hsdp_param)
            elif self._should_run_all_reduce(hsdp_param):
                self._queue_compat_all_reduce(hsdp_param)
            else:
                # No communication needed: apply the local gradient now. The
                # stream synchronize is batched after the loop, matching
                # ``reduce_params``.
                need_synchronize = (
                    hsdp_param.apply_reduced_grad(
                        self._get_pending_unsharded_grad(hsdp_param),
                        self._orig_dtype,
                    )
                    or need_synchronize
                )
        self._synchronize_current_stream_if_needed(need_synchronize)
        self._finish_ignored_allreduce()

    def post_backward(self, *_):
        """Post-backward hook: accumulate, reduce (if enabled) and reshard."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            self._post_backward_without_reduce()
            return
        if self.comm_fusion:
            self.post_backward_for_comm_fusion()
        else:
            self._post_backward_without_comm_fusion()
        if self.reshard_after_backward:
            self.shard()

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the flag that controls whether gradients are synchronized."""
        self.reduce_grads = requires_grad_sync

    def set_reduce_op_type(self, reduce_op_type: str):
        """Set the reduce op type for gradient reduction.

        Args:
            reduce_op_type: ``"sum"`` or ``"avg"``. The value is
                case-insensitive and surrounding whitespace is ignored. Both map
                to ``ReduceOp.SUM``; ``"avg"`` additionally enables division
                by the group size.

        Raises:
            ValueError: for unsupported values.
        """
        fsdp_support_reduce_op = {
            "sum": ops.ReduceOp.SUM,
            "avg": ops.ReduceOp.SUM,
        }
        # Normalize before validating; previously "SUM" / " avg" were rejected
        # because the normalization ran after the membership check.
        reduce_op = (
            reduce_op_type.lower().strip()
            if isinstance(reduce_op_type, str)
            else reduce_op_type
        )
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(
                f"Unsupported reduce op type {reduce_op_type}, "
                f"supported types are {list(fsdp_support_reduce_op.keys())}")
        self._need_div = reduce_op == "avg"
        self.reduce_op_type = fsdp_support_reduce_op[reduce_op]