Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/param.py: 66%
502 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py
16# enhanced with fully_shard parameter management
17# ============================================================================
18"""HSDP parameter"""
19# pylint: disable=W0212
20import itertools
21from typing import Callable, List, Optional, Tuple, Union, cast
23import torch
24import torch.distributed as dist
25from torch import nn
26from torch._prims_common import make_contiguous_strides_for
28from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
29from hyper_parallel.core.dtensor.dtensor import DTensor, SkipDTensorDispatch
30from hyper_parallel.core.dtensor.layout import Layout
31from hyper_parallel.core.dtensor.placement_types import Replicate, Shard, StridedShard
32from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
33from hyper_parallel.core.fully_shard.hsdp_utils import (
34 FullyShardParamMode,
35 GroupInfo,
36 ParamModuleInfo,
37 ShardedState,
38 get_rank_list_for_axes,
39 get_split_rank_lists_for_axes,
40)
41from hyper_parallel.core.fully_shard.utils import (
42 CPUOffloadPolicy,
43 DDPMeshInfo,
44 FSDPMeshInfo,
45 MixedPrecisionPolicy,
46 OffloadPolicy,
47)
48from hyper_parallel.platform import get_platform
49from hyper_parallel.platform.torch.fully_shard.pack_utils import (
50 build_rs_plan,
51 pack_for_reduce_scatter,
52 unpack_from_all_gather,
53)
# Process groups created from explicit rank lists, keyed by the sorted tuple of
# member ranks so the same rank set is only materialized once per process.
_GROUP_INFO_CACHE = {}
# Platform backend used below for create_group / split_group.
platform = get_platform()
59def _copy_without_bumping_version(dst: torch.Tensor, src: torch.Tensor) -> None:
60 """Copy into ``dst`` while preserving its autograd version counter."""
61 # pylint: disable=W0212
62 with torch.autograd._unsafe_preserve_version_counter(dst):
63 dst.copy_(src)
def _build_group_info_from_rank_list(
    group_name: str,
    rank_list,
) -> GroupInfo:
    """Build a ``GroupInfo`` for an explicit list of ranks.

    Args:
        group_name: Label used only for the "invalid" placeholder returned when
            the list holds at most one rank.
        rank_list: Iterable of values convertible to ``int``.

    Returns:
        GroupInfo: Named after the sorted rank tuple. Its group comes from the
        module-level cache when that rank set was seen before, otherwise from
        ``platform.create_group``. The group is ``None`` if creation raised
        ``RuntimeError`` or ``ValueError``; that ``None`` is cached as well.
    """
    ranks = tuple(sorted(int(rank) for rank in rank_list))
    num_ranks = len(ranks)
    if num_ranks <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    if ranks not in _GROUP_INFO_CACHE:
        try:
            created = platform.create_group(list(ranks))
        except (RuntimeError, ValueError):  # pragma: no cover - UT may run without dist init
            created = None
        _GROUP_INFO_CACHE[ranks] = created
    return GroupInfo(str(ranks), _GROUP_INFO_CACHE[ranks], num_ranks)
def _build_group_info_from_process_group(
    group_name: str,
    process_group,
    rank_size: int,
) -> GroupInfo:
    """Wrap an already-created process group in a ``GroupInfo``.

    Args:
        group_name: Fallback name, also used as the prefix of the "invalid" placeholder.
        process_group: Existing process group, or ``None``.
        rank_size: Number of ranks the caller attributes to the group.

    Returns:
        GroupInfo: An "invalid" placeholder of size 1 when there is no group or
        ``rank_size <= 1``. Otherwise the group named after its sorted member
        ranks, falling back to ``group_name`` when the ranks cannot be queried.
    """
    usable = process_group is not None and rank_size > 1
    if not usable:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    try:
        member_ranks = dist.get_process_group_ranks(process_group)
        label = str(tuple(sorted(member_ranks)))
    except (AssertionError, AttributeError, KeyError, RuntimeError, TypeError, ValueError):
        # pragma: no cover - best-effort naming / mocked process groups in UT
        label = group_name
    return GroupInfo(label, process_group, rank_size)
102class TorchHSDPParamV2(HSDPParamV2):
103 """
104 Torch HSDP parameter.
105 """
def __init__(
    self,
    param: nn.Parameter,
    module_info: ParamModuleInfo,
    mesh_info: FSDPMeshInfo,
    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
    mp_policy: Optional[MixedPrecisionPolicy] = None,
    offload_policy: Optional[OffloadPolicy] = None,
    device: Optional[torch.device] = None,
    param_mode: Optional[FullyShardParamMode] = None,
    enable_fsdp_shard: bool = True,
):
    """
    Initialize TorchHSDPParamV2 and shard the parameter.

    Args:
        param (nn.Parameter): The original full parameter to shard.
        module_info (ParamModuleInfo): Ownership and shared-weight metadata.
        mesh_info (FSDPMeshInfo): Mesh topology for shard/replicate dimensions.
        shard_placement_fn (Callable, optional): Returns a Shard placement for the parameter,
            or None to use default (Shard(0)).
        mp_policy (MixedPrecisionPolicy, optional): Mixed precision dtype policy.
        offload_policy (OffloadPolicy, optional): CPU offload policy.
        device (torch.device, optional): Target device for the sharded parameter.
        param_mode (FullyShardParamMode, optional): Parameter management mode; must not be None.
        enable_fsdp_shard (bool): Value reported by ``uses_param_shard``.

    Raises:
        AssertionError: If ``param_mode`` is None.
    """
    # --- static configuration -------------------------------------------------
    self._module_info: ParamModuleInfo = module_info
    self.mesh_info = mesh_info
    self.mp_policy = mp_policy
    self.device = device
    if param_mode is None:
        raise AssertionError("param_mode must be resolved before TorchHSDPParamV2 initialization.")
    self.param_mode = param_mode
    self.enable_fsdp_shard = enable_fsdp_shard

    # --- dtype / offload state ------------------------------------------------
    # The dtype attributes are filled in later by init_dtype_attrs().
    self.orig_dtype = None
    self.param_dtype = None
    self.reduce_dtype = None
    self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
    self.pin_memory = (
        cast(CPUOffloadPolicy, offload_policy).pin_memory if self.offload_to_cpu else False
    )
    self._orig_param_hooks: List[Callable] = []
    self.grad_offload_event: Optional[torch.Event] = None

    # --- description of the incoming parameter --------------------------------
    is_dtensor = isinstance(param, DTensor)
    self._orig_param_is_dtensor = is_dtensor
    self._orig_dtensor_mesh = param.device_mesh if is_dtensor else None
    self._orig_dtensor_placements = tuple(param.placements) if is_dtensor else None
    self._spmd_shard_mesh_dim = self.mesh_info.shard_mesh_dim
    self._spmd_replicate_mesh_dim = self.mesh_info.replicate_mesh_dim

    # --- shard the parameter, then derive communication groups from its layout -
    self._init_sharded_param(param, shard_placement_fn)
    self._init_group_infos()
    self.all_gather_outputs: List[torch.Tensor] = []
    self.unsharded_accumulated_grad = None
    self._param_fqn: Optional[str] = None
    # Communication attributes for prefetch pattern
    self.prefetch_handle: Optional[dist.Work] = None

    def _resync_after_load(*args, **kwargs):
        # Re-synchronize the cached shard views once a state dict has been loaded.
        return self.reset_sharded_param()

    self._post_load_hook_handle = module_info.module.register_load_state_dict_post_hook(
        _resync_after_load
    )

    # --- gradient reduction bookkeeping ---------------------------------------
    self._reduce_scatter_output = None
    self.reduce_scatter_handle = None
    self._all_reduce_output = None
    self.all_reduce_handle = None
    self._save_backward_hooks(param)
    self._grad = None
    self._accumulated_allreduced_grad = True
