Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / activation_checkpoint / sac.py: 0%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/utils/checkpoint.py 

16# enhanced with selective checkpoint support swap 

17# ============================================================================ 

18"""enhanced with selective checkpoint support swap""" 

19# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705 

20from typing import Any, Optional, Union 

21 

22import torch 

23import torch.fx.traceback as fx_traceback 

24from torch._functorch._aot_autograd.functional_utils import is_fun 

25from torch.utils._pytree import tree_map 

26from torch.utils._python_dispatch import TorchDispatchMode 

27from hyper_parallel.core.activation_checkpoint import CheckpointPolicy # patch code 

28from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage # patch code 

29 

30 

31def _is_compiling(func, args, kwargs): 

32 # Check if we are under AOTAutograd tracing 

33 # There should probably be a better way to do this... 

34 # NOTE: unify _is_compiling across all compile stacks 

35 for arg in args: 

36 if isinstance(arg, torch.Tensor) and is_fun(arg): 

37 return True 

38 return False 

39 

40 

class _VersionWrapper:
    """Hold a cached value and detect in-place mutation of cached tensors.

    For a ``torch.Tensor`` the tensor's ``_version`` counter is recorded at
    construction time; for any other value no version is recorded and the
    mutation check is skipped.
    """

    def __init__(self, val):
        wraps_tensor = isinstance(val, torch.Tensor)
        self.val: Union[torch.Tensor, Any] = val
        # ``None`` marks "not a tensor, nothing to verify".
        self.version: Optional[int] = val._version if wraps_tensor else None

    def get_val(self, allow_cache_entry_mutation):
        """Return the wrapped value.

        Raises:
            RuntimeError: if the wrapped tensor's version counter changed
                since it was cached and ``allow_cache_entry_mutation`` is
                falsy.
        """
        must_verify = self.version is not None and not allow_cache_entry_mutation
        if must_verify and self.val._version != self.version:
            # Can we give user a stack trace of where the mutation happened?
            raise RuntimeError(
                "Tensor cached during selective activation checkpoint has been mutated"
            )
        return self.val

55 

56 

57def _maybe_detach(x, any_ret_has_alias_info): 

58 # We detach for two separate reasons: 

59 # - For view ops, we need to ensure that when the tensor is returned from 

60 # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr 

61 # - Avoid reference cycles 

62 # For case 1, it is not enough to check whether x has differentiable dtype 

63 # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. 

64 # when the tensor is a view. 

65 if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): 

66 with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): 

67 # Ensure that view performed beneath autograd properly propagates 

68 # version counter. TODO: Use reentrant_dispatch instead of 

69 # manually manipulating dispatch keys. Using reentrant_dispatch 

70 # would respect inference_mode, though that is not relevant for 

71 # this case. 

72 x = x.detach() 

73 return x 

74 

75 

class SelectiveCheckpointContext:
    """
    Metadata handed to a policy function during selective checkpointing.

    The policy function receives an instance of this class as its first
    argument. Its single attribute, ``is_recompute``, tells the policy
    whether it is being invoked during recomputation or not.

    Example:
        >>> # xdoctest: +SKIP(stub)
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    print(ctx.is_recompute)
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """

    def __init__(self, *, is_recompute):
        # Keyword-only so call sites always spell out which phase they are in.
        self.is_recompute = is_recompute

100 

101 

def _policy_from_bool(b):
    """Translate a boolean policy result into a ``CheckpointPolicy`` member.

    Kept for backward compatibility with policy functions that return a
    bool: a truthy value becomes ``MUST_SAVE`` and a falsy value becomes
    ``PREFER_RECOMPUTE``.
    """
    if b:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

105 

106 

107SAC_IGNORED_OPS = { 

108 # AC inserts different number of detach during forward and recompute. 

109 torch.ops.aten.detach.default, 

110 # AC's determinism check invokes additional metadata ops during forward. 

111 # With subclasses involved, these metadata ops become dispatchable, this 

112 # can result in incorrectness if these ops are selected cached. 

113 torch.ops.prim.device.default, 

114} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) 

115 

116 

class _CachingTorchDispatchMode(TorchDispatchMode):
    """Forward-pass half of selective activation checkpointing (SAC).

    Used together with _CachedTorchDispatchMode to implement SAC. Every op
    that is not in ``SAC_IGNORED_OPS`` is shown to ``policy_fn``, executed,
    and, depending on the returned policy, has its output recorded in
    ``storage`` so the recompute pass can reuse it instead of re-running it.

    Args:
        policy_fn: Callable invoked as
            ``policy_fn(SelectiveCheckpointContext, func, *args, **kwargs)``.
            It may return a ``CheckpointPolicy`` member or a plain bool.
        storage: Object exposing ``save_storage`` and ``swap_storage``,
            both indexed by op. Entries are appended without prior
            initialisation, so these are presumably ``defaultdict(list)``-like
            -- confirm against ``Storage``.
    """

    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage
        # Set once ``storage`` has been handed to SwapManager, so that the
        # registration happens at most one time per mode instance.
        self.add_to_storage = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Normalise first: the ignored-op fast path below unpacks ``kwargs``
        # too, and ``**None`` would raise TypeError.
        kwargs = {} if kwargs is None else kwargs

        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if is_compiling:
            # Overwrite each node's "recompute" tag to add in the user annotation.
            fx_traceback.current_meta["recompute"] = policy

        out = func(*args, **kwargs)

        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        # While compiling, every op is cached regardless of policy: the
        # recompute-side mode reads from ``save_storage`` for every op when
        # compiling, so the forward side has to populate it the same way.
        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.save_storage[func]  # patch code
            storage.append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            group_name = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            storage = self.storage.swap_storage[func]
            # Name passed to SwapTensor: current swap group plus the op.
            funcname = f"{group_name}::{func}"
            storage.append(tree_map(lambda x: SwapTensor(_maybe_detach(x, any_ret_has_alias_info), funcname), out))
        return out

