1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MoE-FFN backward compute graph."""
16
17from hyper_parallel.core.multicore.modules.moe_ffn.common.compute_graph import (
18 ComputeGraph, OperatorNode, TensorSpec, SplitSpec, OpType,
19)
20from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builders import (
21 AllToAllFillConfig, AllToAllType, GmmFillConfig, SwiGLUFillConfig,
22)
23
# Positions of the per-operator tiling buffers in the C++ runtime descriptor
# array. The TensorSpec param_position values used in this module occupy
# 0..19; the tiling buffers follow at 20..24. Entries are listed in the order
# the backward operators are constructed, not in numeric order.
_TILING_POS_ACT_GRAD = 20     # used by the "act_grad" GMM node
_TILING_POS_W1_GRAD = 23      # used by the "w1_grad" GMM node
_TILING_POS_SWIGLU_GRAD = 24  # used by the "swiglu_grad" node
_TILING_POS_GATE_GRAD = 21    # used by the "gate_grad" GMM node
_TILING_POS_W2_GRAD = 22      # used by the "w2_grad" GMM node
30
31
def _build_bwd_tensor_specs(tsv, hidden_size, intermediate_size, dtype_size):
    """Instantiate every TensorSpec referenced by the backward graph.

    Args:
        tsv: object exposing ``single_rank_expert_num``, ``per_rank_seq`` and
            ``all_expert_num``; these provide the leading dimensions.
        hidden_size: size used for the hidden dimension of the shapes.
        intermediate_size: size used for the intermediate dimension; several
            tensors use twice this value.
        dtype_size: forwarded as ``dtype_size`` for every tensor built with
            ``tensor_type=1``.

    Returns:
        A 20-tuple of TensorSpec objects. The tuple index of each spec equals
        its ``param_position`` (0..19), ending with ``glist`` at index 19.
    """
    local_experts = tsv.single_rank_expert_num
    seq = tsv.per_rank_seq
    all_experts = tsv.all_expert_num
    doubled_intermediate = intermediate_size * 2

    def data_spec(name, shape, position, **layout):
        # Tensors carrying ``tensor_type=1`` and the caller-supplied dtype_size;
        # ``layout`` passes through the optional is_dynamic / transpose flags.
        return TensorSpec(name, shape, param_position=position,
                          dtype_size=dtype_size, tensor_type=1, **layout)

    def index_spec(name, length, position, width):
        # 1-D tensors with a fixed dtype_size and no tensor_type or flags.
        return TensorSpec(name, [length], param_position=position, dtype_size=width)

    # Tuple elements are evaluated left to right, so the specs are still
    # constructed in ascending param_position order.
    return (
        # --- dispatch AllToAll ---
        data_spec("target", [seq, hidden_size], 0, is_dynamic=True),
        index_spec("target_offset", all_experts, 1, 8),
        data_spec("src", [seq, hidden_size], 2),
        index_spec("src_offset", all_experts, 3, 8),
        index_spec("size_d", all_experts, 4, 4),
        # --- w1_grad / act_grad GMMs ---
        data_spec("w1_grad_x1", [seq, intermediate_size], 5,
                  is_dynamic=True, transpose=True),
        data_spec("w1_grad_y", [local_experts, intermediate_size, hidden_size], 6),
        data_spec("act_grad_weight", [local_experts, intermediate_size, hidden_size], 7,
                  transpose=True),
        data_spec("act_grad_y", [seq, intermediate_size], 8, is_dynamic=True),
        # --- swiglu_grad ---
        data_spec("swiglu_dy", [seq, doubled_intermediate], 9, is_dynamic=True),
        data_spec("swiglu_out", [seq, doubled_intermediate], 10, is_dynamic=True),
        # --- gate_grad GMM ---
        data_spec("gate_grad_weight", [local_experts, hidden_size, doubled_intermediate], 11,
                  transpose=True),
        data_spec("gate_grad_y", [seq, hidden_size], 12, is_dynamic=True),
        # --- combine AllToAll ---
        data_spec("combine_out", [seq, hidden_size], 13),
        index_spec("target_offset_c", all_experts, 14, 8),
        index_spec("src_offset_c", all_experts, 15, 8),
        index_spec("size_c", all_experts, 16, 4),
        # --- w2_grad GMM ---
        data_spec("w2_grad_x1", [seq, hidden_size], 17,
                  is_dynamic=True, transpose=True),
        data_spec("w2_grad_y", [local_experts, hidden_size, doubled_intermediate], 18),
        # --- group list shared by all GMM nodes ---
        index_spec("glist", local_experts, 19, 8),
    )
102
103
def _build_bwd_ops_first(tsv, specs, *, dispatch_sv, act_grad_sv, w1_grad_sv,
                         swiglu_sv, num_cube_cores):
    """Build the dispatch, act_grad, w1_grad and swiglu_grad operator nodes.

    Args:
        tsv: read here for ``all_expert_num`` (dispatch ``event_group``) and
            ``single_rank_expert_num`` (w1_grad ``event_delta``).
        specs: sequence of TensorSpec objects in the order produced by
            ``_build_bwd_tensor_specs``; the first eleven entries and the last
            entry are used.
        dispatch_sv: split value of the dispatch node; also the divisor in its
            task-count callback.
        act_grad_sv: split value of the act_grad node.
        w1_grad_sv: split value of the w1_grad node.
        swiglu_sv: split value of the swiglu_grad node; also the divisor in its
            task-count callback.
        num_cube_cores: forwarded to both GMM fill configs and used as the
            multiplier in the GMM task-count callback.

    Returns:
        The tuple ``(dispatch, act_grad, w1_grad, swiglu_grad)``.
    """
    (target, target_offset, src, src_offset, size_d,
     w1_grad_x1, w1_grad_y, act_grad_weight, act_grad_y,
     swiglu_dy, swiglu_out) = specs[:11]
    glist = specs[-1]

    # Task-count callbacks. Their ``tsv`` parameter is the object passed in by
    # whoever invokes ``task_num_fn``; it deliberately shadows the outer ``tsv``.
    def alltoall_task_num(tsv):
        return tsv.all_expert_num * (tsv.per_expert_seq_to_other // dispatch_sv)

    def gmm_task_num(tsv, _ncc=num_cube_cores):
        # ``_ncc`` captures the core count at definition time.
        return _ncc * tsv.single_rank_expert_num

    def swiglu_task_num(tsv):
        return (tsv.per_expert_seq // swiglu_sv) * tsv.single_rank_expert_num

    # ---- dispatch (AllToAll): no tiling buffer, hence tiling_position=-1 ----
    dispatch_split = SplitSpec(
        split_inputs=None, split_output_dims=[0], task_num_fn=alltoall_task_num)
    dispatch_fill = AllToAllFillConfig(
        moe_type=AllToAllType.DISPATCH, advance="vector",
        event_group=tsv.all_expert_num)
    dispatch = OperatorNode(
        name="dispatch", op_type=OpType.ALLTOALL,
        inputs=[target_offset, src, src_offset, size_d], outputs=[target],
        param_positions=[1, 2, 3, 4, 0], split_value=dispatch_sv,
        split_spec=dispatch_split,
        tiling_position=-1,
        fill_config=dispatch_fill,
    )

    # ---- act_grad (GMM): target x act_grad_weight -> act_grad_y ----
    act_grad_split = SplitSpec(
        split_inputs=[(0, 0)], split_output_dims=[0], task_num_fn=gmm_task_num)
    act_grad_fill = GmmFillConfig(
        offset_inputs={0}, rank_in_event=True, global_trigger=False,
        out_offset=True, advance="cube_only", num_cube_cores=num_cube_cores)
    act_grad = OperatorNode(
        name="act_grad", op_type=OpType.GMM,
        inputs=[target, act_grad_weight, glist], outputs=[act_grad_y],
        param_positions=[0, 7, 19, 8], split_value=act_grad_sv,
        split_spec=act_grad_split,
        tiling_position=_TILING_POS_ACT_GRAD,
        fill_config=act_grad_fill,
    )

    # ---- w1_grad (GMM): w1_grad_x1 x target -> w1_grad_y ----
    w1_grad_split = SplitSpec(
        split_inputs=[(1, 0)], split_output_dims=[0], task_num_fn=gmm_task_num)
    w1_grad_fill = GmmFillConfig(
        offset_inputs={0, 1}, rank_in_event=True, global_trigger=True,
        out_offset=False, advance="cube_custom",
        event_delta=tsv.single_rank_expert_num, num_cube_cores=num_cube_cores)
    w1_grad = OperatorNode(
        name="w1_grad", op_type=OpType.GMM,
        inputs=[w1_grad_x1, target, glist], outputs=[w1_grad_y],
        param_positions=[5, 0, 19, 6], split_value=w1_grad_sv,
        split_spec=w1_grad_split,
        tiling_position=_TILING_POS_W1_GRAD,
        fill_config=w1_grad_fill,
    )

    # ---- swiglu_grad: act_grad_y, swiglu_dy -> swiglu_out ----
    swiglu_split = SplitSpec(
        split_inputs=[(0, 0)], split_output_dims=[0], task_num_fn=swiglu_task_num)
    swiglu_grad = OperatorNode(
        name="swiglu_grad", op_type=OpType.SWIGLU_GRAD,
        inputs=[act_grad_y, swiglu_dy], outputs=[swiglu_out],
        param_positions=[8, 9, 10], split_value=swiglu_sv,
        split_spec=swiglu_split,
        tiling_position=_TILING_POS_SWIGLU_GRAD,
        fill_config=SwiGLUFillConfig(),
    )

    return dispatch, act_grad, w1_grad, swiglu_grad
160
161
def _build_bwd_ops_second(specs, *, gate_grad_sv, w2_grad_sv, combine_sv, num_cube_cores):
    """Build the gate_grad, combine and w2_grad operator nodes.

    Args:
        specs: sequence of TensorSpec objects in the order produced by
            ``_build_bwd_tensor_specs``; only the last ten entries
            (``swiglu_out`` through ``glist``) are used.
        gate_grad_sv: split value of the gate_grad node.
        w2_grad_sv: split value of the w2_grad node.
        combine_sv: split value of the combine node; also the divisor in its
            task-count callback.
        num_cube_cores: forwarded to both GMM fill configs and used as the
            multiplier in the GMM task-count callback.

    Returns:
        The tuple ``(gate_grad, combine, w2_grad)``.
    """
    (swiglu_out, gate_grad_weight, gate_grad_y,
     combine_out, target_offset_c, src_offset_c, size_c,
     w2_grad_x1, w2_grad_y, glist) = specs[-10:]

    # Task-count callbacks; ``tsv`` is supplied by whoever invokes task_num_fn.
    def gmm_task_num(tsv, _ncc=num_cube_cores):
        # ``_ncc`` captures the core count at definition time.
        return _ncc * tsv.single_rank_expert_num

    def alltoall_task_num(tsv):
        return tsv.all_expert_num * (tsv.per_expert_seq_to_other // combine_sv)

    # ---- gate_grad (GMM): swiglu_out x gate_grad_weight -> gate_grad_y ----
    gate_grad_split = SplitSpec(
        split_inputs=[(0, 0)], split_output_dims=[0], task_num_fn=gmm_task_num)
    gate_grad_fill = GmmFillConfig(
        offset_inputs={0}, rank_in_event=False, global_trigger=False,
        out_offset=True, advance="cube", num_cube_cores=num_cube_cores)
    gate_grad = OperatorNode(
        name="gate_grad", op_type=OpType.GMM,
        inputs=[swiglu_out, gate_grad_weight, glist], outputs=[gate_grad_y],
        param_positions=[10, 11, 19, 12], split_value=gate_grad_sv,
        split_spec=gate_grad_split,
        tiling_position=_TILING_POS_GATE_GRAD,
        fill_config=gate_grad_fill,
    )

    # ---- combine (AllToAll): no tiling buffer, hence tiling_position=-1 ----
    combine_split = SplitSpec(
        split_inputs=[(1, 0)], split_output_dims=[0], task_num_fn=alltoall_task_num)
    combine_fill = AllToAllFillConfig(
        moe_type=AllToAllType.COMBINE, advance="vector_only", event_group=1)
    combine = OperatorNode(
        name="combine", op_type=OpType.ALLTOALL,
        inputs=[target_offset_c, gate_grad_y, src_offset_c, size_c], outputs=[combine_out],
        param_positions=[14, 12, 15, 16, 13], split_value=combine_sv,
        split_spec=combine_split,
        tiling_position=-1,
        fill_config=combine_fill,
    )

    # ---- w2_grad (GMM): w2_grad_x1 x swiglu_out -> w2_grad_y ----
    w2_grad_split = SplitSpec(
        split_inputs=[(1, 0)], split_output_dims=[0], task_num_fn=gmm_task_num)
    w2_grad_fill = GmmFillConfig(
        offset_inputs={0, 1}, rank_in_event=False, global_trigger=True,
        out_offset=False, advance="cube_custom",
        event_delta=1, num_cube_cores=num_cube_cores)
    w2_grad = OperatorNode(
        name="w2_grad", op_type=OpType.GMM,
        inputs=[w2_grad_x1, swiglu_out, glist], outputs=[w2_grad_y],
        param_positions=[17, 10, 19, 18], split_value=w2_grad_sv,
        split_spec=w2_grad_split,
        tiling_position=_TILING_POS_W2_GRAD,
        fill_config=w2_grad_fill,
    )

    return gate_grad, combine, w2_grad
204
205
def build_backward_graph(tsv, *,
                         dispatch_sv: int = 128,
                         act_grad_sv: int = 4096,
                         w1_grad_sv: int = 4096,
                         swiglu_sv: int = 128,
                         gate_grad_sv: int = 4096,
                         w2_grad_sv: int = 4096,
                         combine_sv: int = 128,
                         hidden_size: int = 7168,
                         intermediate_size: int = 2048,
                         dtype_size: int = 2,
                         num_cube_cores: int = 24) -> ComputeGraph:
    """Assemble the MoE-FFN backward DAG from its seven operator nodes.

    Edges added to the graph (producer -> consumer):
        dispatch    -> act_grad
        dispatch    -> w1_grad
        act_grad    -> swiglu_grad
        swiglu_grad -> gate_grad
        gate_grad   -> combine
        gate_grad   -> w2_grad

    act_grad and w1_grad both hang off dispatch; only act_grad feeds
    swiglu_grad. combine and w2_grad both hang off gate_grad. w2_grad's tensor
    inputs are w2_grad_x1 and swiglu_out, but its only graph edge is from
    gate_grad.

    Param positions (C++ memory slots), written as ``inputs | output``:
        dispatch:    [target_offset=1, src=2, src_offset=3, size=4 | target=0]
        act_grad:    [x=0, weight=7, glist=19 | y=8]
        w1_grad:     [x1=5, x2=0, glist=19 | y=6]
        swiglu_grad: [x=8, dy=9 | out=10]
        gate_grad:   [x=10, weight=11, glist=19 | y=12]
        combine:     [target_offset=14, src=12, src_offset=15, size=16 | target=13]
        w2_grad:     [x1=17, x2=10, glist=19 | y=18]

    Args:
        tsv: TaskSplitValue carrying TP/EP/seq partition metadata.
        dispatch_sv: tile size for dispatch AllToAll.
        act_grad_sv: tile size for activation gradient GMM.
        w1_grad_sv: tile size for w1 weight gradient GMM.
        swiglu_sv: tile size for SwiGLU-grad.
        gate_grad_sv: tile size for gate gradient GMM.
        w2_grad_sv: tile size for w2 weight gradient GMM.
        combine_sv: tile size for combine AllToAll.
        hidden_size: model hidden dimension.
        intermediate_size: FFN intermediate dimension after SwiGLU halving
            (half of up-proj output).
        dtype_size: bytes per activation element (2=bf16, 4=fp32).
        num_cube_cores: number of AIC cube cores on the target device.

    Returns:
        A fully-connected ComputeGraph ready for propagate_splits().
    """
    tensor_specs = _build_bwd_tensor_specs(
        tsv, hidden_size, intermediate_size, dtype_size)

    dispatch, act_grad, w1_grad, swiglu_grad = _build_bwd_ops_first(
        tsv,
        tensor_specs,
        dispatch_sv=dispatch_sv,
        act_grad_sv=act_grad_sv,
        w1_grad_sv=w1_grad_sv,
        swiglu_sv=swiglu_sv,
        num_cube_cores=num_cube_cores,
    )
    gate_grad, combine, w2_grad = _build_bwd_ops_second(
        tensor_specs,
        gate_grad_sv=gate_grad_sv,
        w2_grad_sv=w2_grad_sv,
        combine_sv=combine_sv,
        num_cube_cores=num_cube_cores,
    )

    # Registration order of the nodes and of the edges is kept explicit.
    operators = (dispatch, act_grad, w1_grad, swiglu_grad,
                 gate_grad, combine, w2_grad)
    dependencies = (
        (dispatch, act_grad),
        (dispatch, w1_grad),
        (act_grad, swiglu_grad),
        (swiglu_grad, gate_grad),
        (gate_grad, combine),
        (gate_grad, w2_grad),
    )

    graph = ComputeGraph()
    # add_op / add_edge are fluent; thread the returned object through each
    # call so the sequence matches a single chained expression.
    builder = graph
    for operator in operators:
        builder = builder.add_op(operator)
    for producer, consumer in dependencies:
        builder = builder.add_edge(producer, consumer)
    return graph