1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
Common utilities shared by the forward (fwd) and backward (bwd) task builders.
Each function here ports C++ logic that is byte-for-byte identical in the fwd and bwd implementations.
18"""
19
from typing import Optional, Sequence

from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import (
    TensorDescC, RuntimeConfigC,
    MAX_TENSOR_DIMS,
)
24
25
26# ── TensorDescC quick constructor ─────────────────────────────────────────────
27
def make_tensor_desc(
    tensor_type: int = 0,
    data_type: int = 2,
    input_position: int = 0,
    base_ptr_offset: int = 0,
    transpose_flag: int = 0,
    dynamic_shape: int = 0,
    dynamic_dim: int = 0,
    shape: Optional[Sequence[int]] = None,
) -> TensorDescC:
    """Build a ``TensorDescC`` populated from the given fields.

    Args:
        tensor_type, data_type, input_position, base_ptr_offset,
        transpose_flag, dynamic_shape, dynamic_dim: copied verbatim onto the
            same-named ``TensorDescC`` fields. (The meaning of the default
            values, e.g. ``data_type=2``, is defined by the C++ side — not
            visible here.)
        shape: optional per-dimension sizes written into ``dim[0..]``. Only
            the first ``MAX_TENSOR_DIMS`` entries are used; any extra entries
            are silently dropped. ``None`` or an empty sequence leaves ``dim``
            exactly as ``TensorDescC()`` initialised it.

    Returns:
        The newly constructed ``TensorDescC``.
    """
    desc = TensorDescC()
    desc.tensor_type = tensor_type
    desc.data_type = data_type
    desc.input_position = input_position
    desc.base_ptr_offset = base_ptr_offset
    desc.transpose_flag = transpose_flag
    desc.dynamic_shape = dynamic_shape
    desc.dynamic_dim = dynamic_dim
    if shape:
        # The descriptor holds at most MAX_TENSOR_DIMS dims; truncate rather than overflow.
        for axis, size in enumerate(shape[:MAX_TENSOR_DIMS]):
            desc.dim[axis] = size
    return desc
51
52
53# ── tsv counter helpers ───────────────────────────────────────────────────────
54
def advance_tsv_vector(tsv, task_num: int, event_group_size: int) -> None:
    """Record a finished vector op (with event grouping) on the counters in ``tsv``.

    ``pre_task_num`` and ``pre_vector_task_num`` both grow by ``task_num``.
    The current ``pre_event_num`` is saved into ``pre_pre_event_num`` and then
    grows by ``event_group_size``.
    """
    last_event = tsv.pre_event_num
    tsv.pre_task_num += task_num
    tsv.pre_vector_task_num += task_num
    tsv.pre_pre_event_num, tsv.pre_event_num = last_event, last_event + event_group_size
61
62
def advance_tsv_cube(tsv, task_num: int, event_group_size: int) -> None:
    """Record a finished cube op (with event grouping) on the counters in ``tsv``.

    ``pre_task_num`` and ``pre_cube_task_num`` both grow by ``task_num``.
    The current ``pre_event_num`` is saved into ``pre_pre_event_num`` and then
    grows by ``event_group_size``.
    """
    last_event = tsv.pre_event_num
    tsv.pre_task_num += task_num
    tsv.pre_cube_task_num += task_num
    tsv.pre_pre_event_num, tsv.pre_event_num = last_event, last_event + event_group_size
69
70
def advance_tsv_cube_only(tsv, task_num: int) -> None:
    """Bump the total and cube task counters by ``task_num``; event counters are untouched.

    Used by the backward GMM1 builder.
    """
    for counter in ("pre_task_num", "pre_cube_task_num"):
        setattr(tsv, counter, getattr(tsv, counter) + task_num)
75
76
def advance_tsv_vector_only(tsv, task_num: int) -> None:
    """Bump the total and vector task counters by ``task_num``; event counters are untouched.

    Used by the backward A2 builder.
    """
    for counter in ("pre_task_num", "pre_vector_task_num"):
        setattr(tsv, counter, getattr(tsv, counter) + task_num)
81
82
83# ── revise_task_queue (identical in fwd and bwd) ──────────────────────────────
84
def _permute_segment(indices, snapshot, start: int, expert_count: int,
                     tasks_per_expert: int, tasks_per_rank: int,
                     rank_order: list) -> None:
    """Rewrite one segment of ``indices`` in rank-interleaved order, reading from ``snapshot``."""
    # Source offsets are relative to ``start``. As in the C++ original, the
    # segment is presumably laid out rank-major: ``tasks_per_rank`` tasks per
    # rank, one rank after another.
    pos = start
    for expert in range(expert_count):
        for task in range(tasks_per_expert):
            local = expert * tasks_per_expert + task
            for rank in rank_order:
                indices[pos] = snapshot[start + local + rank * tasks_per_rank]
                pos += 1


def revise_task_queue(cfg: "RuntimeConfigC", tsv,
                      dispatch_task_num: int, swiglu_task_num: int) -> None:
    """
    Reorder vector_task_indices for the dispatch and combine segments based on rank_id.
    Directly translates C++ revise_task_queue.

    Within each segment, for every (local expert j, task k) pair, the matching
    task of all ``tsv.ep`` ranks is emitted consecutively. Ranks are visited in
    the rotated order rank_id, rank_id + 1, ..., wrapping modulo ep. Entries
    outside the positions written by the permutation keep their original values.

    dispatch_task_num : task_num of the dispatch (A1) operator
    swiglu_task_num   : task_num of the swiglu / swiglu_grad operator

    Reads tsv.single_rank_expert_num, tsv.all_expert_num, tsv.ep and tsv.rank_id.
    Raises ZeroDivisionError if tsv.all_expert_num or tsv.ep is 0.
    """
    indices = cfg.vector_task_indices
    # Snapshot first: reads come from the pre-permutation order while writes
    # update the live queue in place.
    snapshot = list(indices)

    ep = tsv.ep
    tasks_per_expert = dispatch_task_num // tsv.all_expert_num
    tasks_per_rank = dispatch_task_num // ep
    # Start at this rank and wrap around: rank_id, rank_id+1, ..., rank_id-1 (mod ep).
    rank_order = [(i + tsv.rank_id) % ep for i in range(ep)]

    # Dispatch segment starts at 0. Combine starts after dispatch + swiglu and,
    # as in the C++ original, is permuted with the dispatch-derived geometry.
    for start in (0, dispatch_task_num + swiglu_task_num):
        _permute_segment(indices, snapshot, start, tsv.single_rank_expert_num,
                         tasks_per_expert, tasks_per_rank, rank_order)