Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/multicore/modules/moe_ffn/backward/backward_graph.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MoE-FFN backward compute graph.""" 

16 

17from hyper_parallel.core.multicore.modules.moe_ffn.common.compute_graph import ( 

18 ComputeGraph, OperatorNode, TensorSpec, SplitSpec, OpType, 

19) 

20from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builders import ( 

21 AllToAllFillConfig, AllToAllType, GmmFillConfig, SwiGLUFillConfig, 

22) 

23 

# Slots of the per-operator tiling buffers inside the C++ runtime descriptor
# array.  Slots 0-19 are taken by the tensor specs declared in
# _build_bwd_tensor_specs, so the tiling buffers start at 20.
_TILING_POS_ACT_GRAD = 20     # tiling_position of the "act_grad" GMM node
_TILING_POS_GATE_GRAD = 21    # tiling_position of the "gate_grad" GMM node
_TILING_POS_W2_GRAD = 22      # tiling_position of the "w2_grad" GMM node
_TILING_POS_W1_GRAD = 23      # tiling_position of the "w1_grad" GMM node
_TILING_POS_SWIGLU_GRAD = 24  # tiling_position of the "swiglu_grad" node

30 

31 

def _build_bwd_tensor_specs(tsv, hidden_size, intermediate_size, dtype_size):
    """Declare every tensor slot referenced by the backward operators.

    Args:
        tsv: split-value object; only ``single_rank_expert_num``,
            ``per_rank_seq`` and ``all_expert_num`` are read here.
        hidden_size: width of the hidden dimension.
        intermediate_size: width of the intermediate dimension; the SwiGLU
            and gate/w2 tensors use twice this width.
        dtype_size: ``dtype_size`` forwarded to every data tensor (the
            offset/size/group-list vectors use fixed widths of 8 or 4).

    Returns:
        A 20-tuple of TensorSpec objects whose order matches their
        ``param_position`` (0 through 19).
    """
    local_experts = tsv.single_rank_expert_num
    tokens = tsv.per_rank_seq
    experts = tsv.all_expert_num
    doubled_width = intermediate_size * 2

    def data_spec(name, shape, slot, **flags):
        # Data tensors all carry the caller's dtype_size and tensor_type=1;
        # is_dynamic / transpose are passed only where the slot needs them.
        return TensorSpec(name, shape, param_position=slot,
                          dtype_size=dtype_size, tensor_type=1, **flags)

    def vector_spec(name, length, slot, width):
        # 1-D bookkeeping vectors: fixed element width, no extra flags.
        return TensorSpec(name, [length], param_position=slot, dtype_size=width)

    # The tuple is evaluated left to right, so specs are created in slot order.
    return (
        data_spec("target", [tokens, hidden_size], 0, is_dynamic=True),
        vector_spec("target_offset", experts, 1, 8),
        data_spec("src", [tokens, hidden_size], 2),
        vector_spec("src_offset", experts, 3, 8),
        vector_spec("size_d", experts, 4, 4),
        data_spec("w1_grad_x1", [tokens, intermediate_size], 5,
                  is_dynamic=True, transpose=True),
        data_spec("w1_grad_y", [local_experts, intermediate_size, hidden_size], 6),
        data_spec("act_grad_weight", [local_experts, intermediate_size, hidden_size], 7,
                  transpose=True),
        data_spec("act_grad_y", [tokens, intermediate_size], 8, is_dynamic=True),
        data_spec("swiglu_dy", [tokens, doubled_width], 9, is_dynamic=True),
        data_spec("swiglu_out", [tokens, doubled_width], 10, is_dynamic=True),
        data_spec("gate_grad_weight", [local_experts, hidden_size, doubled_width], 11,
                  transpose=True),
        data_spec("gate_grad_y", [tokens, hidden_size], 12, is_dynamic=True),
        data_spec("combine_out", [tokens, hidden_size], 13),
        vector_spec("target_offset_c", experts, 14, 8),
        vector_spec("src_offset_c", experts, 15, 8),
        vector_spec("size_c", experts, 16, 4),
        data_spec("w2_grad_x1", [tokens, hidden_size], 17,
                  is_dynamic=True, transpose=True),
        data_spec("w2_grad_y", [local_experts, hidden_size, doubled_width], 18),
        vector_spec("glist", local_experts, 19, 8),
    )

102 

103 

def _build_bwd_ops_first(tsv, specs, *, dispatch_sv, act_grad_sv, w1_grad_sv,
                         swiglu_sv, num_cube_cores):
    """Assemble the dispatch, act_grad, w1_grad and swiglu_grad nodes.

    Args:
        tsv: split-value object; ``all_expert_num`` and
            ``single_rank_expert_num`` are read while building fill configs.
        specs: tensor-spec tuple as returned by ``_build_bwd_tensor_specs``;
            the first eleven entries and the last one are used.
        dispatch_sv: split value of the dispatch AllToAll node.
        act_grad_sv: split value of the act_grad GMM node.
        w1_grad_sv: split value of the w1_grad GMM node.
        swiglu_sv: split value of the swiglu_grad node.
        num_cube_cores: forwarded to the GMM fill configs and used as a
            factor in the GMM task count.

    Returns:
        Tuple ``(dispatch, act_grad, w1_grad, swiglu_grad)`` of OperatorNode.
    """
    (target, target_offset, src, src_offset, size_d,
     w1_grad_x1, w1_grad_y, act_grad_weight, act_grad_y,
     swiglu_dy, swiglu_out) = specs[:11]
    glist = specs[-1]

    def cube_split(input_index):
        # GMM nodes split the given input along dim 0; their task count is
        # num_cube_cores * single_rank_expert_num.
        return SplitSpec(
            split_inputs=[(input_index, 0)],
            split_output_dims=[0],
            task_num_fn=lambda tsv, _ncc=num_cube_cores: _ncc * tsv.single_rank_expert_num,
        )

    # --- dispatch (AllToAll): no tiling buffer, hence tiling_position=-1 ---
    dispatch_split = SplitSpec(
        split_inputs=None,
        split_output_dims=[0],
        task_num_fn=lambda tsv: tsv.all_expert_num * (tsv.per_expert_seq_to_other // dispatch_sv),
    )
    dispatch_fill = AllToAllFillConfig(
        moe_type=AllToAllType.DISPATCH,
        advance="vector",
        event_group=tsv.all_expert_num,
    )
    dispatch = OperatorNode(
        name="dispatch",
        op_type=OpType.ALLTOALL,
        inputs=[target_offset, src, src_offset, size_d],
        outputs=[target],
        param_positions=[1, 2, 3, 4, 0],
        split_value=dispatch_sv,
        split_spec=dispatch_split,
        tiling_position=-1,
        fill_config=dispatch_fill,
    )

    # --- act_grad (GMM): consumes the dispatched target ---
    act_grad_split = cube_split(0)
    act_grad_fill = GmmFillConfig(
        offset_inputs={0},
        rank_in_event=True,
        global_trigger=False,
        out_offset=True,
        advance="cube_only",
        num_cube_cores=num_cube_cores,
    )
    act_grad = OperatorNode(
        name="act_grad",
        op_type=OpType.GMM,
        inputs=[target, act_grad_weight, glist],
        outputs=[act_grad_y],
        param_positions=[0, 7, 19, 8],
        split_value=act_grad_sv,
        split_spec=act_grad_split,
        tiling_position=_TILING_POS_ACT_GRAD,
        fill_config=act_grad_fill,
    )

    # --- w1_grad (GMM): the dispatched target is its second input ---
    w1_grad_split = cube_split(1)
    w1_grad_fill = GmmFillConfig(
        offset_inputs={0, 1},
        rank_in_event=True,
        global_trigger=True,
        out_offset=False,
        advance="cube_custom",
        event_delta=tsv.single_rank_expert_num,
        num_cube_cores=num_cube_cores,
    )
    w1_grad = OperatorNode(
        name="w1_grad",
        op_type=OpType.GMM,
        inputs=[w1_grad_x1, target, glist],
        outputs=[w1_grad_y],
        param_positions=[5, 0, 19, 6],
        split_value=w1_grad_sv,
        split_spec=w1_grad_split,
        tiling_position=_TILING_POS_W1_GRAD,
        fill_config=w1_grad_fill,
    )

    # --- swiglu_grad: consumes act_grad's output ---
    swiglu_split = SplitSpec(
        split_inputs=[(0, 0)],
        split_output_dims=[0],
        task_num_fn=lambda tsv: (tsv.per_expert_seq // swiglu_sv) * tsv.single_rank_expert_num,
    )
    swiglu_fill = SwiGLUFillConfig()
    swiglu_grad = OperatorNode(
        name="swiglu_grad",
        op_type=OpType.SWIGLU_GRAD,
        inputs=[act_grad_y, swiglu_dy],
        outputs=[swiglu_out],
        param_positions=[8, 9, 10],
        split_value=swiglu_sv,
        split_spec=swiglu_split,
        tiling_position=_TILING_POS_SWIGLU_GRAD,
        fill_config=swiglu_fill,
    )

    return dispatch, act_grad, w1_grad, swiglu_grad

160 

161 

def _build_bwd_ops_second(specs, *, gate_grad_sv, w2_grad_sv, combine_sv, num_cube_cores):
    """Assemble the gate_grad, combine and w2_grad nodes.

    Args:
        specs: tensor-spec tuple as returned by ``_build_bwd_tensor_specs``;
            only the last ten entries are used.
        gate_grad_sv: split value of the gate_grad GMM node.
        w2_grad_sv: split value of the w2_grad GMM node.
        combine_sv: split value of the combine AllToAll node.
        num_cube_cores: forwarded to the GMM fill configs and used as a
            factor in the GMM task count.

    Returns:
        Tuple ``(gate_grad, combine, w2_grad)`` of OperatorNode.
    """
    (swiglu_out, gate_grad_weight, gate_grad_y,
     combine_out, target_offset_c, src_offset_c, size_c,
     w2_grad_x1, w2_grad_y, glist) = specs[-10:]

    def cube_split(input_index):
        # GMM nodes split the given input along dim 0; their task count is
        # num_cube_cores * single_rank_expert_num.
        return SplitSpec(
            split_inputs=[(input_index, 0)],
            split_output_dims=[0],
            task_num_fn=lambda tsv, _ncc=num_cube_cores: _ncc * tsv.single_rank_expert_num,
        )

    # --- gate_grad (GMM): consumes swiglu_grad's output ---
    gate_grad_split = cube_split(0)
    gate_grad_fill = GmmFillConfig(
        offset_inputs={0},
        rank_in_event=False,
        global_trigger=False,
        out_offset=True,
        advance="cube",
        num_cube_cores=num_cube_cores,
    )
    gate_grad = OperatorNode(
        name="gate_grad",
        op_type=OpType.GMM,
        inputs=[swiglu_out, gate_grad_weight, glist],
        outputs=[gate_grad_y],
        param_positions=[10, 11, 19, 12],
        split_value=gate_grad_sv,
        split_spec=gate_grad_split,
        tiling_position=_TILING_POS_GATE_GRAD,
        fill_config=gate_grad_fill,
    )

    # --- combine (AllToAll): no tiling buffer, hence tiling_position=-1 ---
    combine_split = SplitSpec(
        split_inputs=[(1, 0)],
        split_output_dims=[0],
        task_num_fn=lambda tsv: tsv.all_expert_num * (tsv.per_expert_seq_to_other // combine_sv),
    )
    combine_fill = AllToAllFillConfig(
        moe_type=AllToAllType.COMBINE,
        advance="vector_only",
        event_group=1,
    )
    combine = OperatorNode(
        name="combine",
        op_type=OpType.ALLTOALL,
        inputs=[target_offset_c, gate_grad_y, src_offset_c, size_c],
        outputs=[combine_out],
        param_positions=[14, 12, 15, 16, 13],
        split_value=combine_sv,
        split_spec=combine_split,
        tiling_position=-1,
        fill_config=combine_fill,
    )

    # --- w2_grad (GMM): swiglu_out is its second input ---
    w2_grad_split = cube_split(1)
    w2_grad_fill = GmmFillConfig(
        offset_inputs={0, 1},
        rank_in_event=False,
        global_trigger=True,
        out_offset=False,
        advance="cube_custom",
        event_delta=1,
        num_cube_cores=num_cube_cores,
    )
    w2_grad = OperatorNode(
        name="w2_grad",
        op_type=OpType.GMM,
        inputs=[w2_grad_x1, swiglu_out, glist],
        outputs=[w2_grad_y],
        param_positions=[17, 10, 19, 18],
        split_value=w2_grad_sv,
        split_spec=w2_grad_split,
        tiling_position=_TILING_POS_W2_GRAD,
        fill_config=w2_grad_fill,
    )

    return gate_grad, combine, w2_grad

204 

205 

def build_backward_graph(tsv, *,
                         dispatch_sv: int = 128,
                         act_grad_sv: int = 4096,
                         w1_grad_sv: int = 4096,
                         swiglu_sv: int = 128,
                         gate_grad_sv: int = 4096,
                         w2_grad_sv: int = 4096,
                         combine_sv: int = 128,
                         hidden_size: int = 7168,
                         intermediate_size: int = 2048,
                         dtype_size: int = 2,
                         num_cube_cores: int = 24) -> ComputeGraph:
    """Construct the MoE-FFN backward operator DAG.

    Seven operators are registered in this order: dispatch, act_grad,
    w1_grad, swiglu_grad, gate_grad, combine, w2_grad.

    Edges added (predecessor -> successor):
        dispatch    -> act_grad
        dispatch    -> w1_grad
        act_grad    -> swiglu_grad
        swiglu_grad -> gate_grad
        gate_grad   -> combine
        gate_grad   -> w2_grad

    Param positions per operator (inputs | output):
        dispatch:    [target_offset=1, src=2, src_offset=3, size=4 | target=0]
        act_grad:    [x=0, weight=7, glist=19 | y=8]
        w1_grad:     [x1=5, x2=0, glist=19 | y=6]
        swiglu_grad: [x=8, dy=9 | out=10]
        gate_grad:   [x=10, weight=11, glist=19 | y=12]
        combine:     [target_offset=14, src=12, src_offset=15, size=16 | target=13]
        w2_grad:     [x1=17, x2=10, glist=19 | y=18]

    Args:
        tsv: split-value object (a TaskSplitValue, per the original notes)
            supplying the expert and sequence partition sizes.
        dispatch_sv: split value for the dispatch AllToAll.
        act_grad_sv: split value for the activation-gradient GMM.
        w1_grad_sv: split value for the w1 weight-gradient GMM.
        swiglu_sv: split value for the SwiGLU-grad operator.
        gate_grad_sv: split value for the gate-gradient GMM.
        w2_grad_sv: split value for the w2 weight-gradient GMM.
        combine_sv: split value for the combine AllToAll.
        hidden_size: model hidden dimension.
        intermediate_size: FFN intermediate dimension; tensors on the SwiGLU
            side are declared with twice this width.
        dtype_size: bytes per element of the data tensors (2 by default).
        num_cube_cores: cube-core count used by the GMM operators.

    Returns:
        The populated ComputeGraph (the original notes describe it as ready
        for ``propagate_splits()``, which is not shown in this module).
    """
    tensor_specs = _build_bwd_tensor_specs(tsv, hidden_size, intermediate_size, dtype_size)

    dispatch, act_grad, w1_grad, swiglu_grad = _build_bwd_ops_first(
        tsv,
        tensor_specs,
        dispatch_sv=dispatch_sv,
        act_grad_sv=act_grad_sv,
        w1_grad_sv=w1_grad_sv,
        swiglu_sv=swiglu_sv,
        num_cube_cores=num_cube_cores,
    )
    gate_grad, combine, w2_grad = _build_bwd_ops_second(
        tensor_specs,
        gate_grad_sv=gate_grad_sv,
        w2_grad_sv=w2_grad_sv,
        combine_sv=combine_sv,
        num_cube_cores=num_cube_cores,
    )

    registration_order = (
        dispatch, act_grad, w1_grad, swiglu_grad, gate_grad, combine, w2_grad,
    )
    dependency_edges = (
        (dispatch, act_grad),
        (dispatch, w1_grad),
        (act_grad, swiglu_grad),
        (swiglu_grad, gate_grad),
        (gate_grad, combine),
        (gate_grad, w2_grad),
    )

    graph = ComputeGraph()
    # Each call is made on the previous call's return value, exactly as a
    # fluent chain would; all nodes are registered before any edge is added.
    builder = graph
    for node in registration_order:
        builder = builder.add_op(node)
    for upstream, downstream in dependency_edges:
        builder = builder.add_edge(upstream, downstream)
    return graph