Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / activation_checkpoint / sac.py: 0%
106 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/utils/checkpoint.py
16# enhanced so that selective activation checkpointing also supports swap
17# ============================================================================
18"""enhanced with selective checkpoint support swap"""
19# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705
20from typing import Any, Optional, Union
22import torch
23import torch.fx.traceback as fx_traceback
24from torch._functorch._aot_autograd.functional_utils import is_fun
25from torch.utils._pytree import tree_map
26from torch.utils._python_dispatch import TorchDispatchMode
27from hyper_parallel.core.activation_checkpoint import CheckpointPolicy # patch code
28from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage # patch code
31def _is_compiling(func, args, kwargs):
32 # Check if we are under AOTAutograd tracing
33 # There should probably be a better way to do this...
34 # NOTE: unify _is_compiling across all compile stacks
35 for arg in args:
36 if isinstance(arg, torch.Tensor) and is_fun(arg):
37 return True
38 return False
class _VersionWrapper:
    """Hold a cached value and detect in-place mutation of cached tensors.

    For tensor values the tensor's ``_version`` counter is recorded at wrap
    time; non-tensor values are stored without a version (``None``).
    """

    def __init__(self, val):
        self.val: Union[torch.Tensor, Any] = val
        # Only tensors carry a version counter; everything else is unversioned.
        self.version: Optional[int] = None
        if isinstance(val, torch.Tensor):
            self.version = val._version

    def get_val(self, allow_cache_entry_mutation):
        """Return the wrapped value, optionally verifying it was not mutated.

        Args:
            allow_cache_entry_mutation: when truthy, the version check is skipped.

        Raises:
            RuntimeError: if the wrapped tensor's version counter changed since
                it was cached and mutation is not allowed.
        """
        must_check = self.version is not None and not allow_cache_entry_mutation
        if must_check and self.val._version != self.version:
            # Can we give user a stack trace of where the mutation happened?
            raise RuntimeError(
                "Tensor cached during selective activation checkpoint has been mutated"
            )
        return self.val
57def _maybe_detach(x, any_ret_has_alias_info):
58 # We detach for two separate reasons:
59 # - For view ops, we need to ensure that when the tensor is returned from
60 # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
61 # - Avoid reference cycles
62 # For case 1, it is not enough to check whether x has differentiable dtype
63 # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
64 # when the tensor is a view.
65 if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
66 with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
67 # Ensure that view performed beneath autograd properly propagates
68 # version counter. TODO: Use reentrant_dispatch instead of
69 # manually manipulating dispatch keys. Using reentrant_dispatch
70 # would respect inference_mode, though that is not relevant for
71 # this case.
72 x = x.detach()
73 return x
class SelectiveCheckpointContext:
    """
    Metadata object handed to the policy function on every invocation.

    During selective checkpointing the policy function receives an instance of
    this class as its first argument. Its ``is_recompute`` attribute tells the
    policy whether the current invocation happens during recomputation or not.

    Example:
        >>> # xdoctest: +SKIP(stub)
        >>> import functools
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>     print(ctx.is_recompute)
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """

    def __init__(self, *, is_recompute):
        # Keyword-only so call sites always spell out which phase they are in.
        self.is_recompute = is_recompute
