Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/platform.py: 56%
619 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Torch platform api"""
16from datetime import timedelta
17from typing import Optional, Any, Union
18import dataclasses
19from collections import OrderedDict
21import numpy as np
22from safetensors.torch import save_file, load_file
23import torch
24from torch import nn
25from torch import Tensor
26from torch._C._distributed_c10d import Store, ProcessGroup
27from torch.distributed import Backend
28from torch.distributed.distributed_c10d import _get_default_group
29from torch.nn import Parameter, Module
30from torch.nn.utils.rnn import PackedSequence
31from torch._ops import OpOverload, OpOverloadPacket
32from torch.utils.checkpoint import noop_context_fn
33from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
34import torch.distributed.nn.functional as dist_func
35import torch.distributed as dist
36from hyper_parallel.platform.torch.dtensor import DTensorBase
37from hyper_parallel.platform.torch.pipeline_parallel.stage import PipelineStageBase
38from hyper_parallel.platform.torch.group_utils import create_sub_groups
39from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS
40from hyper_parallel.platform.torch.function_override import override_functions
41from hyper_parallel.platform.torch.init_weights import init_on_device as _init_on_device
# Apply hyper_parallel's torch function overrides once, at import time, so
# they are active before any of the classes below are used. The concrete
# effect lives in hyper_parallel.platform.torch.function_override and is not
# visible from this module.
override_functions()
46# ---------------------------------------------------------------------------
47# Module-level A2A reshape helpers
48# ---------------------------------------------------------------------------
def _a2a_reconstruct(out_perm: torch.Tensor, concat_dim: int) -> torch.Tensor:
    """Rebuild the all-to-all result from the raw ``out_perm`` receive buffer.

    ``out_perm`` has layout ``[ws, *rest_dims]`` with the per-rank chunk axis
    at ``concat_dim + 1``. The leading ``ws`` axis is moved to position
    ``concat_dim`` and merged with the chunk axis directly after it, so the
    returned tensor has one dimension fewer than ``out_perm``.
    """
    total_dims = out_perm.dim()
    # Axes 1..concat_dim keep their relative order, the ws axis (0) lands at
    # concat_dim, and everything from the chunk axis onward stays in place.
    order = [*range(1, concat_dim + 1), 0, *range(concat_dim + 1, total_dims)]
    aligned = out_perm.permute(order).contiguous()
    dims = list(aligned.shape)
    lead, pair, tail = dims[:concat_dim], dims[concat_dim:concat_dim + 2], dims[concat_dim + 2:]
    return aligned.reshape(lead + [pair[0] * pair[1]] + tail)
class _TorchAsyncA2AFunction(torch.autograd.Function):
    """Differentiable wrapper for pre-launched async all-to-all.

    Forward: wait async handle, reconstruct A2A result.
    Backward: launch async head→seq A2A and store handle in ``handle_box``
    for the projection pre-hook to wait, achieving GEMM–A2A overlap.
    """

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim,  # pylint: disable=arguments-differ
                handle_box):
        """Wait for pre-launched async A2A and return reconstructed output.

        Args:
            ctx: Autograd context; stores the communication parameters and
                ``x``'s shape for backward.
            x: Original input tensor. Only its shape is used (to size the
                zero gradient returned by backward).
            work: Async work handle of the already-launched all-to-all.
            out_perm: Receive buffer of that all-to-all, laid out as
                ``[world_size, ...]`` (see ``_a2a_reconstruct``).
            group: Process group used by the backward all-to-all.
            world_size: Number of ranks in ``group``.
            concat_dim: Dimension the received chunks are merged into.
            split_dim: Stored on ``ctx`` but not read by ``backward``.
            handle_box: Optional mutable list; backward appends
                ``(work, out_perm)`` to it, or does nothing when ``None``.

        Returns:
            The reconstructed all-to-all output.
        """
        ctx.group = group
        ctx.world_size = world_size
        ctx.concat_dim = concat_dim
        ctx.split_dim = split_dim
        ctx.handle_box = handle_box
        ctx.x_shape = x.shape
        # Block until the pre-launched collective has filled ``out_perm``.
        work.wait()
        return _a2a_reconstruct(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch async head→seq A2A for backward overlap, or return zero grad.

        The gradient returned for ``x`` is always zeros of ``x``'s shape. The
        communicated gradient is handed off through ``handle_box`` instead.
        Presumably the consumer that waits on that handle routes the real
        gradient (TODO confirm with the projection pre-hook).
        The remaining seven inputs are non-differentiable and receive ``None``.
        """
        if ctx.handle_box is not None:
            # Launch async head→seq A2A (reverse of forward seq→head)
            g = grad_output.contiguous()
            shape = list(g.shape)
            seq_dim = ctx.concat_dim
            s_full = shape[seq_dim]
            # The reshape below splits seq_dim into (world_size, s_full // world_size),
            # adding one dimension; requires s_full divisible by world_size.
            ndim = len(shape) + 1
            # Split the sequence axis per rank and move the rank axis to the
            # front so all_to_all_single scatters one chunk to each rank.
            x_perm = g.reshape(
                shape[:seq_dim] + [ctx.world_size, s_full // ctx.world_size] + shape[seq_dim + 1:]
            ).permute(
                [seq_dim] + list(range(seq_dim)) + list(range(seq_dim + 1, ndim))
            ).contiguous()
            out_perm = torch.empty_like(x_perm)
            work = dist.all_to_all_single(out_perm, x_perm, group=ctx.group, async_op=True)
            # Not waited on here; the consumer of handle_box is responsible.
            ctx.handle_box.append((work, out_perm))
        return grad_output.new_zeros(ctx.x_shape), None, None, None, None, None, None, None
class _AsyncA2ALazyBwd(torch.autograd.Function):
    """All-to-all whose forward AND backward return ``AsyncCollectiveTensor``.

    The stock ``all_to_all_single_autograd`` issues ``wait_tensor`` eagerly in
    its backward. Because the autograd engine ties the backward stream to the
    forward stream, that wait ends up on the FWD main stream even when the BWD
    thread runs inside a side-stream context, and it blocks Attention
    launches there.

    Here both directions call the non-autograd functional collective, which
    returns an ACT. The wait is postponed to the first non-view access by the
    next consumer (e.g. the indexing backward of ``_unpermute``). That leaves
    the FWD thread a short Python window to enqueue its Attention kernels on
    the main stream **before** the wait lands there.
    """

    @staticmethod
    def forward(ctx, input_tensor, output_splits, input_splits, group):  # pylint: disable=arguments-differ
        """Launch the variable-split all-to-all and return a lazily-waited ACT."""
        # Saved so backward can run the mirror-image exchange.
        ctx.input_splits, ctx.output_splits, ctx.group = input_splits, output_splits, group
        # pylint: disable=C0415
        from torch.distributed import _functional_collectives as funcol
        return funcol.all_to_all_single(input_tensor, output_splits, input_splits, group)

    @staticmethod
    def backward(ctx, grad_output):
        """Send gradients back along the reverse route (split sizes swapped)."""
        # pylint: disable=C0415
        from torch.distributed import _functional_collectives as funcol
        reverse_output_splits = ctx.input_splits
        reverse_input_splits = ctx.output_splits
        grad_input = funcol.all_to_all_single(
            grad_output, reverse_output_splits, reverse_input_splits, ctx.group,
        )
        return grad_input, None, None, None
class _TorchSyncHookFunction(torch.autograd.Function):
    """Autograd identity that fires HookCoordinator rendezvous on fwd/bwd.

    Uses a **4-hook** design (``A``, ``B``, ``C``, ``D``) with pure
    COMM / COMPUTE roles — no NONE role. Every rendezvous is a strict
    COMM + COMPUTE pair, guaranteeing NCCL-first dispatch ordering at
    **all** points including layer boundaries.

    Hook placement per MoE layer::

        [A] → dispatch → [B] → module → [C] → combine → [D] → (Attention) → [A_next]

    At layer boundaries (D / A hooks), the Attention that runs between
    layers is treated as COMPUTE, and the combine / combine.bwd is treated
    as COMM, so the coordinator enforces comm-first ordering even across
    layer transitions.
    """

    # 4-hook role tables: (prev_role_idx, next_role_idx).
    # Index encoding: 1 = COMM, 2 = COMPUTE.
    _FWD_ROLES = {
        # (prev, next)        prev op        next op
        "A": (2, 1),  # COMPUTE, COMM     Attention | dispatch
        "B": (1, 2),  # COMM, COMPUTE     dispatch  | module
        "C": (2, 1),  # COMPUTE, COMM     module    | combine
        "D": (1, 2),  # COMM, COMPUTE     combine   | Attention
    }
    _BWD_ROLES = {
        "D": (2, 1),  # COMPUTE, COMM     Attn.bwd     | combine.bwd
        "C": (1, 2),  # COMM, COMPUTE     combine.bwd  | module.bwd
        "B": (2, 1),  # COMPUTE, COMM     module.bwd   | dispatch.bwd
        "A": (1, 2),  # COMM, COMPUTE     dispatch.bwd | Attn.bwd
    }

    # Lazily built tuple (None, HookRole.COMM, HookRole.COMPUTE), indexed by
    # the role indices used in the tables above.
    _ROLE_CACHE = None

    @staticmethod
    def _role_enum(idx: int):
        """Map a role index (1 = COMM, 2 = COMPUTE) to its ``HookRole`` member.

        ``HookRole`` is imported on first use, presumably to avoid an import
        cycle with ``hyper_parallel.core`` (TODO confirm). Index 0 maps to
        ``None``.
        """
        if _TorchSyncHookFunction._ROLE_CACHE is None:
            from hyper_parallel.core.pipeline_parallel.hook_coordinator import HookRole  # pylint: disable=C0415
            _TorchSyncHookFunction._ROLE_CACHE = (None, HookRole.COMM, HookRole.COMPUTE)
        return _TorchSyncHookFunction._ROLE_CACHE[idx]

    @staticmethod
    def forward(ctx, x, hook_name, coordinator):  # pylint: disable=arguments-differ
        """Identity forward that fires a HookCoordinator rendezvous.

        Notifies the previous op's role and rendezvouses for the next op's
        role per the ``_FWD_ROLES`` table. ``"D_LAST"`` is a sentinel
        meaning "skip this rendezvous" (last layer's closing D — no
        Attention follows).

        Args:
            ctx: Autograd context, stores ``hook_name`` and
                ``coordinator`` for the backward pass.
            x: Input tensor, returned unchanged.
            hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``,
                ``"D_LAST"``.
            coordinator: The :class:`HookCoordinator` driving the rendezvous.

        Returns:
            ``x`` unchanged.
        """
        # Saved unconditionally so backward can re-check is_enabled() itself.
        ctx.hook_name = hook_name
        ctx.coordinator = coordinator

        if not coordinator.is_enabled():
            return x

        # ``D_LAST`` marks the last layer's D hook. The "next op" after
        # this hook is the chunk's output (no Attention follows), so the
        # rendezvous is meaningless — skip it. In backward this same
        # hook is the very first BWD hook to fire, where ``combine.bwd``
        # has already free-run before any rendezvous is possible — also
        # skip. Tagging at wrap time replaces the old runtime
        # ``increment_cycle`` / ``bwd_d_should_skip`` mechanisms.
        if hook_name == "D_LAST":
            return x

        prev_idx, next_idx = _TorchSyncHookFunction._FWD_ROLES[hook_name]
        role_of = _TorchSyncHookFunction._role_enum
        # Order matters: report the op that just ran before waiting to run the next.
        coordinator.notify_dispatched(role_of(prev_idx))
        coordinator.rendezvous(role_of(next_idx))
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Identity backward that fires a HookCoordinator rendezvous.

        Mirror of :meth:`forward` using the ``_BWD_ROLES`` table.
        ``"D_LAST"`` skips the rendezvous because this is the first BWD
        hook to fire and ``combine.bwd`` has already dispatched freely
        before any rendezvous can happen.

        Args:
            ctx: Autograd context with ``hook_name`` and
                ``coordinator`` saved during forward.
            grad_output: Gradient w.r.t. the forward output, returned
                unchanged.

        Returns:
            ``(grad_output, None, None)`` — gradients only flow back to
            the tensor input, ``hook_name`` and ``coordinator`` are
            non-tensor inputs.
        """
        hook_name = ctx.hook_name
        coordinator = ctx.coordinator

        # Checked again at backward time: the coordinator may have been
        # enabled/disabled since forward ran.
        if not coordinator.is_enabled():
            return grad_output, None, None

        # Same ``D_LAST`` semantics as forward: this is the first BWD
        # hook to fire and combine.bwd has already dispatched freely
        # before any rendezvous can happen, so skip the rendezvous.
        if hook_name == "D_LAST":
            return grad_output, None, None

        prev_idx, next_idx = _TorchSyncHookFunction._BWD_ROLES[hook_name]
        role_of = _TorchSyncHookFunction._role_enum
        coordinator.notify_dispatched(role_of(prev_idx))
        coordinator.rendezvous(role_of(next_idx))
        return grad_output, None, None
class _TorchP2PExchangeFunction(torch.autograd.Function):
    """Symmetric bidirectional P2P: send local tensor to peer, receive peer's tensor.

    The backward pass performs the same exchange on the gradient. The
    gradient of the tensor this rank received belongs to the peer that sent
    it, so both ranks swap gradients with the same peer.
    """

    @staticmethod
    def _exchange(tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:
        """Send ``tensor`` to ``peer_rank`` and return the tensor received from it.

        The send and the receive are posted together via ``batch_isend_irecv``
        and both are waited on before returning, so the call is blocking.
        Assumes the peer sends a tensor with the same shape and dtype.
        """
        send_buf = tensor.contiguous()
        recv_buf = torch.empty_like(send_buf)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_buf, peer_rank, group),
            dist.P2POp(dist.irecv, recv_buf, peer_rank, group),
        ])
        for req in reqs:
            req.wait()
        return recv_buf

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:  # pylint: disable=arguments-differ
        """Perform symmetric bidirectional P2P exchange with peer_rank."""
        ctx.peer_rank = peer_rank
        ctx.group = group
        return _TorchP2PExchangeFunction._exchange(tensor, peer_rank, group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Perform symmetric P2P exchange for the backward gradient pass."""
        grad_input = _TorchP2PExchangeFunction._exchange(grad_output, ctx.peer_rank, ctx.group)
        # peer_rank and group are non-differentiable inputs.
        return grad_input, None, None
# Mapping from string reduction names to ``torch.distributed.ReduceOp``.
#   * 'all' reuses MIN. The accompanying intent is that tensor elements are
#     converted to int32 first; that conversion is not performed by this map.
#   * 'avg' maps to SUM. Dividing by the group size is left to the caller
#     (``differentiable_reduce_scatter`` does it explicitly).
_OP_MAP = {
    'sum': dist.ReduceOp.SUM,
    'prod': dist.ReduceOp.PRODUCT,
    'max': dist.ReduceOp.MAX,
    'min': dist.ReduceOp.MIN,
    'all': dist.ReduceOp.MIN,
    'avg': dist.ReduceOp.SUM,
}

# 'mean' uses the native AVG reduction when this torch build provides it.
# Otherwise it degrades to SUM, and any division must happen upstream.
_OP_MAP['mean'] = getattr(dist.ReduceOp, "AVG", dist.ReduceOp.SUM)
# pylint: disable=C0103
class TorchPlatform(Platform):
    """Torch platform api.

    PyTorch implementation of :class:`Platform`. Torch types are exposed as
    class attributes, and torch / ``torch.distributed`` calls are wrapped
    behind the platform API methods defined below.
    """
    # Torch types and constructors re-exported on the platform object.
    Tensor = Tensor
    tensor = torch.tensor
    Parameter = Parameter
    Module = Module
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    platform_type = PlatformType.PYTORCH
    # The ``torch`` module itself. Presumably used to resolve dtypes by
    # attribute (e.g. ``tensor_dtype.float32``) — confirm with callers.
    tensor_dtype = torch
    dtype = torch.dtype
    Function = torch.autograd.Function

    # Cache for the custom-ops class; ``custom_ops`` fills it on first access.
    _custom_ops_cls = None
338 @property
339 def custom_ops(self):
340 """Return the Torch platform custom ops instance.
342 .. warning::
343 This is an experimental API that subject to change or deletion.
345 Returns:
346 TorchCustomOps: Custom ops class that raises NotImplementedError
347 for all operators (MindSpore-only at this time).
348 """
349 if self._custom_ops_cls is None:
350 from hyper_parallel.platform.torch.custom_ops import TorchCustomOps # pylint: disable=import-outside-toplevel
351 self._custom_ops_cls = TorchCustomOps
352 return self._custom_ops_cls
354 @staticmethod
355 def is_linear_module(module) -> bool:
356 """Check whether *module* is a ``torch.nn.Linear`` instance."""
357 return isinstance(module, nn.Linear)
359 @staticmethod
360 def is_embedding_module(module) -> bool:
361 """Check whether *module* is a ``torch.nn.Embedding`` instance."""
362 return isinstance(module, nn.Embedding)
364 @staticmethod
365 def device_count(device_handle):
366 """
367 Get the number of available devices.
369 Args:
370 device_handle: The device handle (e.g., torch.cuda, torch.npu).
372 Returns:
373 int: The number of available devices.
374 """
375 return device_handle.device_count()
377 def device_type(self):
378 """
379 Get the current device type.
381 Returns:
382 str: The device type string ("npu" for NPU, "cuda" for GPU).
383 """
384 device_handle = self.get_device_handle()
385 if device_handle == torch.npu:
386 return "npu"
387 return "cuda"
389 def device(self, device_idx=None):
390 """
391 Get a torch.device object for the specified device index.
393 Args:
394 device_idx (Optional[int]): The device index. If None, returns device without index.
396 Returns:
397 torch.device: A torch device object.
398 """
399 device_type = self.device_type()
400 if device_idx is None:
401 return torch.device(device_type)
402 return torch.device(f"{device_type}:{device_idx:d}")
404 @staticmethod
405 def get_rng_state(device=None, device_handle=None):
406 """
407 Get the random number generator state.
409 Args:
410 device (Optional): The device to get RNG state from.
411 device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.).
413 Returns:
414 Tensor: The RNG state as a byte tensor.
415 """
416 if device_handle is None:
417 return torch.get_rng_state()
418 if device is None:
419 return device_handle.get_rng_state()
420 return device_handle.get_rng_state(device)
422 @staticmethod
423 def set_rng_state(state, device=None, device_handle=None):
424 """
425 Set the random number generator state.
427 Args:
428 state (Tensor): The RNG state to set.
429 device (Optional): The device to set RNG state for.
430 device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.).
431 """
432 if device_handle is None:
433 return torch.set_rng_state(state)
434 if device is None:
435 return device_handle.set_rng_state(state)
436 return device_handle.set_rng_state(state, device)
438 @staticmethod
439 def manual_seed(seed):
440 """
441 Set the random seed for reproducibility.
443 Args:
444 seed (int): The random seed value.
446 Returns:
447 torch.Generator: The random number generator.
448 """
449 return torch.manual_seed(seed)
451 @staticmethod
452 def ones(size, dtype=None):
453 """
454 Create a tensor filled with ones.
456 Args:
457 size (tuple): The shape of the output tensor.
458 dtype (Optional[torch.dtype]): The desired data type.
460 Returns:
461 Tensor: A tensor filled with ones.
462 """
463 return torch.ones(size, dtype=dtype)
465 @staticmethod
466 def zeros(size, dtype=None, device=None):
467 """
468 Create a tensor filled with zeros.
470 Args:
471 size (tuple): The shape of the output tensor.
472 dtype (Optional[torch.dtype]): The desired data type.
473 device (Optional[torch.device]): The device to create the tensor on.
475 Returns:
476 Tensor: A tensor filled with zeros.
477 """
478 return torch.zeros(size, dtype=dtype, device=device)
480 @staticmethod
481 def full(size, fill_value, dtype=None):
482 """
483 Create a tensor filled with a scalar value.
485 Args:
486 size (tuple): The shape of the output tensor.
487 fill_value (scalar): The value to fill the tensor with.
488 dtype (Optional[torch.dtype]): The desired data type.
490 Returns:
491 Tensor: A tensor filled with the specified value.
492 """
493 return torch.full(size, fill_value, dtype=dtype)
495 @staticmethod
496 def empty(size, dtype=None):
497 """
498 Create an uninitialized tensor.
500 Args:
501 size (tuple): The shape of the output tensor.
502 dtype (Optional[torch.dtype]): The desired data type.
504 Returns:
505 Tensor: An uninitialized tensor.
506 """
507 return torch.empty(size, dtype=dtype)
509 @staticmethod
510 def get_rank():
511 """
512 Get the rank of the current process in the distributed group.
514 Returns:
515 int: The rank of the current process.
516 """
517 return dist.get_rank()
519 @staticmethod
520 def get_global_rank(group, group_rank):
521 """
522 Get the global rank from a group rank.
524 Args:
525 group (ProcessGroup): The process group.
526 group_rank (int): The rank within the group.
528 Returns:
529 int: The global rank.
530 """
531 return dist.get_global_rank(group, group_rank)
533 @staticmethod
534 def get_world_size():
535 """
536 Get the total number of processes in the distributed group.
538 Returns:
539 int: The world size.
540 """
541 return dist.get_world_size()
543 @staticmethod
544 def get_param_local_shape(param):
545 """
546 Get the local shape of a parameter, handling both regular and distributed tensors.
548 Args:
549 param (Union[Tensor, DTensorBase]): The parameter tensor.
551 Returns:
552 torch.Size: The local shape of the parameter.
553 """
554 if isinstance(param, DTensorBase):
555 return param.local_shape
556 return param.shape
558 @staticmethod
559 def get_param_local_data(param):
560 """
561 Get the local data of a parameter, handling both regular and distributed tensors.
563 Args:
564 param (Union[Tensor, DTensorBase]): The parameter tensor.
566 Returns:
567 Tensor: The local tensor data.
568 """
569 if isinstance(param, DTensorBase):
570 return param.to_local()
571 return param
573 @staticmethod
574 def update_param_data(param, data):
575 """
576 Update the data of a parameter.
578 Args:
579 param (Parameter): The parameter to update.
580 data (Tensor): The new data tensor.
581 """
582 param.data = data
584 @staticmethod
585 def load_into_param(param, data):
586 """Load tensor *data* into *param* (plain tensor or DTensor)."""
587 if isinstance(param, DTensorBase):
588 local = param._local_tensor # pylint: disable=W0212
589 if local.is_meta:
590 # Meta tensor materialisation: replace the placeholder.
591 orig_requires_grad = param.requires_grad
592 param._local_tensor = data # pylint: disable=W0212
593 if data.requires_grad != orig_requires_grad:
594 param.requires_grad_(orig_requires_grad)
595 else:
596 local.copy_(data)
597 else:
598 param.copy_(data)
    @staticmethod
    def get_op_name(func):
        """
        Extract the operation name from various function types.

        Resolution order: the ``__name__`` attribute, then the
        ``OpOverload`` / ``OpOverloadPacket`` qualified names, then parsing
        ``str(func)``.

        Args:
            func: The function or operation to extract the name from.

        Returns:
            str: The operation name, or ``"unknown_op"`` if no strategy applies.
        """
        # NOTE(review): torch's OpOverload / OpOverloadPacket objects appear to
        # define ``__name__`` themselves (e.g. "add.Tensor" for an OpOverload).
        # In that case this first branch wins and the two isinstance branches
        # below are never reached. Confirm against the installed torch version
        # before relying on the overload-stripping logic.
        if hasattr(func, "__name__"):
            return func.__name__
        if isinstance(func, OpOverload):
            # Intended: strip the namespace ("aten::") and overload suffix (".Tensor").
            # NOTE(review): in recent torch ``OpOverload.name`` is a method, not
            # a string, so ``func.name.split`` would fail if this were reached.
            full_name = func.name
            core_name = full_name.split("::")[-1].split(".")[0]
            return core_name
        if isinstance(func, OpOverloadPacket):
            return func.name.split("::")[-1]
        func_str = str(func)
        if "built-in function" in func_str:
            # e.g. "<built-in function add>" -> "add"
            return func_str.split()[-1].strip(">")
        if "function" in func_str:
            # e.g. "<function foo at 0x...>" -> "foo"
            return func_str.split()[1]
        return "unknown_op"
626 @staticmethod
627 def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
628 output = dist_func.all_gather(data, group=group)
629 return torch.cat(output, dim=concat_dim)
631 @staticmethod
632 def chunk(data, split_dim, split_size, index):
633 return torch.chunk(data, split_size, dim=split_dim)[index]
635 @staticmethod
636 def differentiable_all_to_all(input_data, output_shape, group):
637 output_tensor = torch.empty(output_shape, device=input_data.device, dtype=input_data.dtype)
638 output_tensor = dist_func.all_to_all_single(
639 output_tensor,
640 input_data,
641 group=group
642 )
643 return output_tensor
645 @staticmethod
646 def tensor_type_cast(input_data, cast_type):
647 """Cast tensor to specified data type."""
648 type_mapping = {
649 'float32': torch.float32,
650 'float16': torch.float16,
651 'int64': torch.int64,
652 'int32': torch.int32
653 }
654 if cast_type not in type_mapping:
655 raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
656 return input_data.to(type_mapping[cast_type])
658 @staticmethod
659 def differentiable_all_reduce(data, op, group):
660 # Resolve the op from string to ReduceOp enum if necessary
661 reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
662 return dist_func.all_reduce(data, op=reduce_op, group=group)
664 @staticmethod
665 def get_cell_construct(cell):
666 return cell.forward
668 @staticmethod
669 def get_cells_and_names(cell):
670 return cell.named_modules()
672 @staticmethod
673 def get_modules(module):
674 return module.modules()
676 @staticmethod
677 def search_parameter_by_name(cell, param_name: str):
678 """
679 Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
680 Return value: (parent Module instance, parameter's name in parent Module, parameter object).
681 Returns None if not found.
682 """
683 # Remove the "self." prefix from param_name
684 param_name = param_name.replace("self.", "")
685 # Case 1: The parameter is a direct parameter of the current Module
686 if param_name in cell._parameters: # pylint: disable=protected-access
687 return (cell, param_name, cell._parameters[param_name]) # pylint: disable=protected-access
689 # Case 2: The parameter is in a sub-Module
690 if "." in param_name:
691 cell_path, param_key = param_name.rsplit(".", 1)
692 try:
693 # Locate the sub-Module where the parameter resides (supports multi-level paths)
694 target_cell = cell.get_submodule(cell_path)
695 # Check if the sub-Module directly contains this parameter
696 if param_key in target_cell._parameters: # pylint: disable=protected-access
697 return target_cell, param_key, target_cell._parameters[param_key] # pylint: disable=protected-access
698 except AttributeError:
699 pass
701 # Traverse all sub-Modules (recursively) to search for the parameter
702 for _, child_cell in cell.named_children():
703 if isinstance(child_cell, Module):
704 result = TorchPlatform.search_parameter_by_name(child_cell, param_name)
705 if result is not None:
706 return result
708 return None
710 @staticmethod
711 def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
712 """
713 Modify the original parameter in a Module or sub-Module using the search result
714 """
715 parent_cell, param_key, _ = result
716 # Key operation: directly modify the _parameters dictionary.
717 if param_key in parent_cell._parameters: # pylint: disable=protected-access
718 parent_cell._parameters[param_key] = new_param # pylint: disable=protected-access
719 else:
720 parent_cell.register_parameter(param_key, new_param)
721 return True
723 @staticmethod
724 def set_layout_into_parameter(param, layout):
725 """Set layout into parameter"""
726 from hyper_parallel.core.dtensor.dtensor import DTensor # pylint: disable=import-outside-toplevel
727 from hyper_parallel.core.dtensor.layout import _get_slice_tensor_by_layout # pylint: disable=import-outside-toplevel
728 if isinstance(param, DTensor):
729 raise ValueError(f"Parameter {param} has been configured layout, cannot be set repeatedly.")
730 requires_grad = param.requires_grad
731 param_dtensor = DTensor.from_local(
732 _get_slice_tensor_by_layout(param, layout),
733 layout.mesh, layout.alias_placements)
734 new_param = Parameter(param_dtensor, requires_grad=requires_grad)
735 return new_param
737 @staticmethod
738 def differentiable_reduce_scatter(data, dev_num, axis, op, group):
739 input_tuple = torch.chunk(data, dev_num, dim=axis)
740 output_tensor = torch.empty(input_tuple[0].shape, device=data.device, dtype=data.dtype)
742 # Resolve the op from string to ReduceOp enum
743 reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
745 output_tensor = dist_func.reduce_scatter(output_tensor, input_tuple, op=reduce_op, group=group)
747 # Keep manual handling for 'avg' string as it maps to SUM in _OP_MAP
748 if op == 'avg':
749 output_tensor = output_tensor / dev_num
750 return output_tensor
752 @staticmethod
753 def get_device_handle(device_type: str = "npu"):
754 try:
755 handle = getattr(torch, device_type)
756 except AttributeError as e:
757 raise RuntimeError(f"TorchPlatform expect got device handle: 'torch.{device_type}' failed.") from e
758 return handle
760 @staticmethod
761 def get_param_type_size(param):
762 # pylint: disable=W0212
763 return torch._utils._element_size(param.dtype)
765 @staticmethod
766 def is_tensor(obj: Any) -> bool:
767 """Return True if ``obj`` is a ``torch.Tensor``."""
768 return isinstance(obj, Tensor)
770 @staticmethod
771 def get_tensor_storage_size(tensor: Any) -> int:
772 """Return serialized byte size (numel * element size) for a PyTorch tensor."""
773 if not TorchPlatform.is_tensor(tensor):
774 raise TypeError(
775 f"TorchPlatform.get_tensor_storage_size expects torch.Tensor, got {type(tensor)!r}"
776 )
777 return int(tensor.numel()) * int(tensor.element_size())
779 @staticmethod
780 def parameters_dict(cell: Module):
781 return cell.named_parameters()
783 @staticmethod
784 def get_model_state_dict(model, *, options=None):
785 # pylint: disable=C0415
786 from hyper_parallel.platform.torch.fully_shard.state_dict_utils import (
787 get_model_state_dict as _get_model_state_dict,
788 )
789 return _get_model_state_dict(model, options=options)
791 @staticmethod
792 def save_checkpoint(cell: Module, file_path: str, ckpt_format: str = "safetensors") -> None:
793 if ckpt_format == "safetensors":
794 save_file(tensors=cell, filename=file_path)
795 else:
796 torch.save(obj=cell, f=file_path)
798 @staticmethod
799 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
800 if ckpt_format == "safetensors":
801 return load_file(filename=file_path)
802 return torch.load(f=file_path)
804 @staticmethod
805 def new_zero_parameter(param_shape, param_type, requires_grad, device):
806 return nn.Parameter(torch.zeros(param_shape, dtype=param_type, device=device), requires_grad=requires_grad)
808 @staticmethod
809 def new_tensor(tensor_shape, tensor_type, device):
810 return torch.empty(size=tensor_shape, dtype=tensor_type, device=device)
812 @staticmethod
813 def full_like(tensor, fill_value, dtype=None):
814 return torch.full_like(tensor, fill_value, dtype=dtype)
816 @staticmethod
817 def set_tensor_requires_grad(input_tensor):
818 """
819 set requires grad flag for input tensor, only effective for leaf node
820 """
821 if input_tensor.is_leaf:
822 input_tensor.requires_grad = True
824 def _create_group(self, rank_list):
825 group_dict = create_sub_groups(rank_list)
826 return group_dict[tuple(rank_list)]
828 @staticmethod
829 def all_gather_into_tensor(data, group_info, async_op=False):
830 output_shape = list(data.shape)
831 output_shape[0] = output_shape[0] * group_info.rank_size
832 output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
833 handle = dist.all_gather_into_tensor(output, data, group=group_info.group, async_op=async_op)
834 return output, handle
836 @staticmethod
837 def all_reduce(data, group_info, async_op=False):
838 if not data.is_contiguous():
839 data = data.contiguous()
840 handle = dist.all_reduce(data, group=group_info.group, async_op=async_op)
841 return data, handle
843 @staticmethod
844 def broadcast(data, src, group=None, async_op=False):
845 handle = dist.broadcast(data, src, group, async_op)
846 if async_op:
847 handle.wait()
849 @staticmethod
850 def isend(tensor, dst=None, group=None, tag=0):
851 return dist.isend(tensor, dst, group, tag)
853 @staticmethod
854 def irecv(tensor, src=None, group=None, tag=0):
855 return dist.irecv(tensor, src, group, tag)
857 @staticmethod
858 def p2p_exchange(tensor, peer_rank: int, group=None):
859 if peer_rank == dist.get_rank(group):
860 return tensor
861 return _TorchP2PExchangeFunction.apply(tensor, peer_rank, group)
863 @staticmethod
864 def send_object_list(obj_list, dst=None, group=None):
865 dist.send_object_list(obj_list, dst, group)
867 @staticmethod
868 def recv_object_list(obj_list, src=None, group=None):
869 dist.recv_object_list(obj_list, src, group)
871 @staticmethod
872 def reduce_scatter_tensor(data, group_info, async_op=False):
873 output_shape = list(data.shape)
874 output_shape[0] = output_shape[0] // group_info.rank_size
875 output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
876 handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op)
877 return output, handle
879 @staticmethod
880 def all_to_all_single(input_tensor, output_shape, group, async_op=False):
881 output = torch.empty(output_shape, device=input_tensor.device, dtype=input_tensor.dtype)
882 work = dist.all_to_all_single(output, input_tensor, group=group, async_op=async_op)
883 return output, work
885 @staticmethod
886 def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group):
887 """Variable-split all-to-all with autograd support for EP token dispatch/combine."""
888 out_total = sum(output_splits)
889 output = torch.empty(
890 out_total, *input_tensor.shape[1:],
891 dtype=input_tensor.dtype, device=input_tensor.device,
892 )
893 output = dist_func.all_to_all_single(
894 output, input_tensor,
895 output_split_sizes=output_splits,
896 input_split_sizes=input_splits,
897 group=group,
898 )
899 return output
@staticmethod
def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group):
    """Truly-async variant of :meth:`differentiable_all_to_all_single`.

    Per the design of ``_AsyncA2ALazyBwd``, both forward and backward return
    :class:`AsyncCollectiveTensor`, so the ``wait_tensor`` op is queued lazily,
    only when a downstream kernel actually reads the result.

    Why both directions need the lazy wait:

    * FWD: the ACT lazy wait lets the host return immediately, so the paired
      BWD thread's compute kernel can slip into the queue before the wait.
    * BWD: PyTorch's stock backward issues ``wait_tensor`` eagerly, and the
      autograd engine binds the backward stream to the forward stream. Even
      running BWD inside a ``with torch.npu.stream(side_stream)`` context
      therefore does not move that wait off the main stream. Returning ACT from
      backward defers the wait to the next backward op's first consumption.
      This opens a small window in which FWD's Attention kernels can be queued
      onto the main stream **before** the wait lands.

    Args:
        input_tensor: Input tensor, split along dim 0 by ``input_splits``.
        input_splits: ``list[int]`` of rows sent to each rank.
        output_splits: ``list[int]`` of rows received from each rank.
        group: Process group.

    Returns:
        ``AsyncCollectiveTensor`` of shape
        ``[sum(output_splits), *input_tensor.shape[1:]]``.
    """
    # NOTE(review): the splits are handed to ``_AsyncA2ALazyBwd.apply`` as
    # (output_splits, input_splits), the reverse of this method's parameter
    # order. Confirm this matches ``_AsyncA2ALazyBwd.forward``'s signature.
    lazy_a2a = _AsyncA2ALazyBwd.apply
    return lazy_a2a(input_tensor, output_splits, input_splits, group)