@property
def uses_param_shard(self) -> bool:
    """Report ``enable_fsdp_shard``: whether fully_shard physically shards this parameter's storage."""
    shard_enabled = self.enable_fsdp_shard
    return shard_enabled
@property
def is_dtensor_compat_mode(self) -> bool:
    """True when ``param_mode`` is ``FullyShardParamMode.DTENSOR_COMPAT``."""
    compat_mode = FullyShardParamMode.DTENSOR_COMPAT
    return self.param_mode == compat_mode
def _get_base_spmd_placements(self) -> tuple:
    """Select ``self._spmd_mesh`` and return the placements to start from on that mesh.

    Returns:
        tuple: One placement per dimension of the selected mesh, before the
        data-parallel / shard placements are applied.
    """
    if self._orig_param_is_dtensor:
        if self.param_mode == FullyShardParamMode.DTENSOR_UNIFIED:
            # DTENSOR_UNIFIED keeps the original distributed layout and prefixes
            # explicit DP/FSDP mesh dimensions ahead of it on the unified mesh.
            self._spmd_mesh = DeviceMesh.concatenate([self.mesh_info.mesh, self._orig_dtensor_mesh])
            dp_prefix = tuple(Replicate() for _ in range(self.mesh_info.mesh.ndim))
            return dp_prefix + tuple(self._orig_dtensor_placements)
        if self.is_dtensor_compat_mode:
            # Compat mode reuses the DTensor's own mesh and placements untouched.
            self._spmd_mesh = self._orig_dtensor_mesh
            return tuple(self._orig_dtensor_placements)

    # Every other case: fully replicated on the fully_shard mesh.
    self._spmd_mesh = self.mesh_info.mesh
    return tuple(Replicate() for _ in range(self._spmd_mesh.ndim))
def _apply_data_parallel_placements(self, placements: list, shard_placement: Shard) -> tuple:
    """Write the replicate / shard placements into ``placements`` and return them as a tuple.

    Args:
        placements: Mutable list with one entry per ``self._spmd_mesh`` dimension; modified in place.
        shard_placement: Shard placement chosen for this parameter.

    Returns:
        tuple: The final per-mesh-dimension placements.

    Raises:
        AssertionError: If ``placements`` does not have one entry per mesh dimension.
    """
    if len(placements) != self._spmd_mesh.ndim:
        raise AssertionError(
            f"Expected {self._spmd_mesh.ndim} unified placements, got {len(placements)}: {placements}"
        )

    replicate_dim = self._spmd_replicate_mesh_dim
    force_replicate = (
        isinstance(self.mesh_info, DDPMeshInfo)
        and replicate_dim is not None
        and not self._orig_param_is_dtensor
    )
    if force_replicate:
        placements[replicate_dim] = Replicate()

    shard_mesh_dim = self._spmd_shard_mesh_dim
    applies_shard = (
        self.uses_param_shard
        and isinstance(self.mesh_info, FSDPMeshInfo)
        and shard_mesh_dim is not None
    )
    if not applies_shard:
        return tuple(placements)

    # If TP/EP already shards the same tensor dimension, fully_shard must
    # use StridedShard so the unified placement preserves the intended
    # shard order on the concatenated mesh.
    split_factor = 1
    for mesh_idx, placement in enumerate(placements):
        if mesh_idx != shard_mesh_dim and placement.is_shard(shard_placement.dim):
            split_factor *= self._spmd_mesh.mesh_shape[mesh_idx]
    if split_factor > 1:
        placements[shard_mesh_dim] = StridedShard(shard_placement.dim, split_factor=split_factor)
    else:
        placements[shard_mesh_dim] = shard_placement
    return tuple(placements)
def _init_group_infos(self) -> None:
    """Populate the shard / unsharded ``GroupInfo`` objects and the derived size attributes."""
    has_shard_group = (
        self.uses_param_shard
        and self.is_sharded
        and isinstance(self.mesh_info, FSDPMeshInfo)
    )
    if not has_shard_group:
        self.sharded_group_info = GroupInfo("fully_shard_sharded_group_invalid", None, 1)
    else:
        self.sharded_group_info = _build_group_info_from_process_group(
            "fully_shard_sharded_group",
            self.mesh_info.shard_process_group,
            self.mesh_info.shard_mesh_size,
        )

    # The all-reduce group is always derived from the final materialized layout.
    # This keeps replicate_params, DTensor compat, and unified multi-dim layouts
    # on a single source of truth.
    self.unsharded_group_info = self._build_layout_driven_group_info()

    shard_ranks = self.sharded_group_info.rank_size
    replicate_ranks = self.unsharded_group_info.rank_size
    self.shard_size = shard_ranks
    self.dp_size = replicate_ranks
    self.rank_size = max(1, shard_ranks * replicate_ranks)
