Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/platform.py: 64%
359 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""framework platform api"""
16import os
17from datetime import timedelta
18from enum import auto, Enum
19from typing import Optional, Any, Union
21import numpy as np
# Name of the environment variable that selects which AI framework platform to use.
HYPER_PARALLEL_PLATFORM = "HYPER_PARALLEL_PLATFORM"

# Recognized values of that environment variable (compared after lower-casing).
HYPER_PARALLEL_PLATFORM_MINDSPORE = "mindspore"  # MindSpore framework
HYPER_PARALLEL_PLATFORM_TORCH = "torch"  # PyTorch framework
class PlatformType(Enum):
    """Kinds of deep-learning framework platform that can back the platform API.

    Members are numbered 1 and 2 in declaration order, matching what
    ``enum.auto()`` would assign.
    """

    MINDSPORE = 1
    PYTORCH = 2
# Module-wide cache of the active platform object. It starts out empty; get_platform()
# returns it once set, and get_mindspore_platform()/get_torch_platform() overwrite it.
platform: Optional["Platform"] = None
def get_mindspore_platform():
    """Instantiate the MindSpore platform and store it as the module-level ``platform``.

    The MindSpore backend module is imported lazily, so this module can be imported
    without MindSpore installed. An ``ImportError`` from that import propagates to
    the caller; ``get_platform`` relies on it to fall back to PyTorch.

    Returns:
        MindSporePlatform: The newly created instance. It replaces any previously
        cached ``platform``.
    """
    global platform
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.platform import MindSporePlatform as _backend_cls

    created = _backend_cls()
    platform = created
    return created
def get_torch_platform():
    """Instantiate the PyTorch platform and store it as the module-level ``platform``.

    The PyTorch backend module is imported lazily, so this module can be imported
    without PyTorch installed. An ``ImportError`` from that import propagates to
    the caller.

    Returns:
        TorchPlatform: The newly created instance. It replaces any previously
        cached ``platform``.
    """
    global platform
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.platform import TorchPlatform as _backend_cls

    created = _backend_cls()
    platform = created
    return created
def get_platform():
    """Obtain the framework platform instance, creating and caching it on first use.

    Selection order:
      1. A platform already cached in the module-level ``platform`` variable.
      2. The platform named by the ``HYPER_PARALLEL_PLATFORM`` environment variable:
         ``"mindspore"`` or ``"torch"``, case-insensitive, with surrounding
         whitespace ignored.
      3. MindSpore, if it can be imported (the default preferred choice).
      4. PyTorch, as the fallback.

    An unset or empty environment variable falls through to auto-detection. An
    unrecognized value also falls through, but first emits a ``UserWarning``, so a
    typo such as ``"pytorch"`` does not silently select a different framework.

    Returns:
        Platform: An instance of the framework platform.

    Raises:
        ImportError: If the framework requested via the environment variable
            cannot be imported, or, during auto-detection, if neither MindSpore
            nor PyTorch can be imported.
    """
    if platform is not None:
        return platform
    # pylint: disable=C0415
    import warnings

    # os.environ values are always str, so only normalization is needed here.
    platform_type = os.environ.get(HYPER_PARALLEL_PLATFORM, "").strip().lower()
    if platform_type == HYPER_PARALLEL_PLATFORM_MINDSPORE:
        return get_mindspore_platform()
    if platform_type == HYPER_PARALLEL_PLATFORM_TORCH:
        return get_torch_platform()
    if platform_type:
        # A set but unrecognized value used to be ignored silently; surface it.
        warnings.warn(
            f"Unrecognized {HYPER_PARALLEL_PLATFORM}={platform_type!r}; expected "
            f"'{HYPER_PARALLEL_PLATFORM_MINDSPORE}' or '{HYPER_PARALLEL_PLATFORM_TORCH}'. "
            "Falling back to automatic platform detection.",
            stacklevel=2,
        )
    try:
        return get_mindspore_platform()
    except ImportError:
        return get_torch_platform()
# Process-wide cache of communication groups created through Platform.create_group,
# keyed by str(tuple(sorted(rank_list))) so equal rank sets reuse one group.
EXISTING_COMM_GROUPS: dict[str, Any] = {}
105class Platform:
106 """Platform api"""
107 current_grad_handle = None
108 post_grad_handle_process = None
109 grad_sync_stream = None
111 @property
112 def custom_ops(self):
113 """Return the platform-specific custom ops interface.
115 Subclasses MUST override this property to return an object that
116 exposes the platform-specific custom operator implementations.
118 Returns:
119 object: Platform-specific custom ops class instance.
120 """
121 raise NotImplementedError(
122 "Platform subclasses must implement custom_ops"
123 )
125 @staticmethod
126 def get_rank():
127 """Get the rank of the current process in the default process group.
129 Returns:
130 int: The rank of the current process.
131 """
132 raise NotImplementedError("Platform subclasses must implement get_rank")
134 @staticmethod
135 def get_global_rank(group, group_rank):
136 """Convert a group rank to its global rank.
138 Args:
139 group: The process group to query.
140 group_rank (int): The rank within the group.
142 Returns:
143 int: The global rank corresponding to the group rank.
144 """
145 raise NotImplementedError("Platform subclasses must implement get_global_rank")
147 @staticmethod
148 def get_world_size():
149 """Get the total number of processes in the default process group.
151 Returns:
152 int: The world size (total number of processes).
153 """
154 raise NotImplementedError("Platform subclasses must implement get_world_size")
156 @staticmethod
157 def get_op_name(func):
158 """Get the canonical name of an operator function.
160 Args:
161 func: The operator function to query.
163 Returns:
164 str: The canonical name of the operator.
165 """
166 raise NotImplementedError("Platform subclasses must implement get_op_name")
168 @staticmethod
169 def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
170 """Perform differentiable all-gather and concatenate tensors along a dimension.
172 Args:
173 data: The input tensor to gather.
174 group: The process group for collective communication.
175 concat_size (int): The size to concatenate along concat_dim.
176 concat_dim (int): The dimension along which to concatenate.
178 Returns:
179 The concatenated tensor after all-gather operation.
180 """
181 raise NotImplementedError("Platform subclasses must implement differentiable_all_gather_concat")
183 @staticmethod
184 def chunk(data, split_dim, split_size, index):
185 """Split tensor along a dimension and return the chunk at the given index.
187 Args:
188 data: The input tensor to split.
189 split_dim (int): The dimension along which to split.
190 split_size (int): The size of each split chunk.
191 index (int): The index of the chunk to return.
193 Returns:
194 The tensor chunk at the specified index.
195 """
196 raise NotImplementedError("Platform subclasses must implement chunk")
198 @staticmethod
199 def differentiable_all_to_all(input_data, output_shape, group):
200 """Perform differentiable all-to-all communication.
202 Args:
203 input_data: The input tensor to redistribute.
204 output_shape: The shape of the output tensor.
205 group: The process group for collective communication.
207 Returns:
208 The output tensor after all-to-all operation.
209 """
210 raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all")
212 @staticmethod
213 def tensor_type_cast(input_data, cast_type):
214 """Cast tensor to a specified dtype.
216 Args:
217 input_data: The input tensor to cast.
218 cast_type: The target dtype to cast to.
220 Returns:
221 The tensor cast to the specified dtype.
222 """
223 raise NotImplementedError("Platform subclasses must implement tensor_type_cast")
225 @staticmethod
226 def is_tensor(obj: Any) -> bool:
227 """Return True if ``obj`` is this framework's tensor type."""
228 raise NotImplementedError("Platform subclasses must implement is_tensor")
230 @staticmethod
231 def get_tensor_storage_size(tensor: Any) -> int:
232 """Return serialized byte size (numel * element size) for this framework's tensor."""
233 raise NotImplementedError("Platform subclasses must implement get_tensor_storage_size")
235 @staticmethod
236 def differentiable_all_reduce(data, op, group):
237 """Perform differentiable all-reduce operation.
239 Args:
240 data: The input tensor to reduce.
241 op: The reduction operation (e.g., sum, max, min).
242 group: The process group for collective communication.
244 Returns:
245 The reduced tensor with gradients supported.
246 """
247 raise NotImplementedError("Platform subclasses must implement differentiable_all_reduce")
249 @staticmethod
250 def differentiable_reduce_scatter(data, dev_num, axis, op, group):
251 """Perform differentiable reduce-scatter operation.
253 Args:
254 data: The input tensor to reduce and scatter.
255 dev_num (int): The number of devices to scatter across.
256 axis (int): The axis along which to scatter.
257 op: The reduction operation (e.g., sum, max, min).
258 group: The process group for collective communication.
260 Returns:
261 The scattered tensor chunk with gradients supported.
262 """
263 raise NotImplementedError("Platform subclasses must implement differentiable_reduce_scatter")
265 @staticmethod
266 def init_parameters(module, stage_index):
267 """Initialize parameters for a module at a specific pipeline stage.
269 This method is primarily needed for MindSpore platform which requires
270 explicit parameter initialization interface.
272 Args:
273 module: The module whose parameters need to be initialized.
274 stage_index (int): The pipeline stage index for the module.
276 Raises:
277 ValueError: If module is None or stage_index is negative.
278 """
279 if module is None:
280 raise ValueError("input module must not be none.")
281 if stage_index < 0:
282 raise ValueError("input stage_index must be positive.")
284 @staticmethod
285 def get_cell_construct(cell):
286 """Get the construct (forward) function of a cell/module.
288 Args:
289 cell: The cell or module to get the construct function from.
291 Returns:
292 The construct/forward callable of the cell.
293 """
294 raise NotImplementedError("Platform subclasses must implement get_cell_construct")
296 @staticmethod
297 def get_cells_and_names(cell):
298 """Get all nested cells/modules and their names.
300 Args:
301 cell: The root cell or module to traverse.
303 Returns:
304 list: A list of tuples containing (name, cell) pairs.
305 """
306 raise NotImplementedError("Platform subclasses must implement get_cells_and_names")
308 @staticmethod
309 def get_modules(module):
310 raise NotImplementedError("Platform subclasses must implement get_modules")
312 @staticmethod
313 def search_parameter_by_name(cell, param_name: str):
314 """Search for a parameter by name within a cell/module.
316 Args:
317 cell: The cell or module to search in.
318 param_name (str): The name of the parameter to find.
320 Returns:
321 The parameter if found, otherwise None.
322 """
323 raise NotImplementedError("Platform subclasses must implement search_parameter_by_name")
325 @staticmethod
326 def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
327 """Update a parameter by name within a cell/module.
329 Args:
330 cell: The cell or module containing the parameter.
331 result (tuple): A tuple containing (param_name, parameter) to update.
332 new_param: The new parameter value to set.
334 Returns:
335 bool: True if update was successful, False otherwise.
336 """
337 raise NotImplementedError("Platform subclasses must implement update_parameter_by_name")
339 @staticmethod
340 def set_layout_into_parameter(param, layout):
341 """Attach a DTensor layout to a parameter.
343 Args:
344 param: The parameter to attach the layout to.
345 layout: The DTensor layout describing tensor distribution.
346 """
347 raise NotImplementedError("Platform subclasses must implement set_layout_into_parameter")
349 @staticmethod
350 def get_param_local_shape(param):
351 """Get the local shape of a distributed parameter.
353 Args:
354 param: The parameter to query.
356 Returns:
357 tuple: The local shape of the parameter shard.
358 """
359 raise NotImplementedError("Platform subclasses must implement get_param_local_shape")
361 @staticmethod
362 def get_param_local_data(param):
363 """Get the local data tensor of a distributed parameter.
365 Args:
366 param: The parameter to query.
368 Returns:
369 The local tensor data of the parameter shard.
370 """
371 raise NotImplementedError("Platform subclasses must implement get_param_local_data")
373 @staticmethod
374 def update_param_data(param, data):
375 """Update the data of a parameter with new tensor data.
377 Args:
378 param: The parameter to update.
379 data: The new tensor data to assign.
380 """
381 raise NotImplementedError("Platform subclasses must implement update_param_data")
383 @staticmethod
384 def get_param_type_size(param):
385 """Get the size in bytes of a parameter's dtype.
387 Args:
388 param: The parameter to query.
390 Returns:
391 int: The size in bytes of the parameter's data type.
392 """
393 raise NotImplementedError("Platform subclasses must implement get_param_type_size")
395 @staticmethod
396 def new_zero_parameter(param_shape, param_type, requires_grad, device):
397 """Create a new parameter initialized with zeros.
399 Args:
400 param_shape (tuple): The shape of the parameter.
401 param_type: The dtype of the parameter.
402 requires_grad (bool): Whether the parameter requires gradients.
403 device: The device on which to create the parameter.
405 Returns:
406 A new parameter tensor filled with zeros.
407 """
408 raise NotImplementedError("Platform subclasses must implement new_zero_parameter")
410 @staticmethod
411 def new_tensor(tensor_shape, tensor_type, device):
412 """Create a new tensor with the specified shape, dtype, and device.
414 Args:
415 tensor_shape (tuple): The shape of the tensor.
416 tensor_type: The dtype of the tensor.
417 device: The device on which to create the tensor.
419 Returns:
420 A new tensor with uninitialized values.
421 """
422 raise NotImplementedError("Platform subclasses must implement new_tensor")
424 @staticmethod
425 def full_like(tensor, fill_value, dtype=None):
426 """Create a tensor filled with a value, with same shape as input.
428 Args:
429 tensor: The input tensor to copy shape from.
430 fill_value: The value to fill the new tensor with.
431 dtype: Optional dtype for the new tensor. If None, uses input tensor's dtype.
433 Returns:
434 A new tensor filled with the specified value.
435 """
436 raise NotImplementedError("Platform subclasses must implement full_like")
438 @staticmethod
439 def set_tensor_requires_grad(input_tensor):
440 """Enable gradient tracking for a tensor in-place.
442 Args:
443 input_tensor: The tensor to enable gradients for.
445 Returns:
446 The same tensor with requires_grad set to True.
447 """
448 raise NotImplementedError("Platform subclasses must implement set_tensor_requires_grad")
450 @staticmethod
451 def all_gather_into_tensor(data, group_info, async_op=False):
452 """Gather tensors from all ranks into a single output tensor.
454 Args:
455 data: The input tensor to gather.
456 group_info: The process group for collective communication.
457 async_op (bool): If True, returns a work handle for async operation.
459 Returns:
460 The gathered tensor, or a tuple of (tensor, handle) if async_op is True.
461 """
462 raise NotImplementedError("Platform subclasses must implement all_gather_into_tensor")
464 @staticmethod
465 def all_reduce(data, group_info, async_op=False):
466 """Reduce tensors across all ranks using specified operation.
468 Args:
469 data: The input tensor to reduce.
470 group_info: The process group for collective communication.
471 async_op (bool): If True, returns a work handle for async operation.
473 Returns:
474 The reduced tensor, or a tuple of (tensor, handle) if async_op is True.
475 """
476 raise NotImplementedError("Platform subclasses must implement all_reduce")
478 @staticmethod
479 def broadcast(data, src, group, async_op=False):
480 """Broadcast tensor from source rank to all ranks in group.
482 Args:
483 data: The tensor to broadcast (only valid on source rank).
484 src (int): The source rank to broadcast from.
485 group: The process group for collective communication.
486 async_op (bool): If True, returns a work handle for async operation.
488 Returns:
489 The broadcasted tensor, or a tuple of (tensor, handle) if async_op is True.
490 """
491 raise NotImplementedError("Platform subclasses must implement broadcast")
493 @staticmethod
494 def isend(tensor, dst=None, group=None, tag=0):
495 """Send tensor asynchronously to destination rank.
497 Args:
498 tensor: The tensor to send.
499 dst (int, optional): The destination rank. Defaults to None.
500 group: The process group for communication. Defaults to None.
501 tag (int): A tag to identify the send operation. Defaults to 0.
503 Returns:
504 A work handle that can be waited on.
505 """
506 raise NotImplementedError("Platform subclasses must implement isend")
508 @staticmethod
509 def irecv(tensor, src=None, group=None, tag=0):
510 """Receive tensor asynchronously from source rank.
512 Args:
513 tensor: The tensor buffer to receive data into.
514 src (int, optional): The source rank. Defaults to None.
515 group: The process group for communication. Defaults to None.
516 tag (int): A tag to identify the receive operation. Defaults to 0.
518 Returns:
519 A work handle that can be waited on.
520 """
521 raise NotImplementedError("Platform subclasses must implement irecv")
523 @staticmethod
524 def p2p_exchange(tensor, peer_rank: int, group=None):
525 """Differentiable symmetric P2P exchange (send local tensor, receive peer's tensor).
527 Sends ``tensor`` to ``peer_rank`` and simultaneously receives the peer's
528 tensor. The operation is differentiable: the backward pass performs the
529 same symmetric exchange on the upstream gradient.
531 Args:
532 tensor: Local tensor to send.
533 peer_rank (int): Global rank of the communication peer.
534 group: Process group. ``None`` uses the default group.
536 Returns:
537 Tensor received from ``peer_rank``, with the same shape and dtype as
538 the input ``tensor``.
539 """
540 raise NotImplementedError("Platform subclasses must implement p2p_exchange")
542 @staticmethod
543 def send_object_list(obj_list, dst=None, group=None):
544 """Send a list of Python objects to destination rank.
546 Args:
547 obj_list (list): The list of Python objects to send.
548 dst (int, optional): The destination rank. Defaults to None.
549 group: The process group for communication. Defaults to None.
550 """
551 raise NotImplementedError("Platform subclasses must implement send_object_list")
553 @staticmethod
554 def recv_object_list(obj_list, src=None, group=None):
555 """Receive a list of Python objects from source rank.
557 Args:
558 obj_list (list): The list buffer to receive objects into.
559 src (int, optional): The source rank. Defaults to None.
560 group: The process group for communication. Defaults to None.
561 """
562 raise NotImplementedError("Platform subclasses must implement recv_object_list")
564 @staticmethod
565 def reduce_scatter_tensor(data, group_info, async_op=False):
566 """Reduce and scatter tensor across all ranks in group.
568 Args:
569 data: The input tensor to reduce and scatter.
570 group_info: The process group for collective communication.
571 async_op (bool): If True, returns a work handle for async operation.
573 Returns:
574 The scattered tensor chunk, or a tuple of (tensor, handle) if async_op is True.
575 """
576 raise NotImplementedError("Platform subclasses must implement reduce_scatter_tensor")
578 @staticmethod
579 def all_to_all_single(input_tensor, output_shape, group, async_op=False):
580 """All-to-all single collective with optional async execution.
582 Args:
583 input_tensor: Input tensor to scatter.
584 output_shape: Shape of the pre-allocated output tensor.
585 group: Process group (ProcessGroup for torch, group name string for mindspore).
586 async_op: If True, returns a work handle; the output tensor is
587 filled only after ``work.wait()`` is called.
589 Returns:
590 Tuple ``(output, work)`` where *output* is the result tensor and
591 *work* is the async handle (``None`` when ``async_op=False``).
593 Raises:
594 NotImplementedError: Must be implemented by platform subclasses.
595 """
596 raise NotImplementedError("Platform subclasses must implement all_to_all_single")
598 @staticmethod
599 def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,
600 handle_box=None):
601 """Differentiable wrapper that waits for a pre-launched async A2A.
603 Wraps the wait-and-reconstruct step in the platform autograd mechanism
604 so gradients flow correctly through the all-to-all communication.
606 The A2A direction is seq→head (forward): the output gathers along
607 ``concat_dim`` (sequence grows from S/cp to S) and scatters along
608 ``split_dim`` (heads shrink from H to H/ws).
610 In backward, launches an async head→seq A2A on the incoming gradient
611 and appends ``(work, out_perm)`` to ``handle_box`` so the caller can
612 wait just before the projection GEMM, achieving GEMM–A2A overlap.
614 Args:
615 x: Original projection output tensor; anchors the op
616 in the autograd graph.
617 work: Async work handle from ``all_to_all_single(async_op=True)``.
618 out_perm: Output buffer filled once ``work.wait()`` completes
619 (shape ``[ws, ...]``).
620 group: Process group for the reverse A2A in backward.
621 world_size: CP/Ulysses degree.
622 concat_dim: Dimension that is gathered (concatenated) in forward;
623 typically the sequence dimension.
624 split_dim: Dimension that is scattered (split) in forward;
625 typically the head dimension.
626 handle_box: Optional mutable list ``[]``. In backward, ``(work, out_perm)``
627 for the reverse A2A is appended here so the pre-hook can wait.
629 Returns:
630 Result tensor with ``concat_dim`` gathered and ``split_dim`` split,
631 connected to the autograd graph through *x*.
633 Raises:
634 NotImplementedError: Must be implemented by platform subclasses.
635 """
636 raise NotImplementedError("Platform subclasses must implement differentiable_async_a2a_wait")
638 @staticmethod
639 def differentiable_sync_hook(x, hook_name: str, coordinator):
640 """Identity operation that intercepts both forward and backward to call
641 coordinator rendezvous, enabling deterministic comm/compute overlap.
643 This is the differentiable building block for dual-pipe schedules.
644 In the forward pass the coordinator is invoked with the forward-side
645 roles for ``hook_name``; in the backward pass it is invoked with the
646 backward-side roles. The tensor value and gradient flow through
647 unchanged.
649 Args:
650 x: Input tensor. Returned as-is; gradients flow through.
651 hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"`` identifying
652 the position relative to MoE dispatch/combine.
653 coordinator: A :class:`HookCoordinator` instance shared between the
654 forward and backward threads.
656 Returns:
657 The same tensor *x*, attached to the autograd graph so that the
658 backward hook will fire.
659 """
660 raise NotImplementedError("Platform subclasses must implement differentiable_sync_hook")
662 @staticmethod
663 def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group):
664 """Variable-split all-to-all single that supports gradient flow.
666 Unlike ``all_to_all_single`` (which is not differentiable), this method
667 wraps the collective in an autograd function so gradients are correctly
668 routed back through the reverse all-to-all in the backward pass.
669 Intended for Expert Parallelism token dispatch / combine.
671 Args:
672 input_tensor: Input tensor to scatter. Shape ``[sum(input_splits), *feature_dims]``.
673 input_splits: Per-rank sizes of data sent from this rank (list of ints,
674 length equal to ep_degree).
675 output_splits: Per-rank sizes of data received by this rank (list of ints,
676 length equal to ep_degree).
677 group: Process group (ProcessGroup for torch, group name str for mindspore).
679 Returns:
680 Output tensor of shape ``[sum(output_splits), *feature_dims]``.
682 Raises:
683 NotImplementedError: Must be implemented by platform subclasses.
684 """
685 raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all_single")
687 @staticmethod
688 def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group):
689 """Async variant of :meth:`differentiable_all_to_all_single`.
691 Same semantics but launches the collective with ``async_op=True`` and
692 only performs a stream-level ``wait`` — the host returns immediately
693 after dispatching the kernel. Intended for dual-pipe comm/compute
694 overlap paths where the paired COMPUTE side's rendezvous notify must
695 fire right after kernel launch (not after the collective actually
696 completes on device).
698 Args:
699 input_tensor: Input tensor to scatter. Shape ``[sum(input_splits), *feature_dims]``.
700 input_splits: Per-rank sizes of data sent from this rank.
701 output_splits: Per-rank sizes of data received by this rank.
702 group: Process group.
704 Returns:
705 Output tensor of shape ``[sum(output_splits), *feature_dims]``.
707 Raises:
708 NotImplementedError: Must be implemented by platform subclasses.
709 """
710 raise NotImplementedError(
711 "Platform subclasses must implement differentiable_all_to_all_single_async"
712 )
714 @staticmethod
715 def arange(start, end=None, step=1, dtype=None, device=None):
716 """Create a 1-D tensor with evenly spaced values.
718 Args:
719 start: Start of interval (inclusive). If *end* is ``None``,
720 treated as the stop value and *start* defaults to 0.
721 end: End of interval (exclusive). Defaults to ``None``.
722 step: Step size. Defaults to ``1``.
723 dtype: Data type. ``None`` uses the framework default (int64).
724 device: Target device.
726 Returns:
727 1-D tensor ``[start, start+step, ..., end)``.
729 Raises:
730 NotImplementedError: Must be implemented by platform subclasses.
731 """
732 raise NotImplementedError("Platform subclasses must implement arange")
734 @staticmethod
735 def zeros(size, dtype=None, device=None):
736 """Create a zero-filled tensor of the given shape.
738 Args:
739 size: Shape of the tensor (a single tuple/list).
740 dtype: Desired data type. ``None`` uses the framework default (float32).
741 device: Target device. ``None`` uses the framework default.
743 Returns:
744 Zero-filled tensor of the specified shape.
746 Raises:
747 NotImplementedError: Must be implemented by platform subclasses.
748 """
749 raise NotImplementedError("Platform subclasses must implement zeros")
751 @staticmethod
752 def parameters_dict(cell):
753 """Get the parameters dictionary of a cell/module.
755 Args:
756 cell: The cell or module to get parameters from.
758 Returns:
759 dict: A dictionary mapping parameter names to parameters.
760 """
761 raise NotImplementedError("Platform subclasses must implement parameters_dict")
763 @staticmethod
764 def get_model_state_dict(model, *, options=None):
765 """Get the state dictionary of a model.
767 Args:
768 model: The model to extract state from.
769 options: Optional configuration for state dict extraction.
771 Returns:
772 dict: The state dictionary containing model parameters and buffers.
773 """
774 raise NotImplementedError(
775 "Platform subclasses must implement get_model_state_dict"
776 )
778 @staticmethod
779 def save_checkpoint(cell, file_path: str, ckpt_format: str = "safetensors") -> None:
780 """Save a cell/module checkpoint to file.
782 Args:
783 cell: The cell or module to save.
784 file_path (str): The path to save the checkpoint to.
785 ckpt_format (str): The file format.
786 """
787 raise NotImplementedError("Platform subclasses must implement save_checkpoint")
789 @staticmethod
790 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
791 """Load a checkpoint from file.
793 Args:
794 file_path (str): The path to load the checkpoint from.
795 ckpt_format (str): The file format.
797 Returns:
798 dict: The loaded checkpoint state dictionary.
799 """
800 raise NotImplementedError("Platform subclasses must implement load_checkpoint")
802 def _create_group(self, rank_list):
803 """Create a new process group with the specified ranks.
805 Internal method to be implemented by subclasses.
807 Args:
808 rank_list (list): List of ranks to include in the group.
810 Returns:
811 The newly created process group.
812 """
813 raise NotImplementedError("Platform subclasses must implement _create_group")
815 def new_stream(self):
816 """Create a new compute stream for asynchronous operations.
818 Returns:
819 A new stream object for the current device.
820 """
821 raise NotImplementedError("Platform subclasses must implement new_stream")
823 def get_stream_context(self):
824 """Get a context manager for executing operations on a specific stream.
826 Returns:
827 A context manager that can be used with 'with' statement to set stream.
828 """
829 raise NotImplementedError("Platform subclasses must implement get_stream_context")
831 @staticmethod
832 def get_tensor_transform():
833 """Get the tensor transformation utilities for the current framework.
835 Returns:
836 A module or object containing tensor transformation functions.
837 """
838 raise NotImplementedError("Platform subclasses must implement get_tensor_transform")
840 @staticmethod
841 def construct_strided_slice(x, begin, end, stride):
842 """Construct a strided slice operation on a tensor.
844 Args:
845 x: The input tensor to slice.
846 begin: The starting indices for each dimension.
847 end: The ending indices for each dimension.
848 stride: The stride for each dimension.
850 Returns:
851 The sliced tensor.
852 """
853 raise NotImplementedError("Platform subclasses must implement construct_strided_slice")
855 @staticmethod
856 def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
857 """Split inputs into micro-batches for pipeline parallelism.
859 Args:
860 micro_batch_num (int): The number of micro-batches to create.
861 args_batch_dim (list, optional): Batch dimension for each positional arg.
862 kwargs_batch_dim (dict, optional): Batch dimension for each keyword arg.
864 Returns:
865 A decorator that splits function inputs into micro-batches.
866 """
867 raise NotImplementedError("Platform subclasses must implement micro_batch")
869 @staticmethod
870 def get_symmetric_memory_handler():
871 raise NotImplementedError("Platform subclasses must implement get_symmetric_memory_handler")
873 @staticmethod
874 def load_into_param(param, data):
875 raise NotImplementedError("Platform subclasses must implement load_into_param")
877 def create_group(self, rank_list):
878 """Create or retrieve a communication group with the specified ranks.
880 If a group with the same rank list already exists, returns the existing
881 group instead of creating a new one.
883 Args:
884 rank_list (list): List of ranks to include in the group.
886 Returns:
887 The process group for the specified ranks.
888 """
889 group_key = str(tuple(sorted(rank_list)))
890 if group_key in EXISTING_COMM_GROUPS:
891 return EXISTING_COMM_GROUPS[group_key]
893 group = self._create_group(rank_list)
894 EXISTING_COMM_GROUPS[group_key] = group
895 return group
897 @staticmethod
898 def _process_current_handle():
899 """Wait for the current gradient handle and execute post-process callback.
901 Internal method to synchronize pending gradient operations.
902 """
903 if Platform.current_grad_handle is None:
904 return
906 Platform.current_grad_handle.wait()
907 if Platform.post_grad_handle_process is None:
908 return
909 # pylint: disable=E1102
910 Platform.post_grad_handle_process()
912 def set_grad_reduce_handle(self, handle, post_process=None):
913 """Set a new gradient reduction handle after waiting for the current one.
915 Waits for any pending gradient handle on the grad sync stream, then
916 sets the new handle and optional post-process callback.
918 Args:
919 handle: The async work handle for gradient reduction.
920 post_process (callable, optional): Callback to run after handle completes.
921 """
922 if Platform.grad_sync_stream is None:
923 Platform.grad_sync_stream = self.new_stream()
924 stream_context = self.get_stream_context()
925 with stream_context(Platform.grad_sync_stream):
926 Platform._process_current_handle()
927 Platform.current_grad_handle = handle
928 Platform.post_grad_handle_process = post_process
def wait_grad_handle(self):
    """Drain the pending gradient handle and order the caller's stream after it.

    This is a no-op when no handle is pending. Otherwise it proceeds in order:

    1. The pending handle is waited on, and its post-process callback run,
       inside the grad sync stream's context.
    2. An event is recorded on the grad sync stream.
    3. That event is waited on from the caller's current stream.
    4. The pending handle and callback are cleared.

    Note:
        ``sync_event.wait()`` is deliberately issued *outside* the grad sync
        stream context. With torch/MindSpore-style events, ``Event.wait()``
        without an explicit stream targets the *current* stream. Waiting
        inside the context would only order the grad sync stream against
        itself, which is a no-op. Issuing it from the caller's stream makes
        later work there run after the gradient reduction. Whether this also
        blocks the host depends on the backend's event implementation.
    """
    if Platform.current_grad_handle is None:
        return
    if Platform.grad_sync_stream is None:
        Platform.grad_sync_stream = self.new_stream()
    stream_context = self.get_stream_context()
    with stream_context(Platform.grad_sync_stream):
        Platform._process_current_handle()
    # record_event() records on grad_sync_stream explicitly; the wait below must
    # run on the caller's (current) stream, hence outside the `with` block.
    sync_event = Platform.grad_sync_stream.record_event()
    sync_event.wait()
    Platform.current_grad_handle = None
    Platform.post_grad_handle_process = None
