Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / mindspore / platform.py: 44%

565 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore platform api""" 

16from datetime import timedelta 

17from typing import Any, Optional, Union 

18import dataclasses 

19from collections import OrderedDict 

20 

21import contextlib 

22import numpy as np 

23import mindspore as ms 

24import mindspore.common.dtype as mstype 

25from mindspore.mint.distributed import TCPStore 

26 

27from mindspore.nn import Cell 

28from mindspore import mint 

29from mindspore.common.api import _no_grad 

30from mindspore.common._grad_function import _Function 

31from mindspore.common.dtype import type_size_in_bytes 

32from mindspore.common.parameter import Parameter 

33from mindspore.common.tensor import Tensor 

34from mindspore.common.initializer import initializer 

35from mindspore.common.recompute import null_context_fn 

36from mindspore.communication import GlobalComm 

37from mindspore.communication import get_group_size 

38from mindspore.communication import create_group as new_group 

39from mindspore.communication import get_rank as get_rank_id 

40from mindspore.ops import communication as ops_comm 

41from mindspore.ops.function import comm_func 

42from mindspore._c_expression import TensorTransform 

43import mindspore.mint.distributed as dist 

44 

45from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS 

46from hyper_parallel.platform.mindspore.dtensor import DTensorBase 

47from hyper_parallel.platform.mindspore.pipeline_parallel.stage import PipelineStageBase 

48from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters 

49from hyper_parallel.platform.mindspore.init_weights import ( 

50 init_on_device as _init_on_device, 

51 _install_cell_to_empty_patch, 

52) 

53 

# Import-time configuration: pass False to comm_func's in-place switch
# (presumably so communication ops return new tensors -- confirm against MindSpore docs).
comm_func.set_comm_ops_inplace(False)
# Module-level TensorTransform instance obtained once at import time.
_tensor_transform = TensorTransform.get_instance()

56 

57 

58# pylint: disable=C0103 

59 

60 

61def _a2a_reconstruct_ms(out_perm: Tensor, concat_dim: int) -> Tensor: 

62 """Reconstruct A2A result from raw out_perm buffer.""" 

63 new_ndim = out_perm.dim() 

64 chunk_in_perm = concat_dim + 1 

65 recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim)) 

66 x_recon = out_perm.permute(recon_perm).contiguous() 

67 shape = list(x_recon.shape) 

68 merged = shape[concat_dim] * shape[concat_dim + 1] 

69 return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:]) 

70 

71 

72def _normalize_all_to_all_single_result(result, output: Tensor) -> tuple[Tensor, object]: 

73 """Normalize MindSpore all_to_all_single return values to ``(output, handle)``.""" 

74 if isinstance(result, tuple): 

75 if len(result) != 2: 

76 raise ValueError( 

77 "mindspore all_to_all_single returned an unexpected tuple " 

78 f"with length {len(result)}" 

79 ) 

80 return result 

81 return output, result 

82 

83 

def _mindspore_all_to_all_single(input_tensor: Tensor, output_shape, group, async_op=False) -> tuple[Tensor, object]:
    """Run MindSpore ``all_to_all_single`` into a fresh buffer and normalise the result.

    Args:
        input_tensor (Tensor): Tensor to send; its dtype is reused for the receive buffer.
        output_shape: Iterable of ints giving the receive buffer shape.
        group: Communication group forwarded to the collective.
        async_op (bool): Whether to launch asynchronously.

    Returns:
        tuple: ``(output, handle)``; ``handle`` is forced to ``None`` when ``async_op`` is false.
    """
    recv_buffer = mint.empty(tuple(output_shape), dtype=input_tensor.dtype)
    raw = ops_comm.all_to_all_single(recv_buffer, input_tensor, group=group, async_op=async_op)
    out_tensor, work = _normalize_all_to_all_single_result(raw, recv_buffer)
    # Synchronous callers never receive a handle, whatever the op returned.
    return out_tensor, (work if async_op else None)

92 

93 

class _MSAsyncA2AFunction(_Function):
    """Autograd wrapper around an all-to-all that has already been launched asynchronously."""

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box):  # pylint: disable=arguments-differ
        """Wait on ``work`` and return the reconstructed all-to-all output.

        The group, world size, dims, ``handle_box`` and the shape of ``x`` are
        stashed on ``ctx`` for use in ``backward``.
        """
        ctx.x_shape = tuple(x.shape)
        ctx.group, ctx.world_size = group, world_size
        ctx.concat_dim, ctx.split_dim = concat_dim, split_dim
        ctx.handle_box = handle_box
        work.wait()
        return _a2a_reconstruct_ms(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Optionally launch the reverse async all-to-all, then return a zero gradient for ``x``.

        When ``ctx.handle_box`` is not ``None``, the gradient is split along
        ``ctx.concat_dim`` into ``world_size`` chunks, the chunk axis is moved to
        the front, an async all-to-all is started, and ``(work, out_perm)`` is
        appended to the box. The returned tuple always holds a zero tensor of
        ``x``'s shape followed by ``None`` for the seven remaining forward inputs.
        """
        box = ctx.handle_box
        if box is not None:
            grad = grad_output.contiguous()
            dims = list(grad.shape)
            axis = ctx.concat_dim
            world = ctx.world_size
            # Split the sequence axis into (world, chunk).
            split_dims = dims[:axis] + [world, dims[axis] // world] + dims[axis + 1:]
            # Move the newly created world axis to the front; keep the others in order.
            order = [axis] + [d for d in range(len(split_dims)) if d != axis]
            send_buf = grad.reshape(split_dims).permute(order).contiguous()
            recv_buf, handle = _mindspore_all_to_all_single(
                send_buf,
                list(send_buf.shape),
                ctx.group,
                async_op=True,
            )
            box.append((handle, recv_buf))
        zero_grad = mint.zeros(ctx.x_shape, dtype=grad_output.dtype)
        return (zero_grad,) + (None,) * 7

131 

132 

133class MindSporePlatform(Platform): 

134 """MindSpore platform api""" 

135 Tensor = Tensor 

136 tensor = Tensor 

137 Parameter = Parameter 

138 Module = Cell 

139 DTensorBase = DTensorBase 

140 PipelineStageBase = PipelineStageBase 

141 platform_type = PlatformType.MINDSPORE 

142 tensor_dtype = mstype 

143 dtype = ms.Type 

144 Function = _Function 

145 

146 _custom_ops_cls = None 

147 

@property
def custom_ops(self):
    """Return the MindSpore platform custom ops class, importing it on first access.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Returns:
        MindSporeCustomOps: Custom ops class that delegates to DFunction
        implementations wrapping Ascend NPU custom C++ kernels.
    """
    cached = self._custom_ops_cls
    if cached is not None:
        return cached
    # Deferred import: only resolved the first time the property is read.
    from hyper_parallel.platform.mindspore.custom_ops.custom_ops import (  # pylint: disable=import-outside-toplevel
        MindSporeCustomOps,
    )
    self._custom_ops_cls = MindSporeCustomOps
    return MindSporeCustomOps

165 

def __init__(self):
    """Install the ``nn.Cell.to_empty`` patch as soon as the platform instance is created."""
    _install_cell_to_empty_patch()

170 

@staticmethod
def is_linear_module(module) -> bool:
    """Return True when *module* is an ``ms.nn.Dense`` or ``mint.nn.Linear`` instance."""
    linear_types = (ms.nn.Dense, mint.nn.Linear)
    return isinstance(module, linear_types)

175 

@staticmethod
def is_embedding_module(module) -> bool:
    """Return True when *module* is an ``ms.nn.Embedding`` or ``mint.nn.Embedding`` instance."""
    embedding_types = (ms.nn.Embedding, mint.nn.Embedding)
    return isinstance(module, embedding_types)

180 

def device_count(self, device_handle):
    """
    Get the number of available devices for the current device type.

    Args:
        device_handle: Object exposing ``device_context.cpu`` / ``.gpu`` / ``.ascend``.

    Returns:
        int: Result of ``device_count()`` on the matching backend; any device
        type other than ``"cpu"`` or ``"gpu"`` is routed to ``ascend``.
    """
    kind = self.device_type()
    context = device_handle.device_context
    if kind == "cpu":
        backend = context.cpu
    elif kind == "gpu":
        backend = context.gpu
    else:
        backend = context.ascend
    return backend.device_count()

197 

@staticmethod
def get_rng_state(device=None, device_handle=None):
    """
    Get the random number generator state via ``ms.get_rng_state()``.

    Args:
        device (Optional): Accepted for interface compatibility; ignored.
        device_handle (Optional): Accepted for interface compatibility; ignored.

    Returns:
        The value returned by ``ms.get_rng_state()``.
    """
    del device, device_handle  # unused on MindSpore
    return ms.get_rng_state()

212 

@staticmethod
def set_rng_state(state, device=None, device_handle=None):
    """
    Set the random number generator state via ``ms.set_rng_state``.

    Args:
        state: The RNG state to restore.
        device (Optional): Accepted for interface compatibility; ignored.
        device_handle (Optional): Accepted for interface compatibility; ignored.

    Returns:
        Whatever ``ms.set_rng_state(state)`` returns.
    """
    del device, device_handle  # unused on MindSpore
    return ms.set_rng_state(state)

225 

def device_type(self):
    """
    Get the current device type from the MindSpore context.

    Returns:
        str: ``"npu"`` when the context's ``device_target`` is ``"Ascend"``;
        otherwise the lower-cased ``device_target`` value.
    """
    target = ms.get_context("device_target")
    return "npu" if target == "Ascend" else target.lower()

237 

def device(self, device_idx=None):
    """
    Get the device type string.

    Args:
        device_idx (Optional[int]): Accepted for interface compatibility; ignored.

    Returns:
        str: The value of ``self.device_type()``.
    """
    del device_idx  # the index plays no role on MindSpore
    return self.device_type()

251 

@staticmethod
def get_device_handle():
    """
    Get the device handle, which on this platform is the ``mindspore`` module itself.

    Returns:
        module: The imported ``mindspore`` module.
    """
    handle = ms
    return handle

261 

@staticmethod
def manual_seed(seed):
    """
    Set the random seed through ``ms.manual_seed``.

    Args:
        seed (int): The random seed value.

    Returns:
        Whatever ``ms.manual_seed(seed)`` returns.
    """
    outcome = ms.manual_seed(seed)
    return outcome

274 

@staticmethod
def ones(size, dtype=None):
    """
    Create a tensor filled with ones using ``mint.ones``.

    Args:
        size (tuple): The shape of the output tensor.
        dtype (Optional[ms.Type]): The desired data type.

    Returns:
        Tensor: A tensor filled with ones.
    """
    filled = mint.ones(size, dtype=dtype)
    return filled

288 

@staticmethod
def zeros(size, dtype=None, device=None):
    """
    Create a tensor filled with zeros using ``mint.zeros``.

    Args:
        size (tuple): The shape of the output tensor.
        dtype (Optional[ms.Type]): The desired data type.
        device (Optional): When equal to ``"GPU"`` or ``"Ascend"`` the tensor is
            moved with ``.to(device)``; any other value leaves it where it was created.

    Returns:
        Tensor: A tensor filled with zeros.
    """
    result = mint.zeros(size, dtype=dtype)
    return result.to(device) if device in ("GPU", "Ascend") else result

306 

@staticmethod
def full(size, fill_value, dtype=None):
    """
    Create a tensor filled with a scalar value using ``mint.full``.

    Args:
        size (tuple): The shape of the output tensor.
        fill_value (scalar): The value to fill the tensor with.
        dtype (Optional[ms.Type]): The desired data type.

    Returns:
        Tensor: A tensor filled with ``fill_value``.
    """
    filled = mint.full(size, fill_value, dtype=dtype)
    return filled

321 

@staticmethod
def empty(size, dtype=None):
    """
    Create an uninitialized tensor using ``mint.empty``.

    Args:
        size (tuple): The shape of the output tensor.
        dtype (Optional[ms.Type]): The desired data type.

    Returns:
        Tensor: An uninitialized tensor.
    """
    buffer = mint.empty(size, dtype=dtype)
    return buffer

335 

@staticmethod
def get_rank():
    """
    Get the rank of the current process (``mindspore.communication.get_rank``).

    Returns:
        int: The rank of the current process.
    """
    current_rank = get_rank_id()
    return current_rank

345 

@staticmethod
def get_global_rank(group, group_rank):
    """
    Translate a rank inside ``group`` into the global rank.

    Args:
        group (str): The process group name.
        group_rank (int): The rank within the group.

    Returns:
        int: The global rank reported by ``dist.get_global_rank``.
    """
    global_rank = dist.get_global_rank(group, group_rank)
    return global_rank

359 

@staticmethod
def get_world_size():
    """
    Get the total number of processes (``mindspore.communication.get_group_size``).

    Returns:
        int: The world size.
    """
    world_size = get_group_size()
    return world_size

369 

@staticmethod
def get_op_name(func):
    """
    Extract the operation name from a function-like object.

    Args:
        func: Object exposing a ``name`` attribute.

    Returns:
        The value of ``func.name``.
    """
    op_name = func.name
    return op_name

382 

@staticmethod
def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
    """All-gather ``data`` over ``group`` and concatenate the pieces along ``concat_dim``.

    The gathered tensor is returned as-is when ``concat_dim`` is 0. Otherwise it
    is split into ``concat_size`` parts (``ms.ops.Split`` with its default axis)
    and the parts are concatenated along ``concat_dim``.
    """
    gathered, _ = comm_func.all_gather_into_tensor(None, data, group=group)
    if concat_dim != 0:
        pieces = ms.ops.Split(output_num=concat_size)(gathered)
        gathered = ms.mint.concat(pieces, concat_dim)
    return gathered

390 

@staticmethod
def chunk(data, split_dim, split_size, index):
    """Split ``data`` into ``split_size`` parts along ``split_dim`` and return part ``index``."""
    splitter = ms.ops.Split(axis=split_dim, output_num=split_size)
    return splitter(data)[index]

394 

@staticmethod
def differentiable_all_to_all(input_data, output_shape, group):
    """Run a synchronous ``comm_func.all_to_all_single`` and return only the output tensor."""
    received, _handle = comm_func.all_to_all_single(
        output_shape, input_data, group=group, async_op=False
    )
    return received

404 

@staticmethod
def tensor_type_cast(input_data, cast_type):
    """Cast a tensor to the dtype named by ``cast_type``.

    Args:
        input_data: Tensor exposing ``.to(dtype)``.
        cast_type (str): One of ``'float32'``, ``'float16'``, ``'int64'``, ``'int32'``.

    Returns:
        The result of ``input_data.to(<dtype>)``.

    Raises:
        ValueError: If ``cast_type`` is not one of the supported names.
    """
    supported = {
        'float32': ms.float32,
        'float16': ms.float16,
        'int64': ms.int64,
        'int32': ms.int32,
    }
    if cast_type in supported:
        return input_data.to(supported[cast_type])
    raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(supported.keys())}")

