1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Forward tiling lookup tables.
17Each function mirrors the C++ get_tiling_data_* functions.
18String-based GMM tiling is expanded num_cube_cores times (default 24);
19struct-based SwiGLU tiling is expanded 49 times (= 2 * 24 + 1 for Ascend 910B).
20"""
21# pylint: disable=line-too-long
22import copy
23import ctypes
24
25from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import (
26 SwiGluTilingDataC,
27)
28
29
30# ── helpers ───────────────────────────────────────────────────────────────────
31
32def _expand_gmm_string(csv: str, repeat: int = 24) -> bytes:
33 """Parse comma-separated uint32 string and tile it `repeat` times."""
34 vals = [int(x) for x in csv.split(',')]
35 arr = (ctypes.c_uint32 * (len(vals) * repeat))()
36 for i in range(len(vals) * repeat):
37 arr[i] = vals[i % len(vals)]
38 return bytes(arr)
39
40
41def _expand_swiglu_struct(td: SwiGluTilingDataC) -> bytes:
42 """Tile SwiGluTilingDataC 49 times (= 2 * 24 AI Cube cores + 1)."""
43 one = bytes(td)
44 return one * 49
45
46
47def _find_protected_positions(vals, active_map, split_value, is_weight_grad):
48 """Return index positions that must not be replaced by dim_map substitution.
49
50 When split_value collides with a K/N dim in active_map, the M-tile or
51 K-tile positions (depending on GMM type) must be shielded.
52 T is located via the 393216 anchor which sits at T+19.
53 """
54 if split_value is None or split_value not in active_map:
55 return set()
56 for i in range(350, len(vals)):
57 if vals[i] == 393216: # anchor always at t_base+19
58 t_base = i - 19
59 if is_weight_grad:
60 return {t_base + 3, t_base + 4, t_base + 7}
61 return {14, t_base + 1, t_base + 5}
62 return set()
63
64
65def _patch_gmm_dims(csv: str, dim_map: dict, num_groups: int,
66 default_groups: int = 8,
67 split_value: int = None,
68 num_cube_cores: int = 24,
69 default_num_cores: int = 24) -> str:
70 """Replace matrix dimension values in a GMM tiling CSV string.
71
72 dim_map: {old_int_value: new_int_value} for K/N dimension replacements.
73 Entries where old == new are no-ops and can be omitted.
74 num_groups: replacement for default_groups; only applied at index > 350.
75 split_value: M tile size (table key). When it equals a K/N value in dim_map,
76 we must protect M-tile positions from accidental replacement.
77 Two GMM types detected via vals[10]:
78 0 → activation GMM: M tile at {14, T+1, T+5}
79 2 → weight-grad GMM: split (K-tile) at {T+3, T+4, T+7}
80 T is found via the 393216 anchor fixed at T+19.
81 num_cube_cores: number of AI Cube cores on the target hardware (default 24 = 910B).
82 Replaces the hardcoded usedCoreNum/blockDim fields in the CSV:
83 vals[1] — header blockDim field
84 i > 350 — tail usedCoreNum and any derived parallelism fields
85 """
86 active_map = {k: v for k, v in dim_map.items() if k != v}
87 groups_changed = num_groups != default_groups
88 cores_changed = num_cube_cores != default_num_cores
89 if not active_map and not groups_changed and not cores_changed:
90 return csv # fast path: nothing to do
91
92 vals = list(map(int, csv.split(',')))
93 is_weight_grad = vals[10] == 2
94
95 # Header fields (set before main loop, not touched in loop)
96 if groups_changed and vals[0] == default_groups:
97 vals[0] = num_groups # vals[0] = groupNum
98 if cores_changed and vals[1] == default_num_cores:
99 vals[1] = num_cube_cores # vals[1] = blockDim
100
101 protected = _find_protected_positions(vals, active_map, split_value, is_weight_grad)
102
103 for i, v in enumerate(vals):
104 if i in protected:
105 continue
106 if v in active_map:
107 vals[i] = active_map[v]
108 elif cores_changed and i > 350 and v == default_num_cores:
109 vals[i] = num_cube_cores # covers T+0 (usedCoreNum) and derived fields
110 return ','.join(map(str, vals))
111
112
113# ── Forward GMM1: x=[per_rank_seq,7168], weight=[E,7168,4096] ─────────────────
114# C++: get_tiling_data_gmm_x_7168_g2(split_value)
115
# split_value -> forward GMM1 tiling record as a comma-separated uint32 string.
# The single entry's key (4096) equals the value stored at CSV index 14.
# get_up_proj_tiling_bytes patches each record with _patch_gmm_dims (remapping
# 7168 -> hidden_size and 4096 -> intermediate_size * 2) and then tiles it per
# core with _expand_gmm_string.
# NOTE(review): the field layout mirrors the C++ get_tiling_data_* output;
# individual field meanings beyond those documented in _patch_gmm_dims are not
# established here.
GMM1_TABLE = {
    4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,4096,7168,7168,4096,256,7168,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
}
119
120
121# ── Forward GMM2: x=[per_rank_seq,2048], weight=[E,2048,7168] ─────────────────
122# C++: get_tiling_data_gmm_x_2048_g2(split_value)
123
# split_value -> forward GMM2 tiling record as a comma-separated uint32 string.
# The single entry's key (4096) equals the value stored at CSV index 14.
# get_down_proj_tiling_bytes patches each record with _patch_gmm_dims
# (remapping 2048 -> intermediate_size and 7168 -> hidden_size) and then tiles
# it per core with _expand_gmm_string.
# NOTE(review): the field layout mirrors the C++ get_tiling_data_* output;
# individual field meanings beyond those documented in _patch_gmm_dims are not
# established here.
GMM2_TABLE = {
    4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,7168,2048,2048,4096,256,2048,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0",
}
127
128
# ── Forward SwiGLU: colLen=2048 in the table, overridden by intermediate_size ──
130# C++: get_tiling_data_swiglu(split_value) → SwiGluTilingData struct
131
def _make_swiglu(is32b, idbl, row_len, col_len, base_row_len, base_col_len,
                 act_left, bias_empty, qs_empty, as_empty,
                 swi_col_len, per_row_len, mod_row_len, used_core_num):
    """Create a SwiGluTilingDataC populated from positional tile parameters.

    Each argument is written, in the order listed below, to the struct field
    paired with it; the populated struct is returned.
    """
    field_values = (
        ("is32BAligned", is32b),
        ("isDoubleBuffer", idbl),
        ("rowLen", row_len),
        ("colLen", col_len),
        ("baseRowLen", base_row_len),
        ("baseColLen", base_col_len),
        ("activateLeft", act_left),
        ("biasIsEmpty", bias_empty),
        ("quantScaleIsEmpty", qs_empty),
        ("activateScaleIsEmpty", as_empty),
        ("swiColLen", swi_col_len),
        ("perRowLen", per_row_len),
        ("modRowLen", mod_row_len),
        ("usedCoreNum", used_core_num),
    )
    tiling = SwiGluTilingDataC()
    for field_name, value in field_values:
        setattr(tiling, field_name, value)
    return tiling
152
153
# split_value -> SwiGluTilingDataC template; each key is also the struct's
# rowLen. Positional arguments follow _make_swiglu:
#   (is32b, idbl, row_len, col_len, base_row_len, base_col_len,
#    act_left, bias_empty, qs_empty, as_empty,
#    swi_col_len, per_row_len, mod_row_len, used_core_num)
# colLen is 2048 here; get_swiglu_tiling_bytes overwrites it with
# intermediate_size on a copy, so these shared entries are not mutated there.
# NOTE(review): baseRowLen is 19 for every rowLen >= 64 — presumably copied
# verbatim from the C++ get_tiling_data_swiglu table; confirm.
SWIGLU_TABLE = {
    16: _make_swiglu(1,1,16, 2048, 16, 512, 0,0,0,0, 0,0,0,0),
    32: _make_swiglu(1,1,32, 2048, 32, 256, 0,0,0,0, 0,0,0,0),
    64: _make_swiglu(1,1,64, 2048, 19, 512, 0,0,0,0, 0,0,0,0),
    128: _make_swiglu(1,1,128, 2048, 19, 512, 0,0,0,0, 0,0,0,0),
    256: _make_swiglu(1,1,256, 2048, 19, 512, 0,0,0,0, 0,0,0,0),
    512: _make_swiglu(1,1,512, 2048, 19, 512, 0,0,0,0, 0,0,0,0),
    1024: _make_swiglu(1,1,1024,2048, 19, 512, 0,0,0,0, 0,0,0,0),
}
163
164
165# ── Public API ────────────────────────────────────────────────────────────────
166
def get_up_proj_tiling_bytes(split_value: int, *,
                             hidden_size: int = 7168,
                             intermediate_size: int = 2048,
                             num_groups: int = 8,
                             num_cube_cores: int = 24) -> bytes:
    """Return forward GMM1 tiling bytes, expanded num_cube_cores times.

    Shapes: [x, hidden] × [E, hidden, intermediate * 2]. The table record for
    ``split_value`` is remapped from its built-in dims (7168 / 4096) to the
    requested ones, then replicated once per cube core.

    Raises:
        KeyError: if ``split_value`` has no entry in GMM1_TABLE.
    """
    if split_value in GMM1_TABLE:
        template = GMM1_TABLE[split_value]
    else:
        raise KeyError(f"GMM1 tiling: split_value={split_value} not found")

    # The table was generated for hidden=7168 and N=4096 (= 2 * 2048).
    substitutions = {7168: hidden_size, 4096: intermediate_size * 2}
    patched = _patch_gmm_dims(
        template,
        substitutions,
        num_groups,
        split_value=split_value,
        num_cube_cores=num_cube_cores,
    )
    return _expand_gmm_string(patched, repeat=num_cube_cores)
180
181
def get_down_proj_tiling_bytes(split_value: int, *,
                               hidden_size: int = 7168,
                               intermediate_size: int = 2048,
                               num_groups: int = 8,
                               num_cube_cores: int = 24) -> bytes:
    """Return forward GMM2 tiling bytes, expanded num_cube_cores times.

    Shapes: [x, intermediate] × [E, intermediate, hidden]. The table record
    for ``split_value`` is remapped from its built-in dims (2048 / 7168) to
    the requested ones, then replicated once per cube core.

    Raises:
        KeyError: if ``split_value`` has no entry in GMM2_TABLE.
    """
    if split_value in GMM2_TABLE:
        template = GMM2_TABLE[split_value]
    else:
        raise KeyError(f"GMM2 tiling: split_value={split_value} not found")

    # The table was generated for intermediate=2048 and hidden=7168.
    substitutions = {2048: intermediate_size, 7168: hidden_size}
    patched = _patch_gmm_dims(
        template,
        substitutions,
        num_groups,
        split_value=split_value,
        num_cube_cores=num_cube_cores,
    )
    return _expand_gmm_string(patched, repeat=num_cube_cores)
195
196
def get_swiglu_tiling_bytes(split_value: int, *,
                            intermediate_size: int = 2048) -> bytes:
    """Return forward SwiGLU tiling bytes (colLen=intermediate_size), expanded 49 times.

    The template for ``split_value`` is shallow-copied before colLen is
    overridden, so the shared SWIGLU_TABLE entry is left unchanged.

    Raises:
        KeyError: if ``split_value`` has no entry in SWIGLU_TABLE.
    """
    if split_value in SWIGLU_TABLE:
        patched = copy.copy(SWIGLU_TABLE[split_value])
        patched.colLen = intermediate_size
        return _expand_swiglu_struct(patched)
    raise KeyError(f"SwiGLU tiling: split_value={split_value} not found")