def _build_layout_driven_group_info(self):
    """Build the ``GroupInfo`` spanning the mesh axes on which this parameter is replicated.

    The shard mesh axis is excluded when the parameter is physically sharded.
    A group obtained from the mesh is preferred; if that path raises or yields
    no group, the group is built from an explicit rank list instead.

    Returns:
        GroupInfo: An "invalid" placeholder of size 1 when no replicate axis remains.
    """
    excluded_axis = None
    if self.uses_param_shard and self._spmd_shard_mesh_dim is not None:
        excluded_axis = self._spmd_shard_mesh_dim
    group_axes = [
        axis
        for axis, placement in enumerate(self._spmd_placements)
        if placement.is_replicate() and axis != excluded_axis
    ]
    if not group_axes:
        return GroupInfo("fully_shard_unsharded_group_invalid", None, 1)

    dim_names = getattr(self._spmd_mesh, "mesh_dim_names", None)
    if dim_names:
        try:
            axis_names = tuple(dim_names[axis] for axis in group_axes)
            if len(axis_names) == 1:
                # Single axis: ask the mesh for that axis' group directly.
                only_name = axis_names[0]
                axis_group = self._spmd_mesh.get_group(only_name)
                if axis_group is not None:
                    axis_size = self._spmd_mesh.mesh_shape[dim_names.index(only_name)]
                    return _build_group_info_from_process_group(
                        "fully_shard_unsharded_group",
                        axis_group,
                        axis_size,
                    )

            # Several axes (or no single-axis group): split a group covering all of them.
            split_rank_lists = get_split_rank_lists_for_axes(self._spmd_mesh, group_axes)
            merged_group = platform.split_group(split_ranks=split_rank_lists)
            if merged_group is not None:
                merged_size = 1
                for axis in group_axes:
                    merged_size *= self._spmd_mesh.mesh_shape[axis]
                return _build_group_info_from_process_group(
                    "fully_shard_unsharded_group",
                    merged_group,
                    merged_size,
                )
        except (
            AssertionError,
            AttributeError,
            KeyError,
            RuntimeError,
            TypeError,
            ValueError,
        ):
            # Fall back to the explicit rank-list path for mocked meshes in UT
            # or when a mesh implementation cannot materialize a reusable group.
            pass

    fallback_ranks = get_rank_list_for_axes(self._spmd_mesh, group_axes)
    return _build_group_info_from_rank_list("fully_shard_unsharded_group", fallback_ranks)
def _to_local_unsharded_grad(self, grad):
    """Return ``grad`` as the local tensor expected by fully_shard collectives.

    Non-DTensor inputs (including ``None``) are returned unchanged. A DTensor
    has its partial placements reduced, is redistributed to the original
    parameter's mesh/placements when they differ, and is then converted with
    ``to_local()``.
    """
    if not isinstance(grad, DTensor):
        return grad

    has_partial = any(placement.is_partial() for placement in grad.placements)
    if has_partial:
        grad = grad.reduce_partial()

    orig_mesh = self._orig_dtensor_mesh
    orig_placements = self._orig_dtensor_placements
    mesh_differs = orig_mesh is not None and grad.device_mesh.to_hash() != orig_mesh.to_hash()
    if mesh_differs or (
        orig_placements is not None
        and tuple(grad.placements) != tuple(orig_placements)
    ):
        grad = grad.redistribute(self._orig_dtensor_mesh, self._orig_dtensor_placements)
    return grad.to_local()
@property
def accumulated_allreduced_grad(self) -> bool:
    """Flag stored in ``_accumulated_allreduced_grad`` (whether an all-reduced gradient has been accumulated)."""
    flag = self._accumulated_allreduced_grad
    return flag

@accumulated_allreduced_grad.setter
def accumulated_allreduced_grad(self, value: bool) -> None:
    # Plain store; no validation is applied to ``value``.
    setattr(self, "_accumulated_allreduced_grad", value)
328 def _save_backward_hooks(self, param: nn.Parameter) -> None:
329 """Save the backward hooks of the original parameter"""
330 if not hasattr(param, '_backward_hooks') or param._backward_hooks is None:
331 return
333 # Get the set of saved hook function IDs for deduplication
334 if not hasattr(self, '_saved_hook_ids'):
335 object.__setattr__(self, '_saved_hook_ids', set())
337 for _, hook_func in param._backward_hooks.items():
338 # Use the id of hook_func to avoid adding the same function object repeatedly
339 hook_func_id = id(hook_func)
340 if hook_func_id not in self._saved_hook_ids:
341 self._orig_param_hooks.append(hook_func)
342 self._saved_hook_ids.add(hook_func_id)
344 def _migrate_backward_hooks(self, new_param: nn.Parameter) -> None:
345 """Migrate backward hooks from the original parameter to the new parameter"""
346 if not self._orig_param_hooks or hasattr(new_param, "migrate_backward_hooks_run_once"):
347 return
349 # Properly register each hook using the register_hook method
350 for hook_func in self._orig_param_hooks:
351 try:
352 if new_param.requires_grad:
353 new_param.register_hook(hook_func)
354 except RuntimeError:
355 # Skip hook registration if the parameter does not require gradients
356 pass
357 new_param.migrate_backward_hooks_run_once = True
def reduce_scatter_output(self):
    """
    Get the reduce-scatter output tensor and wait for asynchronous operation to complete.

    When a reduce-scatter handle is pending, this waits on it, releases the
    storage of the gradient buffer held in ``self._grad`` (if there is one),
    and clears both the buffer reference and the handle.

    Returns:
        torch.Tensor: The sharded gradient tensor after reduce-scatter operation.
    """
    handle = self.reduce_scatter_handle
    if handle is not None:
        handle.wait()
        # ``_grad`` starts out as None; only release it when a buffer is actually held.
        if self._grad is not None:
            self._grad.untyped_storage().resize_(0)
            self._grad = None
        self.reduce_scatter_handle = None
    return self._reduce_scatter_output