@staticmethod
def all_gather_object(object_list, obj, group=None) -> None:
    """Collect one picklable Python object from every rank into *object_list*.

    Every rank contributes *obj*, and every rank ends up with the complete
    list of contributions.

    Args:
        object_list (list): Output list that receives the gathered objects.
        obj: This rank's contribution.
        group: Process group to communicate over; ``None`` means the default group.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement all_gather_object"
    raise NotImplementedError(message)
@staticmethod
def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
    """Synchronize every process in *group*.

    With ``async_op=False`` each rank blocks until all ranks in the group have
    entered the barrier. With ``async_op=True`` a work handle is returned
    instead, and it must be completed before proceeding.

    Args:
        group: Process/communication group; ``None`` selects the default group.
        async_op (bool): Return a backend-specific async work handle instead of
            blocking. Default: False.
        device_ids: Optional list of device ids; meaning is backend-specific.

    Returns:
        The async work handle when ``async_op`` is True, otherwise ``None``.
        Backends may also return ``None`` when this rank is not a member of
        *group*.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement barrier"
    raise NotImplementedError(message)
@staticmethod
def init_process_group(
    backend: Optional[str] = None,
    *,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Any = None,
    pg_options: Any = None,
    device_id: Any = None
) -> None:
    """Set up the default distributed process group.

    Args:
        backend: Communication backend name.
        init_method: URL describing how peers discover each other.
        timeout: Timeout applied to operations on the process group.
        world_size: Total number of participating processes.
        rank: Rank of this process.
        store: Key/value store used to exchange connection information.
        pg_options: Backend-specific process group options.
        device_id: Device this process is bound to.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement init_process_group"
    raise NotImplementedError(message)
@staticmethod
def destroy_process_group(group=None) -> None:
    """Tear down *group*, or the default process group when *group* is ``None``.

    Args:
        group: The process group to destroy.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement destroy_process_group"
    raise NotImplementedError(message)
