Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/activation_checkpoint/swap.py: 16%

382 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Swap tensor and swap manager implementation for activation checkpointing""" 

16# pylint: disable=W0212 

17 

18import functools 

19import threading 

20import warnings 

21 

22from collections import defaultdict 

23from typing import Any, Dict, List, Optional 

24 

25from hyper_parallel.platform import get_platform 

26 

27platform = get_platform() 

28 

29 

class SwapTensor:
    """A value whose device storage can be swapped to pinned host memory asynchronously.

    Device tensors (a ``platform.Tensor`` whose device is not ``cpu``) move
    through the state machine::

        device --async_offload--> d2h --wait_offload--> host
        host   --async_load-----> h2d --wait_load-----> device

    Any other value (CPU tensors, non-tensor objects) is kept in the
    ``non_tensor`` state and every swap operation is a no-op for it. Wrappers
    flagged by ``protect_if_aliases`` (keep on device) or by
    ``mark_duplicate_swap`` (duplicate registration) skip the copy operations.
    """
    STATE_DEVICE = "device"
    STATE_HOST = "host"
    STATE_D2H = "d2h"
    STATE_H2D = "h2d"
    STATE_NON_TENSOR = "non_tensor"

    def __init__(self, val: Any, funcname: Any) -> None:
        """Wrap ``val`` and, for device tensors, allocate its pinned host buffer.

        Args:
            val: The value to wrap. Only non-CPU ``platform.Tensor`` values
                become swappable.
            funcname: Label of the producer, used in the error messages
                raised by ``async_offload``.
        """
        self.val = val
        # Version counter at wrap time; async_offload compares it to detect
        # in-place modification. Non-tensor values have no ``_version``, so
        # fall back to None instead of raising AttributeError.
        self.ver = getattr(val, "_version", None)
        self.funcname = funcname
        self._keep_on_device = False
        self._duplicate_swap = False
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self._state = self.STATE_DEVICE
            storage_size = val.untyped_storage().size()
            # A view that does not span its whole storage is copied through the
            # tensor view; otherwise the raw storage is copied in one go.
            self.is_slice_tensor = storage_size != val.numel() * platform.get_element_size(val)
            self.val_cpu = platform.empty_like(
                val, device="cpu", pin_memory=True
            )
            # Remembered so the storage can be re-grown after resize_(0).
            self.storage_size = storage_size
        else:
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None

    def dedup_key(self):
        """Return a stable identity key for duplicate-swap detection.

        Two wrappers share a key only when they view the same device storage
        with the same offset, storage size, stride, shape and dtype. Shape and
        dtype are part of the key so that a narrower view (e.g. ``x[:5]``) is
        not treated as a duplicate of ``x``. A duplicate skips its own copies,
        so the surviving wrapper would offload only the slice, and
        ``wait_offload`` would then free the whole storage.

        Returns:
            A hashable tuple, or ``None`` for values that are never swapped.
        """
        if self._state == self.STATE_NON_TENSOR:
            return None
        val = self.val
        storage = val.untyped_storage()
        return (
            str(val.device),
            storage.data_ptr(),
            val.storage_offset(),
            storage.size(),
            tuple(val.stride()),
            tuple(val.shape),
            str(val.dtype),
        )

    def mark_duplicate_swap(self) -> None:
        """Mark this wrapper as a duplicate registration in the same swap group."""
        self._duplicate_swap = True

    def protect_if_aliases(self, output_tensors: List[Any]) -> None:
        """Keep this tensor on device if a non-CPU output shares its storage.

        Args:
            output_tensors: Candidate output tensors. Non-tensors and CPU
                tensors are ignored; sharing is detected by equal storage
                data pointers.
        """
        if self._state == self.STATE_NON_TENSOR:
            return
        self_storage_ptr = self.val.untyped_storage().data_ptr()
        for out in output_tensors:
            if not isinstance(out, platform.Tensor):
                continue
            if str(out.device).lower() == "cpu":
                continue
            if out.untyped_storage().data_ptr() == self_storage_ptr:
                self._keep_on_device = True
                return

    def get_val(self) -> Any:
        """Return the wrapped value.

        Raises:
            RuntimeError: if a device tensor is not currently in 'device' state.
        """
        if self._state == self.STATE_NON_TENSOR:
            return self.val
        if self._state != self.STATE_DEVICE:
            raise RuntimeError(
                f"Cannot call get_val(): tensor is in '{self._state}' state. "
                "Must be in 'device' state."
            )
        return self.val

    def resize_device_storage(self):
        """Re-grow the freed device storage to its original size (compute stream).

        Only acts in 'host' state, and not for duplicate registrations, which
        share their storage with another wrapper.
        """
        if self._state == self.STATE_NON_TENSOR or self._duplicate_swap:
            return

        if self._state != self.STATE_HOST:
            return
        storage = self.val.untyped_storage()
        if storage.size() == self.storage_size:
            return
        storage.resize_(self.storage_size)

    def async_load(self):
        """Issue a non-blocking host-to-device copy and enter 'h2d' state.

        The device storage must already have been re-grown with
        ``resize_device_storage``.

        Raises:
            ValueError: if the host buffer is missing.
        """
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device or self._duplicate_swap:
            return

        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
                "expected 'host'. Operation skipped."
            )
            return

        if self.val_cpu is None:
            raise ValueError("val_cpu must not be None during async_load")
        if self.is_slice_tensor:
            self.val.data.copy_(self.val_cpu, non_blocking=True)
        else:
            self.val.untyped_storage().copy_(self.val_cpu.untyped_storage(), non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """Move from 'h2d' to 'device' once the caller has synchronized the load."""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device or self._duplicate_swap:
            return

        if self._state == self.STATE_DEVICE:
            return  # already loaded
        if self._state != self.STATE_H2D:
            warnings.warn(
                f"[SwapTensor.wait_load] Called in invalid state: {self._state}. "
                "Expected 'h2d'. Skipped."
            )
            return
        self._state = self.STATE_DEVICE

    def async_offload(self):
        """Issue a non-blocking device-to-host copy and enter 'd2h' state.

        Raises:
            RuntimeError: if the storage was resized or the tensor was modified
                in place since it was wrapped.
        """
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device or self._duplicate_swap:
            return

        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
                "expected 'device'. Operation skipped."
            )
            return

        if self.storage_size != self.val.untyped_storage().size():
            raise RuntimeError(
                f"There is a tensor from {self.funcname} cannot be SWAPPED! Its storage has been resized "
                f"presize:{self.storage_size}, current size:{self.val.untyped_storage().size()}"
            )
        if self.ver != self.val._version:
            raise RuntimeError(
                f"There is a tensor from {self.funcname} cannot be SWAPPED! In-place modification happened "
                f"preversion:{self.ver}, current version:{self.val._version}"
            )

        if self.is_slice_tensor:
            self.val_cpu.copy_(self.val, non_blocking=True)
        else:
            self.val_cpu.untyped_storage().copy_(self.val.untyped_storage(), non_blocking=True)
        self._state = self.STATE_D2H

    def wait_offload(self):
        """Free the device storage (resize to 0) and enter 'host' state.

        The caller must make the current stream wait for the offload copy
        before calling this.
        """
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device or self._duplicate_swap:
            return

        if self._state == self.STATE_HOST:
            return
        if self._state != self.STATE_D2H:
            warnings.warn(
                f"[SwapTensor.wait_offload] Called in invalid state: {self._state}. "
                "Expected 'd2h'. Skipped."
            )
            return
        storage = self.val.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)
        self._state = self.STATE_HOST

    @property
    def state(self) -> str:
        """Current swap state (one of the ``STATE_*`` constants)."""
        return self._state

    def __repr__(self):
        if self._state == self.STATE_NON_TENSOR:
            return f"<SwapTensor state=non_tensor, val_type={type(self.val).__name__}>"
        return (
            f"<SwapTensor state={self._state}, duplicate={self._duplicate_swap}, "
            f"device_val={'exists' if self.val is not None else 'None'}>"
        )

200 

201 

class Storage:
    """Manage a collection of tensors for swapping operations.

    ``swap_storage`` maps a key to a list of items, and every ``SwapTensor``
    nested inside those items (as found by ``platform.tree_map``) takes part
    in the swap operations below. ``save_storage`` is created alongside but is
    not read or written by any method of this class.
    """

    def __init__(self):
        self.save_storage: Dict[Any, List[Any]] = defaultdict(list)
        self.swap_storage: Dict[Any, List[Any]] = defaultdict(list)

    def _for_each_swap_tensor(self, action) -> None:
        """Call ``action`` on every SwapTensor nested in ``swap_storage``, in order."""
        def _visit(x):
            if isinstance(x, SwapTensor):
                action(x)
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_visit, item)

    def iter_swap_tensors(self):
        """Return a list of all SwapTensor objects stored in this storage."""
        collected = []
        self._for_each_swap_tensor(collected.append)
        return collected

    def mark_duplicate_swaps(self, seen_keys) -> int:
        """Mark tensors whose dedup key is already known as duplicates.

        Args:
            seen_keys: Set of keys already registered in the swap group. It is
                updated in place with the keys of tensors seen for the first time.

        Returns:
            The number of tensors marked as duplicates.
        """
        duplicate_count = 0
        for swap_tensor in self.iter_swap_tensors():
            dedup_key = swap_tensor.dedup_key()
            if dedup_key is None:
                continue
            if dedup_key in seen_keys:
                swap_tensor.mark_duplicate_swap()
                duplicate_count += 1
            else:
                seen_keys.add(dedup_key)
        return duplicate_count

    def protect_output_tensors(self, outputs: Any):
        """Avoid offloading tensors that alias the wrapped module outputs."""
        output_tensors = []

        def _collect_outputs(x):
            if isinstance(x, platform.Tensor):
                output_tensors.append(x)
            return x

        platform.tree_map(_collect_outputs, outputs)
        if not output_tensors:
            return
        self._for_each_swap_tensor(lambda t: t.protect_if_aliases(output_tensors))

    def launch_load(self):
        """Launch async load for all tensors in swap storage."""
        self._for_each_swap_tensor(lambda t: t.async_load())

    def resize_device_storage(self):
        """Resize device storage for all swap tensors (runs on compute stream)."""
        self._for_each_swap_tensor(lambda t: t.resize_device_storage())

    def wait_load(self):
        """Wait load for all tensors in swap storage."""
        self._for_each_swap_tensor(lambda t: t.wait_load())

    def wait_offload(self):
        """Wait offload for all tensors in swap storage."""
        self._for_each_swap_tensor(lambda t: t.wait_offload())

    def launch_offload(self):
        """Launch async offload for all tensors in swap storage."""
        self._for_each_swap_tensor(lambda t: t.async_offload())

312 

313 

class SwapGroup:
    """Coordinate offload/load of every Storage registered for one layer.

    Each direction records a single event on the copy stream, so the compute
    stream can later block on exactly the copies issued for this group.
    """

    def __init__(self, group_name: str):
        self.group_name = group_name
        self.is_last_group = False
        self._live_storages = []    # registered storages; cleared by wait_load
        self._load_event = None     # recorded on the copy stream by launch_load
        self._offload_event = None  # recorded on the copy stream by launch_offload

    def add(self, storage):
        """Register ``storage``, flagging tensors already registered in this group.

        A warning reports how many duplicate registrations were skipped.
        """
        already_registered = {
            key
            for live in self._live_storages
            for key in (tensor.dedup_key() for tensor in live.iter_swap_tensors())
            if key is not None
        }
        skipped = storage.mark_duplicate_swaps(already_registered)
        if skipped > 0:
            warnings.warn(
                f"SwapGroup '{self.group_name}' skipped {skipped} duplicate tensor swap registration(s)."
            )
        self._live_storages.append(storage)

    def protect_output_tensors(self, outputs: Any):
        """Forward ``outputs`` to every live storage so aliases stay on device."""
        for live in self._live_storages:
            live.protect_output_tensors(outputs)

    def launch_offload(self, copy_stream):
        """Issue device-to-host copies of all live storages on ``copy_stream``.

        The copy stream first waits for the work already queued on the current
        (compute) stream; completion is tracked by ``self._offload_event``.
        """
        ready = platform.new_event()
        ready.record(platform.get_current_stream())
        self._offload_event = platform.new_event()
        on_stream = platform.get_stream_context()
        with platform.no_grad(), on_stream(copy_stream):
            ready.wait(copy_stream)
            for live in self._live_storages:
                live.launch_offload()
            self._offload_event.record(copy_stream)

    def wait_offload(self):
        """Block the current stream on the offload event, then free device storage.

        Raises:
            RuntimeError: if ``launch_offload`` has not been called.
        """
        if self._offload_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_offload() called before launch_offload()."
            )
        current = platform.get_current_stream()
        on_stream = platform.get_stream_context()
        with platform.no_grad(), on_stream(current):
            self._offload_event.wait(current)
            self._offload_event = None
            for live in self._live_storages:
                live.wait_offload()

    def launch_load(self, copy_stream):
        """Re-grow device storage on the compute stream, then copy back on ``copy_stream``."""
        # Allocation happens on the current (compute) stream before the copy
        # stream is told to wait, so the copies target live memory.
        with platform.no_grad():
            for live in self._live_storages:
                live.resize_device_storage()

        ready = platform.new_event()
        ready.record(platform.get_current_stream())
        self._load_event = platform.new_event()
        on_stream = platform.get_stream_context()
        with platform.no_grad(), on_stream(copy_stream):
            ready.wait(copy_stream)
            for live in self._live_storages:
                live.launch_load()  # copy only; resizing was done above
            self._load_event.record(copy_stream)

    def wait_load(self):
        """Block the current stream on the load event and release the storages.

        The live storage list is cleared even if waiting fails.

        Raises:
            RuntimeError: if ``launch_load`` has not been called.
        """
        if self._load_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_load() called before launch_load()."
            )
        try:
            current = platform.get_current_stream()
            on_stream = platform.get_stream_context()
            with platform.no_grad(), on_stream(current):
                self._load_event.wait(current)
                self._load_event = None
                for live in self._live_storages:
                    live.wait_load()
        finally:
            self._live_storages.clear()

402 

403 

class SwapManager:
    """Singleton manager for swap groups and their operations.

    Groups are keyed by name. ``set_forward_prefetch_layer`` links layers into
    a prev/next chain and registers hooks that offload a layer's activations
    after its forward and prefetch them back before its backward.
    """
    _instance: Optional["SwapManager"] = None
    _lock = threading.Lock()  # guards creation of the singleton instance

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path
        # cheap; the locked re-test prevents two threads creating instances.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __new__ returns the shared instance, so __init__ runs on every
        # SwapManager() call; only initialise state the first time.
        if hasattr(self, '_groups'):
            return
        self._groups = {}
        self._current_group_name = ""
        self._counter_lock = threading.Lock()  # guards _layer_count
        self._layer_count = 0
        self._copy_stream = None

    def _get_group(self, group_name: str):
        """Return the named group or raise RuntimeError if it does not exist."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        return group

    def _get_or_create_group(self, group_name: str):
        """Return the named group, creating an empty one on first use."""
        group = self._groups.get(group_name)
        if group is None:
            group = SwapGroup(group_name)
            self._groups[group_name] = group
        return group

    def add_storage(self, group_name: str, storage: "Storage") -> None:
        """Add a storage to a specified swap group, creating the group if needed."""
        self._get_or_create_group(group_name).add(storage)

    def launch_offload(self, group_name: str, copy_stream=None):
        """Launch async offload for a group (on the shared copy stream by default).

        Raises:
            RuntimeError: if the group does not exist.
        """
        group = self._get_group(group_name)
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_offload(copy_stream)

    def protect_output_tensors(self, group_name: str, outputs: Any):
        """Keep tensors that alias the module output on device.

        Raises:
            RuntimeError: if the group does not exist.
        """
        self._get_group(group_name).protect_output_tensors(outputs)

    def wait_offload(self, group_name: str):
        """Wait for offload to complete for a specified swap group.

        Raises:
            RuntimeError: if the group does not exist.
        """
        self._get_group(group_name).wait_offload()

    def launch_load(self, group_name: str, copy_stream=None):
        """Launch async load for a group (on the shared copy stream by default).

        Raises:
            RuntimeError: if the group does not exist.
        """
        group = self._get_group(group_name)
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_load(copy_stream)

    def wait_load(self, group_name: str):
        """Wait for load to complete for a specified swap group.

        Raises:
            RuntimeError: if the group does not exist.
        """
        self._get_group(group_name).wait_load()

    def release_group_storage(self, group_name: str) -> None:
        """Release live storage references held by the swap group.

        Called at the end of backward to free Storage objects that were never
        released via wait_load (e.g. the last layer, which has no next layer
        and therefore never goes through the offload-load cycle). Unknown
        group names are ignored.
        """
        group = self._groups.get(group_name)
        if group is not None:
            group._live_storages.clear()

    def get_current_group_name(self):
        """Return the group name set by the most recent forward pre-hook."""
        return self._current_group_name

    def set_current_group_name(self, group_name):
        """Set the group name returned by ``get_current_group_name``."""
        self._current_group_name = group_name

    def is_last_group(self, group_name: Optional[str] = None) -> bool:
        """Return whether the group (default: the current one) ends the chain.

        Unknown group names return False.
        """
        group_name = self._current_group_name if group_name is None else group_name
        group = self._groups.get(group_name)
        if group is None:
            return False
        return group.is_last_group

    def set_forward_prefetch_layer(self, first_layer, second_layer):
        """
        Configure prefetching and offloading order between two consecutive layers.

        Usage:
            manager = SwapManager()
            for i in range(len(model.layers) - 1):
                manager.set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])

        Ensures idempotency: safe to call multiple times on the same layer pair.
        Group names, prev/next links and hooks are only assigned once per module.
        """

        def _ensure_group_name(module):
            """Assign a unique swap group name to the module if not already assigned."""
            if not hasattr(module, "_swap_group_name"):
                with self._counter_lock:
                    name = f"swap_group_{self._layer_count}"
                    self._layer_count += 1
                module._swap_group_name = name
                module._swap_group_order = {"prev": None, "next": None}
            return module._swap_group_name

        first_name = _ensure_group_name(first_layer)
        second_name = _ensure_group_name(second_layer)

        first_group = self._get_or_create_group(first_name)
        second_group = self._get_or_create_group(second_name)

        # Links are set once; later calls never overwrite an existing link.
        if first_layer._swap_group_order["next"] is None:
            first_layer._swap_group_order["next"] = second_name
        if second_layer._swap_group_order["prev"] is None:
            second_layer._swap_group_order["prev"] = first_name

        first_group.is_last_group = first_layer._swap_group_order["next"] is None
        second_group.is_last_group = second_layer._swap_group_order["next"] is None

        def _forward_pre_hook(group_name, module, _):  # pylint: disable=W0613
            """Publish ``group_name`` as the current group before the layer's forward.

            Skipped while the module is in 'pre_backward' state, i.e. a forward
            that runs between this module's backward pre-hook and backward hook.
            NOTE(review): presumably activation recomputation — confirm.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            SwapManager().set_current_group_name(group_name)

        def _forward_hook(group_name, module, args, output):  # pylint: disable=W0613
            """
            Forward post-hook executed immediately after forward computation
            of the current layer finishes.

            Execution timeline (example with 3 layers, forward order: L0 -> L1 -> L2):

                Compute stream: | Fwd L0 | post(L0) | Fwd L1 | post(L1) | Fwd L2 | post(L2) |
                Copy stream:               | Offload L0 ...   | Offload L1 ...    |

                post(L0): launch offload of L0.
                post(L1): launch offload of L1, then wait for L0's offload.
                post(L2): L2 has no next layer, so it is not offloaded; wait for L1's offload.

            Swap rules:
                1. If a next layer exists, protect tensors aliasing ``output`` and
                   asynchronously offload the current layer's activations
                   (launch_offload).
                2. If a previous layer exists, make the compute stream wait for
                   its offload and free its device storage (wait_offload), which
                   limits the device memory peak.

            Notes:
                - Offloads are issued on the copy stream to overlap the transfer
                  with the forward computation of subsequent layers.
                - Skipped while the module is in 'pre_backward' state so that no
                  offload is triggered during the backward phase.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().protect_output_tensors(group_name, output)
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)

        def _backward_pre_hook(group_name, module, grad_output):  # pylint: disable=W0613
            """
            Pre-backward hook executed immediately before backward computation
            of the current layer starts.

            Execution timeline (example with 3 layers, backward order: L2 -> L1 -> L0):

                Compute stream: | pre(L2) | Grad L2 | pre(L1) | Grad L1 | pre(L0) | Grad L0 |
                Copy stream:    | Load L1 ...       | Load L0 ...       |

                pre(L2): prefetch L1; L2 was never offloaded, so no wait.
                pre(L1): prefetch L0, then wait for L1's load.
                pre(L0): no previous layer to prefetch; wait for L0's load.

            Swap rules:
                1. If a previous layer (in forward order) exists, asynchronously
                   prefetch its activations (launch_load).
                2. If a next layer exists, the current layer was offloaded, so wait
                   until its activations are back in device memory (wait_load).

            Notes:
                - Loads are issued on the copy stream to overlap the transfer
                  with the backward computation of the current layer.
                - The swap state is set to 'pre_backward' so the forward hooks do
                  not issue offloads during the backward phase.
            """
            module._swap_state = "pre_backward"
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().launch_load(prev_name)

            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().wait_load(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            """Leave 'pre_backward' state and drop the group's remaining storages."""
            module._swap_state = "backward"
            SwapManager().release_group_storage(group_name)

        def _register_hooks_once(module, group_name):
            """Register the four swap hooks on ``module`` unless already present."""
            hooks = [
                ("_swap_forward_pre_hook_handle",
                 lambda h: platform.register_forward_pre_hook(module, h, prepend=True),
                 functools.partial(_forward_pre_hook, group_name)),

                ("_swap_forward_hook_handle",
                 module.register_forward_hook,
                 functools.partial(_forward_hook, group_name)),

                ("_swap_backward_pre_hook_handle",
                 lambda h: platform.register_full_backward_pre_hook(module, h, prepend=True),
                 functools.partial(_backward_pre_hook, group_name)),

                ("_swap_backward_hook_handle",
                 lambda h: platform.register_full_backward_hook(module, h),
                 functools.partial(_backward_hook, group_name)),
            ]

            # The handle attribute doubles as the "already registered" marker.
            for attr_name, register_func, hook in hooks:
                if not hasattr(module, attr_name):
                    handle = register_func(hook)
                    setattr(module, attr_name, handle)

        # Register for both layers
        _register_hooks_once(first_layer, first_name)
        _register_hooks_once(second_layer, second_name)

    def _get_copy_stream(self):
        """Return a singleton copy stream, created on first access."""
        if self._copy_stream is None:
            self._copy_stream = platform.new_stream()
        return self._copy_stream