def clear_reduce_scatter_output(self):
    """Drop the reference to the reduce-scatter output tensor so its memory can be freed."""
    setattr(self, "_reduce_scatter_output", None)
def all_reduce_output(self):
    """
    Get the all-reduce output tensor and wait for asynchronous operation to complete.

    A pending all-reduce handle is waited on and then cleared.

    Returns:
        torch.Tensor: The reduced gradient tensor after all-reduce operation.
    """
    pending = self.all_reduce_handle
    if pending is None:
        return self._all_reduce_output
    pending.wait()
    self.all_reduce_handle = None
    return self._all_reduce_output
def clear_all_reduce_output(self):
    """Drop the reference to the all-reduce output tensor so its memory can be freed."""
    setattr(self, "_all_reduce_output", None)
def apply_reduced_grad(self, reduced_grad, param_type):
    """
    Apply reduced gradient to the sharded parameter.

    Reshapes ``reduced_grad`` to match the local shard, optionally
    offloads to CPU, then accumulates or assigns onto
    ``sharded_param.grad`` (or ``sharded_param.main_grad`` when the mixed
    precision policy sets ``apply_grad_on_fp32_main_grad``). Finally the
    pending unsharded gradient (accumulated or on ``unsharded_param``) is
    cleared.

    Args:
        reduced_grad (torch.Tensor): Gradient after reduce-scatter
            and/or all-reduce.
        param_type (Optional[torch.dtype]): Target dtype for the gradient (if conversion is needed).
            Ignored on the ``main_grad`` path.

    Returns:
        bool: True when the gradient was moved to CPU (``offload_to_cpu``), in
        which case the caller is expected to synchronize; False otherwise.
    """
    use_main_grad = self.mp_policy.apply_grad_on_fp32_main_grad
    if use_main_grad:
        # Make sure the attribute exists before reading it.
        if not hasattr(self.sharded_param, "main_grad"):
            self.sharded_param.main_grad = None
        sharded_grad = self.sharded_param.main_grad
    else:
        sharded_grad = self.sharded_param.grad

    sharded_param_local_shape = (
        self.sharded_param.local_shape
        if isinstance(self.sharded_param, DTensor)
        else self.sharded_param.shape
    )
    reduced_grad = reduced_grad.view(sharded_param_local_shape)
    if not use_main_grad and param_type is not None and reduced_grad.dtype != param_type:
        reduced_grad = reduced_grad.to(param_type)

    to_accumulate_grad = sharded_grad is not None
    need_synchronize = False
    if self.offload_to_cpu:
        # A non-blocking copy is only requested for pinned memory when the
        # result is assigned rather than immediately accumulated.
        non_blocking = self.pin_memory and not to_accumulate_grad
        reduced_grad = reduced_grad.to(
            torch.device("cpu"), non_blocking=non_blocking
        )
        need_synchronize = True

    if not to_accumulate_grad:
        if use_main_grad:
            self.sharded_param.main_grad = self.to_sharded_dtensor(reduced_grad)
            self.sharded_param.grad = None
        else:
            self.sharded_param.grad = self.to_sharded_dtensor(reduced_grad)
    else:
        with SkipDTensorDispatch():
            if use_main_grad:
                self.sharded_param.main_grad._local_tensor += reduced_grad
                self.sharded_param.grad = None
            else:
                self.sharded_param.grad._local_tensor += reduced_grad

    # Test the raw attribute: going through ``unsharded_accumulated_grad_data``
    # would convert a DTensor gradient to a local tensor only to compare it
    # against None.
    if self.unsharded_accumulated_grad is not None:
        self.unsharded_accumulated_grad = None
    elif self.unsharded_param.grad is not None:
        self.unsharded_param.grad = None
    return need_synchronize