@staticmethod
def get_process_group_ranks(group=None) -> list[int]:
    """Return the ranks that belong to *group*.

    Args:
        group: Process group to inspect; ``None`` selects the default group.

    Returns:
        list[int]: Member ranks of the group.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement get_process_group_ranks"
    raise NotImplementedError(message)
@staticmethod
def get_backend(group=None):
    """Return the backend name used by *group*.

    Args:
        group: Process group to inspect; ``None`` selects the default group.

    Returns:
        The backend name of the group.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement get_backend"
    raise NotImplementedError(message)
@staticmethod
def split_group(parent_pg: Any = None,
                split_ranks: Optional[list] = None,
                timeout: Optional[timedelta] = None,
                pg_options: Optional[Any] = None,
                group_desc: Optional[str] = None,
                ) -> Any:
    """Derive a new process group from *parent_pg* containing a subset of its ranks.

    Args:
        parent_pg: Group to split from.
        split_ranks (list, optional): Ranks that go into the new group.
        timeout (timedelta, optional): Timeout for operations on the new group.
        pg_options: Backend-specific process group options.
        group_desc (str, optional): Human-readable description of the group.

    Returns:
        The newly created sub-group.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement split_group"
    raise NotImplementedError(message)
@staticmethod
def get_group_local_rank(group=None) -> int:
    """Return this process's rank relative to *group*.

    Args:
        group: Process group to query; ``None`` selects the default group.

    Returns:
        int: Local rank inside the group.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement get_group_local_rank"
    raise NotImplementedError(message)
