Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / activation_checkpoint / activation_swap.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py 

16# enhanced with activation swap functionality. 

17# ============================================================================ 

18"""Activation Swap implementation for PyTorch.""" 

19# pylint: disable=W0212, W0613 

20 

21from abc import ABC, abstractmethod 

22from collections.abc import Iterator 

23from typing import Optional, Callable, Any, Union 

24import warnings 

25import torch 

26from torch import nn 

27from torch.distributed.utils import _replace_by_prefix 

28from hyper_parallel.core.activation_checkpoint.activation_checkpoint import CheckpointPolicy 

29from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage 

30 

31 

32_SWAP_WRAPPED_MODULE = "_swap_wrapped_module" 

33_SWAP_PREFIX = _SWAP_WRAPPED_MODULE + "." 

34 

35 

class FuncModule(nn.Module):
    """
    Minimal :class:`~torch.nn.Module` shell around a plain callable.

    Wrapping lets an ordinary Python function (or any callable that is not
    already a Module) be handed to :func:`swap_wrapper` and
    :func:`~hyper_parallel.platform.torch.platform.TorchPlatform.ckpt_wrapper`
    where an :class:`~torch.nn.Module` is expected.
    The callable is kept on ``_fn`` and :meth:`forward` simply delegates to
    it; the adapter itself adds no parameters of its own.

    Args:
        fn (callable): The function to wrap.

    Example:
        >>> wrapped = swap_wrapper(lambda x: x * 2)
    """

    def __init__(self, fn: Callable):
        super().__init__()
        self._fn = fn

    def forward(self, *args, **kwargs):
        """Invoke the wrapped callable with the given arguments and return its result."""
        wrapped_fn = self._fn
        return wrapped_fn(*args, **kwargs)

60 

61 

def base_check_fn(tensor) -> bool:
    """
    Basic check to determine if a tensor is eligible for offloading.

    A tensor is rejected when:
    - it is a Parameter, or a view whose base is a Parameter;
    - its underlying storage is empty.

    Args:
        tensor (torch.Tensor): Candidate tensor.

    Returns:
        bool: ``True`` if the tensor may be offloaded, ``False`` otherwise.
    """
    parameter_cls = torch.nn.parameter.Parameter
    # ``_base`` is the tensor a view was created from (``None`` for non-views).
    if isinstance(tensor, parameter_cls) or isinstance(tensor._base, parameter_cls):  # pylint: disable=W0212
        return False
    # ``untyped_storage()`` replaces the deprecated ``storage()`` accessor, which
    # builds a TypedStorage and emits a deprecation warning. The emptiness test
    # is unchanged: a storage holds zero bytes exactly when it holds zero elements.
    return tensor.untyped_storage().size() > 0

73 

74 