102def _policy_from_bool(b):
103 # For backward compatibility
104 return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
107SAC_IGNORED_OPS = {
108 # AC inserts different number of detach during forward and recompute.
109 torch.ops.aten.detach.default,
110 # AC's determinism check invokes additional metadata ops during forward.
111 # With subclasses involved, these metadata ops become dispatchable, this
112 # can result in incorrectness if these ops are selected cached.
113 torch.ops.prim.device.default,
114} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
class _CachingTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that runs each op and caches the outputs the policy selects.

    Used together with _CachedTorchDispatchMode to implement SAC. The policy
    function is invoked with ``is_recompute=False``. Outputs are appended to
    ``storage.save_storage[func]`` or ``storage.swap_storage[func]`` depending
    on the policy; other policies store nothing.

    Args:
        policy_fn: callable ``(ctx, func, *args, **kwargs)`` returning a
            ``CheckpointPolicy`` or a bool (bools are converted through
            ``_policy_from_bool``).
        storage: object exposing ``save_storage`` and ``swap_storage``
            mappings that are indexed by op.
    """

    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage
        # Flips to True once ``storage`` has been handed to SwapManager, so
        # the registration happens at most once per mode instance.
        self.add_to_storage = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Normalize before the fast path below: unpacking ``**None`` would
        # raise TypeError when the caller relies on the default.
        kwargs = {} if kwargs is None else kwargs
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if is_compiling:
            # Overwrite each node's "recompute" tag to add in the user annotation.
            fx_traceback.current_meta["recompute"] = policy

        out = func(*args, **kwargs)

        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        # When compiling, cache every output regardless of policy:
        # _CachedTorchDispatchMode reads ``save_storage`` for every op whenever
        # it detects compilation, so both sides must use the same condition.
        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.save_storage[func]  # patch code
            storage.append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            group_name = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            storage = self.storage.swap_storage[func]
            # Label passed to SwapTensor: current swap group plus the op.
            funcname = f"{group_name}::{func}"
            storage.append(tree_map(lambda x: SwapTensor(_maybe_detach(x, any_ret_has_alias_info), funcname), out))
        return out
class _CachedTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that replays cached op outputs instead of re-running ops.

    Used together with _CachingTorchDispatchMode to implement SAC. The policy
    function is invoked with ``is_recompute=True``. Ops whose policy selects
    saving (or any op while compiling) are served from
    ``storage.save_storage``; ``MUST_SWAP`` ops are served from
    ``storage.swap_storage``; every other op is executed normally.

    Args:
        policy_fn: callable ``(ctx, func, *args, **kwargs)`` returning a
            ``CheckpointPolicy`` or a bool (bools are converted through
            ``_policy_from_bool``).
        storage: object exposing ``save_storage`` and ``swap_storage``
            mappings that support ``.get(op)``.
        allow_cache_entry_mutation: forwarded to ``_VersionWrapper.get_val``;
            when truthy the mutation check on saved tensors is skipped.
    """

    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    @staticmethod
    def _pop_cached(storage_map, func):
        """Pop the oldest cached entry for ``func`` from ``storage_map``.

        Raises:
            RuntimeError: if ``func`` has no entry in ``storage_map``, or if
                its entries have already been consumed.
        """
        entries = storage_map.get(func)
        if entries is None:
            raise RuntimeError(f"{func} encountered during backward, but not found in storage")
        if len(entries) == 0:
            raise RuntimeError(
                "Trying to backward an extra time. You are only allowed to backward once "
                "on any region computed under selective activation checkpoint."
            )
        # Entries are consumed in the order they were appended.
        return entries.pop(0)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Normalize before the fast path below: unpacking ``**None`` would
        # raise TypeError when the caller relies on the default.
        kwargs = {} if kwargs is None else kwargs
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            cached = self._pop_cached(self.storage.save_storage, func)  # patch code
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), cached)
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            cached = self._pop_cached(self.storage.swap_storage, func)
            out = tree_map(lambda x: x.get_val(), cached)
        else:
            # Not cached: actually recompute the op.
            out = func(*args, **kwargs)
        return out
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use this with `torch.utils.checkpoint.checkpoint` to control which
    operations are recomputed during the backward pass.

    Args:
        policy_fn_or_list (Callable or List):
          - If a policy function is provided, it should accept a
            :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
            kwargs to the op, and return a :class:`CheckpointPolicy` enum value
            indicating whether the execution of the op should be recomputed or not.
            The dispatch modes created here additionally recognize
            `CheckpointPolicy.MUST_SWAP`, which routes the op's outputs
            through swap storage instead of save storage.
          - If a list of operations is provided, it is equivalent to a policy
            returning `CheckpointPolicy.MUST_SAVE` for the specified
            operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
            operations.
        allow_cache_entry_mutation (bool, optional): By default, an error is
            raised if any tensors cached by selective activation checkpoint are
            mutated in order to ensure correctness. If set to `True`, this check
            is disabled.

    Returns:
        A tuple of two context managers: the caching mode used for the forward
        pass and the cached mode used for recomputation. Both share one
        freshly created storage object.

    Raises:
        ValueError: if a list is given and one of its items is not an
            `OpOverload`.
        TypeError: if `policy_fn_or_list` is neither a list nor callable.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>> import functools
        >>>
        >>> x = torch.rand(10, 10, requires_grad=True)
        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
        >>> ops_to_save = [
        >>>    torch.ops.aten.mm.default,
        >>> ]
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    if op in ops_to_save:
        >>>        return CheckpointPolicy.MUST_SAVE
        >>>    else:
        >>>        return CheckpointPolicy.PREFER_RECOMPUTE
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> # or equivalently
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
        >>>
        >>> def fn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    # context_fn anyway, so proceed as usual.
    if isinstance(policy_fn_or_list, list):
        for op in policy_fn_or_list:
            if isinstance(op, torch._ops.OpOverload):
                continue
            # Packets get an extra hint because they are the likely mistake.
            # The two sentences are joined with a space so they do not run
            # together in the final message.
            _extra_msg = (
                "Please update the OpOverloadPacket to a specific OpOverload. "
                "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
            ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
            raise ValueError(
                f"Expected op in `op_list` to be an OpOverload but got: {op} "
                f"of type {type(op)}. {_extra_msg}"
            )

        def policy_fn(ctx, op, *args, **kwargs):
            # Membership is checked against the caller's list on every call.
            if op in policy_fn_or_list:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE
    elif callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

    # One storage instance is shared so the cached mode can read back what the
    # caching mode stored.
    storage = Storage()  # patch code
    return (
        _CachingTorchDispatchMode(policy_fn, storage),
        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
    )