417 

@staticmethod
def differentiable_all_reduce(data, op, group):
    """Run ``comm_func.all_reduce`` and return only the reduced tensor, discarding the handle."""
    reduced, _handle = comm_func.all_reduce(data, op, group)
    return reduced

422 

@staticmethod
def differentiable_reduce_scatter(data, dev_num, axis, op, group):
    """Reduce-scatter ``data`` over ``group``, scattering along ``axis``.

    For ``axis > 0`` the tensor is first split into ``dev_num`` parts along
    ``axis`` and re-concatenated along dim 0. The reduction is always issued as
    ``'sum'``; when ``op == 'avg'`` the result is divided by ``dev_num``.
    """
    if axis > 0:
        shards = ms.ops.Split(axis=axis, output_num=dev_num)(data)
        data = ms.mint.concat(shards, dim=0)
    scattered, _handle = comm_func.reduce_scatter_tensor(None, data, 'sum', group)
    return scattered / dev_num if op == 'avg' else scattered

431 

@staticmethod
def init_parameters(module, stage_index):
    """Delegate to ``hyper_parallel.platform.mindspore.parameter_init.init_parameters``."""
    outcome = _init_parameters(module, stage_index)
    return outcome

435 

436 # pylint: disable=W0212 

@staticmethod
def update_param_data(param, data):
    """Replace the data held by ``param``.

    ``DTensorBase`` instances are updated through ``set_data``; every other
    object goes through its private ``_update_data`` method.
    """
    if not isinstance(param, DTensorBase):
        param._update_data(data)
        return
    param.set_data(data)

