1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MoE-FFN forward compute graph."""
16
17from hyper_parallel.core.multicore.modules.moe_ffn.common.compute_graph import (
18 ComputeGraph, OperatorNode, TensorSpec, SplitSpec, OpType,
19)
20from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builders import (
21 AllToAllFillConfig, AllToAllType, GmmFillConfig, SwiGLUFillConfig,
22)
23
# Slot indices of the per-operator tiling buffers in the C++ runtime
# descriptor array (up-projection GMM, SwiGLU, down-projection GMM).
# The AllToAll operators pass ``tiling_position=-1`` instead.
_TILING_POS_UP_PROJ, _TILING_POS_SWIGLU, _TILING_POS_DN_PROJ = 17, 18, 19
28
29
def _build_fwd_tensor_specs(tsv, hidden_size, intermediate_size, dtype_size):
    """Create the sixteen TensorSpecs used by the forward graph.

    The specs are returned as a tuple ordered by ``param_position`` (the
    C++ memory slot), so the i-th element has ``param_position == i``.
    ``_build_fwd_ops`` unpacks this tuple positionally, so the order matters.

    Args:
        tsv: TaskSplitValue; ``per_rank_seq``, ``all_expert_num`` and
            ``single_rank_expert_num`` are read for the tensor shapes.
        hidden_size: model hidden dimension.
        intermediate_size: SwiGLU output width; the up-projection output is
            twice this.
        dtype_size: bytes per element for activation and weight tensors.
            Offset/glist tensors use 8-byte elements; size tensors use 4.

    Returns:
        Tuple of 16 TensorSpec objects (slots 0..15).
    """
    local_experts = tsv.single_rank_expert_num
    seq = tsv.per_rank_seq
    experts = tsv.all_expert_num
    up_width = intermediate_size * 2

    # One row per tensor, listed in slot order:
    #   (name, shape, bytes per element, extra TensorSpec kwargs)
    # Only kwargs that differ from TensorSpec's defaults appear in ``extra``.
    rows = (
        ("target", [seq, hidden_size], dtype_size, {"is_dynamic": True}),
        ("target_offset", [experts], 8, {}),
        ("src", [seq, hidden_size], dtype_size, {}),
        ("src_offset", [experts], 8, {}),
        ("size_d", [experts], 4, {}),
        ("up_proj_weight", [local_experts, hidden_size, up_width], dtype_size,
         {"tensor_type": 1}),
        ("up_proj_glist", [local_experts], 8, {}),
        ("up_proj_y", [seq, up_width], dtype_size, {"is_dynamic": True}),
        ("swiglu_out", [seq, intermediate_size], dtype_size,
         {"tensor_type": 1, "is_dynamic": True}),
        ("down_proj_weight", [local_experts, intermediate_size, hidden_size], dtype_size,
         {"tensor_type": 1}),
        ("down_proj_glist", [local_experts], 8, {}),
        ("down_proj_y", [seq, hidden_size], dtype_size,
         {"tensor_type": 1, "is_dynamic": True}),
        ("combine_out", [seq, hidden_size], dtype_size, {}),
        ("target_offset_c", [experts], 8, {}),
        ("src_offset_c", [experts], 8, {}),
        ("size_c", [experts], 4, {}),
    )
    return tuple(
        TensorSpec(name, shape, param_position=slot, dtype_size=nbytes, **extra)
        for slot, (name, shape, nbytes, extra) in enumerate(rows)
    )
85
86
def _build_fwd_ops(tsv, specs, *, dispatch_sv, up_proj_sv, swiglu_sv,
                   down_proj_sv, combine_sv, num_cube_cores):
    """Create the five forward OperatorNodes from the forward TensorSpecs.

    Args:
        tsv: TaskSplitValue. Here it is read only for ``all_expert_num``,
            which becomes the dispatch AllToAll event group. Each
            ``task_num_fn`` callback receives its own ``tsv`` when invoked.
        specs: 16-tuple of TensorSpecs in ``_build_fwd_tensor_specs`` order
            (param_position 0..15).
        dispatch_sv, up_proj_sv, swiglu_sv, down_proj_sv, combine_sv: per-op
            tile sizes, stored as each node's ``split_value``.
        num_cube_cores: cube-core count used by both GMM task counts and
            both GMM fill configs.

    Returns:
        Tuple ``(dispatch, up_proj, swiglu, down_proj, combine)``.
    """
    (target, target_offset, src, src_offset, size_d,
     up_proj_weight, up_proj_glist, up_proj_y, swiglu_out,
     down_proj_weight, down_proj_glist, down_proj_y,
     combine_out, target_offset_c, src_offset_c, size_c) = specs

    # Task-count callbacks for SplitSpec.task_num_fn. Their parameters are
    # named ``tsv`` (and ``_ncc``) to keep the same call signature.
    # NOTE(review): the ``//`` divisions drop any remainder when the
    # per-expert sequence is not a multiple of the tile size. Confirm
    # divisibility is guaranteed upstream.
    def dispatch_tasks(tsv):
        return tsv.all_expert_num * (tsv.per_expert_seq_to_other // dispatch_sv)

    def gmm_tasks(tsv, _ncc=num_cube_cores):
        # Same count for up_proj and down_proj: cube cores x local experts.
        return _ncc * tsv.single_rank_expert_num

    def swiglu_tasks(tsv):
        return (tsv.per_expert_seq // swiglu_sv) * tsv.single_rank_expert_num

    def combine_tasks(tsv):
        return tsv.all_expert_num * (tsv.per_expert_seq_to_other // combine_sv)

    def gmm_fill(rank_in_event):
        # The two GMM fill configs differ only in ``rank_in_event``. A fresh
        # ``{0}`` set is built on every call.
        return GmmFillConfig(
            offset_inputs={0},
            rank_in_event=rank_in_event,
            global_trigger=False,
            out_offset=True,
            advance="cube",
            num_cube_cores=num_cube_cores,
        )

    dispatch = OperatorNode(
        name="dispatch",
        op_type=OpType.ALLTOALL,
        inputs=[target_offset, src, src_offset, size_d],
        outputs=[target],
        param_positions=[1, 2, 3, 4, 0],
        split_value=dispatch_sv,
        split_spec=SplitSpec(split_inputs=None, split_output_dims=[0],
                             task_num_fn=dispatch_tasks),
        tiling_position=-1,
        fill_config=AllToAllFillConfig(
            moe_type=AllToAllType.DISPATCH,
            advance="vector",
            event_group=tsv.all_expert_num,
        ),
    )
    up_proj = OperatorNode(
        name="up_proj",
        op_type=OpType.GMM,
        inputs=[target, up_proj_weight, up_proj_glist],
        outputs=[up_proj_y],
        param_positions=[0, 5, 6, 7],
        split_value=up_proj_sv,
        split_spec=SplitSpec(split_inputs=[(0, 0)], split_output_dims=[0],
                             task_num_fn=gmm_tasks),
        tiling_position=_TILING_POS_UP_PROJ,
        fill_config=gmm_fill(rank_in_event=True),
    )
    swiglu = OperatorNode(
        name="swiglu",
        op_type=OpType.SWIGLU,
        inputs=[up_proj_y],
        outputs=[swiglu_out],
        param_positions=[7, 8],
        split_value=swiglu_sv,
        split_spec=SplitSpec(split_inputs=[(0, 0)], split_output_dims=[0],
                             task_num_fn=swiglu_tasks),
        tiling_position=_TILING_POS_SWIGLU,
        fill_config=SwiGLUFillConfig(),
    )
    down_proj = OperatorNode(
        name="down_proj",
        op_type=OpType.GMM,
        inputs=[swiglu_out, down_proj_weight, down_proj_glist],
        outputs=[down_proj_y],
        param_positions=[8, 9, 10, 11],
        split_value=down_proj_sv,
        split_spec=SplitSpec(split_inputs=[(0, 0)], split_output_dims=[0],
                             task_num_fn=gmm_tasks),
        tiling_position=_TILING_POS_DN_PROJ,
        fill_config=gmm_fill(rank_in_event=False),
    )
    combine = OperatorNode(
        name="combine",
        op_type=OpType.ALLTOALL,
        inputs=[target_offset_c, down_proj_y, src_offset_c, size_c],
        outputs=[combine_out],
        param_positions=[13, 11, 14, 15, 12],
        split_value=combine_sv,
        # Split along input #1 (down_proj_y), dim 0.
        split_spec=SplitSpec(split_inputs=[(1, 0)], split_output_dims=[0],
                             task_num_fn=combine_tasks),
        tiling_position=-1,
        fill_config=AllToAllFillConfig(
            moe_type=AllToAllType.COMBINE,
            advance="vector",
            event_group=1,
        ),
    )
    return dispatch, up_proj, swiglu, down_proj, combine
155
156
def build_forward_graph(tsv, *,
                        dispatch_sv: int = 128,
                        up_proj_sv: int = 4096,
                        swiglu_sv: int = 128,
                        down_proj_sv: int = 4096,
                        combine_sv: int = 128,
                        hidden_size: int = 7168,
                        intermediate_size: int = 2048,
                        dtype_size: int = 2,
                        num_cube_cores: int = 24) -> ComputeGraph:
    """Build the MoE-FFN forward DAG: dispatch -> up_proj -> swiglu -> down_proj -> combine.

    Operator execution order and param_positions (C++ memory slots):
        dispatch:  [target_offset=1, src=2, src_offset=3, size=4 | target=0]
        up_proj:   [x=0, weight=5, glist=6 | y=7]
        swiglu:    [x=7 | out=8]
        down_proj: [x=8, weight=9, glist=10 | y=11]
        combine:   [target_offset=13, src=11, src_offset=14, size=15 | target=12]

    Args:
        tsv: TaskSplitValue carrying TP/EP/seq partition metadata.
        dispatch_sv: tile size for dispatch AllToAll (vector cores).
        up_proj_sv: tile size for up-projection GMM (cube cores).
        swiglu_sv: tile size for SwiGLU (vector cores).
        down_proj_sv: tile size for down-projection GMM (cube cores).
        combine_sv: tile size for combine AllToAll (vector cores).
        hidden_size: model hidden dimension.
        intermediate_size: FFN intermediate dimension after SwiGLU halving
            (half of the up-projection output width).
        dtype_size: bytes per element of the activation and weight tensors
            (2=bf16, 4=fp32). Offset and size tensors use fixed widths.
        num_cube_cores: number of AIC cube cores on the target device.

    Returns:
        A ComputeGraph holding the five operators connected as a linear
        chain (one edge per consecutive pair), ready for propagate_splits().

    Raises:
        ValueError: if any tile size (``*_sv``) is not positive. The
            AllToAll and SwiGLU tile sizes are divisors in the task-count
            callbacks. Without this check, a zero value would only surface
            later as a ZeroDivisionError, far from this call.
    """
    tile_sizes = {
        "dispatch_sv": dispatch_sv,
        "up_proj_sv": up_proj_sv,
        "swiglu_sv": swiglu_sv,
        "down_proj_sv": down_proj_sv,
        "combine_sv": combine_sv,
    }
    non_positive = {name: value for name, value in tile_sizes.items() if value <= 0}
    if non_positive:
        raise ValueError(f"tile sizes must be positive, got {non_positive}")

    specs = _build_fwd_tensor_specs(tsv, hidden_size, intermediate_size, dtype_size)
    dispatch, up_proj, swiglu, down_proj, combine = _build_fwd_ops(
        tsv, specs,
        dispatch_sv=dispatch_sv, up_proj_sv=up_proj_sv, swiglu_sv=swiglu_sv,
        down_proj_sv=down_proj_sv, combine_sv=combine_sv, num_cube_cores=num_cube_cores,
    )
    graph = ComputeGraph()
    # add_op/add_edge are chained, so each call's return value is used as a
    # graph builder. The edges follow the operator execution order.
    (graph.add_op(dispatch).add_op(up_proj).add_op(swiglu).add_op(down_proj).add_op(combine)
     .add_edge(dispatch, up_proj)
     .add_edge(up_proj, swiglu)
     .add_edge(swiglu, down_proj)
     .add_edge(down_proj, combine))
    return graph