Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_lightning_indexer.py: 87%

93 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for lightning_indexer operator.""" 

16import copy 

17from typing import Callable, Optional, Tuple 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from hyper_parallel.platform import get_platform 

21from .parallel_ops import DistributedOp 

22from .parallel_npu_dense_lightning_indexer_softmax_lse import ( 

23 _adjust_bsnd_key, 

24 _adjust_tnd_seq_lens, 

25) 

26 

# Platform handle resolved once at import time.
platform = get_platform()

# Largest signed 64-bit integer (2**63 - 1).
_MAX_INT64 = (1 << 63) - 1

# Replicated-dimension table used by the input-layout checks.
# Outer key: layout string. Middle key: tensor role ('q' = query, 'k' = key,
# 'w' = weights). Inner mapping: dimension index -> human-readable label of a
# dimension that must not be sharded.
_REPLICATED_DIMS = {
    'BSND': {'q': {2: 'N1', 3: 'D'}, 'k': {1: 'S2', 2: 'N2', 3: 'D'}, 'w': {2: 'N1'}},
    'TND': {'q': {1: 'N1', 2: 'D'}, 'k': {1: 'N2', 2: 'D'}, 'w': {1: 'N1'}},
}

45 

46 

def _normalize_lightning_indexer_args(
        query,
        key,
        weights,
        actual_seq_lengths_query=None,
        actual_seq_lengths_key=None,
        block_table=None,
        layout_query='BSND',
        layout_key='BSND',
        sparse_count=2048,
        sparse_mode=3,
        pre_tokens=_MAX_INT64,
        next_tokens=_MAX_INT64,
        return_value=False):
    """Bind a lightning_indexer call to its signature and split it canonically.

    Whatever mix of positional and keyword arguments the caller used, the
    three tensors come back as a positional tuple and every remaining option
    comes back by name, with defaults filled in.

    Args:
        query: Query tensor.
        key: Key tensor.
        weights: Weight tensor.
        actual_seq_lengths_query: Cumulative query sequence lengths (TND only).
        actual_seq_lengths_key: Cumulative key sequence lengths (TND only).
        block_table: Block table for PageAttention (optional).
        layout_query: Input layout string for query, 'BSND' or 'TND'.
        layout_key: Input layout string for key, 'BSND', 'TND', or 'PA_BSND'.
        sparse_count: Number of top-k blocks to retain.
        sparse_mode: Sparse attention mode (0=defaultMask, 3=rightDownCausal).
        pre_tokens: Sparse pre-tokens count.
        next_tokens: Sparse next-tokens count.
        return_value: Whether to output sparse_values.

    Returns:
        tuple: ``((query, key, weights), options)`` where ``options`` is a dict
        holding all remaining arguments keyed by parameter name.
    """
    options = dict(
        actual_seq_lengths_query=actual_seq_lengths_query,
        actual_seq_lengths_key=actual_seq_lengths_key,
        block_table=block_table,
        layout_query=layout_query,
        layout_key=layout_key,
        sparse_count=sparse_count,
        sparse_mode=sparse_mode,
        pre_tokens=pre_tokens,
        next_tokens=next_tokens,
        return_value=return_value,
    )
    return (query, key, weights), options

95 

96 

class LightningIndexerDistributedOp(DistributedOp):
    """Distributed operator for MindSpore built-in lightning_indexer.

    LightningIndexer computes the top-k most relevant key positions for each query token
    in sparse attention. It is a MindSpore built-in op (accessed via
    ``ops.lightning_indexer``), not a custom op, so only the distributed sharding
    logic is implemented here.

    Supports BSND and TND input layouts on both MindSpore and PyTorch platforms.
    Any other ``layout_query`` value is rejected during layout validation.

    Output shapes:
        - BSND: query (B, S1, N1, D) → outputs (B, S1, N2, sparse_count)
        - TND:  query (T1, N1, D)    → outputs (T1, N2, sparse_count)

    Context parallelism (CP) is handled in ``get_expand_impl``:
        - BSND+CP: key S2 is sliced to the causal window for each rank.
        - TND+CP:  actual_seq_qlen / actual_seq_klen are adjusted per rank.
    """

    @staticmethod
    def _infer_output_layout(q_layout: Layout, layout_str: str) -> Layout:
        """Build the output layout for both sparse outputs from the query layout.

        BSND: input (B, S1, N1, D) → output (B, S1, N2, sparse_count)
              tensor_map: (q_tm[0], q_tm[1], -1, -1)
        TND:  input (T1, N1, D)    → output (T1, N2, sparse_count)
              tensor_map: (q_tm[0], -1, -1)

        N2 is always replicated (key's head dimension constraint).
        sparse_count is always replicated (int scalar attribute).

        Args:
            q_layout: Layout of the query input.
            layout_str: 'BSND' or 'TND'. Callers are expected to have validated
                this already; any value other than 'BSND' takes the TND path.

        Returns:
            Layout for the output tensors.
        """
        q_tm = q_layout.tensor_map
        out_layout = Layout.from_device_mesh(q_layout.mesh)
        if layout_str == 'BSND':
            out_tm = (q_tm[0], q_tm[1], -1, -1)
        else:
            out_tm = (q_tm[0], -1, -1)
        out_layout.set_tensor_map(out_tm)
        out_layout.tensor_map_to_placement()
        return out_layout

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (may contain DTensors).
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where cache_values is
            [q_layout, k_layout, w_layout, layout_str].
        """
        norm_args, local_kwargs = _normalize_lightning_indexer_args(*args, **kwargs)

        query_index, key_index, weights = norm_args[0], norm_args[1], norm_args[2]
        # Only the query layout string drives sharding decisions in this class.
        layout_str = local_kwargs['layout_query']

        # The optional sequence-length inputs are converted to local tensors
        # only when supplied.
        qlen_kw = local_kwargs.get('actual_seq_lengths_query')
        klen_kw = local_kwargs.get('actual_seq_lengths_key')
        if qlen_kw is not None:
            local_kwargs['actual_seq_lengths_query'] = qlen_kw.to_local()
        if klen_kw is not None:
            local_kwargs['actual_seq_lengths_key'] = klen_kw.to_local()

        local_args = (query_index.to_local(), key_index.to_local(), weights.to_local())

        cache_values = [query_index.layout, key_index.layout, weights.layout, layout_str]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _validate_input_layouts(
            q_layout: Layout,
            k_layout: Layout,
            w_layout: Layout,
            layout_str: str,
    ) -> None:
        """Validate sharding constraints for all input tensors.

        BSND rules (query/key/weights shapes: (B,S1,N1,D) / (B,S2,N2,D) / (B,S1,N1)):
            - N1 (dim 2) and D (dim 3) of query must be replicated.
            - S2 (dim 1), N2 (dim 2), D (dim 3) of key must be replicated.
            - B sharding of query and key must be identical.
            - B and S1 sharding of weights must match query; N1 must be replicated.

        TND rules (query/key/weights shapes: (T1,N1,D) / (T2,N2,D) / (T1,N1)):
            - N1 (dim 1) and D (dim 2) of query must be replicated.
            - N2 (dim 1) and D (dim 2) of key must be replicated.
            - T1 sharding of weights must match query; N1 must be replicated.

        Args:
            q_layout: Layout of query.
            k_layout: Layout of key.
            w_layout: Layout of weights.
            layout_str: 'BSND' or 'TND'.

        Raises:
            ValueError: If ``layout_str`` is not a supported layout, or if any
                sharding constraint is violated.
        """
        op = "lightning_indexer"

        # Reject unknown layouts up front. Without this check an unsupported
        # string would skip every replicated-dim check and be validated (and
        # later laid out) as if it were TND.
        if layout_str not in _REPLICATED_DIMS:
            raise ValueError(
                f"For {op}, layout_query should be one of {sorted(_REPLICATED_DIMS)}, "
                f"but got {layout_str!r}"
            )

        q_tm = q_layout.tensor_map
        k_tm = k_layout.tensor_map
        w_tm = w_layout.tensor_map
        tms = {'q': (q_tm, 'query'), 'k': (k_tm, 'key'), 'w': (w_tm, 'weights')}

        # A tensor_map entry of -1 marks a replicated (unsharded) dimension.
        for role, dims in _REPLICATED_DIMS[layout_str].items():
            tm, tensor_name = tms[role]
            for dim, label in dims.items():
                if tm[dim] != -1:
                    raise ValueError(
                        f"For {op}, {label} (dim {dim}) of {tensor_name} should be replicated, "
                        f"but got tensor_map={tm}"
                    )

        if layout_str == 'BSND':
            if q_tm[0] != k_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of query and key should match, "
                    f"but got query={q_tm[0]}, key={k_tm[0]}"
                )
            if w_tm[0] != q_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of weights should match query, "
                    f"but got weights={w_tm[0]}, query={q_tm[0]}"
                )
            if w_tm[1] != q_tm[1]:
                raise ValueError(
                    f"For {op}, S1 (dim 1) sharding of weights should match query, "
                    f"but got weights={w_tm[1]}, query={q_tm[1]}"
                )
        else:  # TND
            if w_tm[0] != q_tm[0]:
                raise ValueError(
                    f"For {op}, T1 (dim 0) sharding of weights should match query, "
                    f"but got weights={w_tm[0]}, query={q_tm[0]}"
                )

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Infer output layouts for sparse_indices and sparse_values outputs.

        Rules:
            1. No Partial inputs are allowed on any of the three input tensors.
            2. Input sharding constraints are validated per layout_str (see
               ``_validate_input_layouts`` for the full rule set).
            3. Output tensor shape depends on layout_str:
               - BSND: query (B, S1, N1, D) → outputs (B, S1, N2, sparse_count).
                       B and S1 sharding are inherited from query;
                       N2 and sparse_count are always replicated.
               - TND:  query (T1, N1, D) → outputs (T1, N2, sparse_count).
                       T1 sharding is inherited from query;
                       N2 and sparse_count are always replicated.
            4. Both sparse_indices and sparse_values outputs share the same layout
               (independent deep copies so callers can mutate them safely).

        Args:
            cache_values: [q_layout, k_layout, w_layout, layout_str]

        Returns:
            tuple: ((indices_layout, values_layout), None)

        Raises:
            ValueError: If any input has Partial status, the layout string is
                unsupported, or sharding constraints are violated.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        w_layout = cache_values[2]
        layout_str = cache_values[3]

        self._check_partial_inputs([q_layout, k_layout, w_layout])
        self._validate_input_layouts(q_layout, k_layout, w_layout, layout_str)

        out_layout = self._infer_output_layout(q_layout, layout_str)
        return (out_layout, copy.deepcopy(out_layout)), None

    def get_expand_impl(  # pylint: disable=W0237
            self,
            func: Optional[Callable],
            infer_result: tuple,
            cache_values: list,
            extra_args: Optional[tuple] = None,
    ) -> Optional[Callable]:
        """Return a custom callable if context-parallel adjustments are needed.

        BSND+CP: wraps ``func`` to slice key's S2 to the causal window.
        BSND without S1 sharding: returns None (dispatcher calls ``func`` directly).
        TND: always wraps ``func``; the wrapper adjusts actual_seq_qlen/klen per
             rank when both sequence-length tensors are present and otherwise
             forwards the call unchanged.

        Args:
            func: The underlying op callable.
            infer_result: Output from ``infer_layout``.
            cache_values: [q_layout, k_layout, w_layout, layout_str].
            extra_args: Unused; kept for interface compatibility.

        Returns:
            Callable wrapper or None.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        layout_str = cache_values[3]

        if layout_str == 'BSND':
            # S1 is dim 1 of query; if not sharded, no CP adjustment needed.
            if q_layout.tensor_map[1] == -1:
                return None
            split_id = q_layout.get_split_id(1)

            def _bsnd_cp_impl(*args, **kwargs):
                local_q, local_k = args[0], args[1]
                # The local S1 length and this rank's split index select the key slice.
                sliced_k = _adjust_bsnd_key(local_k, local_q.shape[1], split_id)
                return func(local_q, sliced_k, *args[2:], **kwargs)

            return _bsnd_cp_impl

        # TND: DP always requires seq_len adjustment; CP additionally
        # requires token-level offset adjustment.
        dp_size = k_layout.get_dim_split_num(0)  # DP splits on k's T2
        split_id = q_layout.get_split_id(0)
        # CP degree is query's T1 split count divided by the DP degree.
        cp_size = (q_layout.get_dim_split_num(0) // dp_size
                   if dp_size > 0 else 1)
        cp_rank = split_id % cp_size if cp_size > 1 else 0

        def _tnd_impl(*args, **kwargs):
            local_q, local_k = args[0], args[1]

            qlen_tensor = kwargs.get('actual_seq_lengths_query')
            klen_tensor = kwargs.get('actual_seq_lengths_key')

            # Without both sequence-length tensors there is nothing to adjust.
            if qlen_tensor is None or klen_tensor is None:
                return func(*args, **kwargs)

            adj_q, adj_k = _adjust_tnd_seq_lens(
                local_q, local_k, qlen_tensor, klen_tensor,
                cp_rank=cp_rank,
            )

            return func(*args, **{**kwargs, 'actual_seq_lengths_query': adj_q,
                                  'actual_seq_lengths_key': adj_k})

        return _tnd_impl