Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/platform.py: 44%
565 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore platform api"""
16from datetime import timedelta
17from typing import Any, Optional, Union
18import dataclasses
19from collections import OrderedDict
21import contextlib
22import numpy as np
23import mindspore as ms
24import mindspore.common.dtype as mstype
25from mindspore.mint.distributed import TCPStore
27from mindspore.nn import Cell
28from mindspore import mint
29from mindspore.common.api import _no_grad
30from mindspore.common._grad_function import _Function
31from mindspore.common.dtype import type_size_in_bytes
32from mindspore.common.parameter import Parameter
33from mindspore.common.tensor import Tensor
34from mindspore.common.initializer import initializer
35from mindspore.common.recompute import null_context_fn
36from mindspore.communication import GlobalComm
37from mindspore.communication import get_group_size
38from mindspore.communication import create_group as new_group
39from mindspore.communication import get_rank as get_rank_id
40from mindspore.ops import communication as ops_comm
41from mindspore.ops.function import comm_func
42from mindspore._c_expression import TensorTransform
43import mindspore.mint.distributed as dist
45from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS
46from hyper_parallel.platform.mindspore.dtensor import DTensorBase
47from hyper_parallel.platform.mindspore.pipeline_parallel.stage import PipelineStageBase
48from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters
49from hyper_parallel.platform.mindspore.init_weights import (
50 init_on_device as _init_on_device,
51 _install_cell_to_empty_patch,
52)
# Configure MindSpore's functional communication ops with in-place mode turned off.
comm_func.set_comm_ops_inplace(False)
# Instance obtained from ``TensorTransform.get_instance()``; returned by
# ``MindSporePlatform.get_tensor_transform``.
_tensor_transform = TensorTransform.get_instance()

# pylint: disable=C0103
61def _a2a_reconstruct_ms(out_perm: Tensor, concat_dim: int) -> Tensor:
62 """Reconstruct A2A result from raw out_perm buffer."""
63 new_ndim = out_perm.dim()
64 chunk_in_perm = concat_dim + 1
65 recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim))
66 x_recon = out_perm.permute(recon_perm).contiguous()
67 shape = list(x_recon.shape)
68 merged = shape[concat_dim] * shape[concat_dim + 1]
69 return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:])
72def _normalize_all_to_all_single_result(result, output: Tensor) -> tuple[Tensor, object]:
73 """Normalize MindSpore all_to_all_single return values to ``(output, handle)``."""
74 if isinstance(result, tuple):
75 if len(result) != 2:
76 raise ValueError(
77 "mindspore all_to_all_single returned an unexpected tuple "
78 f"with length {len(result)}"
79 )
80 return result
81 return output, result
def _mindspore_all_to_all_single(input_tensor: Tensor, output_shape, group, async_op=False) -> tuple[Tensor, object]:
    """Run MindSpore ``all_to_all_single`` into a freshly allocated buffer.

    Args:
        input_tensor: Tensor to exchange; its dtype is used for the output.
        output_shape: Shape of the receive buffer.
        group: Communication group passed through to the collective.
        async_op: Launch asynchronously when True.

    Returns:
        tuple: ``(output, handle)``; ``handle`` is always ``None`` when
        ``async_op`` is False.
    """
    buffer = mint.empty(tuple(output_shape), dtype=input_tensor.dtype)
    raw = ops_comm.all_to_all_single(buffer, input_tensor, group=group, async_op=async_op)
    received, handle = _normalize_all_to_all_single_result(raw, buffer)
    # Synchronous callers never get a handle back.
    return received, (handle if async_op else None)
class _MSAsyncA2AFunction(_Function):
    """Autograd wrapper around an all-to-all that was already launched asynchronously.

    ``forward`` waits on the pre-launched work and rebuilds the output.
    ``backward`` optionally launches the reverse all-to-all asynchronously. It
    appends ``(work, out_perm)`` to ``handle_box`` and returns a zero gradient
    for ``x``.
    """

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box):  # pylint: disable=arguments-differ
        """Block on the pre-launched all-to-all and return the reconstructed output."""
        # Everything backward needs is stashed on the context.
        ctx.group, ctx.world_size = group, world_size
        ctx.concat_dim, ctx.split_dim = concat_dim, split_dim
        ctx.handle_box = handle_box
        ctx.x_shape = tuple(x.shape)
        work.wait()
        return _a2a_reconstruct_ms(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch the reverse all-to-all asynchronously (if a handle box is set); return a zero grad."""
        box = ctx.handle_box
        if box is not None:
            grad = grad_output.contiguous()
            dims = list(grad.shape)
            axis = ctx.concat_dim
            per_rank = dims[axis] // ctx.world_size
            split_shape = dims[:axis] + [ctx.world_size, per_rank] + dims[axis + 1:]
            # The split adds one axis; move the world-size axis to the front.
            order = [axis, *range(axis), *range(axis + 1, len(dims) + 1)]
            send_buf = grad.reshape(split_shape).permute(order).contiguous()
            out_perm, work = _mindspore_all_to_all_single(
                send_buf,
                list(send_buf.shape),
                ctx.group,
                async_op=True,
            )
            box.append((work, out_perm))
        zero_grad = mint.zeros(ctx.x_shape, dtype=grad_output.dtype)
        # One gradient slot per forward input; only ``x`` gets a tensor.
        return (zero_grad,) + (None,) * 7