@torch.no_grad()
def _init_sharded_param(
    self,
    param: nn.Parameter,
    shard_placement_fn: Optional[Callable],
) -> None:
    """Shard ``param`` and install the resulting DTensor parameter on the owning module(s).

    Args:
        param: Original parameter; must already live on ``self.device`` or on the meta device.
        shard_placement_fn: Optional callable returning the ``Shard`` placement
            for ``param``; ``None`` (or a ``None`` result) selects ``Shard(0)``.

    Raises:
        AssertionError: If ``param`` is on an unexpected device, or the
            placement function returns something that is not a ``Shard``.
        NotImplementedError: If the shard dimension is not evenly divisible by
            the shard world size.
    """
    if param.device != self.device and param.device.type != "meta":
        raise AssertionError(
            f"Expects the parameter to already be moved to device {self.device} but got {param.device}"
        )

    hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None
    if hsdp_placement is None:
        hsdp_placement = Shard(0)

    # Validate the type before reading ``.dim`` so a non-Shard result always
    # fails with the intended error instead of being read or converted first.
    if not isinstance(hsdp_placement, Shard):
        raise AssertionError(
            f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}"
        )
    if hsdp_placement.dim < 0:
        # if dim is negative, add the number of dimensions of the parameter
        hsdp_placement = Shard(hsdp_placement.dim + param.ndim)

    self.hsdp_placement = hsdp_placement
    base_placements = list(self._get_base_spmd_placements())
    self._spmd_placements = self._apply_data_parallel_placements(base_placements, hsdp_placement)
    param_data = param.to_local() if self._orig_param_is_dtensor else param

    shard_dim = hsdp_placement.dim
    self._orig_size = param_data.size()
    self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)

    if self.uses_param_shard and isinstance(self.mesh_info, FSDPMeshInfo):
        shard_rank = self.mesh_info.shard_mesh_rank
        shard_world_size = self.mesh_info.shard_mesh_size
    else:
        shard_rank = 0
        shard_world_size = 1

    # NOTE(review): param_data is the result of to_local() for DTensor inputs and
    # the raw parameter otherwise, so this branch looks unreachable unless
    # to_local() can itself return a DTensor - confirm before relying on it.
    if isinstance(param_data, DTensor) and isinstance(self.mesh_info, DDPMeshInfo):
        param_data.data = param_data.full_tensor()

    self.is_sharded = bool(self.uses_param_shard and shard_world_size > 1)

    if param_data.size(shard_dim) % shard_world_size != 0:
        raise NotImplementedError(
            f"Uneven sharding on dim {shard_dim} not supported: "
            f"shape={param_data.shape}, world_size={shard_world_size}"
        )
    chunks = torch.chunk(param_data, shard_world_size, dim=shard_dim)
    # Clone so the shard owns its storage instead of aliasing the full parameter.
    sharded_param = chunks[shard_rank].clone().contiguous()
    self.sharded_size = sharded_param.size()
    self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
    if self.offload_to_cpu and not sharded_param.is_meta:
        sharded_param = sharded_param.cpu()
        if self.pin_memory:
            sharded_param = sharded_param.pin_memory()
    # Flat view of the local shard; this is what all_gather_inputs reads.
    self._sharded_param_data = sharded_param.view(-1)

    self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh)
    self._sharding_spec.set_placements(self._spmd_placements)
    self._sharding_spec.placement_to_tensor_map(param.ndim)

    self.sharded_param = nn.Parameter(DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements))
    self.sharded_param.requires_grad_(param.requires_grad)
    self._setattr_on_modules(self.sharded_param)
    # after init, self.sharded_param replaces original param, gradients must accumulate to this Parameter's grad
    self.sharded_param._hsdp_param_initialized = True
    self.sharded_state = ShardedState.SHARDED
    self.param_dtype = None
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
    """Derive ``orig_dtype``, ``param_dtype`` and ``reduce_dtype`` from the mixed precision policy.

    ``reduce_dtype`` is stored as ``None`` when the policy's reduce dtype equals
    its param dtype; ``param_dtype`` is stored as ``None`` when the policy's
    param dtype equals the sharded parameter's own dtype.
    """
    self.orig_dtype = self.sharded_param.dtype
    policy_param_dtype = mp_policy.param_dtype
    policy_reduce_dtype = mp_policy.reduce_dtype
    # The reduce dtype is compared against the policy's param dtype, before the
    # param dtype itself is collapsed to None below.
    self.reduce_dtype = None if policy_reduce_dtype == policy_param_dtype else policy_reduce_dtype
    self.param_dtype = None if policy_param_dtype == self.orig_dtype else policy_param_dtype
def init_all_gather_outputs(
    self,
    all_gather_input_numels: list[int],
    all_gather_input_dtypes: list[torch.dtype],
    world_size: int,
    device: torch.device,
    force_recreate: bool = False,
):
    """
    Allocate output buffers for all-gather communication.

    Each buffer is a 1-D uninitialized tensor of ``numel * world_size`` elements.

    Args:
        all_gather_input_numels: Number of elements per input shard.
        all_gather_input_dtypes: Dtype of each input shard.
        world_size: Number of ranks in the shard process group.
        device: Device on which to allocate the output buffers.
        force_recreate: If True, always recreate buffers even if already initialized.
    """
    if self.all_gather_outputs and not force_recreate:
        return  # already initialized
    buffers = []
    for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes):
        gathered_size = torch.Size([numel * world_size])
        buffers.append(torch.empty(gathered_size, dtype=dtype, device=device))
    self.all_gather_outputs = buffers
def init_unsharded_param(self):
    """
    Initialize unsharded parameter from all-gather outputs.

    This reconstructs the full parameter after all-gather by unpacking the
    gathered flat buffer back to the original tensor layout. The first call
    creates ``self._unsharded_param``; later calls update that same object.
    """
    gathered = self._get_unsharded_param_from_all_gather_output()
    wants_grad = self.sharded_param.requires_grad
    if not hasattr(self, "_unsharded_param"):
        self._unsharded_param = nn.Parameter(gathered, requires_grad=wants_grad)
        return
    # Always refresh the unsharded Parameter from the latest all-gather output.
    # Non-dim0 unpack currently materializes a contiguous tensor copy, so
    # keeping stale .data would otherwise reuse old weights after optimizer.step()
    # mutates only the sharded local shard. Preserve the Parameter object identity
    # so autograd-facing module state stays stable across unshard cycles.
    existing = self._unsharded_param
    existing.data = gathered
    existing.requires_grad_(self.sharded_param.requires_grad)
    existing.grad = None
def _get_unsharded_param_from_all_gather_output(self) -> torch.Tensor:
    """Unpack the single all-gather output buffer into the full local parameter.

    Returns:
        torch.Tensor: The unpacked tensor, wrapped back into a DTensor with the
        original mesh and placements when the source parameter was a DTensor.

    Raises:
        AssertionError: If there is not exactly one all-gather output buffer.
    """
    if len(self.all_gather_outputs) != 1:
        raise AssertionError(
            f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
        )
    gathered_flat = self.all_gather_outputs[0]
    plan = build_rs_plan(
        self,
        self._sharded_local_tensor,
        self.shard_world_size if self.is_sharded else 1,
    )
    full_local = unpack_from_all_gather(gathered_flat, plan)
    if not self._orig_param_is_dtensor:
        return full_local
    # Rebuild the original DTensor view after all-gather so gradient
    # consumers keep seeing the source DTensor layout.
    return DTensor.from_local(
        full_local,
        self._orig_dtensor_mesh,
        self._orig_dtensor_placements,
    )
def to_sharded(self) -> None:
    """Switch the module(s) back to the sharded parameter and release the unsharded storage."""
    if not self.uses_param_shard and self._unsharded_param is not None:
        # Replicate params keep the same local shape across shard/unshard,
        # so persist forward-time state updates before switching objects.
        full = self._unsharded_param
        shard = self.sharded_param
        src = full.to_local() if isinstance(full, DTensor) else full
        dst = shard.to_local() if isinstance(shard, DTensor) else shard
        _copy_without_bumping_version(dst, src)
    self._setattr_on_modules(self.sharded_param)
    self.free_unsharded_param()
    self.sharded_state = ShardedState.SHARDED