@staticmethod
def arange(start, end=None, step=1, dtype=None, device=None):
    """Create a 1-D tensor with evenly spaced values.

    Mirrors ``torch.arange``. When ``end`` is omitted, ``start`` is treated as
    the exclusive upper bound and the sequence begins at 0.

    Args:
        start: Inclusive start of the interval, or the exclusive end when
            ``end`` is ``None``.
        end: Exclusive end of the interval. Optional.
        step: Gap between consecutive values; honoured in both call forms.
        dtype: Optional output dtype.
        device: Optional output device.

    Returns:
        1-D ``torch.Tensor``.
    """
    if end is None:
        # Single-bound form: ``start`` is really the end. Forward ``step`` too,
        # so that arange(6, step=2) yields [0, 2, 4] instead of 0..5.
        start, end = 0, start
    return torch.arange(start, end, step, dtype=dtype, device=device)
@staticmethod
def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,
                                  handle_box=None):
    """Finish an async all-to-all and rebuild its result through autograd.

    All arguments are forwarded unchanged to ``_TorchAsyncA2AFunction.apply``.

    Args:
        x: Input tensor.
        work: Async work handle returned by the all-to-all.
        out_perm: Output buffer filled by the all-to-all.
        group: Process group.
        world_size: World size.
        concat_dim: Dimension used for concatenation.
        split_dim: Dimension used for splitting.
        handle_box: Optional mutable list; according to the original contract,
            backward appends ``(work, out_perm)`` to it.
    """
    forwarded = (x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box)
    return _TorchAsyncA2AFunction.apply(*forwarded)