@staticmethod
def no_grad():
    """Return a context manager under which gradient tracking is disabled.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement no_grad"
    raise NotImplementedError(message)
@staticmethod
def relu(tensor):
    """Return ``max(0, x)`` applied element-wise to *tensor*.

    Args:
        tensor: Input tensor.

    Returns:
        The rectified tensor.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement relu"
    raise NotImplementedError(message)
@staticmethod
def cat(tensors, dim=0):
    """Join *tensors* end to end along dimension *dim*.

    Args:
        tensors: Sequence of tensors to concatenate.
        dim (int): Dimension to concatenate along. Default: 0.

    Returns:
        The concatenated tensor.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement cat"
    raise NotImplementedError(message)
@staticmethod
def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
    """Allocate an uninitialized tensor shaped like *tensor*.

    Args:
        tensor: Tensor whose shape is copied.
        dtype: Element type of the result; ``None`` keeps *tensor*'s dtype.
        device: Device of the result; ``None`` keeps *tensor*'s device.
        pin_memory (bool): Allocate pinned host memory for faster host/device copies.

    Returns:
        An uninitialized tensor with the same shape as *tensor*.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement empty_like"
    raise NotImplementedError(message)
def get_current_stream(self):
    """Return the stream currently used for compute on this device.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement get_current_stream"
    raise NotImplementedError(message)
def new_event(self):
    """Create a fresh event object for cross-stream synchronization.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement new_event"
    raise NotImplementedError(message)