def to_unsharded(self) -> None:
    """Install the unsharded parameter on the module(s) and mark the state as UNSHARDED."""
    full_param = self._unsharded_param
    set_requires_grad_if_needed(self.sharded_param, full_param)
    self._setattr_on_modules(full_param)
    self.sharded_state = ShardedState.UNSHARDED
620 def _setattr_on_modules(self, param: nn.Parameter) -> None:
621 """Set parameter on module and shared modules, preserving pointer consistency."""
622 if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Module.__setattr__:
623 # fast path
624 self._module_info.module._parameters[self._module_info.param_name] = param
625 else:
626 # slow path
627 setattr(self._module_info.module, self._module_info.param_name, param)
628 self._save_backward_hooks(self.sharded_param)
629 self._migrate_backward_hooks(param)
630 # Iterate through all modules that share this parameter to prevent pointer desync.
631 for shared_module, shared_param_name in zip(
632 self._module_info.shared_modules, self._module_info.shared_param_names
633 ):
634 if getattr(shared_module.__setattr__, "__func__", None) is nn.Module.__setattr__:
635 shared_module._parameters[shared_param_name] = param
636 else:
637 setattr(shared_module, shared_param_name, param)
def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
    """
    Wrap a local tensor (the sharded parameter or its gradient) as a DTensor
    using the mesh and placements stored in ``self._sharding_spec``.
    """
    spec = self._sharding_spec
    return DTensor.from_local(tensor, spec.mesh, spec.placements)
def to_accumulated_grad_if_needed(self) -> None:
    """Move a pending ``_unsharded_param.grad`` into ``unsharded_accumulated_grad``.

    The gradient is detached from the parameter, cast to ``reduce_dtype`` when
    one is set and differs, then either stored or added in place to the
    existing accumulator.
    """
    pending = self._unsharded_param.grad
    if pending is None:
        return
    # Keep local gradients alive across no-sync / delayed-sync steps even
    # after the parameter transitions back to the sharded view.
    self._unsharded_param.grad = None
    target_dtype = self.reduce_dtype
    if target_dtype is not None and pending.dtype != target_dtype:
        pending = pending.to(target_dtype)
    if self.unsharded_accumulated_grad is not None:
        self.unsharded_accumulated_grad += pending
    else:
        self.unsharded_accumulated_grad = pending
def accumulate_unsharded_grad_if_needed(self) -> None:
    """Fold ``unsharded_param.grad`` into an existing ``unsharded_accumulated_grad``.

    Does nothing unless both an accumulator and a fresh gradient are present.
    """
    if self.unsharded_accumulated_grad is None:
        return
    fresh = self.unsharded_param.grad
    if fresh is None:
        return
    target_dtype = self.reduce_dtype
    if target_dtype is not None and fresh.dtype != target_dtype:
        fresh = fresh.to(target_dtype)
    self.unsharded_accumulated_grad += fresh
    self.unsharded_param.grad = None
675 def alloc_all_gather_outputs(self) -> None:
676 """Resize all-gather output buffers to their full capacity for communication."""
677 for tensor in self.all_gather_outputs:
678 expected_size = tensor.numel() * tensor.itemsize
679 storage = tensor.untyped_storage()
680 if storage.size() != expected_size:
681 storage.resize_(expected_size)
def free_unsharded_param(self) -> None:
    """Release storage of all-gather outputs to free device memory."""
    for buffer in self.all_gather_outputs:
        backing = buffer.untyped_storage()
        if backing.size() == 0:
            # Already released.
            continue
        backing.resize_(0)
@property
def all_gather_inputs(self) -> list[torch.Tensor]:
    """Return the flat local shard to feed into all-gather, as a one-element list.

    Requires the SHARDED state. With CPU offload the shard is first copied to
    ``self.device`` (non-blocking); it is cast to ``param_dtype`` when one is
    set and differs from the shard's dtype.
    """
    self._assert_in_states(ShardedState.SHARDED)
    shard = self._sharded_param_data
    if self.offload_to_cpu:
        shard = shard.to(self.device, non_blocking=True)
    target_dtype = self.param_dtype
    needs_cast = target_dtype is not None and target_dtype != shard.dtype
    if needs_cast:
        return [shard.to(self.param_dtype)]
    return [shard]
@property
def unsharded_param(self) -> nn.Parameter:
    """The parameter object stored in ``_unsharded_param`` (the full parameter after all-gather)."""
    full_param = self._unsharded_param
    return full_param
@property
def unsharded_grad_data(self) -> torch.Tensor:
    """
    Gradient of the unsharded parameter, converted by ``_to_local_unsharded_grad``.

    Raises:
        AssertionError: If ``unsharded_param.grad`` is None.
    """
    pending = self.unsharded_param.grad
    if pending is not None:
        return self._to_local_unsharded_grad(pending)
    raise AssertionError("Expects unsharded_param.grad to not be None")
@property
def unsharded_accumulated_grad_data(self) -> torch.Tensor:
    """
    Accumulated unsharded gradient, converted by ``_to_local_unsharded_grad``.

    No None check is done here; whatever the converter returns is passed through.
    """
    return self._to_local_unsharded_grad(self.unsharded_accumulated_grad)
@property
def _sharded_local_tensor(self) -> torch.Tensor:
    """The ``_local_tensor`` held by the sharded DTensor parameter."""
    sharded = cast(DTensor, self.sharded_param)
    return sharded._local_tensor
@property
def shard_world_size(self) -> int:
    """Number of ranks along the shard dimension (alias of ``shard_size``)."""
    num_shard_ranks = self.shard_size
    return num_shard_ranks
@property
def replicate_world_size(self) -> int:
    """Number of ranks along the replicate dimension, HSDP only (alias of ``dp_size``)."""
    num_replicate_ranks = self.dp_size
    return num_replicate_ranks
def _assert_in_states(self, *states: ShardedState) -> None:
    """Raise ``AssertionError`` unless ``self.sharded_state`` is one of ``states``."""
    if self.sharded_state in states:
        return
    raise AssertionError(
        f"Expected sharded_state in {states}, got {self.sharded_state}"
    )
