Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / fully_shard / state.py: 59%
352 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Torch HSDP cell state"""
16# pylint: disable=protected-access
18from typing import Optional, List
19from collections import defaultdict
20import torch
22from hyper_parallel.core.fully_shard.hsdp_state import HSDPState
23from hyper_parallel.core.fully_shard.hsdp_utils import (
24 FullyShardParamMode,
25 _get_param_module_infos,
26 infer_fully_shard_param_mode,
27)
28from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy
29from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2
30from hyper_parallel.platform.torch.fully_shard.pack_utils import build_rs_plan
31from hyper_parallel.platform.torch.fully_shard.param_group import get_comm_ctx, HSDPParamGroup, AllReduceParamGroup
34def _to_dtype_if_needed(
35 tensor: torch.Tensor, dtype: Optional[torch.dtype]
36) -> torch.Tensor:
37 """Cast tensor to the given dtype if it differs from current dtype.
39 Args:
40 tensor: The input tensor to potentially cast.
41 dtype: Target dtype. If None or same as tensor dtype, no-op.
42 """
43 if dtype is not None and tensor.dtype != dtype:
44 return tensor.to(dtype)
45 return tensor
48class TorchHSDPStateV2(HSDPState):
49 """Torch HSDP cell state"""
50 # DTensor compat parameters in pure-TP mode can accumulate gradients
51 # directly on ``sharded_param.grad`` without ever materializing an
52 # ``_unsharded_param``. Track their async all-reduce work separately from
53 # the standard unsharded-grad queues.
54 pre_direct_all_reduce_grads = []
55 # Record AllReduceParamGroup that has reduce_scatter issued, waiting for next post_backward to process
56 pre_all_reduce_groups: List[AllReduceParamGroup] = []
57 # Record AllReduceParamGroup that has all_reduce issued, waiting for root_backward_hook to apply
58 pending_all_reduce_groups: List[AllReduceParamGroup] = []
59 @staticmethod
60 def _get_pending_unsharded_grad(hsdp_param):
61 """Return the pending unsharded gradient tensor for all-reduce-based paths."""
62 if hsdp_param.unsharded_accumulated_grad is not None:
63 return hsdp_param.unsharded_accumulated_grad_data
64 return hsdp_param.unsharded_grad_data
66 @staticmethod
67 def _has_pending_unsharded_grad(hsdp_param):
68 """Whether the parameter currently has a gradient waiting for reduction."""
69 if hsdp_param.unsharded_accumulated_grad is not None:
70 return True
71 if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
72 return False
73 return hsdp_param.unsharded_param.grad is not None
75 @staticmethod
76 def _get_local_sharded_grad(hsdp_param):
77 """Return the local gradient tensor currently stored on ``sharded_param``."""
78 grad = hsdp_param.sharded_param.grad
79 if grad is None:
80 return None
81 to_local = getattr(grad, "to_local", None)
82 if callable(to_local):
83 return to_local()
84 return grad
86 def __init__(self, cell, mesh_info, config, platform, device):
87 """
88 Initialize TorchHSDPStateV2.
90 Args:
91 cell (nn.Module): The module whose parameters are managed by this state.
92 mesh_info: Mesh topology for shard/replicate dimensions.
93 config (HSDPConfigV2): HSDP configuration.
94 platform (TorchPlatform): Torch platform abstraction.
95 device (torch.device): Target device.
96 """
97 super().__init__(cell, mesh_info, config, platform, device)
98 self.comm_fusion = config.comm_fusion
99 # Do ReduceScatter/AllReduce for grad
100 self.device = device
101 self.mp_policy = config.mp_policy
102 self.offload_policy = config.offload_policy
103 self.reduce_grads = True
104 # Reshard parameter after backward
105 self.reshard_after_backward = True
106 # Requires AllReduce for grad When HSDP
107 self.requires_all_reduce = True
108 # Default reduce op is decided at the fully_shard-state level:
109 # if any managed parameter is DTensor-backed, use SUM; otherwise AVG.
110 self._user_reduce_op_type = None
111 self.reduce_op_type = self._resolve_default_reduce_op()
112 self._reset_sharded_params = False
113 self._init_param_group()
115 @staticmethod
116 def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
117 """Return the reason why ``hsdp_param`` cannot participate in comm_fusion."""
118 if not hsdp_param.enable_fsdp_shard:
119 return "non-sharded parameters such as replicate_params are not supported"
120 if hsdp_param.param_mode not in (
121 FullyShardParamMode.LOCAL_PARAM,
122 FullyShardParamMode.DTENSOR_UNIFIED,
123 ):
124 return (
125 "param_mode "
126 f"{hsdp_param.param_mode} is not supported"
127 )
128 local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
129 if local_shard is None:
130 return "missing local shard tensor for comm_fusion plan validation"
131 plan_world_size = getattr(hsdp_param, "shard_world_size", None)
132 if plan_world_size is None:
133 plan_world_size = getattr(hsdp_param, "shard_size", 1)
134 try:
135 build_rs_plan(hsdp_param, local_shard, plan_world_size)
136 except NotImplementedError as exc:
137 return str(exc)
138 except (AssertionError, ValueError) as exc:
139 return f"cannot build comm_fusion pack plan: {exc}"
140 return None
142 def _init_param_group(self):
143 """Initialize fused parameter group for communication fusion.
145 When ``comm_fusion`` is enabled, creates an ``HSDPParamGroup`` that packs all
146 parameters into a single buffer for fused all-gather and reduce-scatter,
147 replacing the per-parameter communication pattern.
148 """
149 if self.config.comm_fusion:
150 unsupported_param = next(
151 (
152 hsdp_param
153 for hsdp_param in self.hsdp_params
154 if self._comm_fusion_unsupported_reason(hsdp_param) is not None
155 ),
156 None,
157 )
158 if unsupported_param is not None:
159 param_fqn = getattr(unsupported_param, "_param_fqn", "<unknown>")
160 reason = self._comm_fusion_unsupported_reason(unsupported_param)
161 raise NotImplementedError(
162 f"comm_fusion does not support parameter {param_fqn}: {reason}."
163 )
164 self.param_group = None
165 if self.hsdp_params:
166 # pylint: disable=E1128
167 self.param_group = HSDPParamGroup(
168 self.hsdp_params,
169 self.mesh_info,
170 self.device,
171 self.mp_policy,
172 self.config.comm_fusion_zero_copy,
173 )
175 def _move_states_to_device(self):
176 """move states to device"""
177 for mod in self.modules:
178 for param in mod.parameters():
179 if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
180 continue
181 if param.device == self.device or param.device.type == "meta":
182 continue
183 param.data = param.to(self.device)
184 for buffer in mod.buffers():
185 if buffer.device == self.device or buffer.device.type == "meta":
186 continue
187 buffer.data = buffer.to(self.device)
189 def _init_hsdp_params(self):
190 """init hsdp parameters and replicate parameters for cell."""
191 replicate_params = set(self.config.replicate_params or ())
192 # all parameters in the module tree(s), deduplicated
193 ignored_params = set(self.config.ignored_params or ())
194 visited_params = set()
195 filtered_params = []
196 for mod in self.modules:
197 for _, param in mod.named_parameters():
198 if param in ignored_params:
199 continue
200 if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
201 continue
202 if param in visited_params:
203 continue
204 visited_params.add(param)
205 filtered_params.append(param)
207 module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
208 for param, module_info in zip(filtered_params, module_infos):
209 param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
210 enable_fsdp_shard = param not in replicate_params
211 hsdp_param = TorchHSDPParamV2(param,
212 module_info,
213 self.mesh_info,
214 shard_placement_fn=self.config.shard_placement_fn,
215 mp_policy=self.mp_policy,
216 offload_policy=self.offload_policy,
217 device=self.device,
218 param_mode=param_mode,
219 enable_fsdp_shard=enable_fsdp_shard,
220 )
221 if param in replicate_params:
222 self.replicate_params.append(hsdp_param)
223 else:
224 self.hsdp_params.append(hsdp_param)
225 if hsdp_param.is_sharded:
226 self.sharded_hsdp_params.append(hsdp_param)
228 def _init_mp_dtypes(self):
229 """init mp dtypes for hsdp parameters and replicate parameters"""
230 for hsdp_param in self.hsdp_params:
231 hsdp_param.init_dtype_attrs(self.mp_policy)
232 for replicate_param in self.replicate_params:
233 replicate_param.init_dtype_attrs(self.mp_policy)
234 trainable_params: list[TorchHSDPParamV2] = [
235 p for p in self._iter_managed_params() if p.sharded_param.requires_grad
236 ]
237 orig_dtypes = {p.orig_dtype for p in trainable_params}
238 reduce_dtypes = {p.reduce_dtype for p in trainable_params}
239 if len(trainable_params) > 0 and len(orig_dtypes) != 1:
240 raise AssertionError(
241 f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
242 )
243 self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
244 if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
245 raise AssertionError(
246 f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
247 )
248 self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
250 def _validate_cpu_offload_params(self):
251 """Validate that all parameters are on CPU when CPU offload policy is enabled."""
252 if not isinstance(self.offload_policy, CPUOffloadPolicy):
253 return
254 hsdp_params_not_on_cpu = [
255 hsdp_param
256 for hsdp_param in self._iter_managed_params()
257 if hsdp_param.sharded_param.device.type != "cpu"
258 ]
259 if hsdp_params_not_on_cpu:
260 raise RuntimeError(
261 "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
262 'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
263 "Found following parameters on non-CPU device: "
264 f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n"
265 )
267 def lazy_init(self):
268 if self.is_shard and not self._reset_sharded_params:
269 for hsdp_param in self.hsdp_params:
270 if hsdp_param.is_sharded:
271 hsdp_param.reset_sharded_param()
272 self._reset_sharded_params = True
273 self._validate_no_meta_params()
274 self._validate_cpu_offload_params()
275 self._init_mp_dtypes()
277 def _validate_no_meta_params(self):
278 param_names_on_meta = [
279 hsdp_param._param_fqn
280 for hsdp_param in self._iter_managed_params()
281 if hsdp_param.sharded_param.device.type == "meta"
282 ]
283 if param_names_on_meta:
284 raise RuntimeError(
285 "HSDP parameters should be materialized from meta device before training, "
286 f"but the following were still on meta device: {param_names_on_meta}\n"
287 "For example, call module.to_empty(device) to materialize to device and "
288 "call module.reset_parameters() on each module to initialize values."
289 )
291 def post_backward_for_comm_fusion(self):
292 """post_backward_for_comm_fusion."""
293 # Replicate-only params still use the non-fused compat all-reduce path.
294 # Drain any pending side-path reductions before advancing the fused
295 # param-group pipeline for sharded params.
296 self.reduce_params()
297 # Fused gradient reduction path: first apply any pending async reduction
298 # from the previous module's backward (pipelined overlap), then issue
299 # this module's fused reduce-scatter (+ all-reduce for HSDP).
300 comm_ctx = get_comm_ctx()
301 # Phase 2: apply grads for the param group whose all_reduce is done
302 if comm_ctx.all_reduce_param_group is not None:
303 comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
304 comm_ctx.all_reduce_param_group = None
305 # Phase 1: wait reduce_scatter, issue async all_reduce for previous layer
306 if comm_ctx.pre_param_group is not None:
307 comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
308 comm_ctx.pre_param_group = None
309 if self.param_group is not None:
310 self.param_group.foreach_reduce(
311 reduce_scatter_reduce_op=self.reduce_op_type
312 )
313 for hsdp_param in self.replicate_params:
314 if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
315 continue
316 if not hsdp_param.sharded_param.requires_grad:
317 continue
318 if not self._has_pending_unsharded_grad(hsdp_param):
319 continue
320 reduce_op = self._resolve_reduce_op(hsdp_param)
321 self._queue_compat_all_reduce(hsdp_param, reduce_op)
323 def _resolve_default_reduce_op(self):
324 """Resolve the default reduce op for the whole fully_shard state."""
325 for hsdp_param in self._iter_managed_params():
326 if hsdp_param.param_mode in (
327 FullyShardParamMode.DTENSOR_COMPAT,
328 FullyShardParamMode.DTENSOR_UNIFIED,
329 ):
330 return torch.distributed.ReduceOp.SUM
331 return torch.distributed.ReduceOp.AVG
333 def _resolve_reduce_op(self, hsdp_param=None):
334 """Resolve the gradient reduction op for the current fully_shard state."""
335 if self._user_reduce_op_type is not None:
336 return self._user_reduce_op_type
337 return self.reduce_op_type
339 def _should_run_all_reduce(self, hsdp_param) -> bool:
340 """Whether the current parameter should issue an all-reduce in this backward pass."""
341 return self.requires_all_reduce and hsdp_param.dp_size > 1
343 def _queue_reduce_scatter_then_all_reduce(self, hsdp_param, reduce_op):
344 """Queue the standard FSDP/HSDP reduction path."""
345 hsdp_param.reduce_scatter_grad(
346 dtype=self._reduce_dtype,
347 reduce_op=reduce_op,
348 )
349 HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype))
350 if not self._should_run_all_reduce(hsdp_param):
351 return
352 reduced_grad = hsdp_param.reduce_scatter_output()
353 if (
354 HSDPState.pre_reduce_scatter_params
355 and HSDPState.pre_reduce_scatter_params[-1][0] == hsdp_param
356 ):
357 HSDPState.pre_reduce_scatter_params.pop()
358 hsdp_param.all_reduce_grad(
359 grad=reduced_grad,
360 dtype=self._reduce_dtype,
361 reduce_op=reduce_op,
362 )
363 HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))
365 def _queue_compat_all_reduce(self, hsdp_param, reduce_op):
366 """Queue the compatibility all-reduce path without FSDP sharding."""
367 if not self._should_run_all_reduce(hsdp_param):
368 return
369 hsdp_param.all_reduce_grad(
370 grad=self._get_pending_unsharded_grad(hsdp_param),
371 dtype=self._reduce_dtype,
372 reduce_op=reduce_op,
373 )
374 HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))
376 def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
377 """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
378 return (
379 hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
380 and hsdp_param.enable_fsdp_shard
381 and not hsdp_param.is_sharded
382 and hsdp_param.shard_size == 1
383 and hsdp_param.sharded_param.requires_grad
384 and self._should_run_all_reduce(hsdp_param)
385 and self._get_local_sharded_grad(hsdp_param) is not None
386 )
388 def _queue_direct_compat_all_reduce(self, hsdp_param, reduce_op):
389 """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``."""
390 grad = self._get_local_sharded_grad(hsdp_param)
391 if grad is None:
392 return
393 reduced_grad = grad
394 if self._reduce_dtype is not None and reduced_grad.dtype != self._reduce_dtype:
395 reduced_grad = reduced_grad.to(self._reduce_dtype)
396 handle = None
397 if hsdp_param.unsharded_group_info.group is not None and hsdp_param.dp_size > 1:
398 handle = torch.distributed.all_reduce(
399 reduced_grad,
400 op=reduce_op,
401 group=hsdp_param.unsharded_group_info.group,
402 async_op=True,
403 )
404 TorchHSDPStateV2.pre_direct_all_reduce_grads.append((handle, reduced_grad, grad))
406 def post_backward(self, *unused): # pylint: disable=unused-argument
407 """Reduce gradients and reshard parameters after backward."""
408 for hsdp_param in self._iter_managed_params():
409 hsdp_param.accumulate_unsharded_grad_if_needed()
410 if not self.reduce_grads:
411 if self.reshard_after_backward:
412 self.shard()
413 for hsdp_param in self._iter_managed_params():
414 hsdp_param.to_accumulated_grad_if_needed()
415 return
416 if not self.comm_fusion:
417 # Handle user config replicate params and mirror params.
418 self.reduce_params()
419 for hsdp_param in self._iter_managed_params():
420 if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
421 if self._can_direct_all_reduce_compat_grad(hsdp_param):
422 reduce_op = self._resolve_reduce_op(hsdp_param)
423 self._queue_direct_compat_all_reduce(hsdp_param, reduce_op)
425 # Step 1: wait prev reduce_scatter (for params needing allreduce)
426 prev_group = self._wait_prev_reduce_scatter()
428 # Step 2: wait and apply prev reduce_scatter (for params NOT needing allreduce)
429 self._wait_and_apply_prev_no_allreduce_params()
431 # Step 3: issue current reduce_scatter
432 self._issue_reduce_scatter_for_current_module()
434 # Step 4: issue prev fused allreduce (async) - using saved prev_group
435 self._issue_prev_fused_allreduce(prev_group)
436 else:
437 self.post_backward_for_comm_fusion()
438 if self.reshard_after_backward:
439 self.shard()
441 def _issue_reduce_scatter_for_current_module(self):
442 """Issue reduce_scatter for current module's parameters with fused all-reduce support.
444 This method groups parameters by their replicate_process_group and:
445 1. For params without all_reduce needs: issue reduce_scatter directly
446 2. For params with all_reduce needs: allocate fused buffer and issue reduce_scatter
447 into aligned views, enabling zero-copy fused all_reduce later.
448 """
449 # Collect parameters that need gradient reduction
450 params_to_reduce = []
451 for hsdp_param in self._iter_managed_params():
452 skip_param = (not hasattr(hsdp_param, "_unsharded_param")
453 or hsdp_param.unsharded_param is None
454 or not hsdp_param.sharded_param.requires_grad
455 or self._can_direct_all_reduce_compat_grad(hsdp_param)
456 or (hsdp_param.unsharded_param.grad is None
457 and hsdp_param.unsharded_accumulated_grad_data is None))
458 if skip_param:
459 continue
460 params_to_reduce.append(hsdp_param)
462 if not params_to_reduce:
463 return
465 # Group by replicate_process_group for fused all-reduce
466 # Key: id of process group, or None for params that don't need all_reduce
467 groups_by_comm = defaultdict(list)
468 for hsdp_param in params_to_reduce:
469 if self._should_run_all_reduce(hsdp_param):
470 key = id(hsdp_param.unsharded_group_info.group)
471 groups_by_comm[key].append(hsdp_param)
472 else:
473 groups_by_comm[None].append(hsdp_param)
475 # Handle params that don't need all_reduce (FSDP or single replica)
476 if None in groups_by_comm:
477 for hsdp_param in groups_by_comm[None]:
478 hsdp_param.reduce_scatter_grad(
479 dtype=self._reduce_dtype,
480 reduce_op=self._resolve_reduce_op()
481 )
482 HSDPState.pre_reduce_scatter_params.append(
483 (hsdp_param, self._orig_dtype))
485 # Handle params that need all_reduce (HSDP with multiple replicas)
486 for key, hsdp_params in groups_by_comm.items():
487 if key is None:
488 continue
490 # Create AllReduceParamGroup for fused all-reduce
491 group = AllReduceParamGroup(
492 replicate_group=hsdp_params[0].unsharded_group_info.group,
493 hsdp_params=hsdp_params,
494 orig_dtypes=[self._orig_dtype] * len(hsdp_params),
495 reduce_dtype=self._reduce_dtype,
496 reduce_op=self._resolve_reduce_op(),
497 mp_policy=self.mp_policy,
498 )
500 # Allocate fused buffer with 512-byte alignment
501 group.allocate_fused_buffer(self.device)
503 # Issue reduce_scatter with output directly into fused buffer views
504 for idx, hsdp_param in enumerate(hsdp_params):
505 buffer_view = group.get_param_buffer_view(idx)
506 hsdp_param.reduce_scatter_grad(
507 dtype=self._reduce_dtype,
508 reduce_op=self._resolve_reduce_op(),
509 output_buffer=buffer_view,
510 )
512 # Save group for later all_reduce in reduce_params()
513 TorchHSDPStateV2.pre_all_reduce_groups.append(group)
515 def _wait_prev_reduce_scatter(self) -> List[AllReduceParamGroup]:
516 """Step 1: wait prev reduce_scatter.
518 This enables overlapping:
519 - Layer N-1's reduce_scatter wait with Layer N's backward compute
521 Returns:
522 List of previous AllReduceParamGroups (one per communication group).
523 """
524 if TorchHSDPStateV2.pre_all_reduce_groups:
525 prev_groups = list(TorchHSDPStateV2.pre_all_reduce_groups)
526 TorchHSDPStateV2.pre_all_reduce_groups.clear()
527 for prev_group in prev_groups:
528 for hsdp_param in prev_group.hsdp_params:
529 hsdp_param.reduce_scatter_output()
530 hsdp_param.clear_reduce_scatter_output()
531 if hsdp_param.unsharded_accumulated_grad_data is not None:
532 hsdp_param.unsharded_accumulated_grad = None
533 elif hsdp_param.unsharded_param.grad is not None:
534 hsdp_param.unsharded_param.grad = None
535 return prev_groups
536 return []
538 def _issue_prev_fused_allreduce(self, prev_groups: List[AllReduceParamGroup]):
539 """Step 4: issue previous module's fused allreduce (async).
541 The allreduce handle is collected in pending_all_reduce_groups,
542 and will be processed in root_backward_hook's delay_apply_reduce_grads().
544 Args:
545 prev_groups: List of previous AllReduceParamGroups to issue allreduce for.
546 """
547 for prev_group in prev_groups:
548 prev_group.accumulate_existing_grads_to_buffer()
549 prev_group.issue_async_allreduce()
550 # Move to pending queue for root_backward_hook to process
551 TorchHSDPStateV2.pending_all_reduce_groups.append(prev_group)
553 def _wait_and_apply_prev_no_allreduce_params(self):
554 """Step 2: wait and apply previous reduce_scatter for params NOT needing allreduce.
556 These are FSDP params or single-replica HSDP params that don't need
557 cross-replica allreduce. Their reduce_scatter was issued by the previous
558 module's _issue_reduce_scatter_for_current_module(), and we wait and apply here.
559 """
560 need_synchronize = False
561 while HSDPState.pre_reduce_scatter_params:
562 pre_hsdp_param, pre_orig_dtype = HSDPState.pre_reduce_scatter_params.pop(0)
563 reduced_grad = pre_hsdp_param.reduce_scatter_output()
564 pre_hsdp_param.clear_reduce_scatter_output()
565 need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize
566 pre_hsdp_param.accumulated_allreduced_grad = False
568 if need_synchronize:
569 if self.device.type == "npu":
570 torch.npu.current_stream().synchronize()
571 elif self.device.type == "cuda":
572 torch.cuda.current_stream().synchronize()
573 else:
574 raise NotImplementedError(
575 f"Unsupported device type {self.device.type} for synchronization after CPU offload."
576 )
578 @classmethod
579 def delay_apply_reduce_grads(cls, device: torch.device):
580 """Apply all pending allreduce gradients in root_backward_hook.
582 This is called at the end of root_backward_hook to wait for all
583 async allreduce operations and apply gradients to sharded parameters.
585 Args:
586 device: Device for CPU offload synchronization.
587 """
588 need_synchronize = False
590 for group in cls.pending_all_reduce_groups:
591 need_synchronize = group.wait_and_apply_grads() or need_synchronize
593 cls.pending_all_reduce_groups.clear()
595 if need_synchronize:
596 if device.type == "npu":
597 torch.npu.current_stream().synchronize()
598 elif device.type == "cuda":
599 torch.cuda.current_stream().synchronize()
600 else:
601 raise NotImplementedError(
602 f"Unsupported device type {device.type} for synchronization after CPU offload."
603 )
606 def reduce_scattered_params(self):
607 """
608 reduce_scattered_params
609 """
610 need_synchronize = False
611 while HSDPState.pre_reduce_scatter_params:
612 pre_hsdp_param, pre_orig_dtype = HSDPState.pre_reduce_scatter_params.pop(0)
613 reduced_grad = pre_hsdp_param.reduce_scatter_output()
614 pre_hsdp_param.clear_reduce_scatter_output()
615 need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize
616 pre_hsdp_param.accumulated_allreduced_grad = False
617 if need_synchronize:
618 if self.device.type == "npu":
619 torch.npu.current_stream().synchronize()
620 elif self.device.type == "cuda":
621 torch.cuda.current_stream().synchronize()
622 else:
623 raise NotImplementedError(
624 f"Unsupported device type {self.device.type} for synchronization after CPU offload."
625 )
627 def reduce_params(self):
628 """Apply reduced gradients from pre-staged HSDP parameters to sharded parameters.
630 This function processes two lists of pre-queued HSDP parameters (`pre_reduce_scatter_params`
631 and `pre_all_reduce_params`), retrieves the reduced gradients from asynchronous
632 reduce-scatter/all-reduce operations, clears cached communication outputs, and applies
633 the reduced gradients to the corresponding sharded parameters (including reshaping,
634 dtype conversion, optional CPU offloading, and gradient accumulation/assignment).
636 Note:
637 - Parameters are processed in **FIFO (First-In-First-Out)** order (via `pop(0)`), ensuring
638 gradient application order matches the order of gradient reduction operations.
639 - After retrieving the reduced gradient, the cached communication output (reduce_scatter_output
640 or all_reduce_output) is cleared to free memory and avoid stale data.
641 - Gradient application logic (in `apply_reduced_grad`) includes:
642 1. Reshaping the flat reduced gradient to match the local shard shape
643 2. Optional dtype conversion to `param_type`
644 3. Optional CPU offloading (per the HSDP parameter's offload policy)
645 4. Assigning or accumulating the gradient to `sharded_param.grad`
646 """
647 need_synchronize = False
648 while HSDPState.pre_all_reduce_params:
649 pre_hsdp_param, pre_orig_dtype = HSDPState.pre_all_reduce_params.pop(0)
650 reduced_grad = pre_hsdp_param.all_reduce_output()
651 pre_hsdp_param.clear_all_reduce_output()
652 need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize
654 while TorchHSDPStateV2.pre_direct_all_reduce_grads:
655 handle, reduced_grad, target_grad = TorchHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
656 if handle is not None:
657 handle.wait()
658 if reduced_grad is not target_grad:
659 if reduced_grad.dtype != target_grad.dtype:
660 reduced_grad = reduced_grad.to(target_grad.dtype)
661 target_grad.copy_(reduced_grad)
662 if need_synchronize:
663 if self.device.type == "npu":
664 torch.npu.current_stream().synchronize()
665 elif self.device.type == "cuda":
666 torch.cuda.current_stream().synchronize()
667 else:
668 raise NotImplementedError(
669 f"Unsupported device type {self.device.type} for synchronization after CPU offload."
670 )
672 def set_requires_grad_sync(self, requires_grad_sync):
673 """set requires grad sync flag to control gradient sync."""
674 self.reduce_grads = requires_grad_sync
676 @property
677 def _is_hsdp(self) -> bool:
678 return isinstance(self.mesh_info, HSDPMeshInfo)
680 def set_reduce_op_type(self, reduce_op_type: str):
681 """set reduce op type for gradient reduction."""
682 fsdp_support_reduce_op = {
683 "sum": torch.distributed.ReduceOp.SUM,
684 "avg": torch.distributed.ReduceOp.AVG,
685 }
686 if reduce_op_type not in fsdp_support_reduce_op:
687 raise ValueError(
688 f"Unsupported reduce op type {reduce_op_type}, "
689 f"supported types are {list(fsdp_support_reduce_op.keys())}"
690 )
691 reduce_op: str = reduce_op_type.lower().strip()
692 self._user_reduce_op_type = fsdp_support_reduce_op[reduce_op]
693 self.reduce_op_type = self._user_reduce_op_type