Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/platform.py: 56%

619 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch platform api""" 

16from datetime import timedelta 

17from typing import Optional, Any, Union 

18import dataclasses 

19from collections import OrderedDict 

20 

21import numpy as np 

22from safetensors.torch import save_file, load_file 

23import torch 

24from torch import nn 

25from torch import Tensor 

26from torch._C._distributed_c10d import Store, ProcessGroup 

27from torch.distributed import Backend 

28from torch.distributed.distributed_c10d import _get_default_group 

29from torch.nn import Parameter, Module 

30from torch.nn.utils.rnn import PackedSequence 

31from torch._ops import OpOverload, OpOverloadPacket 

32from torch.utils.checkpoint import noop_context_fn 

33from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper 

34import torch.distributed.nn.functional as dist_func 

35import torch.distributed as dist 

36from hyper_parallel.platform.torch.dtensor import DTensorBase 

37from hyper_parallel.platform.torch.pipeline_parallel.stage import PipelineStageBase 

38from hyper_parallel.platform.torch.group_utils import create_sub_groups 

39from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS 

40from hyper_parallel.platform.torch.function_override import override_functions 

41from hyper_parallel.platform.torch.init_weights import init_on_device as _init_on_device 

42 

# Import-time side effect: apply the overrides defined in
# ``hyper_parallel.platform.torch.function_override`` as soon as this module
# is imported. What exactly gets overridden lives in that module and is not
# visible here.
override_functions()

44 

45 

46# --------------------------------------------------------------------------- 

47# Module-level A2A reshape helpers 

48# --------------------------------------------------------------------------- 

49 

def _a2a_reconstruct(out_perm: torch.Tensor, concat_dim: int) -> torch.Tensor:
    """Rebuild the all-to-all result from its raw receive buffer.

    The buffer is laid out as ``[ws, *rest_dims]``, and the chunk it must be
    merged with sits at position ``concat_dim + 1``. The leading ``ws`` axis is
    moved in front of that chunk, and the two axes are then fused into one.

    Args:
        out_perm: Receive buffer of the all-to-all.
        concat_dim: Dimension (of the reconstructed tensor) that absorbs ``ws``.

    Returns:
        torch.Tensor: A contiguous tensor with ``ws`` folded into ``concat_dim``.
    """
    chunk_axis = concat_dim + 1
    order = [*range(1, chunk_axis), 0, *range(chunk_axis, out_perm.dim())]
    moved = out_perm.permute(order).contiguous()
    dims = list(moved.shape)
    fused = dims[concat_dim] * dims[concat_dim + 1]
    return moved.reshape(dims[:concat_dim] + [fused] + dims[concat_dim + 2:])

63 

64 

class _TorchAsyncA2AFunction(torch.autograd.Function):
    """Differentiable wrapper for pre-launched async all-to-all.

    Forward: wait async handle, reconstruct A2A result.
    Backward: launch async head→seq A2A and store handle in ``handle_box``
    for the projection pre-hook to wait, achieving GEMM–A2A overlap.

    The gradient that backward returns for ``x`` is always a zero tensor of
    ``x``'s shape. When ``handle_box`` is given, the real gradient travels
    out-of-band as ``(work, out_perm)`` appended to ``handle_box``. That pair
    is presumably waited on and reconstructed by the pre-hook mentioned above;
    confirm against the caller.
    """

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim,  # pylint: disable=arguments-differ
                handle_box):
        """Wait for pre-launched async A2A and return reconstructed output.

        Args:
            ctx: Autograd context; stores the parameters backward needs.
            x: Tensor whose shape is recorded so backward can return a
                matching (zero) gradient; its values are not read here.
            work: Async work handle; ``work.wait()`` is called before
                ``out_perm`` is read.
            out_perm: Receive buffer of the pre-launched A2A, laid out as
                ``_a2a_reconstruct`` expects (``[ws, *rest_dims]``).
            group: Process group used for the backward A2A.
            world_size: Number of chunks ``concat_dim`` is split into in
                backward (presumably the size of ``group``; confirm).
            concat_dim: Dimension merged in forward and split in backward.
            split_dim: Saved on ``ctx`` but not read anywhere in this class.
            handle_box: List-like sink for ``(work, out_perm)`` produced by
                backward, or ``None`` to skip launching the backward A2A.

        Returns:
            Tensor: ``_a2a_reconstruct(out_perm, concat_dim)``.
        """
        ctx.group = group
        ctx.world_size = world_size
        ctx.concat_dim = concat_dim
        ctx.split_dim = split_dim
        ctx.handle_box = handle_box
        ctx.x_shape = x.shape
        # The collective was launched by the caller; block until it completes
        # before reading its receive buffer.
        work.wait()
        return _a2a_reconstruct(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch async head→seq A2A for backward overlap, or return zero grad.

        Returns:
            Eight values, one per forward input: a zero tensor shaped like
            ``x``, followed by seven ``None``. The zero gradient is returned in
            both branches; the A2A launched here is not waited on.
        """
        if ctx.handle_box is not None:
            # Launch async head→seq A2A (reverse of forward seq→head)
            g = grad_output.contiguous()
            shape = list(g.shape)
            seq_dim = ctx.concat_dim
            s_full = shape[seq_dim]
            # Rank of the tensor after seq_dim is split into two axes below.
            ndim = len(shape) + 1
            # Split seq_dim into [world_size, s_full // world_size], then move
            # the world_size axis to the front. all_to_all_single's even split
            # along dim 0 then sends one slice per rank.
            # NOTE(review): assumes s_full is divisible by world_size; confirm.
            x_perm = g.reshape(
                shape[:seq_dim] + [ctx.world_size, s_full // ctx.world_size] + shape[seq_dim + 1:]
            ).permute(
                [seq_dim] + list(range(seq_dim)) + list(range(seq_dim + 1, ndim))
            ).contiguous()
            out_perm = torch.empty_like(x_perm)
            work = dist.all_to_all_single(out_perm, x_perm, group=ctx.group, async_op=True)
            # Hand the in-flight work and its receive buffer to the consumer.
            ctx.handle_box.append((work, out_perm))
        return grad_output.new_zeros(ctx.x_shape), None, None, None, None, None, None, None

105 

106 

class _AsyncA2ALazyBwd(torch.autograd.Function):
    """All-to-all whose forward AND backward return ``AsyncCollectiveTensor``.

    The stock ``all_to_all_single_autograd`` calls ``wait_tensor`` eagerly in
    its backward. Because the autograd engine binds the backward stream
    context to the forward stream, that wait lands on the FWD main stream and
    blocks Attention launches, even when the BWD thread runs inside a
    side-stream context.

    Here both directions call the non-autograd functional op and hand back its
    ACT result unchanged. The wait is therefore deferred until the next
    consumer's first non-view access (e.g. the indexing backward of
    ``_unpermute``). That gives the FWD thread a short Python window to
    enqueue its Attention kernels on the main stream **before** the wait
    lands there.
    """

    @staticmethod
    def forward(ctx, input_tensor, output_splits, input_splits, group):  # pylint: disable=arguments-differ
        """Launch the functional all-to-all and return its lazily-waited result."""
        # pylint: disable=C0415
        from torch.distributed._functional_collectives import all_to_all_single
        ctx.group = group
        ctx.output_splits = output_splits
        ctx.input_splits = input_splits
        return all_to_all_single(input_tensor, output_splits, input_splits, group)

    @staticmethod
    def backward(ctx, grad_output):
        """Route gradients back with the split lists swapped (reverse direction)."""
        # pylint: disable=C0415
        from torch.distributed._functional_collectives import all_to_all_single
        reversed_grad = all_to_all_single(
            grad_output, ctx.input_splits, ctx.output_splits, ctx.group,
        )
        return reversed_grad, None, None, None

143 

144 

class _TorchSyncHookFunction(torch.autograd.Function):
    """Autograd identity that fires HookCoordinator rendezvous on fwd/bwd.

    The design uses four hooks (``A``, ``B``, ``C``, ``D``) with strict
    COMM / COMPUTE roles and no NONE role. Every rendezvous pairs one COMM
    with one COMPUTE, which guarantees NCCL-first dispatch ordering at
    **all** points, including layer boundaries.

    Hook placement per MoE layer::

        [A] → dispatch → [B] → module → [C] → combine → [D] → (Attention) → [A_next]

    At layer boundaries (D / A hooks), the Attention that runs between layers
    is treated as COMPUTE and the combine / combine.bwd as COMM. The
    coordinator therefore enforces comm-first ordering across layer
    transitions too.

    The sentinel hook name ``"D_LAST"`` (the last layer's closing D) skips the
    rendezvous in both directions. In forward, no Attention follows it. In
    backward, it is the first hook to fire, and ``combine.bwd`` has already
    dispatched freely.
    """

    # (prev_role_idx, next_role_idx) per hook; 1 = COMM, 2 = COMPUTE.
    _FWD_ROLES = {
        "A": (2, 1),  # Attention   | dispatch
        "B": (1, 2),  # dispatch    | module
        "C": (2, 1),  # module      | combine
        "D": (1, 2),  # combine     | Attention
    }
    _BWD_ROLES = {
        "D": (2, 1),  # Attn.bwd     | combine.bwd
        "C": (1, 2),  # combine.bwd  | module.bwd
        "B": (2, 1),  # module.bwd   | dispatch.bwd
        "A": (1, 2),  # dispatch.bwd | Attn.bwd
    }

    # Lazily populated ``(None, HookRole.COMM, HookRole.COMPUTE)``.
    _ROLE_CACHE = None

    @staticmethod
    def _role_enum(idx: int):
        """Map a role index (1 = COMM, 2 = COMPUTE) to its ``HookRole`` member."""
        cache = _TorchSyncHookFunction._ROLE_CACHE
        if cache is None:
            from hyper_parallel.core.pipeline_parallel.hook_coordinator import HookRole  # pylint: disable=C0415
            cache = (None, HookRole.COMM, HookRole.COMPUTE)
            _TorchSyncHookFunction._ROLE_CACHE = cache
        return cache[idx]

    @staticmethod
    def _fire(role_table, hook_name, coordinator):
        """Notify the finished op's role, then rendezvous for the next op's role.

        Does nothing when the coordinator is disabled or for ``"D_LAST"``.
        Raises ``KeyError`` for hook names missing from *role_table*.
        """
        if not coordinator.is_enabled() or hook_name == "D_LAST":
            return
        prev_idx, next_idx = role_table[hook_name]
        coordinator.notify_dispatched(_TorchSyncHookFunction._role_enum(prev_idx))
        coordinator.rendezvous(_TorchSyncHookFunction._role_enum(next_idx))

    @staticmethod
    def forward(ctx, x, hook_name, coordinator):  # pylint: disable=arguments-differ
        """Return *x* unchanged after a forward rendezvous (``_FWD_ROLES``).

        Args:
            ctx: Autograd context; keeps ``hook_name`` and ``coordinator`` for
                the backward pass.
            x: Input tensor, returned as-is.
            hook_name: ``"A"``, ``"B"``, ``"C"``, ``"D"`` or ``"D_LAST"``.
            coordinator: The :class:`HookCoordinator` driving the rendezvous.

        Returns:
            ``x`` unchanged.
        """
        ctx.hook_name = hook_name
        ctx.coordinator = coordinator
        _TorchSyncHookFunction._fire(_TorchSyncHookFunction._FWD_ROLES, hook_name, coordinator)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return *grad_output* unchanged after a backward rendezvous (``_BWD_ROLES``).

        Returns:
            ``(grad_output, None, None)``. Only the tensor input receives a
            gradient; ``hook_name`` and ``coordinator`` are non-tensor inputs.
        """
        _TorchSyncHookFunction._fire(
            _TorchSyncHookFunction._BWD_ROLES, ctx.hook_name, ctx.coordinator,
        )
        return grad_output, None, None

267 

268 

class _TorchP2PExchangeFunction(torch.autograd.Function):
    """Symmetric bidirectional P2P: send local tensor to peer, receive peer's tensor.

    Forward and backward perform the same exchange. The output on this rank
    is the peer's input, so the gradient of this rank's input is the
    gradient held by the peer.
    """

    @staticmethod
    def _exchange(tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:
        """Send *tensor* to *peer_rank* and return the tensor received from it.

        Both P2P ops are issued as one batch and waited on before returning.

        Args:
            tensor: Local tensor to send (made contiguous first).
            peer_rank: Rank passed to ``dist.P2POp`` as the peer.
            group: Process group for the P2P ops.

        Returns:
            torch.Tensor: Freshly allocated buffer with the peer's data.
        """
        send_buf = tensor.contiguous()
        recv_buf = torch.empty_like(send_buf)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_buf, peer_rank, group),
            dist.P2POp(dist.irecv, recv_buf, peer_rank, group),
        ])
        for req in reqs:
            req.wait()
        return recv_buf

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:  # pylint: disable=arguments-differ
        """Perform symmetric bidirectional P2P exchange with peer_rank."""
        ctx.peer_rank = peer_rank
        ctx.group = group
        return _TorchP2PExchangeFunction._exchange(tensor, peer_rank, group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Perform symmetric P2P exchange for the backward gradient pass."""
        grad_input = _TorchP2PExchangeFunction._exchange(grad_output, ctx.peer_rank, ctx.group)
        return grad_input, None, None

299 

300 

# Translation table from the string reduction names used by callers to
# ``torch.distributed.ReduceOp`` members.
#   - 'all' reuses MIN (intended for tensors whose elements were converted to int32).
#   - 'avg' maps to SUM; the division by the group size is done separately
#     (see the manual handling in ``differentiable_reduce_scatter``).
_OP_MAP = dict(
    sum=dist.ReduceOp.SUM,
    prod=dist.ReduceOp.PRODUCT,
    max=dist.ReduceOp.MAX,
    min=dist.ReduceOp.MIN,
    all=dist.ReduceOp.MIN,
    avg=dist.ReduceOp.SUM,
)

# 'mean' uses the native AVG reduction when this torch build exposes it;
# otherwise it degrades to SUM and any averaging must happen upstream.
_OP_MAP['mean'] = getattr(dist.ReduceOp, "AVG", dist.ReduceOp.SUM)

320 

321 

322# pylint: disable=C0103 

323class TorchPlatform(Platform): 

324 """Torch platform api""" 

325 Tensor = Tensor 

326 tensor = torch.tensor 

327 Parameter = Parameter 

328 Module = Module 

329 DTensorBase = DTensorBase 

330 PipelineStageBase = PipelineStageBase 

331 platform_type = PlatformType.PYTORCH 

332 tensor_dtype = torch 

333 dtype = torch.dtype 

334 Function = torch.autograd.Function 

335 

336 _custom_ops_cls = None 

337 

338 @property 

339 def custom_ops(self): 

340 """Return the Torch platform custom ops instance. 

341 

342 .. warning:: 

343 This is an experimental API that subject to change or deletion. 

344 

345 Returns: 

346 TorchCustomOps: Custom ops class that raises NotImplementedError 

347 for all operators (MindSpore-only at this time). 

348 """ 

349 if self._custom_ops_cls is None: 

350 from hyper_parallel.platform.torch.custom_ops import TorchCustomOps # pylint: disable=import-outside-toplevel 

351 self._custom_ops_cls = TorchCustomOps 

352 return self._custom_ops_cls 

353 

354 @staticmethod 

355 def is_linear_module(module) -> bool: 

356 """Check whether *module* is a ``torch.nn.Linear`` instance.""" 

357 return isinstance(module, nn.Linear) 

358 

359 @staticmethod 

360 def is_embedding_module(module) -> bool: 

361 """Check whether *module* is a ``torch.nn.Embedding`` instance.""" 

362 return isinstance(module, nn.Embedding) 

363 

364 @staticmethod 

365 def device_count(device_handle): 

366 """ 

367 Get the number of available devices. 

368 

369 Args: 

370 device_handle: The device handle (e.g., torch.cuda, torch.npu). 

371 

372 Returns: 

373 int: The number of available devices. 

374 """ 

375 return device_handle.device_count() 

376 

377 def device_type(self): 

378 """ 

379 Get the current device type. 

380 

381 Returns: 

382 str: The device type string ("npu" for NPU, "cuda" for GPU). 

383 """ 

384 device_handle = self.get_device_handle() 

385 if device_handle == torch.npu: 

386 return "npu" 

387 return "cuda" 

388 

389 def device(self, device_idx=None): 

390 """ 

391 Get a torch.device object for the specified device index. 

392 

393 Args: 

394 device_idx (Optional[int]): The device index. If None, returns device without index. 

395 

396 Returns: 

397 torch.device: A torch device object. 

398 """ 

399 device_type = self.device_type() 

400 if device_idx is None: 

401 return torch.device(device_type) 

402 return torch.device(f"{device_type}:{device_idx:d}") 

403 

404 @staticmethod 

405 def get_rng_state(device=None, device_handle=None): 

406 """ 

407 Get the random number generator state. 

408 

409 Args: 

410 device (Optional): The device to get RNG state from. 

411 device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.). 

412 

413 Returns: 

414 Tensor: The RNG state as a byte tensor. 

415 """ 

416 if device_handle is None: 

417 return torch.get_rng_state() 

418 if device is None: 

419 return device_handle.get_rng_state() 

420 return device_handle.get_rng_state(device) 

421 

422 @staticmethod 

423 def set_rng_state(state, device=None, device_handle=None): 

424 """ 

425 Set the random number generator state. 

426 

427 Args: 

428 state (Tensor): The RNG state to set. 

429 device (Optional): The device to set RNG state for. 

430 device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.). 

431 """ 

432 if device_handle is None: 

433 return torch.set_rng_state(state) 

434 if device is None: 

435 return device_handle.set_rng_state(state) 

436 return device_handle.set_rng_state(state, device) 

437 

438 @staticmethod 

439 def manual_seed(seed): 

440 """ 

441 Set the random seed for reproducibility. 

442 

443 Args: 

444 seed (int): The random seed value. 

445 

446 Returns: 

447 torch.Generator: The random number generator. 

448 """ 

449 return torch.manual_seed(seed) 

450 

451 @staticmethod 

452 def ones(size, dtype=None): 

453 """ 

454 Create a tensor filled with ones. 

455 

456 Args: 

457 size (tuple): The shape of the output tensor. 

458 dtype (Optional[torch.dtype]): The desired data type. 

459 

460 Returns: 

461 Tensor: A tensor filled with ones. 

462 """ 

463 return torch.ones(size, dtype=dtype) 

464 

465 @staticmethod 

466 def zeros(size, dtype=None, device=None): 

467 """ 

468 Create a tensor filled with zeros. 

469 

470 Args: 

471 size (tuple): The shape of the output tensor. 

472 dtype (Optional[torch.dtype]): The desired data type. 

473 device (Optional[torch.device]): The device to create the tensor on. 

474 

475 Returns: 

476 Tensor: A tensor filled with zeros. 

477 """ 

478 return torch.zeros(size, dtype=dtype, device=device) 

479 

480 @staticmethod 

481 def full(size, fill_value, dtype=None): 

482 """ 

483 Create a tensor filled with a scalar value. 

484 

485 Args: 

486 size (tuple): The shape of the output tensor. 

487 fill_value (scalar): The value to fill the tensor with. 

488 dtype (Optional[torch.dtype]): The desired data type. 

489 

490 Returns: 

491 Tensor: A tensor filled with the specified value. 

492 """ 

493 return torch.full(size, fill_value, dtype=dtype) 

494 

495 @staticmethod 

496 def empty(size, dtype=None): 

497 """ 

498 Create an uninitialized tensor. 

499 

500 Args: 

501 size (tuple): The shape of the output tensor. 

502 dtype (Optional[torch.dtype]): The desired data type. 

503 

504 Returns: 

505 Tensor: An uninitialized tensor. 

506 """ 

507 return torch.empty(size, dtype=dtype) 

508 

509 @staticmethod 

510 def get_rank(): 

511 """ 

512 Get the rank of the current process in the distributed group. 

513 

514 Returns: 

515 int: The rank of the current process. 

516 """ 

517 return dist.get_rank() 

518 

519 @staticmethod 

520 def get_global_rank(group, group_rank): 

521 """ 

522 Get the global rank from a group rank. 

523 

524 Args: 

525 group (ProcessGroup): The process group. 

526 group_rank (int): The rank within the group. 

527 

528 Returns: 

529 int: The global rank. 

530 """ 

531 return dist.get_global_rank(group, group_rank) 

532 

533 @staticmethod 

534 def get_world_size(): 

535 """ 

536 Get the total number of processes in the distributed group. 

537 

538 Returns: 

539 int: The world size. 

540 """ 

541 return dist.get_world_size() 

542 

543 @staticmethod 

544 def get_param_local_shape(param): 

545 """ 

546 Get the local shape of a parameter, handling both regular and distributed tensors. 

547 

548 Args: 

549 param (Union[Tensor, DTensorBase]): The parameter tensor. 

550 

551 Returns: 

552 torch.Size: The local shape of the parameter. 

553 """ 

554 if isinstance(param, DTensorBase): 

555 return param.local_shape 

556 return param.shape 

557 

558 @staticmethod 

559 def get_param_local_data(param): 

560 """ 

561 Get the local data of a parameter, handling both regular and distributed tensors. 

562 

563 Args: 

564 param (Union[Tensor, DTensorBase]): The parameter tensor. 

565 

566 Returns: 

567 Tensor: The local tensor data. 

568 """ 

569 if isinstance(param, DTensorBase): 

570 return param.to_local() 

571 return param 

572 

573 @staticmethod 

574 def update_param_data(param, data): 

575 """ 

576 Update the data of a parameter. 

577 

578 Args: 

579 param (Parameter): The parameter to update. 

580 data (Tensor): The new data tensor. 

581 """ 

582 param.data = data 

583 

584 @staticmethod 

585 def load_into_param(param, data): 

586 """Load tensor *data* into *param* (plain tensor or DTensor).""" 

587 if isinstance(param, DTensorBase): 

588 local = param._local_tensor # pylint: disable=W0212 

589 if local.is_meta: 

590 # Meta tensor materialisation: replace the placeholder. 

591 orig_requires_grad = param.requires_grad 

592 param._local_tensor = data # pylint: disable=W0212 

593 if data.requires_grad != orig_requires_grad: 

594 param.requires_grad_(orig_requires_grad) 

595 else: 

596 local.copy_(data) 

597 else: 

598 param.copy_(data) 

599 

600 @staticmethod 

601 def get_op_name(func): 

602 """ 

603 Extract the operation name from various function types. 

604 

605 Args: 

606 func: The function or operation to extract the name from. 

607 

608 Returns: 

609 str: The operation name. 

610 """ 

611 if hasattr(func, "__name__"): 

612 return func.__name__ 

613 if isinstance(func, OpOverload): 

614 full_name = func.name 

615 core_name = full_name.split("::")[-1].split(".")[0] 

616 return core_name 

617 if isinstance(func, OpOverloadPacket): 

618 return func.name.split("::")[-1] 

619 func_str = str(func) 

620 if "built-in function" in func_str: 

621 return func_str.split()[-1].strip(">") 

622 if "function" in func_str: 

623 return func_str.split()[1] 

624 return "unknown_op" 

625 

626 @staticmethod 

627 def differentiable_all_gather_concat(data, group, concat_size, concat_dim): 

628 output = dist_func.all_gather(data, group=group) 

629 return torch.cat(output, dim=concat_dim) 

630 

631 @staticmethod 

632 def chunk(data, split_dim, split_size, index): 

633 return torch.chunk(data, split_size, dim=split_dim)[index] 

634 

635 @staticmethod 

636 def differentiable_all_to_all(input_data, output_shape, group): 

637 output_tensor = torch.empty(output_shape, device=input_data.device, dtype=input_data.dtype) 

638 output_tensor = dist_func.all_to_all_single( 

639 output_tensor, 

640 input_data, 

641 group=group 

642 ) 

643 return output_tensor 

644 

645 @staticmethod 

646 def tensor_type_cast(input_data, cast_type): 

647 """Cast tensor to specified data type.""" 

648 type_mapping = { 

649 'float32': torch.float32, 

650 'float16': torch.float16, 

651 'int64': torch.int64, 

652 'int32': torch.int32 

653 } 

654 if cast_type not in type_mapping: 

655 raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}") 