def tree_map(self, fn, tree):
    """Map *fn* over every tensor leaf of a nested list/tuple/dict structure.

    Args:
        fn (callable): Applied to each tensor leaf.
        tree: Nested container holding tensors.

    Returns:
        A structure of the same shape with *fn* applied to each tensor.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement tree_map"
    raise NotImplementedError(message)
@staticmethod
def is_linear_module(module) -> bool:
    """Tell whether *module* is this framework's linear/dense layer type.

    Args:
        module: Module instance to inspect.

    Returns:
        bool: True for the framework's linear layer type.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement is_linear_module"
    raise NotImplementedError(message)
@staticmethod
def is_embedding_module(module) -> bool:
    """Tell whether *module* is this framework's embedding layer type.

    Args:
        module: Module instance to inspect.

    Returns:
        bool: True for the framework's embedding layer type.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement is_embedding_module"
    raise NotImplementedError(message)
@staticmethod
def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
    """Attach *hook* to run before *module*'s forward pass.

    Delegates to ``module.register_forward_pre_hook``, forwarding *prepend*
    and *with_kwargs* as keyword arguments.

    Args:
        module: Module receiving the hook.
        hook (callable): The pre-forward hook.
        prepend (bool): Run this hook before already-registered ones.
        with_kwargs (bool): Pass forward kwargs to the hook as well as args.

    Returns:
        The removable handle returned by the module.
    """
    register = module.register_forward_pre_hook
    return register(hook, prepend=prepend, with_kwargs=with_kwargs)
@staticmethod
def register_full_backward_hook(module, hook, prepend=False):
    """Attach *hook* to fire during *module*'s backward pass.

    Delegates to ``module.register_full_backward_hook``, passing *prepend*
    positionally.

    Args:
        module: Module receiving the hook.
        hook (callable): The backward hook.
        prepend (bool): Run this hook before already-registered ones.

    Returns:
        The removable handle returned by the module.
    """
    register = module.register_full_backward_hook
    return register(hook, prepend)
@staticmethod
def register_full_backward_pre_hook(module, hook, prepend=False):
    """Attach *hook* to fire before *module*'s backward computation.

    Delegates to ``module.register_full_backward_pre_hook``, passing *prepend*
    positionally.

    Args:
        module: Module receiving the hook.
        hook (callable): The backward pre-hook.
        prepend (bool): Run this hook before already-registered ones.

    Returns:
        The removable handle returned by the module.
    """
    register = module.register_full_backward_pre_hook
    return register(hook, prepend)
@property
def checkpoint(self):
    """The framework's activation-checkpointing function.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement checkpoint"
    raise NotImplementedError(message)
