Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_npu_sparse_flash_attention.py: 70%

107 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for npu_sparse_flash_attention operator.""" 

16import copy 

17from typing import Callable, Optional, Tuple 

18 

19from hyper_parallel.core.dtensor.dtensor import DTensor 

20from hyper_parallel.core.dtensor.layout import Layout 

21from .parallel_ops import DistributedOp 

22from .parallel_npu_dense_lightning_indexer_softmax_lse import _adjust_bsnd_key, _adjust_tnd_seq_lens 

23 

24_MAX_INT64 = 9223372036854775807 

25 

26# Maps layout_str -> tensor role -> {dim_index: dim_label} for replicated-dim checks. 

27# 'q' = query, 'k' = key, 'v' = value, 'si' = sparse_indices. 

28# N1 (head num of query) is forbidden from sharding due to severe performance impact. 

29_REPLICATED_DIMS = { 

30 'BSND': { 

31 'q': {2: 'N1', 3: 'D'}, 

32 'k': {1: 'S2', 2: 'N2', 3: 'D'}, 

33 'v': {1: 'S2', 2: 'N2', 3: 'D'}, 

34 'si': {2: 'N2', 3: 'sparse_size'}, 

35 }, 

36 'TND': { 

37 'q': {1: 'N1', 2: 'D'}, 

38 'k': {1: 'N2', 2: 'D'}, 

39 'v': {1: 'N2', 2: 'D'}, 

40 'si': {1: 'N2', 2: 'sparse_size'}, 

41 }, 

42} 

43 

44 

def _normalize_sfa_args(
        query,
        key,
        value,
        sparse_indices,
        scale_value,
        block_table=None,
        actual_seq_lengths_query=None,
        actual_seq_lengths_kv=None,
        query_rope=None,
        key_rope=None,
        sparse_block_size=1,
        layout_query='BSND',
        layout_kv='BSND',
        sparse_mode=3,
        pre_tokens=_MAX_INT64,
        next_tokens=_MAX_INT64,
        attention_mode=2,
        return_softmax_lse=False):
    """Flatten a call's arguments into one fixed-order positional tuple.

    Callers may mix positional and keyword arguments; binding them through
    this signature fills in the defaults and yields a canonical ordering
    that downstream code can index by position.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        sparse_indices: Sparse index tensor (int32).
        scale_value: Scaling factor (float).
        block_table: Optional PageAttention block mapping table.
        actual_seq_lengths_query: Actual query sequence lengths per batch.
        actual_seq_lengths_kv: Actual KV sequence lengths per batch.
        query_rope: Optional MLA query rope tensor.
        key_rope: Optional MLA key rope tensor.
        sparse_block_size: Block size for sparse computation.
        layout_query: Query layout string ('BSND' or 'TND').
        layout_kv: KV layout string ('BSND', 'TND', or 'PA_BSND').
        sparse_mode: Sparse attention mode.
        pre_tokens: Preceding token window size.
        next_tokens: Following token window size.
        attention_mode: Attention mode (0 or 2 for MLA-absorb).
        return_softmax_lse: Whether to return softmax max/sum.

    Returns:
        tuple: ``(positional_args, {})`` where ``positional_args`` holds all
        18 parameters in signature order and the dict is always empty.
    """
    # Indices 0-4: the required inputs.
    required = (query, key, value, sparse_indices, scale_value)
    # Indices 5-9: optional tensor-like inputs.
    optional_inputs = (
        block_table,
        actual_seq_lengths_query,
        actual_seq_lengths_kv,
        query_rope,
        key_rope,
    )
    # Indices 10-17: scalar / string options.
    options = (
        sparse_block_size,
        layout_query,
        layout_kv,
        sparse_mode,
        pre_tokens,
        next_tokens,
        attention_mode,
        return_softmax_lse,
    )
    return required + optional_inputs + options, {}

96 

97 

def _to_local(t):
    """Return ``t.to_local()`` for a DTensor; return any other value unchanged."""
    return t.to_local() if isinstance(t, DTensor) else t

103 

104 

class SparseFlashAttentionDistributedOp(DistributedOp):
    """Distributed operator for npu_sparse_flash_attention.

    Supports BSND and TND input layouts on both MindSpore
    and PyTorch / torch_npu backends.

    Both frameworks provide built-in forward and backward implementations;
    this class handles only the distributed dispatch (layout inference and
    optional context-parallel input adjustment).

    Output shapes relative to inputs:
      - BSND: query (B, S1, N1, D) → attention_out (B, S1, N1, D),
              softmax_max/sum (B, N2, S1, N1/N2)
      - TND:  query (T1, N1, D)    → attention_out (T1, N1, D),
              softmax_max/sum (N2, T1, N1/N2)

    Sharding constraints:
      - N1 (query head dim) must be replicated — TP on this dim is forbidden
        due to severe performance impact.
      - Key/value N2 and D dims must be replicated; in BSND the S2 dim must
        be replicated as well.
      - sparse_indices N2 and sparse_size dims must be replicated.
      - PA_BSND (as ``layout_kv`` or ``layout_query``) is not supported in
        distributed mode.

    Context parallelism:
      - BSND+CP: k, v, and key_rope are sliced to the causal window
        ``[:, :S1_local*(split_id+1), :, :]`` before calling the kernel, matching the
        MindFormers adjust_bsnd_input logic. sparse_indices from lightning_indexer are
        generated with the same truncation, so they remain valid for the sliced k.
      - TND+CP: adjusts actual_seq_lengths_query/kv per rank using
        _adjust_tnd_seq_lens (same logic as dsa_attention.py).
    """

    @staticmethod
    def _infer_softmax_layout(q_layout: Layout, layout_str: str) -> Layout:
        """Build the output layout for softmax_max and softmax_sum.

        BSND: query (B, S1, N1, D) → softmax (B, N2, S1, N1/N2)
              tensor_map: (q_tm[0], -1, q_tm[1], -1)
        TND:  query (T1, N1, D)    → softmax (N2, T1, N1/N2)
              tensor_map: (-1, q_tm[0], -1)

        N2 and N1/N2 are always replicated because N2=1 and N1 is forbidden
        from sharding.

        Args:
            q_layout: Layout of the query input.
            layout_str: 'BSND' or 'TND'. Any value other than 'BSND' takes
                the TND branch; callers are expected to validate first.

        Returns:
            Layout for softmax_max / softmax_sum, built on the query's mesh.
        """
        q_tm = q_layout.tensor_map
        out_layout = Layout.from_device_mesh(q_layout.mesh)
        if layout_str == 'BSND':
            out_tm = (q_tm[0], -1, q_tm[1], -1)
        else:
            out_tm = (-1, q_tm[0], -1)
        out_layout.set_tensor_map(out_tm)
        out_layout.tensor_map_to_placement()
        return out_layout

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (may contain DTensors).
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where
                local_args = (query_local, key_local, value_local,
                              sparse_indices_local, scale_value),
                local_kwargs contains all remaining arguments,
                cache_values = [q_layout, k_layout, v_layout, si_layout, layout_query_str].

        Raises:
            ValueError: If ``layout_kv`` is 'PA_BSND'.
        """
        norm_args, _ = _normalize_sfa_args(*args, **kwargs)
        query, key, value, sparse_indices, scale_value = norm_args[:5]
        layout_query_str = norm_args[11]
        layout_kv_str = norm_args[12]

        # PA_BSND is a KV layout, and cache_values only carries the query
        # layout string, so the later layout validation never sees it.
        # Reject it here, before any local tensor is extracted.
        if layout_kv_str == 'PA_BSND':
            raise ValueError(
                "For npu_sparse_flash_attention, PA_BSND layout is not supported "
                "in distributed mode."
            )

        local_args = (
            _to_local(query),
            _to_local(key),
            _to_local(value),
            _to_local(sparse_indices),
            scale_value,
        )
        local_kwargs = {
            'block_table': _to_local(norm_args[5]),
            'actual_seq_lengths_query': _to_local(norm_args[6]),
            'actual_seq_lengths_kv': _to_local(norm_args[7]),
            'query_rope': _to_local(norm_args[8]),
            'key_rope': _to_local(norm_args[9]),
            'sparse_block_size': norm_args[10],
            'layout_query': layout_query_str,
            'layout_kv': layout_kv_str,
            'sparse_mode': norm_args[13],
            'pre_tokens': norm_args[14],
            'next_tokens': norm_args[15],
            'attention_mode': norm_args[16],
            'return_softmax_lse': norm_args[17],
        }

        # The four primary tensors are read via ``.layout`` and so must be
        # layout-carrying objects (DTensor).
        cache_values = [
            query.layout,
            key.layout,
            value.layout,
            sparse_indices.layout,
            layout_query_str,
        ]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _validate_input_layouts(
            q_layout: Layout,
            k_layout: Layout,
            v_layout: Layout,
            si_layout: Layout,
            layout_str: str,
    ) -> None:
        """Validate sharding constraints for all input tensors.

        BSND rules (shapes: (B,S1,N1,D) / (B,S2,N2,D) / (B,S2,N2,D) / (B,S1,N2,sparse_size)):
          - N1 (dim 2) and D (dim 3) of query must be replicated.
          - S2 (dim 1), N2 (dim 2), D (dim 3) of key and value must be replicated.
          - N2 (dim 2) and sparse_size (dim 3) of sparse_indices must be replicated.
          - B sharding of key, value, and sparse_indices must match query.
          - S1 sharding of sparse_indices must match query.

        TND rules (shapes: (T1,N1,D) / (T2,N2,D) / (T2,N2,D) / (T1,N2,sparse_size)):
          - N1 (dim 1) and D (dim 2) of query must be replicated.
          - N2 (dim 1) and D (dim 2) of key and value must be replicated.
          - N2 (dim 1) and sparse_size (dim 2) of sparse_indices must be replicated.
          - T2 sharding of key and value must match.
          - T1 sharding of sparse_indices must match query.

        PA_BSND is not supported in distributed mode.

        Args:
            q_layout: Layout of query.
            k_layout: Layout of key.
            v_layout: Layout of value.
            si_layout: Layout of sparse_indices.
            layout_str: 'BSND' or 'TND'.

        Raises:
            ValueError: If layout_str is 'PA_BSND' or otherwise unsupported,
                if any required dimension is sharded, or if batch/sequence
                consistency constraints are violated.
        """
        if layout_str == 'PA_BSND':
            raise ValueError(
                "For npu_sparse_flash_attention, PA_BSND layout is not supported "
                "in distributed mode."
            )

        op = "npu_sparse_flash_attention"
        # An unknown layout string would otherwise skip every replicated-dim
        # check and be validated as if it were TND.
        if layout_str not in _REPLICATED_DIMS:
            raise ValueError(
                f"For {op}, layout should be one of {sorted(_REPLICATED_DIMS)}, "
                f"but got {layout_str!r}"
            )

        tensor_maps = {
            'query': q_layout.tensor_map,
            'key': k_layout.tensor_map,
            'value': v_layout.tensor_map,
            'sparse_indices': si_layout.tensor_map,
        }
        role_names = {'q': 'query', 'k': 'key', 'v': 'value', 'si': 'sparse_indices'}

        # 1) Dimensions that must not be sharded at all.
        for role, dims in _REPLICATED_DIMS[layout_str].items():
            tensor_name = role_names[role]
            tm = tensor_maps[tensor_name]
            for dim, label in dims.items():
                if tm[dim] != -1:
                    raise ValueError(
                        f"For {op}, {label} (dim {dim}) of {tensor_name} should be replicated, "
                        f"but got tensor_map={tm}"
                    )

        # 2) Dimensions whose sharding must agree between two tensors.
        #    Each entry is (dim label, dim index, reference tensor, checked tensor).
        if layout_str == 'BSND':
            must_match = (
                ('B', 0, 'query', 'key'),
                ('B', 0, 'query', 'value'),
                ('B', 0, 'query', 'sparse_indices'),
                ('S1', 1, 'query', 'sparse_indices'),
            )
        else:  # TND
            must_match = (
                ('T2', 0, 'key', 'value'),
                ('T1', 0, 'query', 'sparse_indices'),
            )
        for label, dim, ref, other in must_match:
            ref_axis = tensor_maps[ref][dim]
            other_axis = tensor_maps[other][dim]
            if ref_axis != other_axis:
                raise ValueError(
                    f"For {op}, {label} (dim {dim}) sharding of {other} should match {ref}, "
                    f"but got {ref}={ref_axis}, {other}={other_axis}"
                )

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Infer output layouts for all three outputs.

        Rules:
          1. PA_BSND and any other unsupported layout string are rejected.
          2. Partial inputs are not allowed on any of the four primary tensors.
          3. Sharding constraints are validated (see _validate_input_layouts).
          4. attention_out inherits query layout (deep copy).
          5. softmax_max and softmax_sum share the same layout derived from
             query layout with N2 and N1/N2 dims always replicated.
          6. All three output layouts are independent objects.

        Args:
            cache_values: [q_layout, k_layout, v_layout, si_layout, layout_str]

        Returns:
            tuple: ((attn_layout, softmax_max_layout, softmax_sum_layout), None)

        Raises:
            ValueError: If the layout string is unsupported, any input has
                Partial status, or sharding constraints are violated.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        v_layout = cache_values[2]
        si_layout = cache_values[3]
        layout_str = cache_values[4]

        self._check_partial_inputs([q_layout, k_layout, v_layout, si_layout])
        self._validate_input_layouts(q_layout, k_layout, v_layout, si_layout, layout_str)

        attn_layout = copy.deepcopy(q_layout)
        softmax_max_layout = self._infer_softmax_layout(q_layout, layout_str)
        # Deep-copy so softmax_sum does not alias softmax_max's layout object.
        softmax_sum_layout = copy.deepcopy(softmax_max_layout)
        return (attn_layout, softmax_max_layout, softmax_sum_layout), None

    def get_expand_impl(  # pylint: disable=W0237
            self,
            func: Optional[Callable],
            infer_result: tuple,
            cache_values: list,
            extra_args: Optional[tuple] = None,
    ) -> Optional[Callable]:
        """Return a custom callable if context-parallel adjustment is needed.

        BSND (S1 not sharded): returns None — k/v are Replicated; sparse_indices
            reference the full k directly.
        BSND+CP (S1 sharded): wraps func to slice k, v, and key_rope to the
            causal window ``k[:, :S1_local*(split_id+1), :, :]`` before calling
            the kernel. Mirrors MindFormers adjust_bsnd_input logic, ensuring
            that sparse_indices produced by lightning_indexer (which applies the
            same truncation) remain valid.
        TND+CP: wraps func to adjust actual_seq_lengths_query/kv per rank,
            using the same algorithm as dsa_attention._sparse_flash_attention_forward.
        TND (no CP): the same wrapper is returned with cp_rank 0.

        In both TND cases the wrapper calls func unchanged when either
        actual_seq_lengths_query or actual_seq_lengths_kv is None.

        Args:
            func: The underlying op callable.
            infer_result: Output from infer_layout (unused here).
            cache_values: [q_layout, k_layout, v_layout, si_layout, layout_str].
            extra_args: Unused; kept for interface compatibility.

        Returns:
            Callable wrapper or None.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        layout_str = cache_values[4]

        if layout_str == 'BSND':
            if q_layout.tensor_map[1] == -1:
                # S1 not sharded: pure DP or fully replicated.
                # k/v are Replicate on the CP dimension, so sparse_indices reference
                # the full k directly; no truncation needed.
                return None
            split_id = q_layout.get_split_id(1)

            def _bsnd_cp_impl(*args, **kwargs):
                local_q, local_k, local_v = args[0], args[1], args[2]
                s1_local = local_q.shape[1]
                sliced_k = _adjust_bsnd_key(local_k, s1_local, split_id)
                sliced_v = _adjust_bsnd_key(local_v, s1_local, split_id)
                key_rope = kwargs.get('key_rope')
                if key_rope is None:
                    new_kwargs = kwargs
                else:
                    # key_rope is truncated the same way as k and v.
                    new_kwargs = {
                        **kwargs,
                        'key_rope': _adjust_bsnd_key(key_rope, s1_local, split_id),
                    }
                return func(local_q, sliced_k, sliced_v, *args[3:], **new_kwargs)

            return _bsnd_cp_impl

        # TND: CP applies when q's T1 is sharded more finely than k's T2.
        q_split = q_layout.get_dim_split_num(0)
        k_split = k_layout.get_dim_split_num(0)
        split_id = q_layout.get_split_id(0) if q_split > k_split else 0
        # Guard against a zero split count before dividing.
        cp_size = q_split // k_split if k_split > 0 else 1
        cp_rank = split_id % cp_size if cp_size > 1 else 0

        def _tnd_cp_impl(*args, **kwargs):
            local_q, local_k = args[0], args[1]
            qlen_tensor = kwargs.get('actual_seq_lengths_query')
            klen_tensor = kwargs.get('actual_seq_lengths_kv')
            if qlen_tensor is None or klen_tensor is None:
                # Nothing to adjust without both sequence-length inputs.
                return func(*args, **kwargs)
            adj_q, adj_k = _adjust_tnd_seq_lens(
                local_q, local_k, qlen_tensor, klen_tensor,
                cp_rank=cp_rank,
            )
            return func(*args, **{
                **kwargs,
                'actual_seq_lengths_query': adj_q,
                'actual_seq_lengths_kv': adj_k,
            })

        return _tnd_cp_impl