656 return input_data.to(type_mapping[cast_type]) 

657 

658 @staticmethod 

659 def differentiable_all_reduce(data, op, group): 

660 # Resolve the op from string to ReduceOp enum if necessary 

661 reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op 

662 return dist_func.all_reduce(data, op=reduce_op, group=group) 

663 

664 @staticmethod 

665 def get_cell_construct(cell): 

666 return cell.forward 

667 

668 @staticmethod 

669 def get_cells_and_names(cell): 

670 return cell.named_modules() 

671 

672 @staticmethod 

673 def get_modules(module): 

674 return module.modules() 

675 

676 @staticmethod 

677 def search_parameter_by_name(cell, param_name: str): 

678 """ 

679 Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter. 

680 Return value: (parent Module instance, parameter's name in parent Module, parameter object). 

681 Returns None if not found. 

682 """ 

683 # Remove the "self." prefix from param_name 

684 param_name = param_name.replace("self.", "") 

685 # Case 1: The parameter is a direct parameter of the current Module 

686 if param_name in cell._parameters: # pylint: disable=protected-access 

687 return (cell, param_name, cell._parameters[param_name]) # pylint: disable=protected-access 

688 

689 # Case 2: The parameter is in a sub-Module 

690 if "." in param_name: 

691 cell_path, param_key = param_name.rsplit(".", 1) 