444 

@staticmethod
def load_into_param(param, data):
    """Copy ``data`` into a fresh tensor and install that copy into ``param``.

    ``DTensorBase`` instances receive the copy through ``set_data``; every
    other object through its private ``_update`` method.
    """
    # NOTE(review): update_param_data uses `_update_data` while this path uses
    # `_update` -- confirm the difference is intentional.
    staged = MindSporePlatform.empty_like(data)
    staged.copy_(data)
    install = param.set_data if isinstance(param, DTensorBase) else param._update
    install(staged)

453 

@staticmethod
def get_cell_construct(cell):
    """Return the ``construct`` attribute of ``cell``."""
    construct_fn = cell.construct
    return construct_fn

457 

@staticmethod
def get_cells_and_names(cell):
    """Return the result of ``cell.cells_and_names()``."""
    named_cells = cell.cells_and_names()
    return named_cells

461 

@staticmethod
def get_modules(module):
    """Return the result of ``module.cells()``."""
    children = module.cells()
    return children

465 

@staticmethod
def search_parameter_by_name(cell, param_name: str):
    """
    Find the parent Module of a parameter, the parameter's key in that Module, and the parameter.

    Lookup order:
      1. ``param_name`` as a direct entry of ``cell._params``.
      2. If the name contains a dot, the sub-Module addressed by everything before the
         last dot (via ``cell.get_sub_cell``), with the last segment as the key.
      3. A recursive search through every ``Cell`` in ``cell._cells``.

    Args:
        cell: The Module to search from.
        param_name (str): Parameter name; a single leading ``"self."`` is ignored.

    Returns:
        tuple | None: ``(parent Module, key in parent, parameter)``, or ``None`` if not found.
    """
    # Strip only a leading "self.". Interior segments such as "attention.self.query"
    # are real path components and must not be rewritten.
    name = param_name.removeprefix("self.")

    # Case 1: the parameter belongs directly to this Module.
    own_params = cell._params
    if name in own_params:
        return (cell, name, own_params[name])

    # Case 2: dotted path, e.g. "net_b.dense1.weight" -> ("net_b.dense1", "weight").
    if "." in name:
        cell_path, param_key = name.rsplit(".", 1)
        try:
            target_cell = cell.get_sub_cell(cell_path)
            if param_key in target_cell._params:
                return target_cell, param_key, target_cell._params[param_key]
        except AttributeError:
            # The sub-Module path does not exist; fall through to the recursive search.
            pass

    # Case 3: recurse into every child Cell.
    for child_cell in cell._cells.values():
        if isinstance(child_cell, Cell):
            found = MindSporePlatform.search_parameter_by_name(child_cell, name)
            if found is not None:
                return found

    return None