156 

157 

class _CachedTorchDispatchMode(TorchDispatchMode):
    """Recompute-pass half of selective activation checkpointing (SAC).

    Used together with _CachingTorchDispatchMode to implement SAC. For each
    op that is not in ``SAC_IGNORED_OPS`` the policy function is consulted
    again (with ``is_recompute=True``); saved or swapped outputs are taken
    from ``storage`` in first-in-first-out order, and every other op is
    simply re-executed.

    Args:
        policy_fn: Callable invoked as
            ``policy_fn(SelectiveCheckpointContext, func, *args, **kwargs)``.
            It may return a ``CheckpointPolicy`` member or a plain bool.
        storage: Object exposing ``save_storage`` and ``swap_storage``
            mappings from op to a list of cached outputs.
        allow_cache_entry_mutation: Forwarded to the cached entries'
            ``get_val`` on the save path; when falsy, a mutated cached tensor
            raises ``RuntimeError``.
    """

    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    @staticmethod
    def _pop_cached_entry(table, func):
        """Remove and return the oldest cached output for ``func`` in ``table``.

        Raises:
            RuntimeError: if ``func`` has no entry in ``table``, or if all of
                its cached outputs were already consumed.
        """
        entries = table.get(func)
        if entries is None:
            raise RuntimeError(f"{func} encountered during backward, but not found in storage")
        if not entries:
            raise RuntimeError(
                "Trying to backward an extra time. You are only allowed to backward once "
                "on any region computed under selective activation checkpoint."
            )
        return entries.pop(0)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Normalise first: the ignored-op fast path below unpacks ``kwargs``
        # too, and ``**None`` would raise TypeError.
        kwargs = {} if kwargs is None else kwargs

        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            cached = self._pop_cached_entry(self.storage.save_storage, func)  # patch code
            # Saved entries are version-checked wrappers.
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), cached)
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            cached = self._pop_cached_entry(self.storage.swap_storage, func)
            # Swapped entries expose ``get_val()`` without arguments.
            out = tree_map(lambda x: x.get_val(), cached)
        else:
            # Nothing cached for this op: recompute it.
            out = func(*args, **kwargs)
        return out

200 

201 

def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use this with `torch.utils.checkpoint.checkpoint` to control which
    operations are recomputed during the backward pass.

    Args:
        policy_fn_or_list (Callable or List):
          - If a policy function is provided, it should accept a
            :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
            kwargs to the op, and return a :class:`CheckpointPolicy` enum value
            indicating whether the execution of the op should be recomputed or not.
          - If a list of operations is provided, it is equivalent to a policy
            returning `CheckpointPolicy.MUST_SAVE` for the specified
            operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
            operations. The list is snapshotted when this function is called.
        allow_cache_entry_mutation (bool, optional): By default, an error is
            raised if any tensors cached by selective activation checkpoint are
            mutated in order to ensure correctness. If set to `True`, this check
            is disabled.
    Returns:
        A tuple of two context managers: the forward (caching) dispatch mode
        and the recompute (cached) dispatch mode, sharing one storage object.

    Raises:
        ValueError: if a list is given and one of its items is not an
            ``OpOverload``.
        TypeError: if ``policy_fn_or_list`` is neither a list nor callable.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>> import functools
        >>>
        >>> x = torch.rand(10, 10, requires_grad=True)
        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
        >>> ops_to_save = [
        >>>    torch.ops.aten.mm.default,
        >>> ]
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    if op in ops_to_save:
        >>>        return CheckpointPolicy.MUST_SAVE
        >>>    else:
        >>>        return CheckpointPolicy.PREFER_RECOMPUTE
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> # or equivalently
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
        >>>
        >>> def fn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    # context_fn anyway, so proceed as usual.
    if isinstance(policy_fn_or_list, list):
        for op in policy_fn_or_list:
            if isinstance(op, torch._ops.OpOverload):
                continue
            # Give a targeted hint for the common mistake of passing the
            # packet (`aten.mm`) instead of an overload (`aten.mm.default`).
            _extra_msg = (
                "Please update the OpOverloadPacket to a specific OpOverload. "
                "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
            ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
            raise ValueError(
                f"Expected op in `op_list` to be an OpOverload but got: {op} "
                f"of type {type(op)}. {_extra_msg}"
            )

        # Built once, after validation, so the per-op membership test made on
        # every dispatch is a hash lookup rather than a list scan.
        ops_to_save = frozenset(policy_fn_or_list)

        def policy_fn(ctx, op, *args, **kwargs):
            if op in ops_to_save:
                return CheckpointPolicy.MUST_SAVE
            return CheckpointPolicy.PREFER_RECOMPUTE
    elif callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

    # One storage object shared by both modes: the first fills it during
    # forward, the second drains it during recompute.
    storage = Storage()  # patch code
    return (
        _CachingTorchDispatchMode(policy_fn, storage),
        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
    )