692 try: 

693 # Locate the sub-Module where the parameter resides (supports multi-level paths) 

694 target_cell = cell.get_submodule(cell_path) 

695 # Check if the sub-Module directly contains this parameter 

696 if param_key in target_cell._parameters: # pylint: disable=protected-access 

697 return target_cell, param_key, target_cell._parameters[param_key] # pylint: disable=protected-access 

698 except AttributeError: 

699 pass 

700 

701 # Traverse all sub-Modules (recursively) to search for the parameter 

702 for _, child_cell in cell.named_children(): 

703 if isinstance(child_cell, Module): 

704 result = TorchPlatform.search_parameter_by_name(child_cell, param_name) 

705 if result is not None: 

706 return result 

707 

708 return None 

709 

710 @staticmethod 

711 def update_parameter_by_name(cell, result: tuple, new_param) -> bool: 

712 """ 

713 Modify the original parameter in a Module or sub-Module using the search result 

714 """ 

715 parent_cell, param_key, _ = result 

716 # Key operation: directly modify the _parameters dictionary. 

717 if param_key in parent_cell._parameters: # pylint: disable=protected-access 

718 parent_cell._parameters[param_key] = new_param # pylint: disable=protected-access 

719 else: 

720 parent_cell.register_parameter(param_key, new_param) 