502 

@staticmethod
def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
    """
    Replace a parameter inside its parent Module using a search result.

    Args:
        cell: The cell whose parameter is to be updated (not read by this function).
        result: Tuple of ``(parent Module, parameter key, old parameter)``; only the
            first two items are used.
        new_param: New Parameter object that replaces the original one.

    Returns:
        bool: Always ``True``.
    """
    parent_cell, param_key, _ = result
    # Key operation: directly modify the _params dictionary of the parent Module (original storage location)
    parent_cell._params[param_key] = new_param

    # Keep the instance __dict__ in sync, but only when the key is already present there.
    if param_key in parent_cell.__dict__:
        parent_cell.__dict__[param_key] = new_param
    # NOTE(review): the nesting of this statement could not be recovered from the
    # extracted text; it is assumed to run unconditionally -- confirm against upstream.
    parent_cell._params_list[param_key] = new_param
    return True

520 

@staticmethod
def set_layout_into_parameter(param, layout):
    """Wrap ``param`` into a new Parameter backed by a ``DTensor`` built from ``layout``.

    When ``param.has_init`` is false, the local slice is taken from the existing
    data. Otherwise the shape of ``param.init_mode`` is rewritten to the
    inferred slice shape and ``init_mode`` itself becomes the local value.

    Args:
        param: Parameter to convert; must not already be a ``DTensor``.
        layout: Layout providing ``mesh`` and ``alias_placements``.

    Returns:
        Parameter: New parameter carrying the original name, ``requires_grad``
        and ``param_info``.

    Raises:
        ValueError: If ``param`` is already a ``DTensor``.
    """
    from hyper_parallel.core.dtensor.dtensor import DTensor  # pylint: disable=import-outside-toplevel
    from hyper_parallel.core.dtensor.layout import _infer_slice_shape_by_layout, \
        _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
    if isinstance(param, DTensor):
        raise ValueError(f"Parameter {param.name} has been configured layout, cannot be set repeatedly.")
    saved_info = param.param_info
    saved_grad = param.requires_grad
    saved_name = param.name
    local_shape = _infer_slice_shape_by_layout(param.shape, layout)

    if param.has_init:
        # Initialization still pending: shrink the initializer's shape to the local slice.
        param.init_mode.shape = local_shape
        local_value = param.init_mode
    else:
        # Data already present: cut out this rank's slice.
        local_value = _get_slice_tensor_by_layout(param, layout).value()

    dtensor = DTensor.from_local(local_value, layout.mesh, layout.alias_placements)
    wrapped = Parameter(dtensor, name=saved_name, requires_grad=saved_grad)
    wrapped.param_info = saved_info
    return wrapped

548 

@staticmethod
def get_param_local_shape(param):
    """Return ``param.local_shape`` for ``DTensorBase`` instances, else ``param.shape``."""
    return param.local_shape if isinstance(param, DTensorBase) else param.shape

555 

@staticmethod
def get_param_local_data(param):
    """Return ``param.to_local()`` for ``DTensorBase`` instances, else ``param`` itself."""
    return param.to_local() if isinstance(param, DTensorBase) else param

562 

@staticmethod
def get_param_type_size(param):
    """Return ``type_size_in_bytes`` of ``param.dtype``."""
    element_dtype = param.dtype
    return type_size_in_bytes(element_dtype)

566 

@staticmethod
def is_tensor(obj: Any) -> bool:
    """Tell whether ``obj`` is an instance of ``mindspore.Tensor``."""
    verdict = isinstance(obj, Tensor)
    return verdict

571 

@staticmethod
def get_tensor_storage_size(tensor: Any) -> int:
    """Return ``numel * itemsize`` for a MindSpore tensor.

    Raises:
        TypeError: If ``tensor`` is not a ``mindspore.Tensor``.
    """
    if MindSporePlatform.is_tensor(tensor):
        element_count = int(tensor.numel())
        return element_count * int(tensor.itemsize)
    raise TypeError(
        f"MindSporePlatform.get_tensor_storage_size expects mindspore.Tensor, got {type(tensor)!r}"
    )

580 

@staticmethod
def new_zero_parameter(param_shape, param_type, requires_grad, device):
    """Create a Parameter from a ``"zeros"`` initializer.

    The parameter is moved with ``.to(device)`` only when ``device`` is
    ``"GPU"`` or ``"Ascend"``.
    """
    zero_init = initializer("zeros", param_shape, param_type)
    created = Parameter(zero_init, requires_grad=requires_grad)
    return created.to(device) if device in ("GPU", "Ascend") else created

587 

@staticmethod
def new_tensor(tensor_shape, tensor_type, device):
    """Create a Tensor from a shape and dtype.

    The tensor is moved with ``.to(device)`` only when ``device`` is
    ``"GPU"`` or ``"Ascend"``.
    """
    created = Tensor(shape=tensor_shape, dtype=tensor_type)
    return created.to(device) if device in ("GPU", "Ascend") else created

594 

@staticmethod
def full_like(tensor, fill_value, dtype=None):
    """Delegate to ``mint.full_like(tensor, fill_value, dtype=dtype)``."""
    filled = mint.full_like(tensor, fill_value, dtype=dtype)
    return filled

598 

@staticmethod
def isend(tensor, dst=None, group=None, tag=0):
    """Delegate to ``dist.isend`` and return its result."""
    work = dist.isend(tensor, dst, group, tag)
    return work

602 

@staticmethod
def irecv(tensor, src=None, group=None, tag=0):
    """Delegate to ``dist.irecv`` and return its result."""
    work = dist.irecv(tensor, src, group, tag)
    return work

606 

