Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / activation_checkpoint / swap.py: 16%
382 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Swap tensor and swap manager implementation for activation checkpointing"""
16# pylint: disable=W0212
18import functools
19import threading
20import warnings
22from collections import defaultdict
23from typing import Any, Dict, List, Optional
25from hyper_parallel.platform import get_platform
# Backend handle obtained once at import time. Every tensor type check, stream,
# event, tree traversal and hook registration in this module goes through it.
platform = get_platform()
class SwapTensor:
    """Wrap a value so a device tensor can be moved to host memory and back.

    Values that are not ``platform.Tensor`` instances, or that already live on
    the CPU, are wrapped in the ``non_tensor`` state; every swap operation is a
    no-op for them.  Device tensors follow this state cycle::

        device --async_offload--> d2h --wait_offload--> host
        host   --async_load-----> h2d --wait_load-----> device

    The ``async_*`` methods only issue ``non_blocking`` copies and perform no
    synchronization themselves.
    """
    STATE_DEVICE = "device"
    STATE_HOST = "host"
    STATE_D2H = "d2h"
    STATE_H2D = "h2d"
    STATE_NON_TENSOR = "non_tensor"

    def __init__(self, val: Any, funcname: Any) -> None:
        """Wrap ``val``.

        Args:
            val: The value to wrap; may be a device tensor, a CPU tensor or any
                other object.
            funcname: Identifier of the producer of ``val``; only used in error
                messages.
        """
        self.val = val
        # Plain Python values carry no version counter. Reading ``val._version``
        # unconditionally would raise AttributeError for them and make the
        # non-tensor state unreachable.
        self.ver = getattr(val, "_version", None)
        self.funcname = funcname
        self._keep_on_device = False
        self._duplicate_swap = False
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self._state = self.STATE_DEVICE
            self.storage_size = val.untyped_storage().size()
            # A tensor that does not span its whole storage is copied
            # element-wise instead of storage-to-storage.
            self.is_slice_tensor = self.storage_size != val.numel() * platform.get_element_size(val)
            self.val_cpu = platform.empty_like(
                val, device="cpu", pin_memory=True
            )
        else:
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None
            # Defined for every instance so attribute access never depends on
            # which branch the constructor took.
            self.is_slice_tensor = False
            self.storage_size = None

    def _skips_swap(self) -> bool:
        """Return True when offload/load must leave this wrapper untouched."""
        return (
            self._state == self.STATE_NON_TENSOR
            or self._keep_on_device
            or self._duplicate_swap
        )

    def dedup_key(self):
        """Return a stable identity key for duplicate-swap detection.

        Returns:
            ``None`` for non-tensor wrappers, otherwise a tuple of device,
            storage pointer, storage offset, storage size and strides.
        """
        if self._state == self.STATE_NON_TENSOR:
            return None
        # NOTE(review): shape and dtype are not part of the key - confirm that
        # two views differing only in those cannot end up in one swap group.
        return (
            str(self.val.device),
            self.val.untyped_storage().data_ptr(),
            self.val.storage_offset(),
            self.val.untyped_storage().size(),
            tuple(self.val.stride()),
        )

    def mark_duplicate_swap(self) -> None:
        """Mark this wrapper as a duplicate registration in the same swap group."""
        self._duplicate_swap = True

    def protect_if_aliases(self, output_tensors: List[Any]) -> None:
        """Keep this tensor on device if it shares storage with any output.

        Args:
            output_tensors: Candidate outputs; non-tensors and CPU tensors are
                ignored.
        """
        if self._state == self.STATE_NON_TENSOR:
            return
        self_storage_ptr = self.val.untyped_storage().data_ptr()
        for out in output_tensors:
            if not isinstance(out, platform.Tensor):
                continue
            if str(out.device).lower() == "cpu":
                continue
            if out.untyped_storage().data_ptr() == self_storage_ptr:
                self._keep_on_device = True
                return

    def get_val(self) -> Any:
        """Return the wrapped value.

        Raises:
            RuntimeError: If a wrapped tensor is not currently in the
                ``device`` state.
        """
        if self._state == self.STATE_NON_TENSOR:
            return self.val
        if self._state != self.STATE_DEVICE:
            raise RuntimeError(
                f"Cannot call get_val(): tensor is in '{self._state}' state. "
                f"Must be in 'device' state."
            )
        return self.val

    def resize_device_storage(self):
        """Restore the device storage to its recorded size before a load.

        Does nothing unless the wrapper is in the ``host`` state and the
        storage size differs from the size recorded at construction.
        """
        if self._state == self.STATE_NON_TENSOR or self._duplicate_swap:
            return

        if self._state != self.STATE_HOST:
            return
        storage = self.val.untyped_storage()
        if storage.size() == self.storage_size:
            return
        storage.resize_(self.storage_size)

    def async_load(self):
        """Issue a non-blocking host-to-device copy and enter the ``h2d`` state.

        Warns and returns if the wrapper is not in the ``host`` state.

        Raises:
            ValueError: If no host buffer exists.
        """
        if self._skips_swap():
            return

        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
                f"expected 'host'. Operation skipped."
            )
            return

        if self.val_cpu is None:
            raise ValueError("val_cpu must not be None during async_load")
        if self.is_slice_tensor:
            self.val.data.copy_(self.val_cpu, non_blocking=True)
        else:
            self.val.untyped_storage().copy_(self.val_cpu.untyped_storage(), non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """Move from ``h2d`` to ``device``; warn on any other unexpected state."""
        if self._skips_swap():
            return

        if self._state == self.STATE_DEVICE:
            return  # already loaded
        if self._state != self.STATE_H2D:
            warnings.warn(
                f"[SwapTensor.wait_load] Called in invalid state: {self._state}. "
                f"Expected 'h2d'. Skipped."
            )
            return
        self._state = self.STATE_DEVICE

    def async_offload(self):
        """Issue a non-blocking device-to-host copy and enter the ``d2h`` state.

        Warns and returns if the wrapper is not in the ``device`` state.

        Raises:
            RuntimeError: If the storage size or the tensor version changed
                since the wrapper was created.
        """
        if self._skips_swap():
            return

        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
                f"expected 'device'. Operation skipped."
            )
            return

        if self.storage_size != self.val.untyped_storage().size():
            raise RuntimeError(
                f"There is a tensor from {self.funcname} cannot be SWAPPED! Its storage has been resized "
                f"presize:{self.storage_size}, current size:{self.val.untyped_storage().size()}"
            )
        if self.ver != self.val._version:
            raise RuntimeError(
                f"There is a tensor from {self.funcname} cannot be SWAPPED! In-place modification happened "
                f"preversion:{self.ver}, current version:{self.val._version}"
            )

        if self.is_slice_tensor:
            self.val_cpu.copy_(self.val, non_blocking=True)
        else:
            self.val_cpu.untyped_storage().copy_(self.val.untyped_storage(), non_blocking=True)
        self._state = self.STATE_D2H

    def wait_offload(self):
        """Shrink the device storage to zero and enter the ``host`` state.

        Warns and returns if the wrapper is neither in ``d2h`` nor already in
        ``host``.
        """
        if self._skips_swap():
            return

        if self._state == self.STATE_HOST:
            return
        if self._state != self.STATE_D2H:
            warnings.warn(
                f"[SwapTensor.wait_offload] Called in invalid state: {self._state}. "
                f"Expected 'd2h'. Skipped."
            )
            return
        storage = self.val.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)
        self._state = self.STATE_HOST

    @property
    def state(self) -> str:
        """Current state, one of the ``STATE_*`` constants."""
        return self._state

    def __repr__(self):
        if self._state == self.STATE_NON_TENSOR:
            return f"<SwapTensor state=non_tensor, val_type={type(self.val).__name__}>"
        return (
            f"<SwapTensor state={self._state}, duplicate={self._duplicate_swap}, "
            f"device_val={'exists' if self.val is not None else 'None'}>"
        )
class Storage:
    """Hold the saved and swappable values of one checkpointed region.

    ``swap_storage`` maps a key to a list of (possibly nested) items that may
    contain ``SwapTensor`` wrappers; the bulk operations below visit every
    such wrapper via ``platform.tree_map``.
    """

    def __init__(self):
        # Both mappings create an empty list on first access of a key.
        self.save_storage: Dict[Any, List[Any]] = defaultdict(list)
        self.swap_storage: Dict[Any, List[Any]] = defaultdict(list)

    def _for_each_swap_tensor(self, action):
        """Call ``action`` on every SwapTensor found in ``swap_storage``."""
        def _visit(node):
            if isinstance(node, SwapTensor):
                action(node)
            return node

        for entries in self.swap_storage.values():
            for entry in entries:
                platform.tree_map(_visit, entry)

    def iter_swap_tensors(self):
        """Return a list of all SwapTensor objects stored in this storage."""
        found = []
        self._for_each_swap_tensor(found.append)
        return found

    def mark_duplicate_swaps(self, seen_keys) -> int:
        """Flag tensors whose dedup key is already present in ``seen_keys``.

        Keys seen for the first time are added to ``seen_keys`` in place.

        Returns:
            The number of wrappers that were marked as duplicates.
        """
        skipped = 0
        for tensor in self.iter_swap_tensors():
            key = tensor.dedup_key()
            if key is None:
                continue
            if key not in seen_keys:
                seen_keys.add(key)
            else:
                tensor.mark_duplicate_swap()
                skipped += 1
        return skipped

    def protect_output_tensors(self, outputs: Any):
        """Keep stored tensors that alias any tensor in ``outputs`` on device."""
        produced = []

        def _gather(node):
            if isinstance(node, platform.Tensor):
                produced.append(node)
            return node

        platform.tree_map(_gather, outputs)
        if not produced:
            return
        self._for_each_swap_tensor(lambda tensor: tensor.protect_if_aliases(produced))

    def launch_load(self):
        """Start the asynchronous host-to-device copy of every swap tensor."""
        self._for_each_swap_tensor(lambda tensor: tensor.async_load())

    def resize_device_storage(self):
        """Re-grow the device storage of every swap tensor ahead of a load."""
        self._for_each_swap_tensor(lambda tensor: tensor.resize_device_storage())

    def wait_load(self):
        """Finalize the load of every swap tensor."""
        self._for_each_swap_tensor(lambda tensor: tensor.wait_load())

    def wait_offload(self):
        """Finalize the offload of every swap tensor."""
        self._for_each_swap_tensor(lambda tensor: tensor.wait_offload())

    def launch_offload(self):
        """Start the asynchronous device-to-host copy of every swap tensor."""
        self._for_each_swap_tensor(lambda tensor: tensor.async_offload())
class SwapGroup:
    """Coordinate offload and load of all storages registered under one name."""

    def __init__(self, group_name: str):
        self.group_name = group_name
        self.is_last_group = False
        # Storages currently taking part in the offload/load cycle.
        self._live_storages = []
        # Events recorded on the copy stream; None while no transfer is pending.
        self._offload_event = None
        self._load_event = None

    def add(self, storage):
        """Register ``storage``, marking tensors the group already holds as duplicates."""
        registered_keys = (
            tensor.dedup_key()
            for live in self._live_storages
            for tensor in live.iter_swap_tensors()
        )
        seen_keys = {key for key in registered_keys if key is not None}
        skipped = storage.mark_duplicate_swaps(seen_keys)
        if skipped > 0:
            warnings.warn(
                f"SwapGroup '{self.group_name}' skipped {skipped} duplicate tensor swap registration(s)."
            )
        self._live_storages.append(storage)

    def protect_output_tensors(self, outputs: Any):
        """Ask every live storage to keep tensors aliasing ``outputs`` on device."""
        for live in self._live_storages:
            live.protect_output_tensors(outputs)

    def launch_offload(self, copy_stream):
        """Issue the device-to-host copies of all live storages on ``copy_stream``."""
        # The copy stream first waits on an event recorded on the current stream.
        producer_done = platform.new_event()
        producer_done.record(platform.get_current_stream())
        self._offload_event = platform.new_event()
        enter_stream = platform.get_stream_context()
        with platform.no_grad(), enter_stream(copy_stream):
            producer_done.wait(copy_stream)
            for live in self._live_storages:
                live.launch_offload()
            self._offload_event.record(copy_stream)

    def wait_offload(self):
        """Make the current stream wait for the offload, then finalize it.

        Raises:
            RuntimeError: If no offload has been launched.
        """
        if self._offload_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_offload() called before launch_offload()."
            )
        current = platform.get_current_stream()
        enter_stream = platform.get_stream_context()
        with platform.no_grad(), enter_stream(current):
            self._offload_event.wait(current)
            self._offload_event = None
            for live in self._live_storages:
                live.wait_offload()

    def launch_load(self, copy_stream):
        """Re-grow device storage, then issue host-to-device copies on ``copy_stream``."""
        # Storage is resized before the copy stream is entered.
        with platform.no_grad():
            for live in self._live_storages:
                live.resize_device_storage()

        storage_ready = platform.new_event()
        storage_ready.record(platform.get_current_stream())
        self._load_event = platform.new_event()
        enter_stream = platform.get_stream_context()
        with platform.no_grad(), enter_stream(copy_stream):
            storage_ready.wait(copy_stream)
            for live in self._live_storages:
                live.launch_load()  # copy only; resizing already happened above
            self._load_event.record(copy_stream)

    def wait_load(self):
        """Make the current stream wait for the load, then finalize it.

        The list of live storages is cleared afterwards, even when
        finalization raises.

        Raises:
            RuntimeError: If no load has been launched.
        """
        if self._load_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_load() called before launch_load()."
            )
        try:
            current = platform.get_current_stream()
            enter_stream = platform.get_stream_context()
            with platform.no_grad(), enter_stream(current):
                self._load_event.wait(current)
                self._load_event = None
                for live in self._live_storages:
                    live.wait_load()
        finally:
            self._live_storages.clear()
class SwapManager:
    """Process-wide singleton that owns the swap groups and wires layer hooks.

    ``SwapManager()`` always returns the same instance; its state is
    initialized only on the first construction.
    """
    _instance: Optional["SwapManager"] = None
    _lock = threading.Lock()

    def __init__(self):
        # __init__ runs on every SwapManager() call; keep the existing state.
        if hasattr(self, '_groups'):
            return
        self._groups = {}
        self._current_group_name = ""
        self._counter_lock = threading.Lock()
        self._layer_count = 0
        self._copy_stream = None

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock so only one instance is ever created.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def _require_group(self, group_name: str) -> "SwapGroup":
        """Return the named group or raise RuntimeError if it is unknown."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        return group

    def _get_or_create_group(self, group_name: str) -> "SwapGroup":
        """Return the named group, creating an empty one on first use."""
        if group_name not in self._groups:
            self._groups[group_name] = SwapGroup(group_name)
        return self._groups[group_name]

    def _ensure_group_name(self, module) -> str:
        """Assign a unique swap group name to ``module`` if it has none yet.

        The check, the counter increment and the attribute assignment happen
        under ``_counter_lock`` so two callers cannot hand out the same name.
        """
        with self._counter_lock:
            if not hasattr(module, "_swap_group_name"):
                name = f"swap_group_{self._layer_count}"
                self._layer_count += 1
                module._swap_group_name = name
                module._swap_group_order = {"prev": None, "next": None}
            return module._swap_group_name

    def add_storage(self, group_name: str, storage: "Storage") -> None:
        """Add a storage to the named swap group, creating the group if needed."""
        self._get_or_create_group(group_name).add(storage)

    def launch_offload(self, group_name: str, copy_stream=None):
        """Launch async offload for the named group.

        Args:
            group_name: Name of an existing group.
            copy_stream: Stream to copy on; defaults to the manager's own
                copy stream.

        Raises:
            RuntimeError: If the group does not exist.
        """
        group = self._require_group(group_name)
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_offload(copy_stream)

    def protect_output_tensors(self, group_name: str, outputs: Any):
        """Keep tensors of the named group that alias ``outputs`` on device.

        Raises:
            RuntimeError: If the group does not exist.
        """
        self._require_group(group_name).protect_output_tensors(outputs)

    def wait_offload(self, group_name: str):
        """Wait for the offload of the named group to complete.

        Raises:
            RuntimeError: If the group does not exist.
        """
        self._require_group(group_name).wait_offload()

    def launch_load(self, group_name: str, copy_stream=None):
        """Launch async load for the named group.

        Args:
            group_name: Name of an existing group.
            copy_stream: Stream to copy on; defaults to the manager's own
                copy stream.

        Raises:
            RuntimeError: If the group does not exist.
        """
        group = self._require_group(group_name)
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_load(copy_stream)

    def wait_load(self, group_name: str):
        """Wait for the load of the named group to complete.

        Raises:
            RuntimeError: If the group does not exist.
        """
        self._require_group(group_name).wait_load()

    def release_group_storage(self, group_name: str) -> None:
        """Release live storage references held by the swap group.

        Called at the end of backward to free Storage objects that were never
        released via wait_load (e.g. the last layer, which has no next layer
        and therefore never goes through the offload-load cycle).
        Unknown group names are ignored.
        """
        group = self._groups.get(group_name)
        if group is not None:
            group._live_storages.clear()

    def get_current_group_name(self):
        """Return the group name most recently set via set_current_group_name."""
        return self._current_group_name

    def set_current_group_name(self, group_name):
        """Record ``group_name`` as the current swap group."""
        self._current_group_name = group_name

    def is_last_group(self, group_name: Optional[str] = None) -> bool:
        """Return whether the group is the terminal group in the chain.

        Args:
            group_name: Group to query; ``None`` means the current group.

        Returns:
            ``False`` for unknown groups, otherwise the group's flag.
        """
        group_name = self._current_group_name if group_name is None else group_name
        group = self._groups.get(group_name)
        if group is None:
            return False
        return group.is_last_group

    def set_forward_prefetch_layer(self, first_layer, second_layer):
        """
        Configure prefetching and offloading order between two consecutive layers.

        Usage:
            for i in range(len(model.layers) - 1):
                set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])

        Ensures idempotency: safe to call multiple times on the same layer pair.
        An already recorded ``next``/``prev`` link is never overwritten, and
        each hook is registered at most once per module.
        """
        first_name = self._ensure_group_name(first_layer)
        second_name = self._ensure_group_name(second_layer)

        self._get_or_create_group(first_name)
        self._get_or_create_group(second_name)

        if first_layer._swap_group_order["next"] is None:
            first_layer._swap_group_order["next"] = second_name
        if second_layer._swap_group_order["prev"] is None:
            second_layer._swap_group_order["prev"] = first_name

        self._groups[first_name].is_last_group = first_layer._swap_group_order["next"] is None
        self._groups[second_name].is_last_group = second_layer._swap_group_order["next"] is None

        def _forward_pre_hook(group_name, module, _):  # pylint: disable=W0613
            """Publish this layer's group as the current one, except during backward."""
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            SwapManager().set_current_group_name(group_name)

        def _forward_hook(group_name, module, args, output):  # pylint: disable=W0613
            """
            Forward post-hook executed immediately after forward computation
            of the current layer finishes.

            Execution timeline (example with 3 layers, forward order: L0 → L1 → L2):

            Time →
            Forward Compute Stream:
            | Fwd L0 | post(L0) | Fwd L1 | post(L1) | Fwd L2 |

            Copy Stream (offload):
                                | Offload L0 |   -   | Offload L1 |
                                ↑                    ↑
                      offload at post(L0)    offload at post(L1)

            Swap rules:
            1. After forward computation of the current layer completes:
               - If a next layer exists, asynchronously offload the activations
                 of the current layer (launch_offload).

               Example:
               - At post-forward of L0, offload activations of L0.
               - At post-forward of L1, offload activations of L1.

            2. To limit device memory peak:
               - If a previous layer exists, wait until its offload operation
                 has completed (wait_offload).

            Notes:
            - Offload operations are issued on the copy stream to overlap data transfer
              with forward computation of subsequent layers.
            - If the module is already in 'pre_backward' state, this hook is skipped
              to avoid triggering offload during backward phase.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                # Tensors aliasing the layer output must stay on device.
                SwapManager().protect_output_tensors(group_name, output)
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)

        def _backward_pre_hook(group_name, module, grad_input):  # pylint: disable=W0613
            """
            Pre-backward hook executed immediately before backward computation
            of the current layer starts.

            Execution timeline (example with 3 layers, backward order: L2 → L1 → L0):

            Time →
            Backward Compute Stream:
            | pre(L2) | Grad L2 | pre(L1) | Grad L1 | pre(L0) | Grad L0 |

            Copy Stream (load):
                      | Load L1 |    -    | Load L0 |
                      ↑                   ↑
              prefetch at pre(L2)   prefetch at pre(L1)

            Swap rules:
            1. At the beginning of backward for the current layer:
               - If a previous layer exists in backward order, asynchronously
                 prefetch its activations (launch_load).

               Example:
               - At pre-backward of L2, prefetch activations of L1.
               - At pre-backward of L1, prefetch activations of L0.

            2. Before starting backward computation of the current layer:
               - Ensure that the activations of the current layer have already
                 been loaded back to device memory (wait_load).

            Notes:
            - Load operations are issued on the copy stream to overlap data transfer
              with backward computation of the current layer.
            - The swap state is marked as 'pre_backward' to prevent forward hooks
              from issuing offload operations during backward phase.
            """
            module._swap_state = "pre_backward"
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().launch_load(prev_name)

            # Only layers with a next layer were offloaded in forward, so only
            # they have a load to wait for.
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().wait_load(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            """Leave the 'pre_backward' state and drop the group's live storages."""
            module._swap_state = "backward"
            SwapManager().release_group_storage(group_name)

        def _register_hooks_once(module, group_name):
            """Register each swap hook on ``module`` unless its handle attribute exists."""
            hooks = [
                ("_swap_forward_pre_hook_handle",
                 lambda h: platform.register_forward_pre_hook(module, h, prepend=True),
                 functools.partial(_forward_pre_hook, group_name)),

                ("_swap_forward_hook_handle",
                 module.register_forward_hook,
                 functools.partial(_forward_hook, group_name)),

                ("_swap_backward_pre_hook_handle",
                 lambda h: platform.register_full_backward_pre_hook(module, h, prepend=True),
                 functools.partial(_backward_pre_hook, group_name)),

                ("_swap_backward_hook_handle",
                 lambda h: platform.register_full_backward_hook(module, h),
                 functools.partial(_backward_hook, group_name)),
            ]

            for attr_name, register_func, hook in hooks:
                # The stored handle doubles as the "already registered" marker.
                if not hasattr(module, attr_name):
                    handle = register_func(hook)
                    setattr(module, attr_name, handle)

        # Register for both layers
        _register_hooks_once(first_layer, first_name)
        _register_hooks_once(second_layer, second_name)

    def _get_copy_stream(self):
        """Return a singleton copy stream, created on first access."""
        if self._copy_stream is None:
            self._copy_stream = platform.new_stream()
        return self._copy_stream