721 return True 

722 

723 @staticmethod 

724 def set_layout_into_parameter(param, layout): 

725 """Set layout into parameter""" 

726 from hyper_parallel.core.dtensor.dtensor import DTensor # pylint: disable=import-outside-toplevel 

727 from hyper_parallel.core.dtensor.layout import _get_slice_tensor_by_layout # pylint: disable=import-outside-toplevel 

728 if isinstance(param, DTensor): 

729 raise ValueError(f"Parameter {param} has been configured layout, cannot be set repeatedly.") 

730 requires_grad = param.requires_grad 

731 param_dtensor = DTensor.from_local( 

732 _get_slice_tensor_by_layout(param, layout), 

733 layout.mesh, layout.alias_placements) 

734 new_param = Parameter(param_dtensor, requires_grad=requires_grad) 

735 return new_param 

736 

737 @staticmethod 

738 def differentiable_reduce_scatter(data, dev_num, axis, op, group): 

739 input_tuple = torch.chunk(data, dev_num, dim=axis) 

740 output_tensor = torch.empty(input_tuple[0].shape, device=data.device, dtype=data.dtype) 

741 

742 # Resolve the op from string to ReduceOp enum 

743 reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op 

744 

745 output_tensor = dist_func.reduce_scatter(output_tensor, input_tuple, op=reduce_op, group=group) 