@staticmethod
def differentiable_sync_hook(x, hook_name: str, coordinator):
    """Identity op that fires a coordinator rendezvous on forward and backward.

    The call is routed through ``_TorchSyncHookFunction.apply`` unconditionally.
    The autograd graph therefore **always records a SyncHook node, whether or
    not the coordinator is currently enabled**.

    Skipping ``apply`` while the coordinator is disabled would be unsafe.
    Warmup-forwarded graphs would lack the hook nodes. A later ``overlap.run``
    whose BWD thread back-props such a graph would then traverse zero hooks.
    Meanwhile the paired FWD thread, whose current forward DOES record hooks,
    would block at a barrier waiting for a partner that never arrives.

    Args:
        x: Input tensor.
        hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``.
        coordinator: A :class:`HookCoordinator` instance.
    """
    hook_fn = _TorchSyncHookFunction
    return hook_fn.apply(x, hook_name, coordinator)
@staticmethod
def get_tensor_transform():
    """Not available on the torch platform.

    Raises:
        NotImplementedError: Always.
    """
    message = "Unsupported get_tensor_transform for torch platform"
    raise NotImplementedError(message)
@staticmethod
def construct_strided_slice(x, begin, end, stride):
    """Not available on the torch platform; every argument is ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = "Unsupported construct_strided_slice for torch platform"
    raise NotImplementedError(message)
@staticmethod
def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
    """Build a pipeline-parallel ``_MicroBatch`` helper.

    Args:
        micro_batch_num: Number of micro batches.
        args_batch_dim: Forwarded positionally to ``_MicroBatch``.
        kwargs_batch_dim: Forwarded positionally to ``_MicroBatch``.

    Returns:
        A new ``_MicroBatch`` instance.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.pipeline_parallel._utils import _MicroBatch
    splitter = _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)
    return splitter
