Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/multicore/__init__.py: 0%
11 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore multicore handler for hyper-parallel."""
class MSMulticoreHandler:
    """MindSpore platform handler for the MoE-FFN multicore operators.

    Each operator here is a pure pass-through. At call time it looks up the
    function of the same name in
    ``hyper_parallel.core.multicore.platform.mindspore``. It hands that
    function all of its own arguments positionally, in declaration order,
    and returns its result. Nothing is validated or converted in this class.
    """

    def __init__(self):
        """Eagerly import the MindSpore multicore backend package.

        Why this happens in the constructor: according to the original
        authors, the backend package's module-level code sets
        ``ASCEND_CUSTOM_OPP_PATH``, preloads ctypes libraries and adds
        ``build/lib`` to ``sys.path``. That setup MUST run before any
        ``import mindspore`` elsewhere in the process.

        Waiting for the first ``moe_ffn_fwd``/``moe_ffn_bwd`` call is too
        late when symmetric_memory or other modules import mindspore first.
        The backend package is stated not to import mindspore at module
        level, so importing it this early is safe.
        """
        # Imported purely for its module-level side effects; the bound name
        # is never used.
        import hyper_parallel.core.multicore.platform.mindspore  # noqa: F401 # pylint: disable=C0415,W0611

    @staticmethod
    def moe_ffn_fwd(
        dispatch_target, dispatch_target_off,
        dispatch_src, dispatch_src_off, dispatch_size,
        up_proj_weight, up_proj_glist,
        up_proj_y, swiglu_out,
        down_proj_weight, down_proj_glist, down_proj_y,
        combine_target, combine_target_off, combine_src_off, combine_size,
        gmm_workspace, up_proj_tiling, swiglu_tiling, down_proj_tiling,
        runtime_config, all_event_counters,
        rank_id: int, ep: int, expert_num: int,
        hidden_size: int, seq_size: int,
    ):
        """MoE-FFN forward operator (MindSpore backend).

        Forwards all 27 arguments, unchanged and in signature order, to
        ``hyper_parallel.core.multicore.platform.mindspore.moe_ffn_fwd``.
        Returns whatever that function returns.
        """
        # Deferred import. After the first import the module is cached in
        # sys.modules, so repeated calls only pay a dictionary lookup.
        # pylint: disable=C0415
        import hyper_parallel.core.multicore.platform.mindspore as ms_backend

        operands = (
            dispatch_target, dispatch_target_off,
            dispatch_src, dispatch_src_off, dispatch_size,
            up_proj_weight, up_proj_glist,
            up_proj_y, swiglu_out,
            down_proj_weight, down_proj_glist, down_proj_y,
            combine_target, combine_target_off, combine_src_off, combine_size,
            gmm_workspace, up_proj_tiling, swiglu_tiling, down_proj_tiling,
            runtime_config, all_event_counters,
        )
        scalars = (rank_id, ep, expert_num, hidden_size, seq_size)
        return ms_backend.moe_ffn_fwd(*operands, *scalars)

    @staticmethod
    def moe_ffn_bwd(
        dispatch_target, dispatch_target_off,
        dy, dispatch_src_off, dispatch_size,
        hidden, hidden_dw,
        w2, act_grad_y, gate, grad_gate, w1, gate_dx, grad_x,
        combine_target_off, combine_src_off, combine_size,
        permute_out, gate_dw, group_list,
        act_grad_tiling, gate_grad_tiling, w2_grad_tiling, w1_grad_tiling,
        swiglu_grad_tiling, gmm_workspace, swiglu_grad_workspace,
        runtime_config, all_event_counters,
        rank_id: int, ep: int, expert_num: int,
        hidden_size: int, seq_size: int,
    ):
        """MoE-FFN backward operator (MindSpore backend).

        Forwards all 34 arguments, unchanged and in signature order, to
        ``hyper_parallel.core.multicore.platform.mindspore.moe_ffn_bwd``.
        Returns whatever that function returns.
        """
        # Deferred import. After the first import the module is cached in
        # sys.modules, so repeated calls only pay a dictionary lookup.
        # pylint: disable=C0415
        import hyper_parallel.core.multicore.platform.mindspore as ms_backend

        operands = (
            dispatch_target, dispatch_target_off,
            dy, dispatch_src_off, dispatch_size,
            hidden, hidden_dw,
            w2, act_grad_y, gate, grad_gate, w1, gate_dx, grad_x,
            combine_target_off, combine_src_off, combine_size,
            permute_out, gate_dw, group_list,
            act_grad_tiling, gate_grad_tiling, w2_grad_tiling, w1_grad_tiling,
            swiglu_grad_tiling, gmm_workspace, swiglu_grad_workspace,
            runtime_config, all_event_counters,
        )
        scalars = (rank_id, ep, expert_num, hidden_size, seq_size)
        return ms_backend.moe_ffn_bwd(*operands, *scalars)