746 

747 # Keep manual handling for 'avg' string as it maps to SUM in _OP_MAP 

748 if op == 'avg': 

749 output_tensor = output_tensor / dev_num 

750 return output_tensor 

751 

752 @staticmethod 

753 def get_device_handle(device_type: str = "npu"): 

754 try: 

755 handle = getattr(torch, device_type) 

756 except AttributeError as e: 

757 raise RuntimeError(f"TorchPlatform expect got device handle: 'torch.{device_type}' failed.") from e 

758 return handle 

759 

760 @staticmethod 

761 def get_param_type_size(param): 

762 # pylint: disable=W0212 

763 return torch._utils._element_size(param.dtype) 

764 

765 @staticmethod 

766 def is_tensor(obj: Any) -> bool: 

767 """Return True if ``obj`` is a ``torch.Tensor``.""" 

768 return isinstance(obj, Tensor) 

769 

770 @staticmethod 

771 def get_tensor_storage_size(tensor: Any) -> int: 

772 """Return serialized byte size (numel * element size) for a PyTorch tensor.""" 

773 if not TorchPlatform.is_tensor(tensor): 

774 raise TypeError( 

775 f"TorchPlatform.get_tensor_storage_size expects torch.Tensor, got {type(tensor)!r}" 

776 ) 

777 return int(tensor.numel()) * int(tensor.element_size()) 

778 

779 @staticmethod 

780 def parameters_dict(cell: Module): 

781 return cell.named_parameters() 

782 

783 @staticmethod 

784 def get_model_state_dict(model, *, options=None): 

785 # pylint: disable=C0415 

786 from hyper_parallel.platform.torch.fully_shard.state_dict_utils import ( 

787 get_model_state_dict as _get_model_state_dict, 

788 ) 

789 return _get_model_state_dict(model, options=options) 

790 

791 @staticmethod 

792 def save_checkpoint(cell: Module, file_path: str, ckpt_format: str = "safetensors") -> None: 

793 if ckpt_format == "safetensors": 

794 save_file(tensors=cell, filename=file_path) 

795 else: 

796 torch.save(obj=cell, f=file_path) 

797 

798 @staticmethod 

799 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict: 

800 if ckpt_format == "safetensors": 

801 return load_file(filename=file_path) 

802 return torch.load(f=file_path) 

803 

804 @staticmethod 

805 def new_zero_parameter(param_shape, param_type, requires_grad, device): 

806 return nn.Parameter(torch.zeros(param_shape, dtype=param_type, device=device), requires_grad=requires_grad) 

807 

808 @staticmethod 

809 def new_tensor(tensor_shape, tensor_type, device): 

810 return torch.empty(size=tensor_shape, dtype=tensor_type, device=device) 

811 

812 @staticmethod 

813 def full_like(tensor, fill_value, dtype=None): 

814 return torch.full_like(tensor, fill_value, dtype=dtype) 

815 

816 @staticmethod 

817 def set_tensor_requires_grad(input_tensor): 

818 """ 

819 set requires grad flag for input tensor, only effective for leaf node 

820 """ 

821 if input_tensor.is_leaf: 

822 input_tensor.requires_grad = True 

823 

824 def _create_group(self, rank_list): 

825 group_dict = create_sub_groups(rank_list) 

826 return group_dict[tuple(rank_list)] 

827 

828 @staticmethod 

829 def all_gather_into_tensor(data, group_info, async_op=False): 

830 output_shape = list(data.shape) 

831 output_shape[0] = output_shape[0] * group_info.rank_size 

832 output = torch.empty(output_shape, dtype=data.dtype, device=data.device) 

833 handle = dist.all_gather_into_tensor(output, data, group=group_info.group, async_op=async_op) 

834 return output, handle 

835 

836 @staticmethod 

837 def all_reduce(data, group_info, async_op=False): 

838 if not data.is_contiguous(): 

839 data = data.contiguous() 

840 handle = dist.all_reduce(data, group=group_info.group, async_op=async_op) 

841 return data, handle 

842 

843 @staticmethod 

844 def broadcast(data, src, group=None, async_op=False): 

845 handle = dist.broadcast(data, src, group, async_op) 

846 if async_op: 

847 handle.wait() 

848 

849 @staticmethod 

850 def isend(tensor, dst=None, group=None, tag=0): 

