Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / multicore / modules / moe_ffn / backward / tiling_tables.py: 0%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Backward tiling lookup tables. 

17Each dict maps split_value -> CSV tiling string (for GMM ops) or 

18struct fields (for SwiGLU-grad). 

19 

20C++ source functions: 

21 GMM1 (pos 20) → get_tiling_data_first_gmm(split_value) 

22 GMM2 (pos 21) → get_tiling_data_second_gmm(split_value) 

23 GMM3 (pos 22) → get_tiling_data_third_gmm(split_value) 

24 GMM4 (pos 23) → get_tiling_data_fourth_gmm(split_value) 

25 SwiGLU-grad (pos 24) → get_tiling_data_swiglu_grad(split_value) 

26""" 

27# pylint: disable=line-too-long 

28import copy 

29import ctypes 

30 

31from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import ( 

32 SwiGluTilingDataC, 

33) 

34 

35 

36# ── helpers ─────────────────────────────────────────────────────────────────── 

37 

def _expand_gmm_string(csv: str, repeat: int = 24) -> bytes:
    """Serialize a comma-separated list of integers as repeated uint32 data.

    Args:
        csv: Comma-separated integer fields, e.g. ``"8,24,0"``.
        repeat: How many back-to-back copies of the parsed sequence to emit.

    Returns:
        The raw bytes of a ``ctypes.c_uint32`` array holding the parsed
        sequence ``repeat`` times in a row.
    """
    fields = [int(token) for token in csv.split(',')]
    # Size the array type first so an invalid length fails before any fill.
    buffer_type = ctypes.c_uint32 * (len(fields) * repeat)
    # List repetition yields the same cyclic order as indexing modulo len().
    return bytes(buffer_type(*(fields * repeat)))

45 

46 

def _expand_swiglu_struct(td: SwiGluTilingDataC) -> bytes:
    """Return the raw bytes of ``td`` repeated 49 times back to back.

    The count 49 corresponds to 2 * 24 AI Cube cores + 1.
    """
    copies = 2 * 24 + 1
    single = bytes(td)
    return single * copies

50 

51 

def _find_protected_positions(vals, active_map, split_value, is_weight_grad):
    """Return the indices that dim_map substitution must leave untouched.

    Protection is only needed when ``split_value`` is itself a key of
    ``active_map``: slots that hold the split (tile) size would otherwise be
    rewritten as if they held a K/N dimension of the same value.

    The tail base ``T`` is located by scanning from index 350 for the anchor
    value 393216, which sits at ``T + 19``.

    Args:
        vals: Parsed integer fields of the tiling CSV.
        active_map: ``{old: new}`` replacements with no-op entries removed.
        split_value: Table key of the CSV, or ``None``.
        is_weight_grad: ``True`` for a weight-grad GMM, ``False`` for an
            activation GMM.

    Returns:
        ``{T+3, T+4, T+7}`` for a weight-grad GMM, ``{14, T+1, T+5}`` for an
        activation GMM, or an empty set when ``split_value`` is ``None``, is
        not a key of ``active_map``, or the anchor is not found.
    """
    if split_value is None or split_value not in active_map:
        return set()
    for idx in range(350, len(vals)):
        if vals[idx] != 393216:  # anchor always at t_base+19
            continue
        t_base = idx - 19
        if is_weight_grad:
            return {t_base + 3, t_base + 4, t_base + 7}
        return {14, t_base + 1, t_base + 5}
    return set()


def _patch_gmm_dims(csv: str, dim_map: dict, num_groups: int,
                    default_groups: int = 8,
                    split_value: int = None,
                    num_cube_cores: int = 24,
                    default_num_cores: int = 24) -> str:
    """Replace matrix dimension values in a GMM tiling CSV string.

    Args:
        csv: Comma-separated integer tiling fields.
        dim_map: ``{old_int_value: new_int_value}`` for K/N dimension
            replacements. Entries where old == new are no-ops and are ignored.
        num_groups: Replacement for ``default_groups``. It is written only to
            the header slot ``vals[0]`` (groupNum), and only when that slot
            currently holds ``default_groups``.
        default_groups: Group count the CSV was generated with.
        split_value: M tile size (table key). When it equals a K/N value in
            ``dim_map``, the tile positions are protected from accidental
            replacement. Two GMM types are detected via ``vals[10]``:
              0 -> activation GMM: M tile at {14, T+1, T+5}
              2 -> weight-grad GMM: split (K-tile) at {T+3, T+4, T+7}
            T is found via the 393216 anchor fixed at T+19.
        num_cube_cores: Number of AI Cube cores on the target hardware
            (default 24). Replaces the core-count fields of the CSV:
              vals[1]  - header blockDim field
              i > 350  - tail usedCoreNum and any other field equal to
                         ``default_num_cores``
        default_num_cores: Core count the CSV was generated with.

    Returns:
        The patched CSV string, or ``csv`` itself when nothing needs patching.
    """
    active_map = {old: new for old, new in dim_map.items() if old != new}
    groups_changed = num_groups != default_groups
    cores_changed = num_cube_cores != default_num_cores
    if not (active_map or groups_changed or cores_changed):
        return csv  # fast path: nothing to do

    vals = [int(field) for field in csv.split(',')]
    is_weight_grad = vals[10] == 2

    # Header fields are patched explicitly here and excluded from the loop.
    if groups_changed and vals[0] == default_groups:
        vals[0] = num_groups        # vals[0] = groupNum
    if cores_changed and vals[1] == default_num_cores:
        vals[1] = num_cube_cores    # vals[1] = blockDim

    protected = _find_protected_positions(vals, active_map, split_value, is_weight_grad)
    # Shield the header slots as well: a freshly written group or core count
    # that happens to equal a dim_map key must not be substituted a second time.
    protected |= {0, 1}

    for i, v in enumerate(vals):
        if i in protected:
            continue
        if v in active_map:
            vals[i] = active_map[v]
        elif cores_changed and i > 350 and v == default_num_cores:
            vals[i] = num_cube_cores  # covers T+0 (usedCoreNum) and derived fields
    return ','.join(map(str, vals))

116 

117 

def _make_swiglu(row_len, col_len, base_row_len, base_col_len):
    """Create a SwiGluTilingDataC populated for SwiGLU-grad.

    The four arguments fill ``rowLen``, ``colLen``, ``baseRowLen`` and
    ``baseColLen``. ``is32BAligned`` and ``isDoubleBuffer`` are set to 1 and
    every other field assigned here is set to 0.
    """
    field_values = (
        ("is32BAligned", 1),
        ("isDoubleBuffer", 1),
        ("rowLen", row_len),
        ("colLen", col_len),
        ("baseRowLen", base_row_len),
        ("baseColLen", base_col_len),
        ("activateLeft", 0),
        ("biasIsEmpty", 0),
        ("quantScaleIsEmpty", 0),
        ("activateScaleIsEmpty", 0),
        ("swiColLen", 0),
        ("perRowLen", 0),
        ("modRowLen", 0),
        ("usedCoreNum", 0),
    )
    tiling = SwiGluTilingDataC()
    for field_name, field_value in field_values:
        setattr(tiling, field_name, field_value)
    return tiling

136 

137 

138# ── Backward GMM1: x=[per_rank_seq,7168], weight=[E,2048,7168] ─────────────── 

139# C++: get_tiling_data_first_gmm(split_value) 

140# x input [*, 7168] → y output [*, 2048] 

141 

142FIRST_GMM_TABLE = { 

143 4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,2048,7168,7168,4096,256,7168,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 

144} 

145 

146 

147# ── Backward GMM2: x=[per_rank_seq,4096], weight=[E,7168,4096] ─────────────── 

148# C++: get_tiling_data_second_gmm(split_value) 

149# x input [*, 4096] → y output [*, 7168] 

150 

151SECOND_GMM_TABLE = { 

152 4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,7168,4096,4096,4096,256,4096,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 

153} 

154 

155 

156# ── Backward GMM3: x1^T=[sre,7168], x2=[sre,4096], weight grad ─────────────── 

157# C++: get_tiling_data_third_gmm(split_value) 

158 

159THIRD_GMM_TABLE = { 

160 4096: "8,24,0,0,0,0,0,1,1,1,2,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,7168,4096,4096,4096,7168,256,4096,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 

161} 

162 

163 

164# ── Backward GMM4: x1^T=[sre,2048], x2=[sre,7168], weight grad ─────────────── 

165# C++: get_tiling_data_fourth_gmm(split_value) 

166 

167FOURTH_GMM_TABLE = { 

168 4096: "8,24,0,0,0,0,0,1,1,1,2,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,2048,7168,4096,4096,2048,256,4096,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 

169} 

170 

171 

172# ── Backward SwiGLU-grad tiling ─────────────────────────────────────────────── 

173# C++: get_tiling_data_swiglu_grad(split_value) → baseColLen=256 

174 

# One entry per supported split_value; the key doubles as rowLen and the
# remaining tile dimensions are the same for every entry.
SWIGLU_GRAD_TABLE = {
    rows: _make_swiglu(row_len=rows, col_len=2048, base_row_len=19, base_col_len=256)
    for rows in (32, 64, 128)
}

180 

181 

182# ── Public API ──────────────────────────────────────────────────────────────── 

183 

def get_act_grad_tiling_bytes(split_value: int, *,
                              hidden_size: int = 7168,
                              intermediate_size: int = 2048,
                              num_groups: int = 8,
                              num_cube_cores: int = 24) -> bytes:
    """Return tiling bytes for backward GMM1.

    GMM1 is [x,hidden]×[E,intermediate,hidden] → y=[x,intermediate].

    Args:
        split_value: Key into ``FIRST_GMM_TABLE``.
        hidden_size: Replacement for the value 7168 in the table entry.
        intermediate_size: Replacement for the value 2048 in the table entry.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims``; also the
            number of times the patched sequence is repeated in the output.

    Raises:
        KeyError: If ``split_value`` has no entry in ``FIRST_GMM_TABLE``.
    """
    template = FIRST_GMM_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"No first_gmm tiling for split_value={split_value}")
    # With the default sizes both entries map a value to itself (no-ops).
    replacements = {7168: hidden_size, 2048: intermediate_size}
    patched = _patch_gmm_dims(
        template,
        replacements,
        num_groups,
        split_value=split_value,
        num_cube_cores=num_cube_cores,
    )
    return _expand_gmm_string(patched, repeat=num_cube_cores)

198 

199 

def get_gate_grad_tiling_bytes(split_value: int, *,
                               hidden_size: int = 7168,
                               intermediate_size: int = 2048,
                               num_groups: int = 8,
                               num_cube_cores: int = 24) -> bytes:
    """Return tiling bytes for backward GMM2.

    GMM2 is [x,intermediate*2]×[E,hidden,intermediate*2] → y=[x,hidden].

    Args:
        split_value: Key into ``SECOND_GMM_TABLE``.
        hidden_size: Replacement for the value 7168 in the table entry.
        intermediate_size: Twice this value replaces 4096 in the table entry.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims``; also the
            number of times the patched sequence is repeated in the output.

    Raises:
        KeyError: If ``split_value`` has no entry in ``SECOND_GMM_TABLE``.
    """
    template = SECOND_GMM_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"No second_gmm tiling for split_value={split_value}")
    # With the default sizes both entries map a value to itself (no-ops).
    replacements = {4096: intermediate_size * 2, 7168: hidden_size}
    patched = _patch_gmm_dims(
        template,
        replacements,
        num_groups,
        split_value=split_value,
        num_cube_cores=num_cube_cores,
    )
    return _expand_gmm_string(patched, repeat=num_cube_cores)

214 

215 

def get_w2_grad_tiling_bytes(split_value: int, *,
                             hidden_size: int = 7168,
                             intermediate_size: int = 2048,
                             num_groups: int = 8,
                             num_cube_cores: int = 24) -> bytes:
    """Return tiling bytes for backward GMM3 (weight grad).

    GMM3 is [hidden,sre]×[sre,intermediate*2] → grad=[hidden,intermediate*2].

    Args:
        split_value: Key into ``THIRD_GMM_TABLE``.
        hidden_size: Replacement for the value 7168 in the table entry.
        intermediate_size: Twice this value replaces 4096 in the table entry.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims``; also the
            number of times the patched sequence is repeated in the output.

    Raises:
        KeyError: If ``split_value`` has no entry in ``THIRD_GMM_TABLE``.
    """
    template = THIRD_GMM_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"No third_gmm tiling for split_value={split_value}")
    # With the default sizes both entries map a value to itself (no-ops).
    replacements = {7168: hidden_size, 4096: intermediate_size * 2}
    patched = _patch_gmm_dims(
        template,
        replacements,
        num_groups,
        split_value=split_value,
        num_cube_cores=num_cube_cores,
    )
    return _expand_gmm_string(patched, repeat=num_cube_cores)

230 

231 

def get_w1_grad_tiling_bytes(split_value: int, *,
                             hidden_size: int = 7168,
                             intermediate_size: int = 2048,
                             num_groups: int = 8,
                             num_cube_cores: int = 24) -> bytes:
    """Return tiling bytes for backward GMM4 (weight grad).

    GMM4 is [intermediate,sre]×[sre,hidden] → grad=[intermediate,hidden].

    Args:
        split_value: Key into ``FOURTH_GMM_TABLE``.
        hidden_size: Replacement for the value 7168 in the table entry.
        intermediate_size: Replaces 2048 in the table entry; twice this value
            is registered as the replacement for 4096.
        num_groups: Group count forwarded to ``_patch_gmm_dims``.
        num_cube_cores: Core count forwarded to ``_patch_gmm_dims``; also the
            number of times the patched sequence is repeated in the output.

    Raises:
        KeyError: If ``split_value`` has no entry in ``FOURTH_GMM_TABLE``.
    """
    template = FOURTH_GMM_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"No fourth_gmm tiling for split_value={split_value}")
    # With the default sizes every entry maps a value to itself (no-ops).
    replacements = {
        2048: intermediate_size,
        7168: hidden_size,
        4096: intermediate_size * 2,
    }
    patched = _patch_gmm_dims(
        template,
        replacements,
        num_groups,
        split_value=split_value,
        num_cube_cores=num_cube_cores,
    )
    return _expand_gmm_string(patched, repeat=num_cube_cores)

247 

248 

def get_swiglu_grad_tiling_bytes(split_value: int, *,
                                 intermediate_size: int = 2048) -> bytes:
    """Return tiling bytes for backward SwiGLU-grad.

    Args:
        split_value: Key into ``SWIGLU_GRAD_TABLE``.
        intermediate_size: Value written to ``colLen`` of the returned tiling.

    Raises:
        KeyError: If ``split_value`` has no entry in ``SWIGLU_GRAD_TABLE``.
    """
    template = SWIGLU_GRAD_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"No swiglu_grad tiling for split_value={split_value}")
    # Patch a copy so the shared table entry keeps its original colLen.
    patched = copy.copy(template)
    patched.colLen = intermediate_size
    return _expand_swiglu_struct(patched)