class AsyncSaveOnCpu(torch.autograd.graph.saved_tensors_hooks):
    """
    Saved-tensor hooks that hand eligible tensors over to the swap machinery.

    While the context is active, every tensor passed to the pack hook is
    either returned unchanged (it fails ``base_check_fn`` or ``policy_fn``
    reports ``CheckpointPolicy.MUST_SAVE``) or wrapped in a ``SwapTensor``
    stored in this context's ``Storage``, with an integer handle saved in its
    place. The unpack hook resolves the handle back through
    ``SwapTensor.get_val()``.

    NOTE(review): the actual host/device transfer is presumably performed by
    ``SwapTensor`` / ``SwapManager``, which are defined elsewhere - confirm.

    Args:
        policy_fn (callable, optional): Called with each eligible tensor; a
            return value equal to ``CheckpointPolicy.MUST_SAVE`` keeps the
            tensor as is. ``None`` disables the policy check.
    """
    def __init__(self, policy_fn=None) -> None:
        self.add_to_storage = False   # whether ``storage`` was handed to SwapManager yet
        self.storage = Storage()
        self.count_idx = 0            # next handle to issue
        self.pack_count = 0
        self.unpack_count = 0
        self.policy_fn = policy_fn

        def pack_to_cpu(tensor: torch.Tensor):
            # The policy is consulted only for tensors that pass the basic check.
            keep_as_is = (not base_check_fn(tensor)) or (
                policy_fn is not None and policy_fn(tensor) == CheckpointPolicy.MUST_SAVE
            )
            if keep_as_is:
                return tensor

            group = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                # First swappable tensor: register this context's storage once.
                SwapManager().add_storage(group, self.storage)
                self.add_to_storage = True

            handle = self.count_idx
            label = f"{group}::{tensor.shape}"
            self.storage.swap_storage[handle].append(SwapTensor(tensor, label))
            self.count_idx = handle + 1
            self.pack_count += 1
            return handle

        def unpack_from_cpu(idx) -> torch.Tensor:
            # Tensors that were never swapped come back as themselves.
            if isinstance(idx, torch.Tensor):
                return idx

            restored = self.storage.swap_storage[idx].pop(0).get_val()
            self.unpack_count += 1
            # Drop the storage reference once every packed tensor was unpacked.
            if self.unpack_count == self.pack_count:
                self.storage = None
            return restored

        super().__init__(pack_to_cpu, unpack_from_cpu)

118 

119 