@staticmethod
def get_symmetric_memory_handler():
    """Return a new ``TorchSymmetricMemoryHandler`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.symmetric_memory import TorchSymmetricMemoryHandler
    return TorchSymmetricMemoryHandler()
@staticmethod
def get_multicore_handler():
    """Return a new ``TorchMulticoreHandler`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.multicore import TorchMulticoreHandler
    handler = TorchMulticoreHandler()
    return handler
def new_stream(self):
    """Create a new stream from the platform's device handle.

    Returns:
        The object produced by ``self.get_device_handle().Stream()``.
    """
    return self.get_device_handle().Stream()
def get_stream_context(self):
    """Return the device handle's ``stream`` attribute without calling it.

    Presumably this is the stream context-manager factory of the active
    device module (e.g. ``torch.npu.stream``); TODO confirm against
    ``get_device_handle``.
    """
    return self.get_device_handle().stream
@staticmethod
def all_gather_object(object_list, obj, group=None) -> None:
    """
    Gather one picklable object from every rank of ``group`` into ``object_list``.

    Args:
        object_list (list[Any]): Output list; its length must equal the size of ``group``.
        obj (Any): This rank's contribution.
        group (ProcessGroup, optional): Group to gather over. ``None`` (the default)
            selects the global group.

    Returns:
        None. Results are written into ``object_list`` in place.
    """
    dist.all_gather_object(object_list, obj, group=group)
@staticmethod
def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
    """
    Block until every process in ``group`` reaches this call.

    Args:
        group (ProcessGroup, optional): Group to synchronize. ``None`` (the default)
            selects the default process group.
        async_op (bool, optional): Issue the barrier asynchronously. Default: ``False``.
        device_ids (list[int], optional): Device ids for backends that need a device
            for the barrier (e.g. NCCL). Default: ``None``.

    Returns:
        The value of ``dist.barrier``: an async work handle when ``async_op`` is
        True, otherwise ``None``.
    """
    return dist.barrier(group=group, async_op=async_op, device_ids=device_ids)
@staticmethod
def init_process_group(
        backend: Optional[str] = None,
        *,
        init_method: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        world_size: int = -1,
        rank: int = -1,
        store: Optional[Store] = None,
        pg_options: Optional[Any] = None,
        device_id: Optional[Union[torch.device, int]] = None,
) -> None:
    """
    Initialize the global process group unless one already exists.

    When a default group is already present this is a no-op. Otherwise every
    argument is forwarded to ``torch.distributed.init_process_group``, with
    ``backend`` falling back to ``"hccl"`` when it is ``None``.

    Args:
        backend (str or Backend, optional): Backend for distributed communication.
        init_method (str, optional): URL describing how to initialize the group
            (torch defaults to "env://"). Mutually exclusive with ``store``.
        timeout (timedelta, optional): Timeout for the group. Torch defaults to 10
            minutes for NCCL and 30 minutes for other backends.
        world_size (int, optional): Number of processes; required with ``store``.
        rank (int, optional): Rank of this process, in ``[0, world_size - 1]``;
            required with ``store``.
        store (Store, optional): Key/value store reachable by all workers, used to
            exchange connection/address information. Mutually exclusive with
            ``init_method``.
        pg_options (ProcessGroupOptions, optional): Extra options used when
            constructing process groups.
        device_id (torch.device | int, optional): Device this process works on.
    """
    try:
        _get_default_group()
    except (ValueError, RuntimeError):
        # No default group yet; the exception type differs across torch versions.
        pass
    else:
        return

    dist.init_process_group(
        backend="hccl" if backend is None else backend,
        init_method=init_method,
        timeout=timeout,
        world_size=world_size,
        rank=rank,
        store=store,
        pg_options=pg_options,
        device_id=device_id,
    )
@staticmethod
def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
    """
    Destroy a process group and evict it from ``EXISTING_COMM_GROUPS``.

    Args:
        group (ProcessGroup, optional): Group to destroy. When omitted, the default
            group is destroyed. Torch then tears down every process group, so the
            whole ``EXISTING_COMM_GROUPS`` cache is cleared as well.
    """
    default_group = _get_default_group()
    group = group or default_group
    if group == default_group:
        # Destroying the default (WORLD) group destroys all sub-groups too. Any
        # cached handle would be stale and could be handed out again by
        # split_group after a re-initialization, so drop them all.
        EXISTING_COMM_GROUPS.clear()
    else:
        stale_keys = [key for key, cached in EXISTING_COMM_GROUPS.items() if cached == group]
        for key in stale_keys:
            del EXISTING_COMM_GROUPS[key]
    dist.destroy_process_group(group)
@staticmethod
def get_process_group_ranks(group: Optional[ProcessGroup] = None) -> list[int]:
    """
    List the global ranks that belong to ``group``.

    Args:
        group (Optional[ProcessGroup]): Group to inspect. A falsy value (default
            ``None``) selects the global group.

    Returns:
        Rank list as returned by ``dist.get_process_group_ranks``.
    """
    target = group if group else _get_default_group()
    return dist.get_process_group_ranks(target)
@staticmethod
def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
    """
    Look up the communication backend used by ``group``.

    Args:
        group (ProcessGroup, optional): Group to inspect. A falsy value (default
            ``None``) selects the global group.

    Returns:
        The backend reported by ``dist.get_backend``.
    """
    target = group if group else _get_default_group()
    return dist.get_backend(target)
@staticmethod
def split_group(parent_pg: Optional[ProcessGroup] = None,
                split_ranks: Optional[list] = None,
                timeout: Optional[timedelta] = None,
                pg_options: Optional[Any] = None,
                group_desc: Optional[str] = None,
                ) -> Optional[ProcessGroup]:
    """
    Create (or reuse) one process group per rank list in ``split_ranks``.

    Returns the group that contains the current rank.

    Args:
        parent_pg (Optional[ProcessGroup]): Group the new groups are split from.
            NOTE(review): currently unused. Ranks are passed straight to
            ``dist.new_group``, which expects global ranks; confirm callers rely
            on that.
        split_ranks (Optional[list]): ``list[list[int]]`` of rank lists.
        timeout (Optional[timedelta]): Timeout for newly created groups (torch
            default: 10 minutes for NCCL, 30 minutes for other backends).
        pg_options (Optional[Any]): Extra options used when constructing groups.
        group_desc (Optional[str]): Description of the new process groups.

    Returns:
        Optional[ProcessGroup]: The split group containing the current rank, or
        ``None`` if the current rank is in none of them.

    Raises:
        ValueError: If ``split_ranks`` is ``None`` or empty.
    """
    if split_ranks is None or len(split_ranks) == 0:
        raise ValueError("split_ranks cannot be None or empty")

    # Forward only the options the caller actually set. This keeps torch builds
    # that predate ``group_desc`` working when it is not supplied.
    # NOTE(review): a group already in EXISTING_COMM_GROUPS is reused as-is, even
    # if it was created with different options.
    new_group_kwargs = {}
    if timeout is not None:
        new_group_kwargs["timeout"] = timeout
    if pg_options is not None:
        new_group_kwargs["pg_options"] = pg_options
    if group_desc is not None:
        new_group_kwargs["group_desc"] = group_desc

    current_rank = TorchPlatform.get_rank()
    split_group = None
    # Every rank walks the full list in the same order: dist.new_group must be
    # entered by all processes, even for groups they are not members of.
    for split_rank in split_ranks:
        dist_group = TorchPlatform.get_created_group(split_rank)
        if dist_group is None:
            dist_group = dist.new_group(ranks=split_rank, **new_group_kwargs)
            EXISTING_COMM_GROUPS[str(tuple(sorted(split_rank)))] = dist_group
        if current_rank in split_rank:
            split_group = dist_group

    return split_group
@staticmethod
def get_group_local_rank(group: ProcessGroup = None) -> int:
    """Return this process's rank within ``group`` (the default group when falsy)."""
    target = group if group else _get_default_group()
    local_rank = target.rank()
    return local_rank
@staticmethod
def no_grad():
    """Return a fresh ``torch.no_grad()`` context manager."""
    grad_disabled_ctx = torch.no_grad()
    return grad_disabled_ctx
@staticmethod
def relu(tensor):
    """Apply ``torch.relu`` elementwise (negatives become zero)."""
    activated = torch.relu(tensor)
    return activated
@staticmethod
def cat(tensors, dim=0):
    """Concatenate ``tensors`` along ``dim`` via ``torch.cat``."""
    joined = torch.cat(tensors, dim=dim)
    return joined
@staticmethod
def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
    """Allocate an uninitialized tensor shaped like ``tensor``.

    ``dtype``/``device`` default to those of ``tensor`` when ``None``;
    ``pin_memory`` is forwarded to ``torch.empty_like`` unchanged.
    """
    options = {"dtype": dtype, "device": device, "pin_memory": pin_memory}
    return torch.empty_like(tensor, **options)
def get_current_stream(self):
    """Return ``current_stream()`` of the platform's device handle."""
    return self.get_device_handle().current_stream()
def new_event(self):
    """Create a new ``Event`` from the platform's device handle."""
    return self.get_device_handle().Event()
def tree_map(self, fn, tree):
    """Apply ``fn`` to every leaf of ``tree`` using torch's pytree utilities."""
    pytree = torch.utils._pytree  # pylint: disable=protected-access
    return pytree.tree_map(fn, tree)
@property
def checkpoint(self):
    """The ``torch.utils.checkpoint.checkpoint`` function."""
    ckpt_module = torch.utils.checkpoint
    return ckpt_module.checkpoint
@staticmethod
def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
    """Wrap ``module`` with torch's ``checkpoint_wrapper``.

    A callable that is not an ``nn.Module`` is first adapted with ``FuncModule``
    so the wrapper always receives a module.

    Args:
        module: ``nn.Module`` or plain callable to wrap.
        checkpoint_fn: Optional checkpoint function forwarded to the wrapper.
        **checkpoint_fn_kwargs: Extra keyword arguments forwarded to the wrapper.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import FuncModule
    needs_adapter = callable(module) and not isinstance(module, torch.nn.Module)
    target = FuncModule(module) if needs_adapter else module
    return checkpoint_wrapper(target, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)
@staticmethod
def swap_wrapper(module, policy_fn=None):
    """Delegate to the project's activation-swap ``swap_wrapper`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import activation_swap
    return activation_swap.swap_wrapper(module, policy_fn=policy_fn)