@staticmethod
def p2p_exchange(tensor, peer_rank: int, group=None):  # pylint: disable=unused-argument
    """Unsupported on this platform; always raises ``NotImplementedError``."""
    message = "p2p_exchange is not yet supported on the MindSpore platform."
    raise NotImplementedError(message)

612 

@staticmethod
def send_object_list(obj_list, dst=None, group=None):
    """Delegate to ``send_object_list`` in ``pipeline_parallel._utils`` (imported on call)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.pipeline_parallel._utils import send_object_list as _send
    _send(obj_list, dst, group)

618 

@staticmethod
def recv_object_list(obj_list, src=None, group=None):
    """Delegate to ``recv_object_list`` in ``pipeline_parallel._utils`` (imported on call)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.pipeline_parallel._utils import recv_object_list as _recv
    _recv(obj_list, src, group)

624 

@staticmethod
def set_tensor_requires_grad(input_tensor):
    """
    Call ``requires_grad_()`` on ``input_tensor``; nothing is returned.
    """
    enable_grad = input_tensor.requires_grad_
    enable_grad()

631 

def _create_group(self, rank_list):
    """Return a communication group name for ``rank_list``, creating the group if needed.

    If ``self._maybe_reuse_world_group`` yields a group, that group is returned.
    Otherwise a group named after the sorted rank tuple is created and
    recorded in ``EXISTING_COMM_GROUPS``.
    """
    reused = self._maybe_reuse_world_group(rank_list)
    if reused is None:
        name = str(tuple(sorted(rank_list)))
        new_group(rank_ids=rank_list, group=name)
        EXISTING_COMM_GROUPS[name] = name
        return name
    return reused

641 

@staticmethod
def all_gather_into_tensor(data, group_info, async_op=False):
    """All-gather ``data`` across a communication group.

    Args:
        data: Tensor to gather.
        group_info: Either a group-name string or an object exposing
            ``group_name`` (the same two forms ``all_reduce`` accepts).
        async_op (bool): Whether to launch asynchronously.

    Returns:
        Whatever ``comm_func.all_gather_into_tensor`` returns.
    """
    # Accept a bare group name too, for parity with all_reduce.
    group_name = group_info if isinstance(group_info, str) else group_info.group_name
    return comm_func.all_gather_into_tensor(None, data, group=group_name, async_op=async_op)

645 

@staticmethod
def all_reduce(data, group_info, async_op=False):
    """All-reduce ``data`` through ``dist.all_reduce``.

    Args:
        data: Tensor to reduce; the same object is returned.
        group_info: Either a group-name string or an object exposing ``group_name``.
        async_op (bool): Whether to launch asynchronously.

    Returns:
        tuple: ``(data, handle)`` where ``handle`` is what ``dist.all_reduce`` returned.
    """
    group_name = group_info if isinstance(group_info, str) else group_info.group_name
    handle = dist.all_reduce(data, group=group_name, async_op=async_op)
    return data, handle

653 

@staticmethod
def broadcast(data, src, group=None, async_op=False):
    """Broadcast ``data`` from ``src`` through ``dist.broadcast`` and return ``data``.

    When ``async_op`` is true the returned handle is waited on before
    returning, so the caller always receives ``data`` rather than a handle.
    """
    work = dist.broadcast(data, src, group, async_op)
    if not async_op:
        return data
    work.wait()
    return data

660 

@staticmethod
def reduce_scatter_tensor(data, group_info, async_op=False):
    """Reduce-scatter ``data`` across a communication group.

    Args:
        data: Tensor to reduce and scatter.
        group_info: Either a group-name string or an object exposing
            ``group_name`` (the same two forms ``all_reduce`` accepts).
        async_op (bool): Whether to launch asynchronously.

    Returns:
        Whatever ``comm_func.reduce_scatter_tensor`` returns.
    """
    # Accept a bare group name too, for parity with all_reduce.
    group_name = group_info if isinstance(group_info, str) else group_info.group_name
    return comm_func.reduce_scatter_tensor(None, data, group=group_name, async_op=async_op)

664 

@staticmethod
def all_to_all_single(input_tensor, output_shape, group, async_op=False):
    """Delegate to the module-level ``_mindspore_all_to_all_single`` helper."""
    outcome = _mindspore_all_to_all_single(input_tensor, output_shape, group, async_op=async_op)
    return outcome

668 

@staticmethod
def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,  # pylint: disable=unused-argument
                                  handle_box=None):
    """Apply ``_MSAsyncA2AFunction`` to the given arguments, forwarded in order."""
    forward_args = (x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box)
    return _MSAsyncA2AFunction.apply(*forward_args)

675 

@staticmethod
def parameters_dict(cell: Cell):
    """Return the result of ``cell.parameters_and_names()``."""
    named_params = cell.parameters_and_names()
    return named_params

679 

@staticmethod
def get_tensor_transform():
    """Return the module-level ``TensorTransform`` instance."""
    transform = _tensor_transform
    return transform

683 

@staticmethod
def construct_strided_slice(x, begin, end, stride):
    """Delegate to ``ms.ops.strided_slice(x, begin, end, stride)``."""
    sliced = ms.ops.strided_slice(x, begin, end, stride)
    return sliced

687 

@staticmethod
def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
    """Build a ``_MicroBatch`` from ``pipeline_parallel._utils`` (imported on call)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.pipeline_parallel._utils import _MicroBatch as _Splitter
    return _Splitter(micro_batch_num, args_batch_dim, kwargs_batch_dim)

693 

@staticmethod
def get_model_state_dict(model, *, options=None):
    """Unsupported on this platform; always raises ``NotImplementedError``."""
    message = "get_model_state_dict is not yet supported on MindSpore"
    raise NotImplementedError(message)

699 

@staticmethod
def save_checkpoint(cell: Union[Cell, dict], file_path: str, ckpt_format: str = "safetensors") -> None:
    """Save a cell's parameters, or a dict of values, through ``ms.save_checkpoint``.

    Args:
        cell: A dict, or an object whose ``_params`` mapping is saved. In the
            dict case, values that are ``Tensor`` but not ``Parameter`` are
            wrapped as ``Parameter(value, name=key)``; all other values pass through.
        file_path (str): Destination checkpoint file name.
        ckpt_format (str): Format forwarded to ``ms.save_checkpoint``.
    """
    if not isinstance(cell, dict):
        to_save = cell._params
    else:
        to_save = {}
        for key, value in cell.items():
            needs_wrap = isinstance(value, Tensor) and not isinstance(value, Parameter)
            to_save[key] = Parameter(value, name=key) if needs_wrap else value
    ms.save_checkpoint(save_obj=to_save, ckpt_file_name=file_path, format=ckpt_format)

714 

@staticmethod
def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
    """Load a checkpoint through ``ms.load_checkpoint`` and return its result."""
    loaded = ms.load_checkpoint(ckpt_file_name=file_path, format=ckpt_format)
    return loaded

718 

@staticmethod
def get_symmetric_memory_handler():
    """Create and return a new ``MSSymmetricMemoryHandler`` (imported on call)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.symmetric_memory import MSSymmetricMemoryHandler
    return MSSymmetricMemoryHandler()

