Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / multicore / modules / moe_ffn / forward / tiling_tables.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Forward tiling lookup tables. 

17Each function mirrors the C++ get_tiling_data_* functions. 

18String-based GMM tiling is expanded num_cube_cores times (default 24); 

19struct-based SwiGLU tiling is expanded 49 times (= 2 * 24 + 1 for Ascend 910B). 

20""" 

21# pylint: disable=line-too-long 

22import copy 

23import ctypes 

24 

25from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import ( 

26 SwiGluTilingDataC, 

27) 

28 

29 

30# ── helpers ─────────────────────────────────────────────────────────────────── 

31 

32def _expand_gmm_string(csv: str, repeat: int = 24) -> bytes: 

33 """Parse comma-separated uint32 string and tile it `repeat` times.""" 

34 vals = [int(x) for x in csv.split(',')] 

35 arr = (ctypes.c_uint32 * (len(vals) * repeat))() 

36 for i in range(len(vals) * repeat): 

37 arr[i] = vals[i % len(vals)] 

38 return bytes(arr) 

39 

40 

def _expand_swiglu_struct(td: SwiGluTilingDataC) -> bytes:
    """Return the raw bytes of `td` repeated 49 times back-to-back.

    The count is fixed at 49 (= 2 * 24 + 1, the Ascend 910B figure quoted in
    the module docstring).
    """
    record = bytes(td)
    return b"".join([record] * 49)

45 

46 

47def _find_protected_positions(vals, active_map, split_value, is_weight_grad): 

48 """Return index positions that must not be replaced by dim_map substitution. 

49 

50 When split_value collides with a K/N dim in active_map, the M-tile or 

51 K-tile positions (depending on GMM type) must be shielded. 

52 T is located via the 393216 anchor which sits at T+19. 

53 """ 

54 if split_value is None or split_value not in active_map: 

55 return set() 

56 for i in range(350, len(vals)): 

57 if vals[i] == 393216: # anchor always at t_base+19 

58 t_base = i - 19 

59 if is_weight_grad: 

60 return {t_base + 3, t_base + 4, t_base + 7} 

61 return {14, t_base + 1, t_base + 5} 

62 return set() 

63 

64 

65def _patch_gmm_dims(csv: str, dim_map: dict, num_groups: int, 

66 default_groups: int = 8, 

67 split_value: int = None, 

68 num_cube_cores: int = 24, 

69 default_num_cores: int = 24) -> str: 

70 """Replace matrix dimension values in a GMM tiling CSV string. 

71 

72 dim_map: {old_int_value: new_int_value} for K/N dimension replacements. 

73 Entries where old == new are no-ops and can be omitted. 

74 num_groups: replacement for default_groups; only applied at index > 350. 

75 split_value: M tile size (table key). When it equals a K/N value in dim_map, 

76 we must protect M-tile positions from accidental replacement. 

77 Two GMM types detected via vals[10]: 

78 0 → activation GMM: M tile at {14, T+1, T+5} 

79 2 → weight-grad GMM: split (K-tile) at {T+3, T+4, T+7} 

80 T is found via the 393216 anchor fixed at T+19. 

81 num_cube_cores: number of AI Cube cores on the target hardware (default 24 = 910B). 

82 Replaces the hardcoded usedCoreNum/blockDim fields in the CSV: 

83 vals[1] — header blockDim field 

84 i > 350 — tail usedCoreNum and any derived parallelism fields 