@staticmethod
def swap_tensor_wrapper(target, tag=None):
    """Delegate to the project's activation-swap ``swap_tensor_wrapper`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import activation_swap
    return activation_swap.swap_tensor_wrapper(target, tag=tag)
1225 @property
1226 def noop_context_fn(self):
1227 return noop_context_fn
@staticmethod
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """Delegate to the project's SAC ``create_selective_checkpoint_contexts``.

    Both arguments are forwarded positionally; the helper is imported lazily.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import sac
    return sac.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation)
@staticmethod
def async_save_on_cpu(policy_fn=None):
    """Return a new ``AsyncSaveOnCpu`` built from ``policy_fn`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import activation_swap
    return activation_swap.AsyncSaveOnCpu(policy_fn)
@staticmethod
def get_element_size(tensor):
    """Return the size in bytes of one element of ``tensor``."""
    bytes_per_element = tensor.element_size()
    return bytes_per_element
@staticmethod
def tensor_to_numpy(tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a numpy array.

    The tensor is detached from autograd and moved to host memory first, because
    ``Tensor.numpy()`` refuses tensors that require grad. As before, a CPU tensor
    that needs no copy shares memory with the returned array.

    Args:
        tensor: ``torch.Tensor`` on any device.

    Returns:
        ``np.ndarray`` holding the tensor's values.
    """
    return tensor.detach().cpu().numpy()