725 

@staticmethod
def get_multicore_handler():
    """Create and return a new ``MSMulticoreHandler`` (imported on call)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.multicore import MSMulticoreHandler
    handler = MSMulticoreHandler()
    return handler

731 

def new_stream(self):
    """Create and return a new ``ms.runtime.Stream``."""
    stream = ms.runtime.Stream()
    return stream

734 

def get_stream_context(self):
    """Return ``ms.runtime.StreamCtx`` itself (the callable, not an instance)."""
    stream_ctx = ms.runtime.StreamCtx
    return stream_ctx

737 

@staticmethod
def all_gather_object(object_list, obj, group=None) -> None:
    """
    Gather objects from the given group into ``object_list`` via ``dist.all_gather_object``.

    Args:
        object_list (list[Any]): Output list, whose size equals the size of the group.
        obj (Any): The object contributed by the current rank.
        group (ProcessGroup, optional): The process group to gather over. Default is
            ``None``, meaning the global group.

    Returns:
        None. Objects are gathered into ``object_list``.
    """
    dist.all_gather_object(object_list, obj, group)
    return None

753 

@staticmethod
def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
    """
    Synchronize all processes in the given communication group via ``dist.barrier``.

    Args:
        group (str, optional): The communication group to work on. Default is ``None``,
            meaning the default world group.
        async_op (bool, optional): Whether this op should be asynchronous. Default: ``False``.
        device_ids (list[int], optional): Reserved parameter on Ascend. Default: ``None``.

    Returns:
        CommHandle if ``async_op`` is True; otherwise ``None``.
    """
    handle = dist.barrier(group, async_op, device_ids)
    return handle

769 

@staticmethod
def init_process_group(
        backend: str = None,
        *,
        init_method: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        world_size: int = -1,
        rank: int = -1,
        store: TCPStore = None,
        pg_options=None,
        device_id=None
) -> None:
    """
    Initialize the global process group, unless it is already initialized.

    Args:
        backend (str): The backend used to init the process group. ``None`` is replaced by
            ``"hccl"``, which is currently the only supported backend.
        init_method (str, optional): URL specifying how to initialize the process group. Default is ``None``.
        timeout (timedelta, optional): Timeout for the API call. Default is ``None``.
        world_size (int): Number of processes. Default is ``-1``.
        rank (int, optional): Rank of the current process. Default is ``-1``.
        store (Store, optional): An object that stores key/value data, facilitating the exchange of
            inter-process communication addresses and connection information. Default is ``None``.
            Currently, only the ``TCPStore`` type is supported.
        pg_options (ProcessGroupOptions, optional): Reserved parameter. Currently has no effect.
        device_id (int, optional): Reserved parameter. Currently has no effect.
    """
    backend = "hccl" if backend is None else backend
    try:
        already_initialized = dist.is_initialized()
    except AttributeError:
        # Treated as "not initialized" when the query raises AttributeError.
        already_initialized = False
    if already_initialized:
        return
    dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                            rank=rank, store=store, pg_options=pg_options, device_id=device_id)

806 

@staticmethod
def destroy_process_group(group: Optional[str] = None) -> None:
    """
    Destroy the given process group and drop it from ``EXISTING_COMM_GROUPS``.

    Args:
        group (str, optional): The group to destroy. Default: ``None`` means ``hccl_world_group``. If group
            is None or "hccl_world_group", destroy the global process group and all process groups relative
            to the global process group.
    """
    # Collect the keys first so the registry is not mutated while being iterated.
    stale_keys = [key for key, name in EXISTING_COMM_GROUPS.items() if name == group]
    for key in stale_keys:
        del EXISTING_COMM_GROUPS[key]
    dist.destroy_process_group(group)

822 

@staticmethod
def get_process_group_ranks(group: Optional[str] = None) -> list[int]:
    """
    Get all ranks in the given process group via ``dist.get_process_group_ranks``.

    Args:
        group (str, optional): The process group to work on. Default: ``None`` means ``hccl_world_group``.

    Returns:
        List[int]: List of ranks in the given process group.
    """
    ranks = dist.get_process_group_ranks(group)
    return ranks

835 

@staticmethod
def get_backend(group: Optional[str] = None) -> str:
    """
    Get the backend of the given process group via ``dist.get_backend``.

    Args:
        group (str, optional): The process group to work on. Default: ``None`` means ``hccl_world_group``.

    Returns:
        str: The backend of the group.
    """
    backend_name = dist.get_backend(group)
    return backend_name

848 

@staticmethod
def split_group(parent_pg: Optional[str] = None,
                split_ranks: Optional[list] = None,
                timeout: Optional[timedelta] = None,
                pg_options: Optional[str] = None,
                group_desc: Optional[str] = None,
                ) -> str:
    """
    Return the group for the first entry of ``split_ranks`` that contains the current rank.

    The world group is reused when ``_maybe_reuse_world_group`` yields one,
    then an already-created group when ``get_created_group`` yields one.
    Otherwise a new group named after the sorted rank tuple is created and
    recorded in ``EXISTING_COMM_GROUPS``.

    Args:
        parent_pg (str, Optional): The process group to split from (not read here).
        split_ranks (Optional[list]): A list like ``list[list[int]]``.
        timeout (Optional[timedelta]): Timeout for the API call (not read here).
        pg_options (Optional[str]): Reserved parameter. Currently has no effect.
        group_desc (Optional[str]): Description of the process group (not read here).

    Returns:
        str: The split group name.

    Raises:
        ValueError: If ``split_ranks`` is ``None`` or empty, or no entry contains the current rank.
    """
    if split_ranks is None or len(split_ranks) == 0:
        raise ValueError("split_ranks cannot be None or empty")

    current_rank = MindSporePlatform.get_rank()
    owning_ranks = next((ranks for ranks in split_ranks if current_rank in ranks), None)
    if owning_ranks is None:
        raise ValueError(f"Split group invalid rank, the Split_ranks {split_ranks} does not contain current rank"
                         f" {current_rank}")

    existing = MindSporePlatform._maybe_reuse_world_group(owning_ranks)
    if existing is not None:
        return existing
    existing = MindSporePlatform.get_created_group(owning_ranks)
    if existing:
        return existing

    name = str(tuple(sorted(owning_ranks)))
    new_group(rank_ids=owning_ranks, group=name)
    EXISTING_COMM_GROUPS[name] = name
    return name

887 

@staticmethod
def get_group_local_rank(group=None) -> int:
    """
    Get the rank id of the current process inside ``group``.

    The current rank is read from ``MindSporePlatform.get_rank()`` and translated with
    ``dist.get_group_rank``.
    """
    current_rank = MindSporePlatform.get_rank()
    return dist.get_group_rank(group, current_rank)

892 

@staticmethod
def no_grad():
    """Return a new ``_no_grad()`` object from ``mindspore.common.api``."""
    # NOTE(review): presumably used as a context manager that disables gradient recording - confirm.
    grad_guard = _no_grad()
    return grad_guard

896 

@staticmethod
def relu(tensor):
    """Apply ``mint.nn.functional.relu`` to ``tensor`` and return the result."""
    activated = mint.nn.functional.relu(tensor)
    return activated

900 

@staticmethod
def cat(tensors, dim=0):
    """Concatenate ``tensors`` along ``dim`` by delegating to ``mint.cat``."""
    joined = mint.cat(tensors, dim=dim)
    return joined

904 

@staticmethod
def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
    """
    Delegate to ``mint.empty_like`` for ``tensor``.

    ``dtype``, ``device`` and ``pin_memory`` are keyword-only and are passed through unchanged.
    """
    options = {"dtype": dtype, "device": device, "pin_memory": pin_memory}
    return mint.empty_like(tensor, **options)

908 

def get_current_stream(self):
    """Return the result of ``ms.runtime.current_stream()``."""
    stream = ms.runtime.current_stream()
    return stream

911 

def new_event(self):
    """Create and return a new ``ms.runtime.Event``."""
    event = ms.runtime.Event()
    return event

914 

def tree_map(self, fn, tree):
    """
    Apply fn to each leaf in a nested structure (list / tuple / dict),
    preserving the original structure.

    Args:
        fn (Callable): Function applied to every leaf, i.e. every value that is not a
            dict, tuple or list.
        tree (Any): The nested structure to walk. A bare leaf is also accepted.

    Returns:
        Any: A new structure with the same nesting as ``tree``. Dicts are rebuilt by calling
        ``type(tree)`` with ``(key, value)`` pairs, namedtuples are rebuilt with their own
        type, other tuples become plain tuples and lists become plain lists.
    """
    if isinstance(tree, dict):
        return type(tree)(
            (k, self.tree_map(fn, v)) for k, v in tree.items()
        )

    if isinstance(tree, tuple):
        mapped = [self.tree_map(fn, v) for v in tree]
        # A namedtuple constructor takes one positional argument per field, so it must be
        # rebuilt explicitly; otherwise it would be flattened into a plain tuple.
        if hasattr(tree, "_asdict") and hasattr(tree, "_fields"):
            return type(tree)(*mapped)
        return tuple(mapped)

    if isinstance(tree, list):
        return [self.tree_map(fn, v) for v in tree]

    # leaf
    return fn(tree)

933 

@staticmethod
def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
    """
    Register ``hook`` through ``module.register_forward_pre_hook``.

    ``with_kwargs`` is forwarded. ``prepend`` is accepted but not passed on to the module.
    Returns whatever the module's registration call returns.
    """
    handle = module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
    return handle

937 

@staticmethod
def register_full_backward_hook(module, hook, prepend=False):
    """
    Register ``hook`` through ``module.register_backward_hook``.

    ``prepend`` is accepted but not passed on to the module.
    Returns whatever the module's registration call returns.
    """
    handle = module.register_backward_hook(hook)
    return handle

941 

@staticmethod
def register_full_backward_pre_hook(module, hook, prepend=False):
    """
    Register ``hook`` through ``module.register_backward_pre_hook``.

    ``prepend`` is accepted but not passed on to the module.
    Returns whatever the module's registration call returns.
    """
    handle = module.register_backward_pre_hook(hook)
    return handle

945 

@property
def checkpoint(self):
    """Expose ``ms.recompute`` as this platform's checkpoint function."""
    recompute_fn = ms.recompute
    return recompute_fn

