1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Polymorphic fill configs for Multicore MoE-FFN task scheduling.
17
18Each operator type has a FillConfig subclass that encapsulates both the
19config data and the fill logic (fill method). OperatorNode.fill_config
20holds an instance; gen_runtime_data calls op.fill_config.fill(cfg, op, tsv).
21
22Public fill config classes:
23 AllToAllFillConfig — dispatch and combine (fwd + bwd), unified
24 GmmFillConfig — all GMM variants G1/G2/G3/G4 (fwd + bwd)
25 SwiGLUFillConfig — SwiGLU fwd and SwiGLU-grad bwd (no config fields)
26
27Utility functions (called directly from gen_runtime_data):
28 add_terminate — terminate task; caller passes trigger_count int
29 add_dynamic_data — dynamic data record (fwd pos=6, bwd pos=19)
30 revise_gmm_task_queue_bwd — backward GMM1/GMM4 interleave in cube_task_indices
31"""
32
33
34from abc import ABC, abstractmethod
35from dataclasses import dataclass, field
36from enum import Enum
37from typing import Set
38
39from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import (
40 TaskDescC, TensorDescC, RuntimeConfigC,
41 TaskAiCoreType, TaskType, DynamicType,
42 MAX_TENSOR_DIMS,
43)
44from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builder_utils import (
45 advance_tsv_vector, advance_tsv_cube,
46 advance_tsv_cube_only, advance_tsv_vector_only,
47)
48from hyper_parallel.core.multicore.modules.moe_ffn.common.compute_graph import OpType
49
50
51# ── MoE AllToAll type enum ────────────────────────────────────────────────────
52
class AllToAllType(Enum):
    """
    Semantic role of an MoE AllToAll; the role selects the event wiring.

    DISPATCH
        Scatter tokens from model-parallel ranks to the expert-holding ranks.
        Every expert group fires its own trigger event:
            dependent_event = pre_pre_event_num + 0
            trigger_event   = pre_event_num + (i // per_g_e_num) + 1
            trigger_count   = task_num * ep // all_expert_num

    COMBINE
        Gather expert results back to the originating model-parallel rank.
        All tasks share the global trigger (wait for all experts first):
            dependent_event = pre_pre_event_num
                              + (i // per_g_e_num) % single_rank_expert_num + 1
            trigger_event   = all_event_num
            trigger_count   = task_num

    OTHER
        Reserved for AllToAll patterns outside the MoE dispatch/combine
        semantic.
    """

    DISPATCH = 1  # per-expert trigger
    COMBINE = 2   # global trigger
    OTHER = 3     # reserved
75
76
77# ── Abstract base ──────────────────────────────────────────────────────────────
78
class FillConfig(ABC):
    """Common interface shared by every fill config.

    A concrete subclass carries the configuration values of one operator
    kind and supplies the task-building logic through ``fill``.
    """

    @abstractmethod
    def fill(self, cfg: RuntimeConfigC, op, tsv) -> None:
        """Write the tasks of ``op`` into ``cfg`` using the runtime state ``tsv``."""
86
87
88# ── AllToAll (dispatch + combine, forward + backward) ─────────────────────────
89
@dataclass
class AllToAllFillConfig(FillConfig):
    """
    AllToAll behaviour config covering dispatch and combine for fwd and bwd.

    moe_type : AllToAllType
        Determines event wiring; see AllToAllType for details. Any value
        other than DISPATCH is wired like COMBINE.

    advance : "vector" | "vector_only"
        "vector"      — advance_tsv_vector(tsv, task_num, event_group_size=event_group)
                        Advances pre_event_num/pre_pre_event_num/pre_task_num/
                        pre_vector_task_num.
                        Used by: dispatch (fwd/bwd), combine forward.
        "vector_only" — advance_tsv_vector_only(tsv, task_num)
                        Only advances pre_task_num/pre_vector_task_num; does not
                        advance event counters.
                        Used by: combine backward (event advance is deferred to GMM3).

    event_group : int
        Only effective when advance="vector"; passed to advance_tsv_vector.
        dispatch (fwd/bwd): all_expert_num
        combine forward:    1
    """
    moe_type: AllToAllType = AllToAllType.DISPATCH
    advance: str = "vector"   # "vector" | "vector_only"
    event_group: int = 1      # only used when advance="vector"

    def fill(self, cfg: RuntimeConfigC, op, tsv) -> None:
        """
        Append ``op.task_num`` AllToAll tasks to ``cfg`` and advance ``tsv``.

        Each task is stored at ``cfg.all_tasks[tsv.pre_task_num + i]`` and
        registered in ``cfg.vector_task_indices``; ``cfg.task_index_num[1]``
        grows by ``op.task_num``.

        Raises:
            ValueError: if ``advance`` is not one of the two supported modes,
                if ``tsv.all_expert_num`` is not positive, or if there are
                tasks but fewer than one task per expert (which would make
                the expert-group divisor zero).
        """
        # Validate before touching cfg so a bad config never leaves it
        # half-filled, and so a misspelt mode is not silently treated as
        # "vector_only".
        if self.advance not in ("vector", "vector_only"):
            raise ValueError(
                "AllToAllFillConfig.advance must be 'vector' or 'vector_only', "
                f"got {self.advance!r}"
            )

        task_num = op.task_num
        all_expert_num = tsv.all_expert_num
        if all_expert_num <= 0:
            raise ValueError(
                f"tsv.all_expert_num must be positive, got {all_expert_num}"
            )
        per_g_e_num = task_num // all_expert_num  # tasks per expert group
        if task_num > 0 and per_g_e_num == 0:
            raise ValueError(
                f"op.task_num ({task_num}) must be at least tsv.all_expert_num "
                f"({all_expert_num}) so every expert group owns a task"
            )

        param = op.param_positions
        num_inputs = len(op.inputs)
        num_outputs = len(op.outputs)
        is_dispatch = self.moe_type == AllToAllType.DISPATCH
        # Loop-invariant trigger count of the DISPATCH wiring.
        dispatch_triggers = (task_num * tsv.ep // all_expert_num
                             if is_dispatch else None)

        for i in range(task_num):
            group = i // per_g_e_num  # expert-group index of this task

            task = TaskDescC()
            task.task_type = TaskType.TASK_SHMEM_PUT_MEM_SIGNAL
            # NOTE(review): the core type is CUBE although the task is queued
            # in vector_task_indices below — kept as in the original; confirm
            # this is intentional.
            task.task_aicore_type = TaskAiCoreType.TASK_AICORE_CUBE
            task.num_inputs = num_inputs
            task.num_outputs = num_outputs

            for j, spec in enumerate(op.inputs):
                td = TensorDescC()
                td.data_type = spec.dtype_size
                td.input_position = param[j]
                td.base_ptr_offset = 0

                if j == 1:
                    # src data: tensor_type and is_dynamic come from TensorSpec
                    td.tensor_type = spec.tensor_type
                    td.dynamic_shape = int(spec.is_dynamic)
                else:
                    # metadata (target_offset / src_offset / size): fixed
                    # type=0, offset by the expert-group index
                    td.tensor_type = 0
                    td.base_ptr_offset = group

                task.inputs[j] = td

            out_spec = op.outputs[0]
            out = TensorDescC()
            out.tensor_type = out_spec.tensor_type
            out.data_type = out_spec.dtype_size
            out.input_position = param[num_inputs]  # output follows the inputs
            out.base_ptr_offset = 0
            out.dynamic_shape = int(out_spec.is_dynamic)
            task.outputs[0] = out

            if is_dispatch:
                # Per-expert trigger: one event per expert group.
                task.dependent_event = tsv.pre_pre_event_num + 0
                task.trigger_event = tsv.pre_event_num + group + 1
                cfg.all_event_num_triggers[task.trigger_event] = dispatch_triggers
            else:  # COMBINE (or OTHER)
                # Global trigger: every task signals the final event.
                current_dep = group % tsv.single_rank_expert_num
                task.dependent_event = tsv.pre_pre_event_num + current_dep + 1
                task.trigger_event = tsv.all_event_num
                cfg.all_event_num_triggers[task.trigger_event] = task_num

            task.task_index = i
            task.task_split_num = task_num
            task.task_split_value = op.split_value
            task.tiling_data_position = 0xFFFFFFFF  # AllToAll carries no tiling data

            cfg.all_tasks[tsv.pre_task_num + i] = task
            cfg.vector_task_indices[tsv.pre_vector_task_num + i] = tsv.pre_task_num + i

        cfg.task_index_num[1] += task_num
        if self.advance == "vector":
            advance_tsv_vector(tsv, task_num, event_group_size=self.event_group)
        else:  # "vector_only" (validated above)
            advance_tsv_vector_only(tsv, task_num)
179
180
181# ── GMM (up_proj/down_proj/act_grad/gate_grad/w1_grad/w2_grad, fwd + bwd) ─────
182
@dataclass
class GmmFillConfig(FillConfig):
    """
    GMM behaviour config covering all GMM variants (fwd up_proj/down_proj,
    bwd act_grad/w1_grad/gate_grad/w2_grad).
    Tensor-level attributes (tensor_type, dtype_size, is_dynamic, transpose,
    shape) are read from TensorSpec.

    offset_inputs : Set[int]
        Set of input indices that receive base_ptr_offset and dynamic_shape.
        fwd/bwd activation GMMs (up_proj, down_proj, act_grad, gate_grad): {0}
        bwd weight-grad GMMs (w1_grad, w2_grad): {0, 1}

    rank_in_event : bool
        True  → dependent_event adds single_rank_expert_num * rank_id.
                Used by: GMM1 (fwd/bwd), GMM4/w1_grad (bwd).
        False → no rank offset.
                Used by: GMM2 (fwd/bwd), GMM3/w2_grad (bwd).

    global_trigger : bool
        False → trigger_event = pre_event_num + data_index + 1 (per-expert trigger).
                Used by: GMM1/GMM2 (activation-gradient path).
        True  → trigger_event = all_event_num (global trigger).
                Used by: GMM4/GMM3 (weight-gradient path).

    out_offset : bool
        True  → out.base_ptr_offset = data_index * 4096 * shape[1], and the
                output dims are filled in.
        False → out.base_ptr_offset = 0 and dims are left untouched
                (weight-grad output writes to dedicated buffer).

    advance : "cube" | "cube_only" | "cube_custom"
        "cube"        — advance_tsv_cube(tsv, task_num,
                        event_group_size=task_num // num_cube_cores)
                        Standard cube advance (also advances events).
                        Used by: GMM1 fwd, GMM2 fwd/bwd.
        "cube_only"   — advance_tsv_cube_only(tsv, task_num)
                        Advances task/cube counters only; event advance is deferred
                        to the subsequent GMM4.
                        Used by: GMM1 bwd (GMM1 and GMM4 run in parallel sharing events).
        "cube_custom" — Manual advance: pre_task_num/pre_cube_task_num += task_num,
                        pre_pre_event_num = pre_event_num,
                        pre_event_num += event_delta.
                        Used by: GMM4 bwd (event_delta = sre), GMM3 bwd (event_delta = 1).

    event_delta : int
        Only effective when advance="cube_custom"; increment for pre_event_num.
        GMM4 bwd: single_rank_expert_num (computed and passed at graph declaration time).
        GMM3 bwd: 1.

    num_cube_cores : int
        Number of AI Cube cores (910B=24). Consecutive runs of this many
        tasks share one data_index (data_index = i // num_cube_cores), and it
        is the trigger count written for each trigger event.
    """
    offset_inputs: Set[int] = field(default_factory=lambda: {0})
    rank_in_event: bool = False
    global_trigger: bool = False
    out_offset: bool = True
    advance: str = "cube"       # "cube" | "cube_only" | "cube_custom"
    event_delta: int = 0        # only used when advance="cube_custom"
    num_cube_cores: int = 24    # number of AI Cube cores (910B=24)

    # Rows skipped per data_index step when computing base_ptr_offset.
    # Deliberately unannotated so it is a class constant, not a dataclass field.
    # NOTE(review): presumably the per-expert row capacity — confirm.
    _ROWS_PER_GROUP = 4096

    def fill(self, cfg: RuntimeConfigC, op, tsv) -> None:
        """
        Append ``op.task_num`` grouped-matmul tasks to ``cfg`` and advance ``tsv``.

        Each task is stored at ``cfg.all_tasks[tsv.pre_task_num + i]`` and
        registered in ``cfg.cube_task_indices``; ``cfg.task_index_num[0]``
        grows by ``op.task_num``.

        Raises:
            ValueError: if ``advance`` is not a supported mode or
                ``num_cube_cores`` is not positive.
        """
        # Validate before touching cfg: an unknown mode used to fall through
        # to the "cube_custom" branch, and a non-positive core count used to
        # surface as a ZeroDivisionError.
        if self.advance not in ("cube", "cube_only", "cube_custom"):
            raise ValueError(
                "GmmFillConfig.advance must be 'cube', 'cube_only' or "
                f"'cube_custom', got {self.advance!r}"
            )
        cores = self.num_cube_cores
        if cores <= 0:
            raise ValueError(
                f"GmmFillConfig.num_cube_cores must be positive, got {cores}"
            )

        task_num = op.task_num
        param = op.param_positions
        num_inputs = len(op.inputs)
        num_outputs = len(op.outputs)
        glist_j = num_inputs - 1  # last input is always group_list
        rows = self._ROWS_PER_GROUP
        # Rank offset applied to dependent_event; loop-invariant.
        dep_extra = (tsv.single_rank_expert_num * tsv.rank_id
                     if self.rank_in_event else 0)

        for i in range(task_num):
            data_index = i // cores
            task = TaskDescC()
            task.task_type = TaskType.TASK_GROUPED_MATMUL
            task.task_aicore_type = TaskAiCoreType.TASK_AICORE_CUBE
            task.num_inputs = num_inputs
            task.num_outputs = num_outputs

            for j, spec in enumerate(op.inputs):
                td = TensorDescC()
                td.input_position = param[j]
                td.dynamic_shape = 0
                td.base_ptr_offset = 0
                td.data_type = spec.dtype_size
                # group_list has fixed type=0; x / weight take it from TensorSpec
                td.tensor_type = 0 if j == glist_j else spec.tensor_type

                if j in self.offset_inputs:
                    td.base_ptr_offset = data_index * rows * spec.shape[1]
                    td.dynamic_shape = int(spec.is_dynamic)

                td.transpose_flag = int(spec.transpose)
                for k in range(min(len(spec.shape), MAX_TENSOR_DIMS)):
                    td.dim[k] = spec.shape[k]
                task.inputs[j] = td

            out_spec = op.outputs[0]
            out = TensorDescC()
            out.tensor_type = out_spec.tensor_type
            out.data_type = out_spec.dtype_size
            out.input_position = param[num_inputs]  # output follows the inputs
            out.dynamic_shape = int(out_spec.is_dynamic)
            if self.out_offset:
                out.base_ptr_offset = data_index * rows * out_spec.shape[1]
                # C++ fills dims only for activation outputs, not weight grads
                for k in range(min(len(out_spec.shape), MAX_TENSOR_DIMS)):
                    out.dim[k] = out_spec.shape[k]
            else:
                out.base_ptr_offset = 0
            task.outputs[0] = out

            task.dependent_event = tsv.pre_pre_event_num + dep_extra + data_index + 1
            if self.global_trigger:
                task.trigger_event = tsv.all_event_num
            else:
                task.trigger_event = tsv.pre_event_num + data_index + 1
            cfg.all_event_num_triggers[task.trigger_event] = cores

            task.task_index = i
            task.task_split_num = task_num
            task.task_split_value = op.split_value
            task.tiling_data_position = op.tiling_position

            cfg.all_tasks[tsv.pre_task_num + i] = task
            cfg.cube_task_indices[tsv.pre_cube_task_num + i] = tsv.pre_task_num + i

        cfg.task_index_num[0] += task_num
        if self.advance == "cube":
            advance_tsv_cube(tsv, task_num, event_group_size=task_num // cores)
        elif self.advance == "cube_only":
            advance_tsv_cube_only(tsv, task_num)
        else:  # "cube_custom" (validated above)
            tsv.pre_task_num += task_num
            tsv.pre_cube_task_num += task_num
            tsv.pre_pre_event_num = tsv.pre_event_num
            tsv.pre_event_num += self.event_delta
313
314
315# ── SwiGLU / SwiGLU-grad ──────────────────────────────────────────────────────
316
@dataclass
class SwiGLUFillConfig(FillConfig):
    """
    SwiGLU fill config — forward (TASK_SWI_GLU) and backward gradient
    (TASK_SWI_GLU_GRAD).

    No config fields; task_type is derived from op.op_type, split_value and
    task_num are read from op. All input/output tensor_type values are always
    1 (vector operator convention).
    """

    @staticmethod
    def _tensor_list_desc(spec, position, task_index, split_value):
        """Build the descriptor of one SwiGLU input or output tensor.

        ``position`` is the entry of ``op.param_positions`` for this tensor;
        ``task_index`` and ``split_value`` determine the element offset of
        the slice handled by the task.
        """
        td = TensorDescC()
        td.tensor_type = 1  # all SwiGLU inputs/outputs are tensor lists
        td.data_type = spec.dtype_size
        td.input_position = position
        td.base_ptr_offset = task_index * spec.shape[1] * split_value
        td.dynamic_shape = int(spec.is_dynamic)
        for k in range(min(len(spec.shape), MAX_TENSOR_DIMS)):
            td.dim[k] = spec.shape[k]
        return td

    def fill(self, cfg: RuntimeConfigC, op, tsv) -> None:
        """
        Append ``op.task_num`` SwiGLU (or SwiGLU-grad) tasks to ``cfg`` and
        advance ``tsv`` via ``advance_tsv_vector``.

        Each task is stored at ``cfg.all_tasks[tsv.pre_task_num + i]`` and
        registered in ``cfg.vector_task_indices``; ``cfg.task_index_num[1]``
        grows by ``op.task_num``.

        Raises:
            ValueError: if ``op.split_value`` is not positive, or if there are
                tasks but ``tsv.per_expert_seq // op.split_value`` is zero
                (the per-event task count is used as a divisor).
        """
        task_type = (TaskType.TASK_SWI_GLU if op.op_type == OpType.SWIGLU
                     else TaskType.TASK_SWI_GLU_GRAD)

        split_value = op.split_value
        task_num = op.task_num
        # Validate before touching cfg; these used to surface as a bare
        # ZeroDivisionError.
        if split_value <= 0:
            raise ValueError(f"op.split_value must be positive, got {split_value}")
        num_triggers = tsv.per_expert_seq // split_value  # tasks per event
        if task_num > 0 and num_triggers <= 0:
            raise ValueError(
                f"tsv.per_expert_seq ({tsv.per_expert_seq}) must be at least "
                f"op.split_value ({split_value})"
            )

        param = op.param_positions
        num_inputs = len(op.inputs)
        num_outputs = len(op.outputs)

        for i in range(task_num):
            task = TaskDescC()
            task.task_type = task_type
            task.task_aicore_type = TaskAiCoreType.TASK_AICORE_VECTOR
            task.num_inputs = num_inputs
            task.num_outputs = num_outputs

            for j, spec in enumerate(op.inputs):
                task.inputs[j] = self._tensor_list_desc(
                    spec, param[j], i, split_value)

            # Output positions follow the input positions in param_positions.
            for j, spec in enumerate(op.outputs):
                task.outputs[j] = self._tensor_list_desc(
                    spec, param[num_inputs + j], i, split_value)

            ev_idx = i // num_triggers
            task.dependent_event = tsv.pre_pre_event_num + ev_idx + 1
            task.trigger_event = tsv.pre_event_num + ev_idx + 1
            cfg.all_event_num_triggers[task.trigger_event] = num_triggers

            task.task_index = i
            task.task_split_num = task_num
            task.task_split_value = split_value
            task.tiling_data_position = op.tiling_position

            cfg.all_tasks[tsv.pre_task_num + i] = task
            cfg.vector_task_indices[tsv.pre_vector_task_num + i] = tsv.pre_task_num + i

        cfg.task_index_num[1] += task_num
        advance_tsv_vector(tsv, task_num, event_group_size=tsv.single_rank_expert_num)
380
381
382# ── Utility functions ──────────────────────────────────────────────────────────
383
def add_terminate(cfg: RuntimeConfigC, tsv, trigger_count: int) -> None:
    """
    Register the terminate task in cfg.

    trigger_count : stored in cfg.all_event_num_triggers[tsv.all_event_num].
        Values the caller is expected to pass:
        forward:  combine_op.task_num // tsv.ep * tsv.ep
        backward: w1_grad_op.task_num + w2_grad_op.task_num
                  + combine_op.task_num // tsv.ep * tsv.ep

    The task is placed at cfg.all_tasks[tsv.pre_task_num] and queued in
    cfg.vector_task_indices; tsv itself is not modified.
    """
    trigger_table = cfg.all_event_num_triggers
    # Global event first: the terminate task's own event is written afterwards
    # and wins if both happen to share a slot.
    trigger_table[tsv.all_event_num] = trigger_count

    terminate = TaskDescC()
    terminate.task_type = TaskType.TASK_TERMINATE
    terminate.dependent_event = tsv.pre_pre_event_num + 1
    terminate.trigger_event = tsv.pre_event_num + 1
    trigger_table[terminate.trigger_event] = 1

    task_slot = tsv.pre_task_num
    cfg.all_tasks[task_slot] = terminate
    cfg.vector_task_indices[tsv.pre_vector_task_num] = task_slot
    cfg.task_index_num[1] += 1
402
403
def add_dynamic_data(cfg: RuntimeConfigC, tsv, dynamic_input_position: int) -> None:
    """
    Populate cfg.dynamic_data with the MoE-FFN dynamic data record.

    dynamic_input_position: forward = 6, backward = 19.
    The group size is taken from tsv.single_rank_expert_num and the maximum
    sequence length is stored as -1.
    """
    record = cfg.dynamic_data
    record.dynamic_type = DynamicType.DYNAMIC_DSV3_MOE_FFN
    record.dynamic_input_position = dynamic_input_position
    record.dynamic_group_size = tsv.single_rank_expert_num
    record.dynamic_max_seq_len = -1
413
414
def revise_gmm_task_queue_bwd(cfg: RuntimeConfigC, tsv,
                              act_grad_task_num: int,
                              num_cube_cores: int = 24) -> None:
    """
    Backward-only: reorder cfg.cube_task_indices in place so the w1_grad and
    act_grad entries alternate expert by expert.

    Before: act_grad entries start at offset 0, w1_grad entries start at
            offset act_grad_task_num.
    After:  [w1_grad exp0, act_grad exp0, w1_grad exp1, act_grad exp1, ...],
            each chunk being num_cube_cores entries long.

    Only the first 2 * tsv.single_rank_expert_num * num_cube_cores entries
    are rewritten.
    """
    snapshot = list(cfg.cube_task_indices)  # read from a copy while overwriting
    # Start offset of each stream in the snapshot, in output order:
    # w1_grad first, then act_grad.
    stream_starts = (act_grad_task_num, 0)

    for expert in range(tsv.single_rank_expert_num):
        expert_shift = expert * num_cube_cores
        for slot, stream_start in enumerate(stream_starts):
            dst_base = (len(stream_starts) * expert + slot) * num_cube_cores
            src_base = stream_start + expert_shift
            for core in range(num_cube_cores):
                cfg.cube_task_indices[dst_base + core] = snapshot[src_base + core]