@staticmethod
def clip_grad_norm_(
    parameters, max_norm, norm_type=2.0,
    error_if_nonfinite=False, foreach=None,
):
    """Delegate to the project's ``clip_grad_norm_`` implementation (imported lazily).

    ``parameters``, ``max_norm`` and ``norm_type`` are forwarded positionally;
    ``error_if_nonfinite`` and ``foreach`` are forwarded by keyword.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch import clip_grad
    flags = {"error_if_nonfinite": error_if_nonfinite, "foreach": foreach}
    return clip_grad.clip_grad_norm_(parameters, max_norm, norm_type, **flags)
@staticmethod
def profiler_record(name):
    """Return a ``torch.profiler.record_function`` context labelled ``name``."""
    recorder = torch.profiler.record_function(name)
    return recorder
def cast_fp_tensor(self, dtype, x):
    """
    Cast ``x`` to ``dtype`` when it is a floating-point tensor of another dtype.

    Non-tensors, non-floating tensors and tensors already in ``dtype`` are
    returned unchanged (the same object).
    """
    needs_cast = (
        isinstance(x, torch.Tensor)
        and torch.is_floating_point(x)
        and x.dtype != dtype
    )
    if needs_cast:
        return x.to(dtype)
    return x
def apply_to_tensors(self, fn, container):
    """Recursively apply ``fn`` to every tensor inside ``container``.

    Handled containers:
    * Dataclass instances: rebuilt with ``dataclasses.replace``.
    * ``OrderedDict``: type preserved.
    * ``dict``: rebuilt as a plain ``dict``.
    * Namedtuples.
    * ``list``/``tuple``/``set``: type preserved.
    * ``PackedSequence``.

    Any other object, including dataclass *types*, is returned unchanged.

    Args:
        fn: Callable applied to each ``torch.Tensor`` leaf; its return value
            replaces the leaf.
        container: Arbitrary, possibly nested, structure.

    Returns:
        A structure of the same shape with tensors replaced by ``fn(tensor)``.
        A ``PackedSequence`` is returned as-is after ``fn`` has been called on
        its ``data``; the result is discarded, so only side effects of ``fn``
        apply there.
    """

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        # is_dataclass() is also True for dataclass *classes*; only rebuild instances.
        if dataclasses.is_dataclass(x) and not isinstance(x, type):
            # Only init=True fields may be passed to replace(). init=False fields
            # are produced by __init__/__post_init__ and would make replace() raise.
            changes = {
                f.name: apply(getattr(x, f.name))
                for f in dataclasses.fields(x)
                if f.init
            }
            return dataclasses.replace(x, **changes)
        if isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        if isinstance(x, PackedSequence):
            apply(x.data)
            return x
        if isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"):
            # Namedtuples take their fields positionally.
            return type(x)(*(apply(el) for el in x))
        if isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        return x

    return apply(container)
@property
def meta_device(self):
    """A ``torch.device`` for the ``"meta"`` backend (shapes only, no storage)."""
    meta = torch.device("meta")
    return meta
def init_on_device(self, device, include_buffers=False):
    """Delegate to the module-level ``_init_on_device`` helper.

    Args:
        device: Target device, forwarded positionally.
        include_buffers: Forwarded by keyword.
    """
    ctx = _init_on_device(device, include_buffers=include_buffers)
    return ctx
def str_to_dtype(self, dtype_str: str) -> torch.dtype:
    """Map a ``torch.<name>`` string from checkpoint metadata to a ``torch.dtype``.

    Args:
        dtype_str: String such as ``"torch.float32"`` or ``"torch.bfloat16"``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If the string is not of the form ``torch.<name>``, or if
            ``<name>`` does not name a ``torch.dtype``. This includes names that
            are not attributes of ``torch`` at all.
    """
    parts = dtype_str.split(".", 1)
    if len(parts) != 2:
        raise ValueError(
            f"Expected dtype string like 'torch.float32', got {dtype_str!r}."
        )
    prefix, name = parts
    if prefix != "torch":
        raise ValueError(
            f"Expected PyTorch dtype string with prefix 'torch', got {dtype_str!r}."
        )
    # The getattr default keeps unknown names (e.g. "torch.notadtype") on the
    # ValueError path instead of leaking an AttributeError.
    dtype = getattr(torch, name, None)
    if isinstance(dtype, torch.dtype):
        return dtype
    raise ValueError(f"{dtype_str!r} does not resolve to a torch.dtype.")
def list_to_size(self, size_list: list[int]) -> torch.Size:
    """Wrap a list of ints as a ``torch.Size``."""
    size = torch.Size(size_list)
    return size