85 """ 

86 active_map = {k: v for k, v in dim_map.items() if k != v} 

87 groups_changed = num_groups != default_groups 

88 cores_changed = num_cube_cores != default_num_cores 

89 if not active_map and not groups_changed and not cores_changed: 

90 return csv # fast path: nothing to do 

91 

92 vals = list(map(int, csv.split(','))) 

93 is_weight_grad = vals[10] == 2 

94 

95 # Header fields (set before main loop, not touched in loop) 

96 if groups_changed and vals[0] == default_groups: 

97 vals[0] = num_groups # vals[0] = groupNum 

98 if cores_changed and vals[1] == default_num_cores: 

99 vals[1] = num_cube_cores # vals[1] = blockDim 

100 

101 protected = _find_protected_positions(vals, active_map, split_value, is_weight_grad) 

102 

103 for i, v in enumerate(vals): 

104 if i in protected: 

105 continue 

106 if v in active_map: 

107 vals[i] = active_map[v] 

108 elif cores_changed and i > 350 and v == default_num_cores: 

109 vals[i] = num_cube_cores # covers T+0 (usedCoreNum) and derived fields 

110 return ','.join(map(str, vals)) 

111 

112 

113# ── Forward GMM1: x=[per_rank_seq,7168], weight=[E,7168,4096] ───────────────── 

114# C++: get_tiling_data_gmm_x_7168_g2(split_value) 

115 

116GMM1_TABLE = { 

117 4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,4096,7168,7168,4096,256,7168,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 

118} 

119 

120 

121# ── Forward GMM2: x=[per_rank_seq,2048], weight=[E,2048,7168] ───────────────── 

122# C++: get_tiling_data_gmm_x_2048_g2(split_value) 

123 

124GMM2_TABLE = { 

125 4096: "8,24,0,0,0,0,0,1,1,1,0,0,0,0,4096,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4294967295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,4096,7168,2048,2048,4096,256,2048,128,256,64,8,8,1,1,0,0,0,0,393216,131072,0,1,1,1,1,4,4,0,0,2,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 

126} 

127 

128 

129# ── Forward SwiGLU: colLen=2048 ─────────────────────────────────────────────── 

130# C++: get_tiling_data_swiglu(split_value) → SwiGluTilingData struct 

131 

def _make_swiglu(is32b, idbl, row_len, col_len, base_row_len, base_col_len,
                 act_left, bias_empty, qs_empty, as_empty,
                 swi_col_len, per_row_len, mod_row_len, used_core_num):
    """Create a SwiGluTilingDataC and populate its fields from the arguments.

    Each argument is stored in the struct field listed next to it below;
    fields are assigned in that order. Returns the populated struct.
    """
    field_values = (
        ("is32BAligned", is32b),
        ("isDoubleBuffer", idbl),
        ("rowLen", row_len),
        ("colLen", col_len),
        ("baseRowLen", base_row_len),
        ("baseColLen", base_col_len),
        ("activateLeft", act_left),
        ("biasIsEmpty", bias_empty),
        ("quantScaleIsEmpty", qs_empty),
        ("activateScaleIsEmpty", as_empty),
        ("swiColLen", swi_col_len),
        ("perRowLen", per_row_len),
        ("modRowLen", mod_row_len),
        ("usedCoreNum", used_core_num),
    )
    td = SwiGluTilingDataC()
    for field_name, value in field_values:
        setattr(td, field_name, value)
    return td

152 

153 

# Keyed by split_value; each entry's rowLen equals its key and colLen is 2048.
# Only rowLen, baseRowLen and baseColLen differ between entries.
SWIGLU_TABLE = {
    row_len: _make_swiglu(
        is32b=1, idbl=1,
        row_len=row_len, col_len=2048,
        base_row_len=base_row_len, base_col_len=base_col_len,
        act_left=0, bias_empty=0, qs_empty=0, as_empty=0,
        swi_col_len=0, per_row_len=0, mod_row_len=0, used_core_num=0,
    )
    for row_len, base_row_len, base_col_len in (
        (16, 16, 512),
        (32, 32, 256),
        (64, 19, 512),
        (128, 19, 512),
        (256, 19, 512),
        (512, 19, 512),
        (1024, 19, 512),
    )
}

163 

164 

165# ── Public API ──────────────────────────────────────────────────────────────── 

166 

def get_up_proj_tiling_bytes(split_value: int, *,
                             hidden_size: int = 7168,
                             intermediate_size: int = 2048,
                             num_groups: int = 8,
                             num_cube_cores: int = 24) -> bytes:
    """Return forward GMM1 tiling bytes for [x,hidden]x[E,hidden,intermediate*2].

    The GMM1_TABLE template for `split_value` is patched so that 7168 becomes
    `hidden_size` and 4096 becomes `intermediate_size * 2`, with group and
    core counts adjusted, and is then repeated `num_cube_cores` times.

    Raises KeyError when `split_value` has no entry in GMM1_TABLE.
    """
    template = GMM1_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"GMM1 tiling: split_value={split_value} not found")
    replacements = {7168: hidden_size, 4096: intermediate_size * 2}
    patched = _patch_gmm_dims(template, replacements, num_groups,
                              split_value=split_value,
                              num_cube_cores=num_cube_cores)
    return _expand_gmm_string(patched, repeat=num_cube_cores)

180 

181 

def get_down_proj_tiling_bytes(split_value: int, *,
                               hidden_size: int = 7168,
                               intermediate_size: int = 2048,
                               num_groups: int = 8,
                               num_cube_cores: int = 24) -> bytes:
    """Return forward GMM2 tiling bytes for [x,intermediate]x[E,intermediate,hidden].

    The GMM2_TABLE template for `split_value` is patched so that 2048 becomes
    `intermediate_size` and 7168 becomes `hidden_size`, with group and core
    counts adjusted, and is then repeated `num_cube_cores` times.

    Raises KeyError when `split_value` has no entry in GMM2_TABLE.
    """
    template = GMM2_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"GMM2 tiling: split_value={split_value} not found")
    replacements = {2048: intermediate_size, 7168: hidden_size}
    patched = _patch_gmm_dims(template, replacements, num_groups,
                              split_value=split_value,
                              num_cube_cores=num_cube_cores)
    return _expand_gmm_string(patched, repeat=num_cube_cores)

195 

196 

def get_swiglu_tiling_bytes(split_value: int, *,
                            intermediate_size: int = 2048) -> bytes:
    """Return forward SwiGLU tiling bytes with colLen set to `intermediate_size`.

    The SWIGLU_TABLE entry is copied before colLen is overridden, so the
    shared table entry is not modified. The result is the struct expanded
    49 times.

    Raises KeyError when `split_value` has no entry in SWIGLU_TABLE.
    """
    template = SWIGLU_TABLE.get(split_value)
    if template is None:
        raise KeyError(f"SwiGLU tiling: split_value={split_value} not found")
    tiling = copy.copy(template)
    tiling.colLen = intermediate_size
    return _expand_swiglu_struct(tiling)