class ActivationWrapper(torch.nn.Module, ABC):
    """
    Abstract base for modules that wrap another module for Activation Swap.

    Subclasses must implement :meth:`forward`; this class is not meant to be
    instantiated directly.
    """

    def __init__(self, module: Union[nn.Module, Callable]):
        # Plain callables are adapted so the wrapped object is always a Module.
        wrapped = module
        if not isinstance(wrapped, nn.Module) and callable(wrapped):
            wrapped = FuncModule(wrapped)
        super().__init__()
        self._swap_wrapped_module = wrapped
        # Strip the wrapper prefix on save so the result loads into a
        # non-swap wrapped module.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # Re-add the prefix on load so such a state_dict loads back into a
        # swap-wrapped module.
        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise ValueError("Subclasses should implement forward().")

    def __getattr__(self, name: str) -> Any:
        """Look ``name`` up on this wrapper first, then on the wrapped module."""
        try:
            value = super().__getattr__(name)  # nn.Module's own lookup
        except AttributeError:
            value = getattr(self._swap_wrapped_module, name)
        return value

    def __getitem__(self, key: int) -> Any:
        """Delegate indexing to the wrapped module (e.g. an nn.Sequential)."""
        inner = self._swap_wrapped_module
        return inner.__getitem__(key)  # type: ignore[operator]

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[tuple[str, torch.nn.Parameter]]:
        """
        Yield ``(name, parameter)`` pairs like :meth:`nn.Module.named_parameters`,

        with every occurrence of ``_SWAP_PREFIX`` removed from the names.
        """
        named = super().named_parameters(*args, **kwargs)
        for full_name, param in named:
            clean_name = full_name.replace(_SWAP_PREFIX, "")
            yield clean_name, param

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,  # pylint: disable=W0613
        state_dict: dict[str, Any],
        prefix: str,
        *args: Any,  # pylint: disable=W0613
    ) -> dict[str, Any]:
        """
        State-dict hook registered in ``__init__`` via ``_register_state_dict_hook``.

        Rewrites keys starting with ``prefix + _SWAP_PREFIX`` so they start with
        ``prefix`` only, which lets the result be loaded into a module that is
        not swap-wrapped. Loading back into a swap-wrapped module still works
        because ``_pre_load_state_dict_hook`` restores the prefix.

        Returns:
            The same ``state_dict`` object, modified in place.
        """
        wrapped_prefix = f"{prefix}{_SWAP_PREFIX}"
        _replace_by_prefix(state_dict, wrapped_prefix, prefix)
        return state_dict

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        Load pre-hook registered in ``__init__`` via ``register_load_state_dict_pre_hook``.

        Rewrites keys starting with ``prefix`` so they start with
        ``prefix + _SWAP_PREFIX``, allowing a state_dict produced by a
        non-swapped module to be loaded into a ``swap_wrapper`` module.
        """
        wrapped_prefix = prefix + f"{_SWAP_PREFIX}"
        _replace_by_prefix(state_dict, prefix, wrapped_prefix)

200 

201 

class SwapWrapper(ActivationWrapper):
    """
    Module wrapper that runs the wrapped module's forward inside an
    ``AsyncSaveOnCpu`` context.

    Args:
        mod (nn.Module or callable): Module (or plain callable) to wrap.
        policy_fn (callable, optional): Forwarded to ``AsyncSaveOnCpu`` on
            every forward call.
    """
    def __init__(self, mod: Union[nn.Module, Callable], policy_fn: Optional[Callable] = None):
        super().__init__(mod)
        self.policy_fn = policy_fn

    def forward(self, *args, **kwargs):
        """Call the wrapped module with a fresh ``AsyncSaveOnCpu`` context active."""
        swap_ctx = AsyncSaveOnCpu(policy_fn=self.policy_fn)
        with swap_ctx:
            return self._swap_wrapped_module(*args, **kwargs)

213 

214 

def swap_wrapper(module: Union[nn.Module, Callable], policy_fn: Optional[Callable] = None) -> SwapWrapper:
    """Wrap ``module`` in a :class:`SwapWrapper`, passing ``policy_fn`` through.

    Args:
        module (nn.Module or callable): Module or plain callable to wrap.
        policy_fn (callable, optional): Policy callback given to the wrapper.

    Returns:
        SwapWrapper: The wrapping module.
    """
    wrapper = SwapWrapper(module, policy_fn)
    return wrapper

217 

218 

def swap_tensor_wrapper(target, tag: Optional[str] = None):
    """Register selected tensors into the current swap group.

    This helper is intended to be used inside a forward path that already
    participates in the existing swap scheduling managed by ``SwapManager``.
    It preserves the input structure and returns the original tensors.

    Args:
        target: A tensor, or a nested container of tensors traversed with
            ``torch.utils._pytree.tree_map``. Non-tensor leaves and tensors
            rejected by ``base_check_fn`` are passed through unregistered.
        tag (str, optional): Label used in each registered tensor's name.
            Defaults to ``"<group_name>_swap_tensor"``.

    Returns:
        ``target`` itself when there is no current group (a warning is
        issued) or the current group is the last one; otherwise the
        ``tree_map`` result, which has the same structure and holds the same
        tensor objects.
    """
    group_name = SwapManager().get_current_group_name()
    if not group_name:
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this helper, so the offending call site is identifiable.
        warnings.warn(
            f"Tensor {tag} cannot be swapped, for its group is unregistered.",
            stacklevel=2,
        )
        return target
    if SwapManager().is_last_group(group_name):
        # Nothing is registered for the last group.
        return target

    storage = Storage()
    # The label prefix does not depend on the tensor, so compute it once.
    tensor_tag = tag or f"{group_name}_swap_tensor"
    count_idx = 0

    def _register_tensor(leaf):
        """Record an eligible tensor leaf in ``storage``; always return the leaf unchanged."""
        nonlocal count_idx
        if not isinstance(leaf, torch.Tensor) or not base_check_fn(leaf):
            return leaf

        funcname = f"{tensor_tag}::{tuple(leaf.shape)}"
        storage.swap_storage[count_idx].append(SwapTensor(leaf, funcname))
        count_idx += 1
        return leaf

    wrapped = torch.utils._pytree.tree_map(  # pylint: disable=protected-access
        _register_tensor,
        target,
    )
    # Only hand the storage to the manager when something was registered.
    if count_idx > 0:
        SwapManager().add_storage(group_name, storage)
    return wrapped