133class MindSporePlatform(Platform):
134 """MindSpore platform api"""
135 Tensor = Tensor
136 tensor = Tensor
137 Parameter = Parameter
138 Module = Cell
139 DTensorBase = DTensorBase
140 PipelineStageBase = PipelineStageBase
141 platform_type = PlatformType.MINDSPORE
142 tensor_dtype = mstype
143 dtype = ms.Type
144 Function = _Function
146 _custom_ops_cls = None
148 @property
149 def custom_ops(self):
150 """Return the MindSpore platform custom ops instance.
152 .. warning::
153 This is an experimental API that subject to change or deletion.
155 Returns:
156 MindSporeCustomOps: Custom ops class that delegates to DFunction
157 implementations wrapping Ascend NPU custom C++ kernels.
158 """
159 if self._custom_ops_cls is None:
160 from hyper_parallel.platform.mindspore.custom_ops.custom_ops import ( # pylint: disable=import-outside-toplevel
161 MindSporeCustomOps,
162 )
163 self._custom_ops_cls = MindSporeCustomOps
164 return self._custom_ops_cls
166 def __init__(self):
167 # Ensure MindSpore ``nn.Cell.to_empty`` is patched as soon as the
168 # MindSpore platform instance is created.
169 _install_cell_to_empty_patch()
171 @staticmethod
172 def is_linear_module(module) -> bool:
173 """Check whether *module* is a MindSpore ``Dense`` (linear) or ``mint.nn.Linear`` layer."""
174 return isinstance(module, (ms.nn.Dense, mint.nn.Linear))
176 @staticmethod
177 def is_embedding_module(module) -> bool:
178 """Check whether *module* is a MindSpore ``Embedding`` or ``mint.nn.Embedding`` layer."""
179 return isinstance(module, (ms.nn.Embedding, mint.nn.Embedding))
181 def device_count(self, device_handle):
182 """
183 Get the number of available devices.
185 Args:
186 device_handle: The device handle (e.g., ms.device_context).
188 Returns:
189 int: The number of available devices.
190 """
191 device_type = self.device_type()
192 if device_type == "cpu":
193 return device_handle.device_context.cpu.device_count()
194 if device_type == "gpu":
195 return device_handle.device_context.gpu.device_count()
196 return device_handle.device_context.ascend.device_count()
198 @staticmethod
199 def get_rng_state(device=None, device_handle=None):
200 """
201 Get the random number generator state.
203 Args:
204 device (Optional): The device to get RNG state from (not used in MindSpore).
205 device_handle (Optional): The device handle (not used in MindSpore).
207 Returns:
208 Tensor: The RNG state as a tensor.
209 """
210 _ = device, device_handle
211 return ms.get_rng_state()
213 @staticmethod
214 def set_rng_state(state, device=None, device_handle=None):
215 """
216 Set the random number generator state.
218 Args:
219 state (Tensor): The RNG state to set.
220 device (Optional): The device to set RNG state for (not used in MindSpore).
221 device_handle (Optional): The device handle (not used in MindSpore).
222 """
223 _ = device, device_handle
224 return ms.set_rng_state(state)
226 def device_type(self):
227 """
228 Get the current device type.
230 Returns:
231 str: The device type string ("npu" for Ascend, "gpu" for GPU, "cpu" for CPU).
232 """
233 device_type = ms.get_context("device_target")
234 if device_type == "Ascend":
235 return "npu"
236 return device_type.lower()
238 def device(self, device_idx=None):
239 """
240 Get the device type string.
242 Args:
243 device_idx (Optional[int]): The device index (not used in MindSpore).
245 Returns:
246 str: The device type string.
247 """
248 _ = device_idx
249 device_type = self.device_type()
250 return device_type
252 @staticmethod
253 def get_device_handle():
254 """
255 Get the MindSpore module as the device handle.
257 Returns:
258 module: The mindspore module.
259 """
260 return ms
262 @staticmethod
263 def manual_seed(seed):
264 """
265 Set the random seed for reproducibility.
267 Args:
268 seed (int): The random seed value.
270 Returns:
271 None
272 """
273 return ms.manual_seed(seed)
275 @staticmethod
276 def ones(size, dtype=None):
277 """
278 Create a tensor filled with ones.
280 Args:
281 size (tuple): The shape of the output tensor.
282 dtype (Optional[ms.Type]): The desired data type.
284 Returns:
285 Tensor: A tensor filled with ones.
286 """
287 return mint.ones(size, dtype=dtype)
289 @staticmethod
290 def zeros(size, dtype=None, device=None):
291 """
292 Create a tensor filled with zeros.
294 Args:
295 size (tuple): The shape of the output tensor.
296 dtype (Optional[ms.Type]): The desired data type.
297 device (Optional[ms.device]): The device to create the tensor on.
299 Returns:
300 Tensor: A tensor filled with zeros.
301 """
302 tensor = mint.zeros(size, dtype=dtype)
303 if device in ("GPU", "Ascend"):
304 return tensor.to(device)
305 return tensor
307 @staticmethod
308 def full(size, fill_value, dtype=None):
309 """
310 Create a tensor filled with a scalar value.
312 Args:
313 size (tuple): The shape of the output tensor.
314 fill_value (scalar): The value to fill the tensor with.
315 dtype (Optional[ms.Type]): The desired data type.
317 Returns:
318 Tensor: A tensor filled with the specified value.
319 """
320 return mint.full(size, fill_value, dtype=dtype)
322 @staticmethod
323 def empty(size, dtype=None):
324 """
325 Create an uninitialized tensor.
327 Args:
328 size (tuple): The shape of the output tensor.
329 dtype (Optional[ms.Type]): The desired data type.
331 Returns:
332 Tensor: An uninitialized tensor.
333 """
334 return mint.empty(size, dtype=dtype)
336 @staticmethod
337 def get_rank():
338 """
339 Get the rank of the current process in the distributed group.
341 Returns:
342 int: The rank of the current process.
343 """
344 return get_rank_id()
346 @staticmethod
347 def get_global_rank(group, group_rank):
348 """
349 Get the global rank from a group rank.
351 Args:
352 group (str): The process group name.
353 group_rank (int): The rank within the group.
355 Returns:
356 int: The global rank.
357 """
358 return dist.get_global_rank(group, group_rank)
360 @staticmethod
361 def get_world_size():
362 """
363 Get the total number of processes in the distributed group.
365 Returns:
366 int: The world size.
367 """
368 return get_group_size()
370 @staticmethod
371 def get_op_name(func):
372 """
373 Extract the operation name from a function.
375 Args:
376 func: The function to extract the name from.
378 Returns:
379 str: The operation name.
380 """
381 return func.name
383 @staticmethod
384 def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
385 output, _ = comm_func.all_gather_into_tensor(None, data, group=group)
386 if concat_dim == 0:
387 return output
388 output_tensors = ms.ops.Split(output_num=concat_size)(output)
389 return ms.mint.concat(output_tensors, concat_dim)
391 @staticmethod
392 def chunk(data, split_dim, split_size, index):
393 return ms.ops.Split(axis=split_dim, output_num=split_size)(data)[index]
395 @staticmethod
396 def differentiable_all_to_all(input_data, output_shape, group):
397 output_tensor, _ = comm_func.all_to_all_single(
398 output_shape,
399 input_data,
400 group=group,
401 async_op=False
402 )
403 return output_tensor
405 @staticmethod
406 def tensor_type_cast(input_data, cast_type):
407 """Cast tensor to specified data type."""
408 type_mapping = {
409 'float32': ms.float32,
410 'float16': ms.float16,
411 'int64': ms.int64,
412 'int32': ms.int32
413 }
414 if cast_type not in type_mapping:
415 raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
416 return input_data.to(type_mapping[cast_type])
418 @staticmethod
419 def differentiable_all_reduce(data, op, group):
420 output, _ = comm_func.all_reduce(data, op, group)
421 return output
423 @staticmethod
424 def differentiable_reduce_scatter(data, dev_num, axis, op, group):
425 if axis > 0:
426 data = ms.mint.concat(ms.ops.Split(axis=axis, output_num=dev_num)(data), dim=0)
427 output_tensor, _ = comm_func.reduce_scatter_tensor(None, data, 'sum', group)
428 if op == 'avg':
429 output_tensor = output_tensor / dev_num
430 return output_tensor
432 @staticmethod
433 def init_parameters(module, stage_index):
434 return _init_parameters(module, stage_index)
436 # pylint: disable=W0212
437 @staticmethod
438 def update_param_data(param, data):
439 """update param data"""
440 if isinstance(param, DTensorBase):
441 param.set_data(data)
442 else:
443 param._update_data(data)
445 @staticmethod
446 def load_into_param(param, data):
447 copy_tensor = MindSporePlatform.empty_like(data)
448 copy_tensor.copy_(data)
449 if isinstance(param, DTensorBase):
450 param.set_data(copy_tensor)
451 else:
452 param._update(copy_tensor)
454 @staticmethod
455 def get_cell_construct(cell):
456 return cell.construct
458 @staticmethod
459 def get_cells_and_names(cell):
460 return cell.cells_and_names()
462 @staticmethod
463 def get_modules(module):
464 return module.cells()
466 @staticmethod
467 def search_parameter_by_name(cell, param_name: str):
468 """
469 Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
470 Return value: (parent Module instance, parameter's name in parent Module, parameter object).
471 Returns None if not found.
472 """
473 # Remove the "self." prefix from param_name (to maintain compatibility with original logic)
474 param_name = param_name.replace("self.", "")
475 # Case 1: The parameter is a direct parameter of the current Module (not in any sub-Module)
476 if param_name in cell._params:
477 return (cell, param_name, cell._params[param_name])
479 # Case 2: The parameter is in a sub-Module (supports multi-level nesting, e.g., "net_b.dense1.weight")
480 if "." in param_name:
481 # Split into: sub-Module path + parameter name (e.g., "net_b.dense1" + "weight")
482 cell_path, param_key = param_name.rsplit(".", 1)
483 try:
484 # Locate the sub-Module where the parameter resides (supports multi-level paths)
485 target_cell = cell.get_sub_cell(cell_path)
486 # Check if the sub-Module directly contains this parameter
487 if param_key in target_cell._params:
488 return target_cell, param_key, target_cell._params[param_key]
489 except AttributeError:
490 # Sub-Module path does not exist or the parameter is not in that sub-Module
491 pass
493 # Traverse all sub-Modules (recursively) to search for the parameter
494 for _, child_cell in cell._cells.items():
495 if isinstance(child_cell, Cell):
496 # Recursively search within the sub-Module
497 result = MindSporePlatform.search_parameter_by_name(child_cell, param_name)
498 if result is not None:
499 return result
501 return None
503 @staticmethod
504 def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
505 """
506 Modify the original parameter in a Module or sub-Module using the search result
507 Args:
508 cell: The cell which parameter is to update
509 result: A tuple contains parent Module, parameter key and old parameter.
510 new_param: New Parameter object (used to replace the original parameter)
511 """
512 parent_cell, param_key, _ = result
513 # Key operation: directly modify the _params dictionary of the parent Module (original storage location)
514 parent_cell._params[param_key] = new_param
516 if param_key in parent_cell.__dict__:
517 parent_cell.__dict__[param_key] = new_param
518 parent_cell._params_list[param_key] = new_param
519 return True
521 @staticmethod
522 def set_layout_into_parameter(param, layout):
523 """Set layout in to parameter"""
524 from hyper_parallel.core.dtensor.dtensor import DTensor # pylint: disable=import-outside-toplevel
525 from hyper_parallel.core.dtensor.layout import _infer_slice_shape_by_layout, \
526 _get_slice_tensor_by_layout # pylint: disable=import-outside-toplevel
527 if isinstance(param, DTensor):
528 raise ValueError(f"Parameter {param.name} has been configured layout, cannot be set repeatedly.")
529 param_info = param.param_info
530 requires_grad = param.requires_grad
531 name = param.name
532 slice_shape = _infer_slice_shape_by_layout(param.shape, layout)
534 if not param.has_init:
535 # has been init, get slice data
536 param_dtensor = DTensor.from_local(
537 _get_slice_tensor_by_layout(param, layout).value(), layout.mesh, layout.alias_placements
538 )
539 param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
540 param.param_info = param_info
541 else:
542 # has not been init, need to modify init shape
543 param.init_mode.shape = slice_shape
544 param_dtensor = DTensor.from_local(param.init_mode, layout.mesh, layout.alias_placements)
545 param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
546 param.param_info = param_info
547 return param
549 @staticmethod
550 def get_param_local_shape(param):
551 """get param local shape"""
552 if isinstance(param, DTensorBase):
553 return param.local_shape
554 return param.shape
556 @staticmethod
557 def get_param_local_data(param):
558 """get param local shape"""
559 if isinstance(param, DTensorBase):
560 return param.to_local()
561 return param
563 @staticmethod
564 def get_param_type_size(param):
565 return type_size_in_bytes(param.dtype)
567 @staticmethod
568 def is_tensor(obj: Any) -> bool:
569 """Return True if ``obj`` is a ``mindspore.Tensor``."""
570 return isinstance(obj, Tensor)
572 @staticmethod
573 def get_tensor_storage_size(tensor: Any) -> int:
574 """Return serialized byte size (numel * itemsize) for a MindSpore tensor."""
575 if not MindSporePlatform.is_tensor(tensor):
576 raise TypeError(
577 f"MindSporePlatform.get_tensor_storage_size expects mindspore.Tensor, got {type(tensor)!r}"
578 )
579 return int(tensor.numel()) * int(tensor.itemsize)
581 @staticmethod
582 def new_zero_parameter(param_shape, param_type, requires_grad, device):
583 param = Parameter(initializer("zeros", param_shape, param_type), requires_grad=requires_grad)
584 if device in ("GPU", "Ascend"):
585 return param.to(device)
586 return param
588 @staticmethod
589 def new_tensor(tensor_shape, tensor_type, device):
590 tensor = Tensor(shape=tensor_shape, dtype=tensor_type)
591 if device in ("GPU", "Ascend"):
592 return tensor.to(device)
593 return tensor
595 @staticmethod
596 def full_like(tensor, fill_value, dtype=None):
597 return mint.full_like(tensor, fill_value, dtype=dtype)
599 @staticmethod
600 def isend(tensor, dst=None, group=None, tag=0):
601 return dist.isend(tensor, dst, group, tag)
603 @staticmethod
604 def irecv(tensor, src=None, group=None, tag=0):
605 return dist.irecv(tensor, src, group, tag)
607 @staticmethod
608 def p2p_exchange(tensor, peer_rank: int, group=None): # pylint: disable=unused-argument
609 raise NotImplementedError(
610 "p2p_exchange is not yet supported on the MindSpore platform."
611 )
613 @staticmethod
614 def send_object_list(obj_list, dst=None, group=None):
615 # pylint: disable=C0415
616 from hyper_parallel.platform.mindspore.pipeline_parallel._utils import send_object_list
617 send_object_list(obj_list, dst, group)
619 @staticmethod
620 def recv_object_list(obj_list, src=None, group=None):
621 # pylint: disable=C0415
622 from hyper_parallel.platform.mindspore.pipeline_parallel._utils import recv_object_list
623 recv_object_list(obj_list, src, group)
625 @staticmethod
626 def set_tensor_requires_grad(input_tensor):
627 """
628 set requires grad flag for input tensor
629 """
630 input_tensor.requires_grad_()
632 def _create_group(self, rank_list):
633 world_group = self._maybe_reuse_world_group(rank_list)
634 if world_group is not None:
635 return world_group
637 group_name = str(tuple(sorted(rank_list)))
638 new_group(rank_ids=rank_list, group=group_name)
639 EXISTING_COMM_GROUPS[group_name] = group_name
640 return group_name
642 @staticmethod
643 def all_gather_into_tensor(data, group_info, async_op=False):
644 return comm_func.all_gather_into_tensor(None, data, group=group_info.group_name, async_op=async_op)
646 @staticmethod
647 def all_reduce(data, group_info, async_op=False):
648 if isinstance(group_info, str):
649 handle = dist.all_reduce(data, group=group_info, async_op=async_op)
650 else:
651 handle = dist.all_reduce(data, group=group_info.group_name, async_op=async_op)
652 return data, handle
654 @staticmethod
655 def broadcast(data, src, group=None, async_op=False):
656 handle = dist.broadcast(data, src, group, async_op)
657 if async_op:
658 handle.wait()
659 return data
661 @staticmethod
662 def reduce_scatter_tensor(data, group_info, async_op=False):
663 return comm_func.reduce_scatter_tensor(None, data, group=group_info.group_name, async_op=async_op)
665 @staticmethod
666 def all_to_all_single(input_tensor, output_shape, group, async_op=False):
667 return _mindspore_all_to_all_single(input_tensor, output_shape, group, async_op=async_op)
669 @staticmethod
670 def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim, # pylint: disable=unused-argument
671 handle_box=None):
672 return _MSAsyncA2AFunction.apply(
673 x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box
674 )
676 @staticmethod
677 def parameters_dict(cell: Cell):
678 return cell.parameters_and_names()
680 @staticmethod
681 def get_tensor_transform():
682 return _tensor_transform
684 @staticmethod
685 def construct_strided_slice(x, begin, end, stride):
686 return ms.ops.strided_slice(x, begin, end, stride)
688 @staticmethod
689 def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
690 # pylint: disable=C0415
691 from hyper_parallel.platform.mindspore.pipeline_parallel._utils import _MicroBatch
692 return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)
694 @staticmethod
695 def get_model_state_dict(model, *, options=None):
696 raise NotImplementedError(
697 "get_model_state_dict is not yet supported on MindSpore"
698 )
700 @staticmethod
701 def save_checkpoint(cell: Union[Cell, dict], file_path: str, ckpt_format: str = "safetensors") -> None:
702 if isinstance(cell, dict):
703 save_dict = {}
704 for k, v in cell.items():
705 if isinstance(v, Parameter):
706 save_dict[k] = v
707 elif isinstance(v, Tensor):
708 save_dict[k] = Parameter(v, name=k)
709 else:
710 save_dict[k] = v
711 else:
712 save_dict = cell._params
713 ms.save_checkpoint(save_obj=save_dict, ckpt_file_name=file_path, format=ckpt_format)
715 @staticmethod
716 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
717 return ms.load_checkpoint(ckpt_file_name=file_path, format=ckpt_format)
719 @staticmethod
720 def get_symmetric_memory_handler():
721 # pylint: disable=C0415
722 from hyper_parallel.platform.mindspore.symmetric_memory import MSSymmetricMemoryHandler
723 symmetric_memory = MSSymmetricMemoryHandler()
724 return symmetric_memory
726 @staticmethod
727 def get_multicore_handler():
728 # pylint: disable=C0415
729 from hyper_parallel.platform.mindspore.multicore import MSMulticoreHandler
730 return MSMulticoreHandler()
732 def new_stream(self):
733 return ms.runtime.Stream()
735 def get_stream_context(self):
736 return ms.runtime.StreamCtx
738 @staticmethod
739 def all_gather_object(object_list, obj, group=None) -> None:
740 """
741 Gathers objects from the given group into object list.
743 Args:
744 object_list (list[Any]): Define the output list, which size equal to the size of group.
745 obj (Any): The object on current rank and in given process group.
746 group (ProcessGroup, optional): The process group to gather obj. Default is ``None``, and ``None`` means
747 global group.
749 Returns:
750 None. Objs are gathered into ``object_list``.
751 """
752 dist.all_gather_object(object_list, obj, group)
754 @staticmethod
755 def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
756 """
757 Synchronize all processes in the given communication group.
759 Args:
760 group (str, optional): The communication group to work on. Default is ``None``,
761 meaning the default world group.
762 async_op (bool, optional): Whether this op should be asynchronous. Default: ``False``.
763 device_ids (list[int], optional): Reserved parameter on Ascend. Default: ``None``.
765 Returns:
766 CommHandle if ``async_op`` is True; otherwise ``None``.
767 """
768 return dist.barrier(group, async_op, device_ids)
770 @staticmethod
771 def init_process_group(
772 backend: str = None,
773 *,
774 init_method: Optional[str] = None,
775 timeout: Optional[timedelta] = None,
776 world_size: int = -1,
777 rank: int = -1,
778 store: TCPStore = None,
779 pg_options=None,
780 device_id=None
781 ) -> None:
782 """
783 Initialize global process group.
785 Args:
786 backend (str): The backend used to init process group. Default is ``"hccl"`` and now only support hccl.
787 init_method (str, optional): URL specifying how to initialize the process group. Default is ``None``.
788 timeout (timedelta, optional): Timeout for API executed. Default is ``None``.
789 world_size (int): Number of processes. Default is ``-1``.
790 rank (int, optional): Rank of the current process. Default is ``-1``.
791 store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
792 communication addresses and connection information. Default is ``None``. Currently, only the
793 ``TCPStore`` type is supported.
794 pg_options (ProcessGroupOptions, optional): Reserved parameter. Current not take effect.
795 device_id (int, optional): Reserved parameter. Current not take effect.
796 """
797 if backend is None:
798 backend = "hccl"
799 try:
800 if dist.is_initialized():
801 return
802 except AttributeError:
803 pass
804 dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
805 rank=rank, store=store, pg_options=pg_options, device_id=device_id)
807 @staticmethod
808 def destroy_process_group(group: Optional[str] = None) -> None:
809 """
810 Destroy given process group.
812 Args:
813 group (str, optional): Specify the group to destroy. Default: ``None`` means ``hccl_world_group``. If group
814 is None or "hccl_world_group", destroy global process group and all process groups relative to global
815 process group.
816 """
817 if group in EXISTING_COMM_GROUPS.values():
818 keys_to_destroy = [k for k, v in EXISTING_COMM_GROUPS.items() if v == group]
819 for k in keys_to_destroy:
820 del EXISTING_COMM_GROUPS[k]
821 dist.destroy_process_group(group)
823 @staticmethod
824 def get_process_group_ranks(group: Optional[str] = None) -> list[int]:
825 """
826 Get all ranks in given process group.
828 Args:
829 group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.
831 Returns:
832 List[int]: List of ranks in given process group.
833 """
834 return dist.get_process_group_ranks(group)
836 @staticmethod
837 def get_backend(group: Optional[str] = None) -> str:
838 """
839 Get the backend of given process group.
841 Args:
842 group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.
844 Returns:
845 str: The backend of the group.
846 """
847 return dist.get_backend(group)
849 @staticmethod
850 def split_group(parent_pg: Optional[str] = None,
851 split_ranks: Optional[list] = None,
852 timeout: Optional[timedelta] = None,
853 pg_options: Optional[str] = None,
854 group_desc: Optional[str] = None,
855 ) -> str:
856 """
857 Create split group for a specific group rank in split_ranks, which group contains current rank id.
859 Args:
860 parent_pg (str, Optional): A process group which the goal group split from.
861 split_ranks (Optional[list]): A list like ``list[list[int]]``.
862 timeout (Optional[timedelta]): Timeout for API executed. Default is ``None``.
863 pg_options (Optional[str]): Reserved parameter. Current not take effect.
864 group_desc (Optional[str]): Description of process group.
866 Returns:
867 str: The split group name.
868 """
869 if split_ranks is None or len(split_ranks) == 0:
870 raise ValueError("split_ranks cannot be None or empty")
872 rank_id = MindSporePlatform.get_rank()
873 for split_rank in split_ranks:
874 if rank_id in split_rank:
875 world_group = MindSporePlatform._maybe_reuse_world_group(split_rank)
876 if world_group is not None:
877 return world_group
878 split_group = MindSporePlatform.get_created_group(split_rank)
879 if split_group:
880 return split_group
881 group_name = str(tuple(sorted(split_rank)))
882 new_group(rank_ids=split_rank, group=group_name)
883 EXISTING_COMM_GROUPS[group_name] = group_name
884 return group_name
885 raise ValueError(f"Split group invalid rank, the Split_ranks {split_ranks} does not contain current rank"
886 f" {rank_id}")
888 @staticmethod
889 def get_group_local_rank(group=None) -> int:
890 """get group local rank id."""
891 return dist.get_group_rank(group, MindSporePlatform.get_rank())
893 @staticmethod
894 def no_grad():
895 return _no_grad()
897 @staticmethod
898 def relu(tensor):
899 return mint.nn.functional.relu(tensor)
901 @staticmethod
902 def cat(tensors, dim=0):
903 return mint.cat(tensors, dim=dim)
905 @staticmethod
906 def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
907 return mint.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)
909 def get_current_stream(self):
910 return ms.runtime.current_stream()
912 def new_event(self):
913 return ms.runtime.Event()
915 def tree_map(self, fn, tree):
916 """
917 Apply fn to each leaf in a nested structure (list / tuple / dict),
918 preserving the original structure.
919 """
920 if isinstance(tree, dict):
921 return type(tree)(
922 (k, self.tree_map(fn, v)) for k, v in tree.items()
923 )
925 if isinstance(tree, tuple):
926 return tuple(self.tree_map(fn, v) for v in tree)
928 if isinstance(tree, list):
929 return [self.tree_map(fn, v) for v in tree]
931 # leaf
932 return fn(tree)
934 @staticmethod
935 def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
936 return module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
938 @staticmethod
939 def register_full_backward_hook(module, hook, prepend=False):
940 return module.register_backward_hook(hook)
942 @staticmethod
943 def register_full_backward_pre_hook(module, hook, prepend=False):
944 return module.register_backward_pre_hook(hook)
946 @property
947 def checkpoint(self):
948 return ms.recompute
950 @staticmethod
951 def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
952 # pylint: disable=C0415
953 from hyper_parallel.platform.mindspore.activation_checkpoint.checkpoint_wrapper import checkpoint_wrapper
954 return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)
@staticmethod
def swap_wrapper(module, policy_fn=None):
    """
    Wrap ``module`` with the MindSpore activation ``swap_wrapper``.

    The helper is imported lazily at call time; ``policy_fn`` is forwarded
    unchanged.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
        swap_wrapper as _wrap_with_swap,
    )
    return _wrap_with_swap(module, policy_fn=policy_fn)
@staticmethod
def swap_tensor_wrapper(target, tag=None):
    """
    Delegate to the MindSpore activation ``swap_tensor_wrapper``.

    The helper is imported lazily at call time; ``target`` and ``tag`` are
    forwarded unchanged.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
        swap_tensor_wrapper as _wrap_tensor_with_swap,
    )
    return _wrap_tensor_with_swap(target, tag=tag)
