1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
Backward tiling lookup tables.

Each GMM table maps split_value -> CSV tiling string; the SwiGLU-grad table
maps split_value -> a pre-filled SwiGluTilingDataC struct.

C++ source functions:
    GMM1 (pos 20) → get_tiling_data_first_gmm(split_value)
    GMM2 (pos 21) → get_tiling_data_second_gmm(split_value)
    GMM3 (pos 22) → get_tiling_data_third_gmm(split_value)
    GMM4 (pos 23) → get_tiling_data_fourth_gmm(split_value)
    SwiGLU-grad (pos 24) → get_tiling_data_swiglu_grad(split_value)
26"""
27# pylint: disable=line-too-long
28import copy
29import ctypes
30
31from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import (
32 SwiGluTilingDataC,
33)
34
35
36# ── helpers ───────────────────────────────────────────────────────────────────
37
def _expand_gmm_string(csv: str, repeat: int = 24) -> bytes:
    """Pack a CSV tiling string as uint32 words and repeat it ``repeat`` times.

    Args:
        csv: Comma-separated integer fields, e.g. ``"8,24,0"``. Every field is
            stored as one ``ctypes.c_uint32``.
        repeat: How many back-to-back copies of the packed sequence to emit.

    Returns:
        ``4 * field_count * repeat`` bytes in the platform's native byte order.

    Raises:
        ValueError: If a field is not an integer literal (an empty ``csv``
            included), or if ``repeat`` is negative.
    """
    if repeat < 0:
        # A negative ctypes array length raises ValueError; keep that contract
        # explicit now that the repetition is done on the packed bytes.
        raise ValueError(f"repeat must be >= 0, not {repeat}")
    vals = [int(field) for field in csv.split(',')]
    # Pack a single copy through ctypes (same uint32 conversion as assigning
    # element by element), then repeat the bytes instead of looping in Python.
    one_copy = bytes((ctypes.c_uint32 * len(vals))(*vals))
    return one_copy * repeat
45
46
def _expand_swiglu_struct(td: SwiGluTilingDataC) -> bytes:
    """Serialize ``td`` and repeat the result 49 times.

    The count is fixed at 49, which the original author describes as
    2 * 24 AI Cube cores + 1.

    Args:
        td: Tiling struct to serialize; converted with ``bytes(td)``.

    Returns:
        49 consecutive copies of the struct's byte image.
    """
    copies = 49  # = 2 * 24 + 1, see docstring
    image = bytes(td)
    return image * copies
50
51
def _find_protected_positions(vals, active_map, split_value, is_weight_grad):
    """Return the indices that dim_map substitution must leave untouched.

    Protection is only needed when ``split_value`` is itself a key of
    ``active_map``; otherwise the substitution cannot hit the split fields and
    an empty set is returned.

    The tiling base index ``t_base`` is located through the 393216 anchor
    value, which sits at ``t_base + 19``. Only indices >= 350 are scanned.

    Args:
        vals: Parsed tiling fields (indexable sequence of ints).
        active_map: ``{old: new}`` replacements that actually change a value.
        split_value: Table key of the tiling row, or ``None``.
        is_weight_grad: ``True`` for weight-grad GMM rows, ``False`` for
            activation GMM rows.

    Returns:
        ``{t_base + 3, t_base + 4, t_base + 7}`` for weight-grad rows,
        ``{14, t_base + 1, t_base + 5}`` otherwise, or an empty set when no
        protection applies or the anchor is not found.
    """
    if split_value is None or split_value not in active_map:
        return set()

    anchor_value = 393216  # marker expected at t_base + 19
    scan_start = 350       # everything before this index is ignored
    anchor_idx = next(
        (i for i in range(scan_start, len(vals)) if vals[i] == anchor_value),
        None,
    )
    if anchor_idx is None:
        return set()

    t_base = anchor_idx - 19
    if is_weight_grad:
        return {t_base + 3, t_base + 4, t_base + 7}
    return {14, t_base + 1, t_base + 5}
68
69
def _patch_gmm_dims(csv: str, dim_map: dict, num_groups: int,
                    default_groups: int = 8,
                    split_value: int = None,
                    num_cube_cores: int = 24,
                    default_num_cores: int = 24) -> str:
    """Rewrite dimension, group-count and core-count fields of a GMM tiling CSV.

    Args:
        csv: Comma-separated integer tiling fields.
        dim_map: ``{old_int_value: new_int_value}`` for K/N dimension
            replacements. Entries where old == new are ignored. Every
            unprotected field equal to an old value is replaced.
        num_groups: Written to ``vals[0]`` (groupNum) when that field currently
            equals ``default_groups``. No other index is changed for groups.
        default_groups: Group count baked into the CSV.
        split_value: M tile size (table key), or ``None``. When it is also a
            key of the effective dim_map, the split positions are shielded
            from replacement. The row type is read from ``vals[10]``:
                0 → activation GMM: M tile at {14, T+1, T+5}
                2 → weight-grad GMM: split (K-tile) at {T+3, T+4, T+7}
            T is found via the 393216 anchor fixed at T+19.
        num_cube_cores: Target number of AI Cube cores. Replaces
            ``default_num_cores`` at ``vals[1]`` (blockDim) and at every index
            > 350 holding that value (usedCoreNum and derived fields).
        default_num_cores: Core count baked into the CSV.

    Returns:
        The patched CSV string, or ``csv`` itself when nothing would change.

    Raises:
        ValueError: If a field of ``csv`` is not an integer literal.
        IndexError: If ``csv`` has fewer than 11 fields.
    """
    active_map = {old: new for old, new in dim_map.items() if old != new}
    groups_changed = num_groups != default_groups
    cores_changed = num_cube_cores != default_num_cores
    if not (active_map or groups_changed or cores_changed):
        return csv  # fast path: nothing to do

    vals = [int(field) for field in csv.split(',')]
    is_weight_grad = vals[10] == 2

    # Header fields are patched up front. The loop below still visits indices
    # 0 and 1, but only rewrites them if their value is a key of active_map.
    if groups_changed and vals[0] == default_groups:
        vals[0] = num_groups  # vals[0] = groupNum
    if cores_changed and vals[1] == default_num_cores:
        vals[1] = num_cube_cores  # vals[1] = blockDim

    protected = _find_protected_positions(vals, active_map, split_value, is_weight_grad)

    tail_start = 350  # core-count replacement only applies beyond this index
    for i, v in enumerate(vals):
        if i in protected:
            continue
        if v in active_map:
            vals[i] = active_map[v]
        elif cores_changed and i > tail_start and v == default_num_cores:
            vals[i] = num_cube_cores  # covers T+0 (usedCoreNum) and derived fields
    return ','.join(map(str, vals))
116
117
def _make_swiglu(row_len, col_len, base_row_len, base_col_len):
    """Build a SwiGluTilingDataC for SwiGLU-grad with the given tile dimensions.

    Args:
        row_len: Value stored in ``rowLen``.
        col_len: Value stored in ``colLen``.
        base_row_len: Value stored in ``baseRowLen``.
        base_col_len: Value stored in ``baseColLen``.

    Returns:
        A new ``SwiGluTilingDataC`` with ``is32BAligned`` and
        ``isDoubleBuffer`` set to 1, the four tile dimensions filled in, and
        every other field assigned below set to 0.
    """
    td = SwiGluTilingDataC()

    # Fixed flags.
    td.is32BAligned = 1
    td.isDoubleBuffer = 1

    # Caller-supplied tile dimensions.
    td.rowLen = row_len
    td.colLen = col_len
    td.baseRowLen = base_row_len
    td.baseColLen = base_col_len

    # Remaining fields are explicitly zeroed.
    td.activateLeft = 0
    td.biasIsEmpty = 0
    td.quantScaleIsEmpty = 0
    td.activateScaleIsEmpty = 0
    td.swiColLen = 0
    td.perRowLen = 0
    td.modRowLen = 0
    td.usedCoreNum = 0
    return td
136
137
138# ── Backward GMM1: x=[per_rank_seq,7168], weight=[E,2048,7168] ───────────────
139# C++: get_tiling_data_first_gmm(split_value)
140# x input [*, 7168] → y output [*, 2048]
141
142FIRST_GMM_TABLE = {
 # Key: split_value. Value: comma-separated integer tiling fields, read by
 # get_act_grad_tiling_bytes(), which patches 7168 -> hidden_size and
 # 2048 -> intermediate_size before packing. Only split_value 4096 is populated.
143 4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,2048,7168,7168,4096,256,7168,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
144}
145
146
147# ── Backward GMM2: x=[per_rank_seq,4096], weight=[E,7168,4096] ───────────────
148# C++: get_tiling_data_second_gmm(split_value)
149# x input [*, 4096] → y output [*, 7168]
150
151SECOND_GMM_TABLE = {
 # Key: split_value. Value: comma-separated integer tiling fields, read by
 # get_gate_grad_tiling_bytes(), which patches 4096 -> intermediate_size * 2 and
 # 7168 -> hidden_size before packing. Only split_value 4096 is populated.
152 4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,7168,4096,4096,4096,256,4096,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
153}
154
155
156# ── Backward GMM3: x1^T=[sre,7168], x2=[sre,4096], weight grad ───────────────
157# C++: get_tiling_data_third_gmm(split_value)
158
159THIRD_GMM_TABLE = {
 # Key: split_value. Value: comma-separated integer tiling fields, read by
 # get_w2_grad_tiling_bytes(), which patches 7168 -> hidden_size and
 # 4096 -> intermediate_size * 2 before packing. Field 10 is 2, which
 # _patch_gmm_dims treats as a weight-grad row. Only split_value 4096 is populated.
160 4096: "8,24,0,0,0,0,0,1,1,1,2,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,7168,4096,4096,4096,7168,256,4096,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
161}
162
163
164# ── Backward GMM4: x1^T=[sre,2048], x2=[sre,7168], weight grad ───────────────
165# C++: get_tiling_data_fourth_gmm(split_value)
166
167FOURTH_GMM_TABLE = {
 # Key: split_value. Value: comma-separated integer tiling fields, read by
 # get_w1_grad_tiling_bytes(), which passes 2048 -> intermediate_size,
 # 7168 -> hidden_size and 4096 -> intermediate_size * 2 to _patch_gmm_dims.
 # Field 10 is 2, which _patch_gmm_dims treats as a weight-grad row.
 # Only split_value 4096 is populated.
168 4096: "8,24,0,0,0,0,0,1,1,1,2,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,2048,7168,4096,4096,2048,256,4096,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
169}
170
171
172# ── Backward SwiGLU-grad tiling ───────────────────────────────────────────────
173# C++: get_tiling_data_swiglu_grad(split_value) → baseColLen=256
174
# Maps split_value -> pre-filled SwiGluTilingDataC. The rows differ only in
# rowLen (equal to the key); colLen=2048, baseRowLen=19 and baseColLen=256
# are shared by all entries.
SWIGLU_GRAD_TABLE = {
    rows: _make_swiglu(row_len=rows, col_len=2048, base_row_len=19, base_col_len=256)
    for rows in (32, 64, 128)
}
180
181
182# ── Public API ────────────────────────────────────────────────────────────────
183
def get_act_grad_tiling_bytes(split_value: int, *,
                              hidden_size: int = 7168,
                              intermediate_size: int = 2048,
                              num_groups: int = 8,
                              num_cube_cores: int = 24) -> bytes:
    """Backward GMM1 [x,hidden]×[E,intermediate,hidden] → y=[x,intermediate].

    Looks up the CSV row for ``split_value`` in ``FIRST_GMM_TABLE``, patches
    it for the requested sizes, and packs it once per cube core.

    Args:
        split_value: Key into ``FIRST_GMM_TABLE``.
        hidden_size: Replaces the table's baked-in 7168.
        intermediate_size: Replaces the table's baked-in 2048.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims`` and used
            as the repeat count of the packed output.

    Returns:
        The patched tiling fields as uint32 bytes, repeated
        ``num_cube_cores`` times.

    Raises:
        KeyError: If the table has no row for ``split_value``.
    """
    csv = FIRST_GMM_TABLE.get(split_value)
    if csv is None:
        raise KeyError(f"No first_gmm tiling for split_value={split_value}")
    dim_map = {7168: hidden_size, 2048: intermediate_size}
    csv = _patch_gmm_dims(csv, dim_map, num_groups,
                          split_value=split_value,
                          num_cube_cores=num_cube_cores)
    return _expand_gmm_string(csv, repeat=num_cube_cores)
198
199
def get_gate_grad_tiling_bytes(split_value: int, *,
                               hidden_size: int = 7168,
                               intermediate_size: int = 2048,
                               num_groups: int = 8,
                               num_cube_cores: int = 24) -> bytes:
    """Backward GMM2 [x,intermediate*2]×[E,hidden,intermediate*2] → y=[x,hidden].

    Looks up the CSV row for ``split_value`` in ``SECOND_GMM_TABLE``, patches
    it for the requested sizes, and packs it once per cube core.

    Args:
        split_value: Key into ``SECOND_GMM_TABLE``.
        hidden_size: Replaces the table's baked-in 7168.
        intermediate_size: ``intermediate_size * 2`` replaces the table's
            baked-in 4096.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims`` and used
            as the repeat count of the packed output.

    Returns:
        The patched tiling fields as uint32 bytes, repeated
        ``num_cube_cores`` times.

    Raises:
        KeyError: If the table has no row for ``split_value``.
    """
    csv = SECOND_GMM_TABLE.get(split_value)
    if csv is None:
        raise KeyError(f"No second_gmm tiling for split_value={split_value}")
    dim_map = {4096: intermediate_size * 2, 7168: hidden_size}
    csv = _patch_gmm_dims(csv, dim_map, num_groups,
                          split_value=split_value,
                          num_cube_cores=num_cube_cores)
    return _expand_gmm_string(csv, repeat=num_cube_cores)
214
215
def get_w2_grad_tiling_bytes(split_value: int, *,
                             hidden_size: int = 7168,
                             intermediate_size: int = 2048,
                             num_groups: int = 8,
                             num_cube_cores: int = 24) -> bytes:
    """Backward GMM3 weight-grad [hidden,sre]×[sre,intermediate*2] → grad=[hidden,intermediate*2].

    Looks up the CSV row for ``split_value`` in ``THIRD_GMM_TABLE``, patches
    it for the requested sizes, and packs it once per cube core.

    Args:
        split_value: Key into ``THIRD_GMM_TABLE``.
        hidden_size: Replaces the table's baked-in 7168.
        intermediate_size: ``intermediate_size * 2`` replaces the table's
            baked-in 4096, except at positions ``_patch_gmm_dims`` protects.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims`` and used
            as the repeat count of the packed output.

    Returns:
        The patched tiling fields as uint32 bytes, repeated
        ``num_cube_cores`` times.

    Raises:
        KeyError: If the table has no row for ``split_value``.
    """
    csv = THIRD_GMM_TABLE.get(split_value)
    if csv is None:
        raise KeyError(f"No third_gmm tiling for split_value={split_value}")
    dim_map = {7168: hidden_size, 4096: intermediate_size * 2}
    csv = _patch_gmm_dims(csv, dim_map, num_groups,
                          split_value=split_value,
                          num_cube_cores=num_cube_cores)
    return _expand_gmm_string(csv, repeat=num_cube_cores)
230
231
def get_w1_grad_tiling_bytes(split_value: int, *,
                             hidden_size: int = 7168,
                             intermediate_size: int = 2048,
                             num_groups: int = 8,
                             num_cube_cores: int = 24) -> bytes:
    """Backward GMM4 weight-grad [intermediate,sre]×[sre,hidden] → grad=[intermediate,hidden].

    Looks up the CSV row for ``split_value`` in ``FOURTH_GMM_TABLE``, patches
    it for the requested sizes, and packs it once per cube core.

    Args:
        split_value: Key into ``FOURTH_GMM_TABLE``.
        hidden_size: Replaces the table's baked-in 7168.
        intermediate_size: Replaces the table's baked-in 2048;
            ``intermediate_size * 2`` is also mapped from 4096, except at
            positions ``_patch_gmm_dims`` protects.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims`` and used
            as the repeat count of the packed output.

    Returns:
        The patched tiling fields as uint32 bytes, repeated
        ``num_cube_cores`` times.

    Raises:
        KeyError: If the table has no row for ``split_value``.
    """
    csv = FOURTH_GMM_TABLE.get(split_value)
    if csv is None:
        raise KeyError(f"No fourth_gmm tiling for split_value={split_value}")
    dim_map = {
        2048: intermediate_size,
        7168: hidden_size,
        4096: intermediate_size * 2,
    }
    csv = _patch_gmm_dims(csv, dim_map, num_groups,
                          split_value=split_value,
                          num_cube_cores=num_cube_cores)
    return _expand_gmm_string(csv, repeat=num_cube_cores)
247
248
def get_swiglu_grad_tiling_bytes(split_value: int, *,
                                 intermediate_size: int = 2048) -> bytes:
    """Backward SwiGLU-grad tiling (colLen=intermediate_size).

    Args:
        split_value: Key into ``SWIGLU_GRAD_TABLE``.
        intermediate_size: Value written to ``colLen`` of the returned tiling.

    Returns:
        The serialized tiling struct as produced by ``_expand_swiglu_struct``.

    Raises:
        KeyError: If the table has no entry for ``split_value``.
    """
    template = SWIGLU_GRAD_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"No swiglu_grad tiling for split_value={split_value}")
    # Patch a copy so the shared table entry keeps its original colLen.
    td = copy.copy(template)
    td.colLen = intermediate_size
    return _expand_swiglu_struct(td)