851 return dist.isend(tensor, dst, group, tag) 

852 

853 @staticmethod 

854 def irecv(tensor, src=None, group=None, tag=0): 

855 return dist.irecv(tensor, src, group, tag) 

856 

857 @staticmethod 

858 def p2p_exchange(tensor, peer_rank: int, group=None): 

859 if peer_rank == dist.get_rank(group): 

860 return tensor 

861 return _TorchP2PExchangeFunction.apply(tensor, peer_rank, group) 

862 

863 @staticmethod 

864 def send_object_list(obj_list, dst=None, group=None): 

865 dist.send_object_list(obj_list, dst, group) 

866 

867 @staticmethod 

868 def recv_object_list(obj_list, src=None, group=None): 

869 dist.recv_object_list(obj_list, src, group) 

870 

871 @staticmethod 

872 def reduce_scatter_tensor(data, group_info, async_op=False): 

873 output_shape = list(data.shape) 

874 output_shape[0] = output_shape[0] // group_info.rank_size 

875 output = torch.empty(output_shape, dtype=data.dtype, device=data.device) 

876 handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op) 

877 return output, handle 

878 

879 @staticmethod 

880 def all_to_all_single(input_tensor, output_shape, group, async_op=False): 

881 output = torch.empty(output_shape, device=input_tensor.device, dtype=input_tensor.dtype) 

882 work = dist.all_to_all_single(output, input_tensor, group=group, async_op=async_op) 

883 return output, work 

884 

885 @staticmethod 

886 def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group): 

887 """Variable-split all-to-all with autograd support for EP token dispatch/combine.""" 

888 out_total = sum(output_splits) 

889 output = torch.empty( 

890 out_total, *input_tensor.shape[1:], 

891 dtype=input_tensor.dtype, device=input_tensor.device, 

892 ) 

893 output = dist_func.all_to_all_single( 

894 output, input_tensor, 

895 output_split_sizes=output_splits, 

896 input_split_sizes=input_splits, 

897 group=group, 

898 ) 

899 return output 

900 

901 @staticmethod 

902 def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group): 

903 """Truly-async variant of :meth:`differentiable_all_to_all_single`. 

904 

905 Both forward AND backward return :class:`AsyncCollectiveTensor`, 

906 so the ``wait_tensor`` op is queued lazily — only when a downstream 

907 kernel actually reads the result. 

908 

909 Why both directions need lazy wait: 

910 

911 * FWD: ACT lazy wait lets host return immediately and the paired 

912 BWD thread's compute kernel slip into the queue before the wait. 

913 * BWD: PyTorch's stock backward issues ``wait_tensor`` eagerly, 

914 and the autograd engine binds backward stream to the forward 

915 stream — so even running BWD inside a ``with torch.npu.stream 

916 (side_stream)`` context does not move that wait off the main 

917 stream. Returning ACT from backward defers the wait to the 

918 next backward op's first consumption, opening a small window 

919 during which FWD's Attention kernels can be queued onto the 

920 main stream **before** the wait lands. 

921 

922 Args: 

923 input_tensor: Input tensor, split along dim 0 by ``input_splits``. 

924 input_splits: ``list[int]`` — rows sent to each rank. 

925 output_splits: ``list[int]`` — rows received from each rank. 

926 group: Process group. 

927 

928 Returns: 

929 ``AsyncCollectiveTensor`` of shape 

930 ``[sum(output_splits), *input_tensor.shape[1:]]``. 

931 """ 

932 return _AsyncA2ALazyBwd.apply(input_tensor, output_splits, input_splits, group) 

933 

934 @staticmethod 

935 def arange(start, end=None, step=1, dtype=None, device=None): 

936 """Create a 1-D tensor with evenly spaced values.""" 

937 if end is None: 

938 return torch.arange(start, dtype=dtype, device=device) 

939 return torch.arange(start, end, step, dtype=dtype, device=device) 

940 

941 @staticmethod 

942 def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim, 

943 handle_box=None): 

944 """Wait async A2A handle and reconstruct result (differentiable). 

945 

946 Args: 

947 x: Input tensor. 

948 work: Async work handle from all_to_all. 

949 out_perm: Output buffer from all_to_all. 

950 group: Process group. 

951 world_size: World size. 

952 concat_dim: Dimension for concatenation. 

953 split_dim: Dimension for split. 

954 handle_box: Optional mutable list; backward appends (work, out_perm) here. 

955 """ 

956 return _TorchAsyncA2AFunction.apply( 

957 x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box 

958 ) 

959 

960 @staticmethod 

961 def differentiable_sync_hook(x, hook_name: str, coordinator): 

962 """Identity op that fires coordinator rendezvous on forward and backward. 

963 

964 Always goes through ``_TorchSyncHookFunction.apply`` so that the 

965 autograd graph **records a SyncHook node regardless of whether the 

966 coordinator is currently enabled**. Skipping ``apply`` when 

967 disabled would leave warmup-forwarded graphs without the hook 

968 nodes, and a later ``overlap.run`` — whose BWD thread back-props 

969 such a graph — would then traverse zero hooks while the paired FWD 

970 thread (whose current forward DOES record hooks) waits at a 

971 barrier for a partner that never arrives. 

972 

973 Args: 

974 x: Input tensor. 

975 hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``. 

976 coordinator: A :class:`HookCoordinator` instance. 

977 """ 

978 return _TorchSyncHookFunction.apply(x, hook_name, coordinator) 

979 

980 @staticmethod 

981 def get_tensor_transform(): 

982 raise NotImplementedError("Unsupported get_tensor_transform for torch platform") 

983 

984 @staticmethod 

985 def construct_strided_slice(x, begin, end, stride): 

986 raise NotImplementedError("Unsupported construct_strided_slice for torch platform") 

987 

988 @staticmethod 

989 def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None): 

990 # pylint: disable=C0415 

991 from hyper_parallel.platform.torch.pipeline_parallel._utils import _MicroBatch 

992 return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim) 

993 

994 @staticmethod 

995 def get_symmetric_memory_handler(): 

996 # pylint: disable=C0415 

997 from hyper_parallel.platform.torch.symmetric_memory import TorchSymmetricMemoryHandler 

998 symmetric_memory = TorchSymmetricMemoryHandler() 

999 return symmetric_memory 

1000 

1001 @staticmethod 

1002 def get_multicore_handler(): 

1003 # pylint: disable=C0415 

1004 from hyper_parallel.platform.torch.multicore import TorchMulticoreHandler 

1005 return TorchMulticoreHandler() 

1006 

1007 def new_stream(self): 

1008 device = self.get_device_handle() 