@staticmethod
def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
    """Wrap *module* so that its activations are recomputed via checkpointing.

    Args:
        module: Module to wrap.
        checkpoint_fn: Optional custom checkpoint function.
        **checkpoint_fn_kwargs: Extra keyword arguments for the checkpoint function.

    Returns:
        The checkpoint-enabled wrapper module.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement ckpt_wrapper"
    raise NotImplementedError(message)
@staticmethod
def swap_wrapper(module, policy_fn=None):
    """Wrap *module* so that its activations can be swapped out.

    Args:
        module: Module to wrap.
        policy_fn: Optional per-tensor policy deciding what to swap.

    Returns:
        The swap-enabled wrapper module.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement swap_wrapper"
    raise NotImplementedError(message)
@staticmethod
def swap_tensor_wrapper(target, tag=None):
    """Register the tensor(s) in *target* with the active swap group.

    Args:
        target: A tensor, or a nested container of tensors.
        tag: Optional debug label attached to the registered tensors.

    Returns:
        *target*, semantically unchanged.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement swap_tensor_wrapper"
    raise NotImplementedError(message)
@property
def noop_context_fn(self):
    """A context function for checkpointing that does nothing.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement noop_context_fn"
    raise NotImplementedError(message)
@staticmethod
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """Build the context functions used for selective activation checkpointing.

    Args:
        policy_fn_or_list: Policy callable, or list of layer names, selecting
            what to checkpoint.
        allow_cache_entry_mutation (bool): Permit mutation of cached entries.

    Returns:
        The context functions for selective checkpointing.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement create_selective_checkpoint_contexts"
    raise NotImplementedError(message)
@staticmethod
def async_save_on_cpu(policy_fn=None):
    """Return a context manager that asynchronously offloads activations to CPU.

    Args:
        policy_fn: Optional policy selecting which activations are offloaded.

    Returns:
        The offloading context manager.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement async_save_on_cpu"
    raise NotImplementedError(message)
