1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
16# enhanced with activation swap functionality.
17# ============================================================================
18"""Activation Swap implementation for PyTorch."""
19# pylint: disable=W0212, W0613
20
21from abc import ABC, abstractmethod
22from collections.abc import Iterator
23from typing import Optional, Callable, Any, Union
24import warnings
25import torch
26from torch import nn
27from torch.distributed.utils import _replace_by_prefix
28from hyper_parallel.core.activation_checkpoint.activation_checkpoint import CheckpointPolicy
29from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage
30
31
32_SWAP_WRAPPED_MODULE = "_swap_wrapped_module"
33_SWAP_PREFIX = _SWAP_WRAPPED_MODULE + "."
34
35
class FuncModule(nn.Module):
    """
    Adapt a plain callable to the :class:`~torch.nn.Module` interface.

    Lets ordinary Python functions (or any callable that is not itself a
    Module) be passed to :func:`swap_wrapper` and
    :func:`~hyper_parallel.platform.torch.platform.TorchPlatform.ckpt_wrapper`
    wherever an :class:`~torch.nn.Module` is expected. The callable is kept
    on the ``_fn`` attribute and invoked from :meth:`forward`; FuncModule
    defines no parameters of its own.

    Args:
        fn (callable): The function to wrap.

    Example:
        >>> wrapped = swap_wrapper(lambda x: x * 2)
    """

    def __init__(self, fn: Callable):
        super().__init__()
        self._fn = fn

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with all given arguments and return its result."""
        target_fn = self._fn
        return target_fn(*args, **kwargs)
60
61
def base_check_fn(tensor) -> bool:
    """
    Return whether ``tensor`` is eligible for swapping.

    A tensor is rejected when:

    - it is an :class:`~torch.nn.Parameter`, or a view whose ``_base`` is a
      Parameter;
    - its underlying storage is empty.

    Args:
        tensor (torch.Tensor): Candidate tensor.

    Returns:
        bool: ``True`` if the tensor may be registered for swapping,
        ``False`` otherwise.
    """
    if isinstance(tensor, nn.Parameter) or isinstance(tensor._base, nn.Parameter):  # pylint: disable=W0212
        return False
    # ``untyped_storage()`` replaces the deprecated ``Tensor.storage()``, which
    # routes through PyTorch's TypedStorage-removal warning. Its size is in
    # bytes rather than elements, but only emptiness matters here.
    return tensor.untyped_storage().size() > 0
73
74
class AsyncSaveOnCpu(torch.autograd.graph.saved_tensors_hooks):
    """
    Saved-tensor hooks that register forward activations with ``SwapManager``.

    While this context manager is active, every tensor autograd saves for
    backward goes through the pack hook:

    - tensors rejected by :func:`base_check_fn`, or for which ``policy_fn``
      returns ``CheckpointPolicy.MUST_SAVE``, are saved unchanged;
    - every other tensor is wrapped in a ``SwapTensor`` and appended to this
      instance's ``Storage`` under a fresh integer index, and that index is
      what autograd keeps. The ``Storage`` is handed to
      ``SwapManager().add_storage`` for the current group the first time this
      happens.

    The unpack hook reverses this. A saved tensor is returned as-is, and an
    integer index is resolved by popping its ``SwapTensor`` and calling
    ``get_val()``. Once as many entries have been unpacked as were packed, the
    instance drops its reference to the ``Storage``.

    This class performs no data movement itself; any transfer to or from the
    CPU happens inside ``SwapTensor`` / ``SwapManager`` (not visible here).

    NOTE(review): each entry is popped on unpack, so unpacking the same saved
    tensor twice (e.g. a second backward with ``retain_graph=True``) would
    fail. Confirm this is intentionally unsupported.

    Args:
        policy_fn (callable, optional): Called with each eligible tensor; a
            return value of ``CheckpointPolicy.MUST_SAVE`` keeps it unswapped.
    """
    def __init__(self, policy_fn=None) -> None:
        self.add_to_storage = False
        self.storage = Storage()
        self.count_idx = 0
        self.pack_count = 0
        self.unpack_count = 0
        self.policy_fn = policy_fn

        def _pack(tensor: torch.Tensor):
            # Ineligible tensors and those the policy insists on saving are
            # handed back to autograd untouched.
            keep_as_is = (not base_check_fn(tensor)) or (
                policy_fn is not None and policy_fn(tensor) == CheckpointPolicy.MUST_SAVE
            )
            if keep_as_is:
                return tensor

            group = SwapManager().get_current_group_name()
            # Register this instance's storage with the manager only once.
            if not self.add_to_storage:
                SwapManager().add_storage(group, self.storage)
                self.add_to_storage = True
            label = f"{group}::{tensor.shape}"
            bucket = self.storage.swap_storage[self.count_idx]
            bucket.append(SwapTensor(tensor, label))
            slot = self.count_idx
            self.count_idx += 1
            self.pack_count += 1
            return slot

        def _unpack(packed) -> torch.Tensor:
            # A tensor here means _pack decided not to swap it.
            if isinstance(packed, torch.Tensor):
                return packed

            restored = self.storage.swap_storage[packed].pop(0).get_val()
            self.unpack_count += 1
            # Everything packed has now been consumed; release the storage.
            if self.unpack_count == self.pack_count:
                self.storage = None
            return restored

        super().__init__(_pack, _unpack)
118
119
120class ActivationWrapper(torch.nn.Module, ABC):
121 """
122 Base class for Activation Swap.
123
124 Not meant to be instantiated directly.
125 """
126
127 def __init__(self, module: Union[nn.Module, Callable]):
128 if callable(module) and not isinstance(module, nn.Module):
129 module = FuncModule(module)
130 super().__init__()
131 self._swap_wrapped_module = module
132 # state_dict post hook to remove prefix to allow loading into a
133 # non-swap wrapped module.
134 self._register_state_dict_hook(self._post_state_dict_hook)
135 # load_state_dict pre-hook to allow loading back into
136 # swap-wrapped module.
137 self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)
138
139 @abstractmethod
140 def forward(self, *args, **kwargs):
141 raise ValueError("Subclasses should implement forward().")
142
143 def __getattr__(self, name: str) -> Any:
144 """Forward missing attributes to wrapped module."""
145 try:
146 return super().__getattr__(name) # defer to nn.Module's logic
147 except AttributeError:
148 return getattr(self._swap_wrapped_module, name)
149
150 def __getitem__(self, key: int) -> Any:
151 """Forward indexing calls in case the module is a nn.Sequential."""
152 return self._swap_wrapped_module.__getitem__(key) # type: ignore[operator]
153
154 def named_parameters(
155 self,
156 *args,
157 **kwargs,
158 ) -> Iterator[tuple[str, torch.nn.Parameter]]:
159 """
160 Override :meth:`named_parameters()` to intercept parameter names.
161
162 remove all occurrences of ``_SWAP_PREFIX``.
163 """
164 for param_name, param in super().named_parameters(*args, **kwargs):
165 yield param_name.replace(_SWAP_PREFIX, ""), param
166
167 @staticmethod
168 def _post_state_dict_hook(
169 module: nn.Module, # pylint: disable=W0613
170 state_dict: dict[str, Any],
171 prefix: str,
172 *args: Any, # pylint: disable=W0613
173 ) -> dict[str, Any]:
174 """
175 _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed.
176
177 For ``swap_wrapper``, it will strip swap-wrapped module prefix,
178 so that this module can be loaded into non-swapped modules.
179 It would still be able to be loaded into swap-wrapped modules as this class,
180 adds the prefix back before loading the state_dict.
181 """
182 _replace_by_prefix(state_dict, f"{prefix}{_SWAP_PREFIX}", prefix)
183 return state_dict
184
185 @staticmethod
186 def _pre_load_state_dict_hook(
187 module: nn.Module,
188 state_dict: dict[str, Any],
189 prefix: str,
190 *args: Any,
191 ) -> None:
192 """
193 ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` is called.
194
195 For ``swap_wrapper``, it will add back the module
196 prefix so that non-swapped modules can be loaded into
197 swap_wrapper modules properly.
198 """
199 _replace_by_prefix(state_dict, prefix, prefix + f"{_SWAP_PREFIX}")
200
201
class SwapWrapper(ActivationWrapper):
    """
    Module wrapper that runs the wrapped forward under :class:`AsyncSaveOnCpu`.

    Each call to :meth:`forward` creates a fresh ``AsyncSaveOnCpu`` context,
    so the tensors autograd saves during that forward go through its pack and
    unpack hooks.

    Args:
        mod (nn.Module or callable): Module (or plain callable) to wrap.
        policy_fn (callable, optional): Per-tensor policy passed to
            ``AsyncSaveOnCpu``; tensors for which it returns
            ``CheckpointPolicy.MUST_SAVE`` are saved without swapping.
    """
    def __init__(self, mod: Union[nn.Module, Callable], policy_fn: Optional[Callable] = None):
        super().__init__(mod)
        self.policy_fn = policy_fn

    def forward(self, *args, **kwargs):
        """Run the wrapped module with the swap hooks active and return its output."""
        hooks = AsyncSaveOnCpu(policy_fn=self.policy_fn)
        with hooks:
            return self._swap_wrapped_module(*args, **kwargs)
213
214
def swap_wrapper(module: Union[nn.Module, Callable], policy_fn: Optional[Callable] = None) -> SwapWrapper:
    """
    Wrap ``module`` in a :class:`SwapWrapper`.

    Args:
        module (nn.Module or callable): Module or plain callable to wrap.
        policy_fn (callable, optional): Per-tensor policy forwarded to the
            wrapper.

    Returns:
        SwapWrapper: The new wrapper around ``module``.
    """
    wrapper = SwapWrapper(module, policy_fn=policy_fn)
    return wrapper
217
218
def swap_tensor_wrapper(target, tag: Optional[str] = None):
    """Register selected tensors into the current swap group.

    This helper is intended to be used inside a forward path that already
    participates in the existing swap scheduling managed by ``SwapManager``.
    Every tensor leaf of ``target`` that passes :func:`base_check_fn` is
    wrapped in a ``SwapTensor`` and recorded in a fresh ``Storage``. If at
    least one tensor was recorded, that ``Storage`` is passed to
    ``SwapManager().add_storage`` for the current group.

    Nothing is registered, and ``target`` itself is returned, when:

    - no swap group is current (a warning is emitted), or
    - ``SwapManager().is_last_group`` reports the current group as the last.

    Args:
        target: A tensor or a pytree (nested lists/tuples/dicts) containing
            tensors. Non-tensor leaves pass through untouched.
        tag (str, optional): Label used in the registered tensors' names.
            When falsy, ``f"{group_name}_swap_tensor"`` is used.

    Returns:
        A structure mirroring ``target`` that holds the same tensor objects
        (no copies are made).
    """
    group_name = SwapManager().get_current_group_name()
    if not group_name:
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this helper.
        warnings.warn(
            f"Tensor {tag} cannot be swapped, for its group is unregistered.",
            stacklevel=2,
        )
        return target
    if SwapManager().is_last_group(group_name):
        return target

    storage = Storage()
    # The tag is the same for every tensor in this call, so build it once.
    tensor_tag = tag or f"{group_name}_swap_tensor"
    registered = 0

    def _register_leaf(leaf):
        # Record eligible tensor leaves; every leaf is returned unchanged so
        # the output mirrors ``target``.
        nonlocal registered
        if isinstance(leaf, torch.Tensor) and base_check_fn(leaf):
            funcname = f"{tensor_tag}::{tuple(leaf.shape)}"
            storage.swap_storage[registered].append(SwapTensor(leaf, funcname))
            registered += 1
        return leaf

    wrapped = torch.utils._pytree.tree_map(_register_leaf, target)  # pylint: disable=protected-access
    if registered > 0:
        SwapManager().add_storage(group_name, storage)
    return wrapped