1009 return device.Stream() 

1010 

1011 def get_stream_context(self): 

1012 device = self.get_device_handle() 

1013 return device.stream 

1014 

1015 @staticmethod 

1016 def all_gather_object(object_list, obj, group=None) -> None: 

1017 """ 

1018 Gathers objects from the given group into object list. 

1019 

1020 Args: 

1021 object_list (list[Any]): Define the output list, which size equal to the size of group. 

1022 obj (Any): The object on current rank and in given process group. 

1023 group (ProcessGroup, optional): The process group to gather obj. Default is ``None``, and ``None`` means 

1024 global group. 

1025 

1026 Returns: 

1027 None. Objs are gathered into ``object_list``. 

1028 """ 

1029 dist.all_gather_object(object_list, obj, group) 

1030 

1031 @staticmethod 

1032 def barrier(group=None, async_op: bool = False, device_ids=None) -> Any: 

1033 """ 

1034 Synchronize all processes in the given process group. 

1035 

1036 Args: 

1037 group (ProcessGroup, optional): The process group to work on. Default is ``None``, 

1038 meaning the default process group. 

1039 async_op (bool, optional): Whether this op should be asynchronous. Default: ``False``. 

1040 device_ids (list[int], optional): Device ids for backends that require a device for 

1041 barrier (e.g. NCCL). Default: ``None``. 

1042 

1043 Returns: 

1044 Async work handle if ``async_op`` is True; otherwise ``None``. 

1045 """ 

1046 return dist.barrier(group, async_op, device_ids) 

1047 

1048 @staticmethod 

1049 def init_process_group( 

1050 backend: Optional[str] = None, 

1051 *, 

1052 init_method: Optional[str] = None, 

1053 timeout: Optional[timedelta] = None, 

1054 world_size: int = -1, 

1055 rank: int = -1, 

1056 store: Optional[Store] = None, 

1057 pg_options: Optional[Any] = None, 

1058 device_id: Optional[Union[torch.device, int]] = None, 

1059 ) -> None: 

1060 """ 

1061 Initialize global process group. 

1062 

1063 Args: 

1064 backend (str or Backend, optional): The backend to use for distributed communication. 

1065 init_method (str, optional): URL specifying how to initialize the process group. Default is "env://", 

1066 can not be specified at the same time with ``store``. 

1067 timeout (timedelta, optional): Timeout for process group. Default 10 minutes for NCCL and for other 

1068 backends 30 minutes. 

1069 world_size (int, optional): Number of processes. If ``store`` is specified, world_size is required. 

1070 rank (int, optional): Rank of the current process, which value must between 0 and ``world_size``-1. If 

1071 ``store`` is specified, rank is required. 

1072 store (Store, optional): Key/value store accessible to all workers, used to exchange connection/address 

1073 information. Can not be specified at the same time with ``init_method``. 

1074 pg_options (ProcessGroupOptions, optional): Extra options to pass during constructing process groups. 

1075 device_id (torch.device | int, optional): Specific device this process will work on. 

1076 """ 

1077 try: 

1078 _get_default_group() 

1079 # except multi version error 

1080 except (ValueError, RuntimeError): 

1081 if backend is None: 

1082 backend = "hccl" 

1083 dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size, 

1084 rank=rank, store=store, pg_options=pg_options, device_id=device_id) 

1085 

1086 @staticmethod 

1087 def destroy_process_group(group: Optional[ProcessGroup] = None) -> None: 

1088 """ 

1089 Destroy given process group. 

1090 

1091 Args: 

1092 group (ProcessGroup, optional): Given process group will be destroyed, if not given, all process groups 

1093 will be destroyed. 

1094 """ 

1095 group = group or _get_default_group() 

1096 if group in EXISTING_COMM_GROUPS.values(): 

1097 keys_to_destroy = [k for k, v in EXISTING_COMM_GROUPS.items() if v == group] 

1098 for k in keys_to_destroy: 

1099 del EXISTING_COMM_GROUPS[k] 

1100 dist.destroy_process_group(group) 

1101 

1102 @staticmethod 

1103 def get_process_group_ranks(group: Optional[ProcessGroup] = None) -> list[int]: 

1104 """ 

1105 Get all ranks relative to given process group. 

1106 

1107 Args: 

1108 group (Optional[ProcessGroup]): Process group worked on. Default is ``None``, and ``None`` means global 

1109 group. 

1110 

1111 Returns: 

1112 Rank list. 

1113 """ 

1114 group = group or _get_default_group() 

1115 return dist.get_process_group_ranks(group) 

1116 

1117 @staticmethod 

1118 def get_backend(group: Optional[ProcessGroup] = None) -> Backend: 

1119 """ 

1120 Get the backend of the given process group. 

1121 

1122 Args: 

1123 group (ProcessGroup, optional): Process group worked on. Default is ``None``, and ``None`` means global 

1124 group. 

1125 

1126 Returns: 

1127 The backend object of the given process group. 

1128 """ 

1129 group = group or _get_default_group() 

1130 return dist.get_backend(group) 

1131 

1132 @staticmethod 

1133 def split_group(parent_pg: Optional[ProcessGroup] = None, 

1134 split_ranks: Optional[list] = None, 

1135 timeout: Optional[timedelta] = None, 

1136 pg_options: Optional[Any] = None, 

1137 group_desc: Optional[str] = None, 

1138 ) -> Optional[ProcessGroup]: 

1139 """ 

1140 Create split groups for every group rank in split_ranks, and return the split process group which relative to 

1141 current rank id. 

1142 

1143 Args: 

1144 parent_pg (Optional[ProcessGroup]): A process group which the goal group split from. 

1145 split_ranks (Optional[list]): A list like ``list[list[int]]``. 

1146 timeout (Optional[timedelta]): Timeout for process group. Default 10 minutes for NCCL and for other 

1147 backend 30 minutes. 

1148 pg_options (Optional[Any]): Extra options to pass during constructing process groups. 

1149 group_desc (Optional[str]): Description of process group. 

1150 

1151 Return: 

1152 Optional[ProcessGroup]: One of split process group which relative to current rank id 

1153 """ 

1154 if split_ranks is None or len(split_ranks) == 0: 

1155 raise ValueError("split_ranks cannot be None or empty") 

1156 

1157 split_group = None 

1158 for split_rank in split_ranks: 

1159 dist_group = TorchPlatform.get_created_group(split_rank) 

1160 if dist_group is None: 

1161 dist_group = dist.new_group(ranks=split_rank) 

1162 EXISTING_COMM_GROUPS[str(tuple(sorted(split_rank)))] = dist_group 