@property
def noop_context_fn(self):
    """``null_context_fn`` from ``mindspore.common.recompute`` (returned uncalled)."""
    context_fn = null_context_fn
    return context_fn
@staticmethod
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Delegate to the MindSpore selective-activation-checkpoint (``sac``) helper.

    The helper is imported lazily at call time. ``policy_fn_or_list`` and
    ``allow_cache_entry_mutation`` are forwarded unchanged.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.sac import (
        create_selective_checkpoint_contexts as _build_sac_contexts,
    )
    return _build_sac_contexts(
        policy_fn_or_list,
        allow_cache_entry_mutation=allow_cache_entry_mutation,
    )
@staticmethod
def async_save_on_cpu(policy_fn=None):
    """
    Construct and return an ``AsyncSaveOnCpu`` object with the given ``policy_fn``.

    The class is imported lazily at call time from the MindSpore
    ``activation_swap`` module.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
        AsyncSaveOnCpu as _AsyncSaveOnCpu,
    )
    return _AsyncSaveOnCpu(policy_fn=policy_fn)
@staticmethod
def get_element_size(tensor):
    """Return ``tensor.itemsize``, the byte size of a single element."""
    bytes_per_element = tensor.itemsize
    return bytes_per_element
@staticmethod
def tensor_to_numpy(tensor) -> np.ndarray:
    """
    Convert a MindSpore tensor to a NumPy array by calling ``tensor.asnumpy()``.

    Whether the result shares memory with ``tensor`` is decided by
    MindSpore's ``asnumpy`` and is not established here.
    """
    to_numpy = tensor.asnumpy
    return to_numpy()
@staticmethod
def clip_grad_norm_(
    parameters, max_norm, norm_type=2.0,
    error_if_nonfinite=False, foreach=None,
):
    """
    Gradient-norm clipping placeholder for the MindSpore backend.

    Every argument is accepted only so the signature matches callers; none
    of them is used.

    Raises:
        NotImplementedError: Always.
    """
    message = "clip_grad_norm_ is not yet supported on MindSpore"
    raise NotImplementedError(message)
@property
def meta_device(self):
    """Name of the meta device on this backend: the literal string ``"meta"``."""
    device_name = "meta"
    return device_name
def init_on_device(self, device, include_buffers=False):
    """
    Return ``_init_on_device(device, include_buffers=include_buffers)``.

    ``_init_on_device`` is defined elsewhere in this module. It is presumably
    a context manager that creates new parameters (and optionally buffers)
    on ``device``; confirm against its definition.
    """
    device_ctx = _init_on_device(device, include_buffers=include_buffers)
    return device_ctx
def cast_fp_tensor(self, dtype, x):
    """
    Cast ``x`` to ``dtype`` only when it is a floating-point ``ms.Tensor``.

    The original object is returned unchanged when ``x`` is not an
    ``ms.Tensor``, is not floating point, or already has ``dtype``.
    Otherwise the result of ``x.to(dtype)`` is returned.
    """
    if not isinstance(x, ms.Tensor):
        return x
    if not ms.ops.is_floating_point(x):
        return x
    if x.dtype == dtype:
        return x
    return x.to(dtype)
def apply_to_tensors(self, fn, container):
    """
    Recursively apply ``fn`` to every ``ms.Tensor`` inside ``container``.

    Traversal rules, checked in this order:

    * ``ms.Tensor``: replaced by ``fn(x)``.
    * Dataclass instance: rebuilt with ``dataclasses.replace``, mapping only
      its ``init=True`` fields. ``init=False`` fields are not traversed:
      ``replace`` refuses them as arguments and re-creates them itself
      (default / ``__post_init__``).
    * ``OrderedDict``: same class, insertion order kept.
    * ``dict``: rebuilt as a plain ``dict``.
    * Named tuple: same named-tuple type.
    * ``list`` / ``tuple`` / ``set``: same container type.
    * Anything else, including dataclass *classes*: returned unchanged.

    Args:
        fn: Callable applied to each tensor; its result replaces the tensor.
        container: Arbitrary, possibly nested, object to traverse.

    Returns:
        An object with the same structure and every tensor mapped through ``fn``.
    """

    def apply(x):
        if isinstance(x, ms.Tensor):
            return fn(x)
        # ``is_dataclass`` is also true for dataclass classes, which are not
        # instances and would make ``dataclasses.replace`` raise TypeError.
        if dataclasses.is_dataclass(x) and not isinstance(x, type):
            # Passing an init=False field to ``replace`` raises ValueError,
            # so only constructor (init=True) fields are mapped and passed.
            changes = {
                f.name: apply(getattr(x, f.name))
                for f in dataclasses.fields(x)
                if f.init
            }
            return dataclasses.replace(x, **changes)
        if isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        if isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"):
            # namedtuple: rebuild positionally to keep the concrete type.
            return type(x)(*(apply(el) for el in x))
        if isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        return x

    return apply(container)
@staticmethod
def profiler_record(name):
    """
    Return a context manager that marks a profiled region named ``name``.

    NOTE: on the MindSpore backend this is currently a no-op placeholder. It
    returns ``contextlib.nullcontext()`` and does not emit any
    ``mindspore.profiler`` record, so ``name`` is ignored.

    Args:
        name: Label for the region; accepted for interface compatibility only.

    Returns:
        A ``contextlib.nullcontext`` instance, which yields ``None`` on entry.
    """
    del name  # accepted for interface compatibility; intentionally unused
    return contextlib.nullcontext()
def str_to_dtype(self, dtype_str: str) -> Any:
    """
    Resolve checkpoint dtype strings (``mindspore.*`` or short ``str(Tensor.dtype)`` e.g. ``Float32``).

    Two forms are accepted:

    * Fully qualified, e.g. ``"mindspore.float32"``: looked up as
      ``ms.<name>``.
    * Short form, e.g. ``"Float32"``: looked up as ``ms.<lowercased name>``.

    Any other dotted string, e.g. one with a non-``mindspore`` prefix, goes
    through the short-form lookup and therefore normally fails.

    Args:
        dtype_str: The dtype string to resolve.

    Returns:
        The matching attribute of the ``mindspore`` module.

    Raises:
        ValueError: If no matching attribute exists. This includes an unknown
            ``"mindspore.<name>"``, which previously escaped as AttributeError.
    """
    prefix, sep, name = dtype_str.partition(".")
    if sep and prefix == "mindspore":
        dtype = getattr(ms, name, None)
    else:
        dtype = getattr(ms, dtype_str.lower(), None)
    if dtype is None:
        raise ValueError(
            f"Expected dtype string like 'mindspore.float32' or 'Float32', got {dtype_str!r}."
        )
    return dtype
def list_to_size(self, size_list: list[int]) -> tuple[int, ...]:
    """Return the elements of ``size_list`` as a new tuple."""
    return (*size_list,)
@staticmethod
def _maybe_reuse_world_group(rank_list):
    """
    Return the world communication group when ``rank_list`` covers every rank.

    ``rank_list`` is compared order-insensitively against
    ``0 .. world_size - 1``.

    * On a match, ``GlobalComm.WORLD_COMM_GROUP`` is stored in
      ``EXISTING_COMM_GROUPS`` under ``str(tuple(sorted(rank_list)))`` and
      returned.
    * Otherwise ``None`` is returned and nothing is cached.
    """
    ranks = tuple(sorted(rank_list))
    world_size = MindSporePlatform.get_world_size()
    if ranks == tuple(range(world_size)):
        world_group = GlobalComm.WORLD_COMM_GROUP
        EXISTING_COMM_GROUPS[str(ranks)] = world_group
        return world_group
    return None