def reset_sharded_param(self) -> None:
    """Reset sharded param after load_state_dict.

    Re-reads the parameter currently installed on the owning module and brings
    ``sharded_param``, its ``_local_tensor``, the flat ``_sharded_param_data``
    view and ``_sharding_spec`` back in sync with it. Returns early without
    changes to the local data when the installed tensor is on the meta device.

    Raises:
        AssertionError: If module-parameter swapping is enabled but the object
            identity changed, if the new local tensor's size differs from
            ``sharded_size``, if ``sharded_param`` is not a DTensor, or if the
            refreshed local tensor is not contiguous.
    """
    module_info = self._module_info
    new_param = getattr(module_info.module, module_info.param_name)
    if new_param is not self.sharded_param:
        # Ensure object identity is preserved after parameter conversion.
        if torch.__future__.get_swap_module_params_on_conversion():
            raise AssertionError(
                f"Expects swap_tensors to preserve object but got {new_param} "
                f"instead of {self.sharded_param}"
            )
        if isinstance(new_param, DTensor):
            self.sharded_param = new_param
            if not getattr(self.sharded_param, "_hsdp_param_initialized", None):
                # reset _hsdp_param_initialized flag.
                self.sharded_param._hsdp_param_initialized = True
        # Otherwise (plain tensor): keep the existing ``self.sharded_param``
        # reference; only its local tensor and the flat data view are updated below.

    local_tensor = new_param._local_tensor if isinstance(new_param, DTensor) else new_param
    if local_tensor.is_meta:
        return
    updated_local_tensor = False
    same_local_tensor = False
    if isinstance(self._sharded_param_data, torch.Tensor):
        same_local_tensor = (
            # when sharding param with shape (1, ...) over 2 ranks
            # local_tensor on rank 1 can be size 0, data_ptr() can be 0
            self._sharded_param_data.untyped_storage().data_ptr() > 0
            and self._sharded_param_data.untyped_storage().data_ptr()
            == local_tensor.untyped_storage().data_ptr()
        )
    sharded_size = self.sharded_size
    shard_dim = self.hsdp_placement.dim
    length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
    if not same_local_tensor:
        # A different storage was installed (e.g. by load_state_dict); it must
        # have exactly the shape of the local shard.
        if local_tensor.size() != sharded_size:
            raise AssertionError(
                f"Expected sharded_size to be {sharded_size}, got {local_tensor.size()}"
            )
        updated_local_tensor = True
    if self.pin_memory and not local_tensor.is_pinned():
        local_tensor = local_tensor.cpu().pin_memory()
        updated_local_tensor = True
    if updated_local_tensor:
        # Refresh the flat view whenever the local tensor object is replaced,
        # including the re-pinning case where storage was previously shared;
        # otherwise the view would keep pointing at the old unpinned storage.
        self._sharded_param_data = local_tensor.view(-1)
    if not isinstance(self.sharded_param, DTensor):
        raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
    if updated_local_tensor:
        # Only change the local tensor object if needed
        self.sharded_param._local_tensor = local_tensor.narrow(
            dim=shard_dim, start=0, length=length
        )
        if not self.sharded_param._local_tensor.is_contiguous():
            raise AssertionError(
                "Expected sharded_param._local_tensor to be contiguous"
            )
    self._sharding_spec = cast(DTensor, self.sharded_param).layout