1163 if TorchPlatform.get_rank() in split_rank: 

1164 split_group = dist_group 

1165 

1166 return split_group 

1167 

1168 @staticmethod 

1169 def get_group_local_rank(group: ProcessGroup = None) -> int: 

1170 """get group local rank id.""" 

1171 group = group or _get_default_group() 

1172 return group.rank() 

1173 

1174 @staticmethod 

1175 def no_grad(): 

1176 return torch.no_grad() 

1177 

1178 @staticmethod 

1179 def relu(tensor): 

1180 return torch.relu(tensor) 

1181 

1182 @staticmethod 

1183 def cat(tensors, dim=0): 

1184 return torch.cat(tensors, dim=dim) 

1185 

1186 @staticmethod 

1187 def empty_like(tensor, *, dtype=None, device=None, pin_memory=False): 

1188 return torch.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory) 

1189 

1190 def get_current_stream(self): 

1191 device = self.get_device_handle() 

1192 return device.current_stream() 

1193 

1194 def new_event(self): 

1195 device = self.get_device_handle() 

1196 return device.Event() 

1197 

1198 def tree_map(self, fn, tree): 

1199 return torch.utils._pytree.tree_map(fn, tree) # pylint: disable=protected-access 

1200 

1201 @property 

1202 def checkpoint(self): 

1203 return torch.utils.checkpoint.checkpoint 

1204 

1205 @staticmethod 

1206 def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs): 

1207 # pylint: disable=C0415 

1208 from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import FuncModule 

1209 if callable(module) and not isinstance(module, torch.nn.Module): 

1210 module = FuncModule(module) 

1211 return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs) 

1212 

1213 @staticmethod 

1214 def swap_wrapper(module, policy_fn=None): 

1215 # pylint: disable=C0415 

1216 from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import swap_wrapper 

1217 return swap_wrapper(module, policy_fn=policy_fn) 

1218 

1219 @staticmethod 

1220 def swap_tensor_wrapper(target, tag=None): 

1221 # pylint: disable=C0415 

1222 from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import swap_tensor_wrapper 

1223 return swap_tensor_wrapper(target, tag=tag) 

1224 

1225 @property 

1226 def noop_context_fn(self): 

1227 return noop_context_fn 

1228 

1229 @staticmethod 

1230 def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): 

1231 # pylint: disable=C0415 

1232 from hyper_parallel.platform.torch.activation_checkpoint.sac import create_selective_checkpoint_contexts 

1233 return create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation) 

1234 

1235 @staticmethod 

1236 def async_save_on_cpu(policy_fn=None): 

1237 # pylint: disable=C0415 

1238 from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import AsyncSaveOnCpu 

1239 return AsyncSaveOnCpu(policy_fn) 

1240 

1241 @staticmethod 

1242 def get_element_size(tensor): 

1243 """Get Tensor Element Size""" 

1244 return tensor.element_size() 

1245 

1246 @staticmethod 

1247 def tensor_to_numpy(tensor) -> np.ndarray: 

1248 """Convert PyTorch tensor to numpy array.""" 

1249 return tensor.cpu().numpy() 

1250 

1251 @staticmethod 

1252 def clip_grad_norm_( 

1253 parameters, max_norm, norm_type=2.0, 

1254 error_if_nonfinite=False, foreach=None, 

1255 ): 

1256 # pylint: disable=C0415 

1257 from hyper_parallel.platform.torch.clip_grad import ( 

1258 clip_grad_norm_ as _clip_grad_norm, 

1259 ) 

1260 return _clip_grad_norm( 

1261 parameters, max_norm, norm_type, 

1262 error_if_nonfinite=error_if_nonfinite, foreach=foreach, 

1263 ) 

1264 

1265 @staticmethod 

1266 def profiler_record(name): 

1267 """Profiler context manager for recording operations using torch.profiler.""" 

1268 return torch.profiler.record_function(name) 

1269 

1270 def cast_fp_tensor(self, dtype, x): 

1271 """ 

1272 Cast floating-point tensor to target dtype if applicable. 

1273 """ 

1274 if ( 

1275 not isinstance(x, torch.Tensor) 

1276 or not torch.is_floating_point(x) 

1277 or x.dtype == dtype 

1278 ): 

1279 return x 

1280 return x.to(dtype) 

1281 

1282 def apply_to_tensors(self, fn, container): 

1283 """Recursively apply to all tensor in different kinds of container types.""" 

1284 

1285 def apply(x): 

1286 

1287 if isinstance(x, torch.Tensor): 

1288 return fn(x) 

1289 if hasattr(x, "__dataclass_fields__"): 

1290 dc = dataclasses.replace(x) 

1291 changes = { 

1292 f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) 

1293 } 

1294 return dataclasses.replace(dc, **changes) 

1295 if isinstance(x, OrderedDict): 

1296 od = x.__class__() 

1297 for key, value in x.items(): 

1298 od[key] = apply(value) 

1299 return od 

1300 if isinstance(x, PackedSequence): 

1301 apply(x.data) 

1302 return x 

1303 if isinstance(x, dict): 

1304 return {key: apply(value) for key, value in x.items()} 

1305 if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"): 

1306 res = (apply(el) for el in x) 

1307 return type(x)(*res) 

1308 if isinstance(x, (list, tuple, set)): 

1309 return type(x)(apply(el) for el in x) 

1310 return x 

1311 

1312 return apply(container) 

1313 

1314 

1315 @property 

1316 def meta_device(self): 

1317 return torch.device("meta") 

1318 

1319 def init_on_device(self, device, include_buffers=False): 

1320 return _init_on_device(device, include_buffers=include_buffers) 

1321 

1322 def str_to_dtype(self, dtype_str: str) -> torch.dtype: 

1323 """Map ``torch.<type>`` strings from checkpoint metadata to ``torch.dtype``.""" 

1324 parts = dtype_str.split(".", 1) 

1325 if len(parts) != 2: 

1326 raise ValueError( 

1327 f"Expected dtype string like 'torch.float32', got {dtype_str!r}." 

1328 ) 

1329 prefix, name = parts 

1330 if prefix != "torch": 

1331 raise ValueError( 

1332 f"Expected PyTorch dtype string with prefix 'torch', got {dtype_str!r}." 

1333 ) 

1334 dtype = getattr(torch, name) 

1335 if isinstance(dtype, torch.dtype): 

1336 return dtype 

1337 raise ValueError(f"{dtype_str!r} does not resolve to a torch.dtype.") 

1338 

1339 def list_to_size(self, size_list: list[int]) -> torch.Size: 

1340 return torch.Size(size_list)