@staticmethod
def get_element_size(tensor):
    """Return the size of a single element of *tensor*.

    The unit is defined by subclasses. It is presumably bytes, as with
    ``torch.Tensor.element_size`` — TODO confirm.

    Args:
        tensor: Tensor to inspect.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement get_element_size"
    raise NotImplementedError(message)
@staticmethod
def tensor_to_numpy(tensor) -> np.ndarray:
    """Return the contents of a framework tensor as a NumPy array.

    Args:
        tensor: Tensor to convert.

    Returns:
        np.ndarray: The tensor's data.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement tensor_to_numpy"
    raise NotImplementedError(message)
@staticmethod
def profiler_record(name):
    """Return a profiler marker (context manager or decorator) labelled *name*.

    Args:
        name (str): Label of the profiled region.

    Returns:
        A context manager or decorator that profiles a code region.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement profiler_record"
    raise NotImplementedError(message)
def cast_fp_tensor(self, dtype, x):
    """Cast *x* to *dtype* when it is a floating-point tensor.

    Args:
        dtype: Target dtype.
        x: Input tensor.

    Returns:
        *x* cast to *dtype*, or *x* unchanged when it is not floating-point.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement cast_fp_tensor"
    raise NotImplementedError(message)
def apply_to_tensors(self, fn, container):
    """Apply *fn* recursively to each tensor inside *container*.

    Nested lists, tuples and dicts are traversed.

    Args:
        fn (callable): Applied to every tensor found.
        container: Nested structure holding tensors.

    Returns:
        A structure of the same shape with *fn* applied to each tensor.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement apply_to_tensors"
    raise NotImplementedError(message)
@staticmethod
def clip_grad_norm_(
    parameters, max_norm: float, norm_type: float = 2.0,
    error_if_nonfinite: bool = False, foreach=None,
):
    """Clip gradients in place so their total norm is at most *max_norm*.

    The required communication is derived from each parameter's DTensor
    spec. Concrete platforms must provide the implementation.

    Args:
        parameters: An ``nn.Module``, a single ``Tensor``, or an iterable of
            ``Tensor`` s whose gradients are clipped.
        max_norm: Upper bound on the total gradient norm.
        norm_type: Order of the norm. Default ``2.0``.
        error_if_nonfinite: Raise if the total norm is non-finite. Default ``False``.
        foreach: Ignored; accepted for API compatibility.

    Returns:
        The total gradient norm before clipping.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement clip_grad_norm_"
    raise NotImplementedError(message)