def _get_unsharded_param_data(self, async_op: bool = False) -> Tuple[torch.Tensor, Optional[dist.Work]]:
    """
    Fill the all-gather output buffer with the unsharded parameter data.

    The first all-gather input is used as the source. The output buffer is
    (re)initialized and allocated on every call. When the parameter is not
    sharded, has no shard process group, or the shard world size is at most 1,
    the input is copied into the output buffer and no collective is issued.

    Args:
        async_op: Forwarded to ``dist.all_gather_into_tensor``.

    Returns:
        (unsharded_param, handle): ``self.all_gather_outputs[0]`` and the value
        returned by ``dist.all_gather_into_tensor``; the handle is ``None`` on
        the copy-only path.
    """
    sharded = self.is_sharded
    shard_input = self.all_gather_inputs[0]

    # An unsharded parameter (below threshold) gets a single-rank sized buffer.
    self.init_all_gather_outputs(
        all_gather_input_numels=[shard_input.numel()],
        all_gather_input_dtypes=[shard_input.dtype],
        world_size=self.shard_world_size if sharded else 1,
        device=self.device,
    )
    self.alloc_all_gather_outputs()

    if not sharded or self.sharded_group_info.group is None or self.shard_world_size <= 1:
        # No communication needed, just copy.
        _copy_without_bumping_version(self.all_gather_outputs[0], shard_input)
        return self.all_gather_outputs[0], None

    gathered = self.all_gather_outputs[0]
    work = dist.all_gather_into_tensor(
        gathered,
        shard_input,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return gathered, work
def unshard(self, async_op: bool = False) -> None:
    """
    Start gathering the unsharded parameter data unless a gather is pending.

    A non-``None`` ``self.prefetch_handle`` means the gather was already
    triggered (e.g. by HSDPState.prefetch()), so the call is a no-op.
    Otherwise the handle returned by ``_get_unsharded_param_data`` is stored
    in ``self.prefetch_handle``.

    Args:
        async_op: Forwarded to ``_get_unsharded_param_data``.
    """
    if self.prefetch_handle is None:
        gathered = self._get_unsharded_param_data(async_op=async_op)
        # Only the communication handle is kept; the data lives in the
        # all-gather output buffer owned by this object.
        self.prefetch_handle = gathered[1]
def wait_for_unshard(self) -> None:
    """
    Finish a pending unshard and switch the parameter to the unsharded state.

    Requires the parameter to be in ``ShardedState.SHARDED``. If a prefetch
    handle is stored, it is waited on and then cleared. Afterwards the
    unsharded parameter is initialized and ``to_unsharded`` is called.
    """
    self._assert_in_states(ShardedState.SHARDED)

    pending = self.prefetch_handle
    if pending is not None:
        pending.wait()
        self.prefetch_handle = None

    self.init_unsharded_param()
    self.to_unsharded()
def shard(self) -> None:
    """
    Move the parameter from the unsharded state back to the sharded state.

    Requires the parameter to be in ``ShardedState.UNSHARDED`` and then
    delegates the transition to ``to_sharded``.
    """
    self._assert_in_states(ShardedState.UNSHARDED)
    self.to_sharded()
def reduce_scatter_grad(
    self,
    async_op: bool = True,
    dtype: Optional[torch.dtype] = None,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG,
    output_buffer: Optional[torch.Tensor] = None,
) -> Union[None, Tuple[torch.Tensor, Optional[dist.Work]]]:
    """
    Reduce-scatter the full gradient so each rank keeps only its shard.

    Requires the parameter to be in ``ShardedState.UNSHARDED``. The gradient
    is cast to the reduce dtype and stored in ``self._grad``, then packed into
    a flat tensor. The result tensor and handle are also recorded in
    ``self._reduce_scatter_output`` and ``self.reduce_scatter_handle``.

    Args:
        async_op: Forwarded to ``dist.reduce_scatter_tensor``.
        dtype: Reduce dtype; defaults to the gradient's own dtype.
        reduce_op: Reduction passed to ``dist.reduce_scatter_tensor``
            (average by default).
        output_buffer: Optional pre-allocated output buffer for fused
            all-reduce. When provided, the reduce-scatter writes directly
            into it, enabling zero-copy fusion with a subsequent all_reduce.
            On the communicating path it must hold exactly
            ``grad_flat.numel() // shard_world_size`` elements of the reduce
            dtype.

    Returns:
        (sharded_grad, handle): On the communicating path, the output tensor
        and the value returned by ``dist.reduce_scatter_tensor``. When no
        communication is needed, the packed flat gradient and ``None``.

    Raises:
        ValueError: If ``output_buffer`` has the wrong element count or dtype
            on the communicating path.
    """
    self._assert_in_states(ShardedState.UNSHARDED)

    # An accumulated gradient, when present, takes precedence over the
    # regular unsharded gradient.
    grad = (
        self.unsharded_accumulated_grad_data
        if self.unsharded_accumulated_grad is not None
        else self.unsharded_grad_data
    )
    reduce_dtype = dtype or grad.dtype
    self._grad = grad.to(reduce_dtype)

    # The packing plan is sized for a single rank unless a real
    # reduce-scatter will run.
    will_communicate = (
        self.is_sharded
        and self.sharded_group_info.group is not None
        and self.shard_world_size > 1
    )
    plan_world_size = self.shard_world_size if will_communicate else 1
    plan = build_rs_plan(self, self._grad, plan_world_size)
    grad_flat = pack_for_reduce_scatter(self._grad, plan).reshape(-1)

    # No reduce-scatter when the parameter is unsharded (below threshold),
    # there is no shard group, or the shard group has a single rank.
    if not self.is_sharded or self.sharded_group_info.group is None or self.shard_world_size <= 1:
        if output_buffer is None:
            self._reduce_scatter_output = grad_flat
        else:
            output_buffer.copy_(grad_flat)
            self._reduce_scatter_output = output_buffer
        self.reduce_scatter_handle = None
        # NOTE(review): the returned tensor is grad_flat even when
        # output_buffer received the copy - confirm callers expect this.
        return grad_flat, None

    shard_numel = grad_flat.numel() // self.shard_world_size
    if output_buffer is None:
        self._reduce_scatter_output = torch.empty(shard_numel, dtype=reduce_dtype, device=self._grad.device)
    else:
        # A caller-supplied buffer is used as-is, so it must match exactly.
        if output_buffer.numel() != shard_numel:
            raise ValueError(
                f"output_buffer size mismatch: expected {shard_numel}, got {output_buffer.numel()}"
            )
        if output_buffer.dtype != reduce_dtype:
            raise ValueError(
                f"output_buffer dtype mismatch: expected {reduce_dtype}, got {output_buffer.dtype}"
            )
        self._reduce_scatter_output = output_buffer

    self.reduce_scatter_handle = dist.reduce_scatter_tensor(
        self._reduce_scatter_output,
        grad_flat,
        op=reduce_op,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return self._reduce_scatter_output, self.reduce_scatter_handle
def all_reduce_grad(
    self,
    grad: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    async_op: bool = True,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
) -> Union[None, Tuple[torch.Tensor, Optional[dist.Work]]]:
    """
    All-reduce a gradient over the replicate group (HSDP mode).

    Args:
        grad: Gradient tensor to reduce. If None, the accumulated unsharded
            gradient data is used when an accumulated gradient exists,
            otherwise the regular unsharded gradient data.
        dtype: If given and different from the gradient's dtype, the gradient
            is cast to it before reducing.
        async_op: Forwarded to ``dist.all_reduce``.
        reduce_op: Reduction passed to ``dist.all_reduce`` (average by default).

    Returns:
        (reduced_grad, handle): The gradient tensor and the value returned by
        ``dist.all_reduce``. The handle is ``None`` when there is no replicate
        group or the replicate world size is at most 1; in that case the
        gradient is returned without any communication.
    """
    if grad is None:
        # Fall back to the gradient held by the parameter itself.
        grad = (
            self.unsharded_accumulated_grad_data
            if self.unsharded_accumulated_grad is not None
            else self.unsharded_grad_data
        )

    if dtype is not None and dtype != grad.dtype:
        grad = grad.to(dtype)

    replicate_group = self.unsharded_group_info.group
    if replicate_group is None or self.replicate_world_size <= 1:
        return grad, None

    # dist.all_reduce works in place on ``grad``.
    self.all_reduce_handle = dist.all_reduce(
        grad,
        op=reduce_op,
        group=replicate_group,
        async_op=async_op,
    )
    self._all_reduce_output = grad
    return grad, self.all_reduce_handle
def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """Make dst_tensor's requires_grad flag match src_tensor's.

    The flag on ``dst_tensor`` is only written when it differs from the one
    on ``src_tensor``; ``src_tensor`` is never modified.
    """
    wanted = src_tensor.requires_grad
    if dst_tensor.requires_grad == wanted:
        return
    dst_tensor.requires_grad_(wanted)