949 

@staticmethod
def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
    """
    Wrap ``module`` with the MindSpore ``checkpoint_wrapper`` helper.

    ``checkpoint_fn`` and any extra keyword arguments are forwarded unchanged.
    """
    # The helper is imported at call time rather than at module import time.
    # NOTE(review): presumably to avoid a circular import - confirm.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper as _wrap_module,
    )
    return _wrap_module(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)

955 

@staticmethod
def swap_wrapper(module, policy_fn=None):
    """Wrap ``module`` with the ``swap_wrapper`` helper from ``activation_swap``, forwarding ``policy_fn``."""
    # The helper is imported at call time rather than at module import time.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
        swap_wrapper as _wrap_module,
    )
    return _wrap_module(module, policy_fn=policy_fn)

961 

@staticmethod
def swap_tensor_wrapper(target, tag=None):
    """Wrap ``target`` with the ``swap_tensor_wrapper`` helper from ``activation_swap``, forwarding ``tag``."""
    # The helper is imported at call time rather than at module import time.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
        swap_tensor_wrapper as _wrap_target,
    )
    return _wrap_target(target, tag=tag)

967 

@property
def noop_context_fn(self):
    """Expose ``null_context_fn`` from ``mindspore.common.recompute`` as the no-op context function."""
    context_fn = null_context_fn
    return context_fn