@staticmethod
def get_created_group(rank_list: Union[list[int], tuple[int, ...]]):
    """Look up an already-created process group by its member ranks.

    The lookup key is the string form of the sorted rank tuple, so
    ``[1, 0]`` and ``(0, 1)`` resolve to the same cached group.

    Args:
        rank_list (Union[list[int], tuple[int, ...]]): Ranks of the group, in
            any order. This may be a tuple of any length; the previous
            ``tuple[int]`` annotation wrongly meant a 1-tuple.

    Returns:
        The cached process group for these ranks, or ``None`` when no group
        has been registered for them.
    """
    group_key = str(tuple(sorted(rank_list)))
    if group_key in EXISTING_COMM_GROUPS:
        return EXISTING_COMM_GROUPS[group_key]
    return None
@classmethod
def mark_created_groups(cls, process_group: Union[Any, list[Any]]) -> None:
    """Record existing process group(s) in the global group cache.

    Each group is keyed by its sorted member ranks, as reported by
    ``cls.get_process_group_ranks``. This is the same key scheme
    ``create_group`` and ``get_created_group`` use, so later lookups with
    those ranks reuse the group.

    Args:
        process_group (Union[Any, list[Any]]): One process group, or a list of
            them. Anything that is not a ``list`` is treated as a single group.
    """
    groups = process_group if isinstance(process_group, list) else [process_group]
    for pg in groups:
        member_ranks = cls.get_process_group_ranks(pg)
        EXISTING_COMM_GROUPS[str(tuple(sorted(member_ranks)))] = pg
@property
def meta_device(self):
    """The framework's meta device.

    Tensors on the meta device carry shape/dtype information without
    allocating storage. This is useful for shape inference and deferred
    model initialization.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement meta_device"
    raise NotImplementedError(message)
def init_on_device(self, device, include_buffers=False):
    """Return a context manager under which new module parameters are created on *device*.

    Args:
        device: Target device for parameter initialization.
        include_buffers (bool): Also place module buffers on *device*.

    Returns:
        The device-initialization context manager.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement init_on_device"
    raise NotImplementedError(message)
def str_to_dtype(self, dtype_str: str) -> Any:
    """Translate a serialized dtype string (e.g. ``torch.float32``) into the backend dtype.

    Args:
        dtype_str (str): Dtype identifier as written in checkpoint metadata.

    Returns:
        The framework dtype object (e.g. ``torch.dtype`` or a MindSpore dtype).

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement str_to_dtype"
    raise NotImplementedError(message)
def list_to_size(self, size_list: list[int]) -> Any:
    """Turn a checkpoint-metadata shape list into the framework's size type (e.g. ``torch.Size``).

    Args:
        size_list (list[int]): Global tensor shape as a list of ints.

    Returns:
        The framework-specific size object.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "Platform subclasses must implement list_to_size"
    raise NotImplementedError(message)