Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/multicore/__init__.py: 0%
11 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PyTorch multicore handler for hyper-parallel."""
class TorchMulticoreHandler:
    """Handler exposing the MoE-FFN multicore operators on the PyTorch platform.

    Both operators are static methods that defer to the functions of the same
    name in ``hyper_parallel.core.multicore.platform.torch``; nothing is
    computed in this class itself.
    """

    def __init__(self):
        # Import the torch platform package at construction time so that its
        # module-level code runs immediately. Per the original author's note,
        # that code sets ASCEND_CUSTOM_OPP_PATH and preloads ctypes libraries,
        # and it must run before torch_npu is imported anywhere else in the
        # process.
        # NOTE(review): the original comment refers to this package as
        # "platform/torch/__init__.py" -- confirm it matches the path below.
        import hyper_parallel.core.multicore.platform.torch  # noqa: F401  # pylint: disable=C0415,W0611

    @staticmethod
    def moe_ffn_fwd(
        dispatch_target, dispatch_target_off,
        dispatch_src, dispatch_src_off, dispatch_size,
        up_proj_weight, up_proj_glist,
        up_proj_y, swiglu_out,
        down_proj_weight, down_proj_glist, down_proj_y,
        combine_target, combine_target_off, combine_src_off, combine_size,
        gmm_workspace, up_proj_tiling, swiglu_tiling, down_proj_tiling,
        runtime_config, all_event_counters,
        rank_id: int, ep: int, expert_num: int,
        hidden_size: int, seq_size: int,
    ):
        """Run the MoE-FFN forward operator via the PyTorch backend.

        Thin pass-through: the backend ``moe_ffn_fwd`` is imported lazily on
        each call and receives every argument positionally, in the order they
        are declared here.

        Returns:
            Whatever the backend ``moe_ffn_fwd`` returns, unchanged.
        """
        # pylint: disable=C0415
        from hyper_parallel.core.multicore.platform.torch import moe_ffn_fwd as backend_fwd

        # Positional order must mirror the signature above exactly.
        forwarded = (
            dispatch_target, dispatch_target_off,
            dispatch_src, dispatch_src_off, dispatch_size,
            up_proj_weight, up_proj_glist,
            up_proj_y, swiglu_out,
            down_proj_weight, down_proj_glist, down_proj_y,
            combine_target, combine_target_off, combine_src_off, combine_size,
            gmm_workspace, up_proj_tiling, swiglu_tiling, down_proj_tiling,
            runtime_config, all_event_counters,
            rank_id, ep, expert_num, hidden_size, seq_size,
        )
        return backend_fwd(*forwarded)

    @staticmethod
    def moe_ffn_bwd(
        dispatch_target, dispatch_target_off,
        dy, dispatch_src_off, dispatch_size,
        hidden, hidden_dw,
        w2, act_grad_y, gate, grad_gate, w1, gate_dx, grad_x,
        combine_target_off, combine_src_off, combine_size,
        permute_out, gate_dw, group_list,
        act_grad_tiling, gate_grad_tiling, w2_grad_tiling, w1_grad_tiling,
        swiglu_grad_tiling, gmm_workspace, swiglu_grad_workspace,
        runtime_config, all_event_counters,
        rank_id: int, ep: int, expert_num: int,
        hidden_size: int, seq_size: int,
    ):
        """Run the MoE-FFN backward operator via the PyTorch backend.

        Thin pass-through: the backend ``moe_ffn_bwd`` is imported lazily on
        each call and receives every argument positionally, in the order they
        are declared here.

        Returns:
            Whatever the backend ``moe_ffn_bwd`` returns, unchanged.
        """
        # pylint: disable=C0415
        from hyper_parallel.core.multicore.platform.torch import moe_ffn_bwd as backend_bwd

        # Positional order must mirror the signature above exactly.
        forwarded = (
            dispatch_target, dispatch_target_off,
            dy, dispatch_src_off, dispatch_size,
            hidden, hidden_dw,
            w2, act_grad_y, gate, grad_gate, w1, gate_dx, grad_x,
            combine_target_off, combine_src_off, combine_size,
            permute_out, gate_dw, group_list,
            act_grad_tiling, gate_grad_tiling, w2_grad_tiling, w1_grad_tiling,
            swiglu_grad_tiling, gmm_workspace, swiglu_grad_workspace,
            runtime_config, all_event_counters,
            rank_id, ep, expert_num, hidden_size, seq_size,
        )
        return backend_bwd(*forwarded)