971 

@staticmethod
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Delegate to ``create_selective_checkpoint_contexts`` from the MindSpore ``sac`` module.

    Both arguments are forwarded unchanged and the helper's result is returned as is.
    """
    # The helper is imported at call time rather than at module import time.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.sac import (
        create_selective_checkpoint_contexts as _build_contexts,
    )
    return _build_contexts(
        policy_fn_or_list,
        allow_cache_entry_mutation=allow_cache_entry_mutation,
    )

978 

@staticmethod
def async_save_on_cpu(policy_fn=None):
    """Construct and return an ``AsyncSaveOnCpu`` from ``activation_swap``, forwarding ``policy_fn``."""
    # The class is imported at call time rather than at module import time.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
        AsyncSaveOnCpu as _AsyncSaveOnCpu,
    )
    return _AsyncSaveOnCpu(policy_fn=policy_fn)

984 

@staticmethod
def get_element_size(tensor):
    """Return the ``itemsize`` attribute of ``tensor``."""
    element_size = tensor.itemsize
    return element_size

989 

@staticmethod
def tensor_to_numpy(tensor) -> np.ndarray:
    """Return the numpy array produced by ``tensor.asnumpy()``."""
    array = tensor.asnumpy()
    return array

994 

@staticmethod
def clip_grad_norm_(
        parameters, max_norm, norm_type=2.0,
        error_if_nonfinite=False, foreach=None,
):
    """
    Placeholder for gradient-norm clipping; it always raises.

    Raises:
        NotImplementedError: Always, whatever arguments are given.
    """
    message = "clip_grad_norm_ is not yet supported on MindSpore"
    raise NotImplementedError(message)

1004 

@property
def meta_device(self):
    """Name of the meta device: the constant string ``"meta"``."""
    device_name = "meta"
    return device_name

1008 

def init_on_device(self, device, include_buffers=False):
    """Delegate to the module-level ``_init_on_device`` helper, forwarding ``device`` and ``include_buffers``."""
    device_scope = _init_on_device(device, include_buffers=include_buffers)
    return device_scope

1011 

def cast_fp_tensor(self, dtype, x):
    """
    Cast a floating-point tensor to ``dtype``.

    ``x`` is returned untouched when it is not an ``ms.Tensor``, when it is not a floating-point
    tensor, or when it already has the requested dtype. Otherwise ``x.to(dtype)`` is returned.
    """
    if not isinstance(x, ms.Tensor):
        return x
    if not ms.ops.is_floating_point(x):
        return x
    if x.dtype == dtype:
        return x
    return x.to(dtype)

1023 

def apply_to_tensors(self, fn, container):
    """
    Recursively apply ``fn`` to every ``ms.Tensor`` found inside ``container``.

    Supported containers are dataclass instances, ``OrderedDict`` (and subclasses), ``dict``,
    namedtuples, ``list``, ``tuple`` and ``set``. Each is rebuilt with the transformed members.
    Any other object is returned unchanged.
    """

    def _visit(item):
        if isinstance(item, ms.Tensor):
            return fn(item)

        if hasattr(item, "__dataclass_fields__"):
            # Work on a shallow copy so the caller's dataclass instance is not modified.
            clone = dataclasses.replace(item)
            updated = {}
            for field in dataclasses.fields(clone):
                updated[field.name] = _visit(getattr(clone, field.name))
            return dataclasses.replace(clone, **updated)

        if isinstance(item, OrderedDict):
            # Checked before plain dict so the concrete ordered type is kept.
            rebuilt = item.__class__()
            for key, value in item.items():
                rebuilt[key] = _visit(value)
            return rebuilt

        if isinstance(item, dict):
            return {key: _visit(value) for key, value in item.items()}

        looks_like_namedtuple = (
            isinstance(item, tuple) and hasattr(item, "_asdict") and hasattr(item, "_fields")
        )
        if looks_like_namedtuple:
            # namedtuple constructors take one positional argument per field.
            members = (_visit(member) for member in item)
            return type(item)(*members)

        if isinstance(item, (list, tuple, set)):
            return type(item)(_visit(member) for member in item)

        return item

    return _visit(container)

1051 

@staticmethod
def profiler_record(name):
    """
    Profiler context manager placeholder.

    Returns a ``contextlib.nullcontext()``. ``name`` is accepted but not used, so entering the
    returned context records nothing.
    """
    noop_scope = contextlib.nullcontext()
    return noop_scope

1056 

def str_to_dtype(self, dtype_str: str) -> Any:
    """
    Resolve checkpoint dtype strings (``mindspore.*`` or short ``str(Tensor.dtype)`` e.g. ``Float32``).

    A string of the form ``mindspore.<name>`` is looked up as the attribute ``<name>`` on the
    ``mindspore`` module. Any other string is lower-cased and looked up on the same module.

    Raises:
        ValueError: If the lower-cased lookup finds nothing (or finds ``None``).
    """
    prefix, dot, name = dtype_str.partition(".")
    if dot and prefix == "mindspore":
        return getattr(ms, name)

    resolved = getattr(ms, dtype_str.lower(), None)
    if resolved is None:
        raise ValueError(
            f"Expected dtype string like 'mindspore.float32' or 'Float32', got {dtype_str!r}."
        )
    return resolved

1069 

def list_to_size(self, size_list: list[int]) -> tuple[int, ...]:
    """Convert ``size_list`` into a tuple with the same elements in the same order."""
    return (*size_list,)

1072 

@staticmethod
def _maybe_reuse_world_group(rank_list):
    """
    Reuse the default world group for full-world rank lists.

    Returns ``GlobalComm.WORLD_COMM_GROUP`` when the sorted ``rank_list`` equals
    ``0 .. world_size - 1``, and records that mapping in ``EXISTING_COMM_GROUPS`` under the
    string form of the sorted rank tuple. Returns ``None`` otherwise.
    """
    ordered_ranks = tuple(sorted(rank_list))
    every_rank = tuple(range(MindSporePlatform.get_world_size()))
    if ordered_ranks == every_rank:
        world_group = GlobalComm.WORLD_COMM_GROUP
        EXISTING_COMM_GROUPS[str(ordered_ranks)] = world_group
        return world_group
    return None