1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for npu_sparse_flash_attention operator."""
16import copy
17from typing import Callable, Optional, Tuple
18
19from hyper_parallel.core.dtensor.dtensor import DTensor
20from hyper_parallel.core.dtensor.layout import Layout
21from .parallel_ops import DistributedOp
22from .parallel_npu_dense_lightning_indexer_softmax_lse import _adjust_bsnd_key, _adjust_tnd_seq_lens
23
# Largest signed 64-bit integer; used as the default for pre_tokens / next_tokens.
_MAX_INT64 = (1 << 63) - 1

# Dimensions that must stay replicated, looked up as
#   layout string -> tensor role -> {dim index: dim label}.
# Roles: 'q' (query), 'k' (key), 'v' (value), 'si' (sparse_indices).
# The query head dim N1 is listed because sharding it has a severe performance cost.
_REPLICATED_DIMS = {
    'BSND': dict(
        q={2: 'N1', 3: 'D'},
        k={1: 'S2', 2: 'N2', 3: 'D'},
        v={1: 'S2', 2: 'N2', 3: 'D'},
        si={2: 'N2', 3: 'sparse_size'},
    ),
    'TND': dict(
        q={1: 'N1', 2: 'D'},
        k={1: 'N2', 2: 'D'},
        v={1: 'N2', 2: 'D'},
        si={1: 'N2', 2: 'sparse_size'},
    ),
}
43
44
def _normalize_sfa_args(
        query,
        key,
        value,
        sparse_indices,
        scale_value,
        block_table=None,
        actual_seq_lengths_query=None,
        actual_seq_lengths_kv=None,
        query_rope=None,
        key_rope=None,
        sparse_block_size=1,
        layout_query='BSND',
        layout_kv='BSND',
        sparse_mode=3,
        pre_tokens=_MAX_INT64,
        next_tokens=_MAX_INT64,
        attention_mode=2,
        return_softmax_lse=False):
    """Collapse any positional/keyword call form into one fixed-order tuple.

    Python's own argument binding does the work: whichever way the caller
    supplied the arguments, they arrive here as named parameters (with defaults
    filled in) and are returned in declaration order.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        sparse_indices: Sparse index tensor (int32).
        scale_value: Scaling factor (float).
        block_table: Optional PageAttention block mapping table.
        actual_seq_lengths_query: Actual query sequence lengths per batch.
        actual_seq_lengths_kv: Actual KV sequence lengths per batch.
        query_rope: Optional MLA query rope tensor.
        key_rope: Optional MLA key rope tensor.
        sparse_block_size: Block size for sparse computation.
        layout_query: Query layout string ('BSND' or 'TND').
        layout_kv: KV layout string ('BSND', 'TND', or 'PA_BSND').
        sparse_mode: Sparse attention mode.
        pre_tokens: Preceding token window size.
        next_tokens: Following token window size.
        attention_mode: Attention mode (0 or 2 for MLA-absorb).
        return_softmax_lse: Whether to return softmax max/sum.

    Returns:
        tuple: ``(positional_args, {})`` where ``positional_args`` holds all 18
        arguments in declaration order (indices 0-4 required inputs, 5-9
        optional tensors, 10-17 scalar/string attributes) and the dict is
        always empty.
    """
    required_inputs = (query, key, value, sparse_indices, scale_value)
    optional_tensors = (
        block_table,
        actual_seq_lengths_query,
        actual_seq_lengths_kv,
        query_rope,
        key_rope,
    )
    attributes = (
        sparse_block_size,
        layout_query,
        layout_kv,
        sparse_mode,
        pre_tokens,
        next_tokens,
        attention_mode,
        return_softmax_lse,
    )
    return required_inputs + optional_tensors + attributes, {}
96
97
def _to_local(t):
    """Return ``t.to_local()`` when ``t`` is a DTensor; otherwise return ``t`` unchanged."""
    return t.to_local() if isinstance(t, DTensor) else t
103
104
class SparseFlashAttentionDistributedOp(DistributedOp):
    """Distributed operator for npu_sparse_flash_attention.

    Supports BSND and TND input layouts on both MindSpore
    and PyTorch / torch_npu backends.

    Both frameworks provide built-in forward and backward implementations;
    this class handles only the distributed dispatch (layout inference and
    optional context-parallel adjustment of the kernel inputs).

    Output shapes relative to inputs:
        - BSND: query (B, S1, N1, D) → attention_out (B, S1, N1, D),
                softmax_max/sum (B, N2, S1, N1/N2)
        - TND:  query (T1, N1, D)    → attention_out (T1, N1, D),
                softmax_max/sum (N2, T1, N1/N2)

    Sharding constraints:
        - N1 (query head dim) must be replicated — TP on this dim is forbidden
          due to severe performance impact.
        - Key/value N2 and D dims must be replicated; in BSND the S2 dim must be
          replicated as well. In TND the T2 dim may be sharded, provided key and
          value are sharded identically.
        - sparse_indices N2 and sparse_size dims must be replicated.
        - layout_kv='PA_BSND' is not supported in distributed mode (rejected in
          ``preprocess``).
        - Query layouts other than 'BSND' / 'TND' are rejected.

    Context parallelism:
        - BSND+CP: k, v, and key_rope are sliced to the causal window
          ``[:, :S1_local*(split_id+1), :, :]`` before calling the kernel, matching the
          MindFormers adjust_bsnd_input logic. sparse_indices from lightning_indexer are
          generated with the same truncation, so they remain valid for the sliced k.
        - TND+CP: adjusts actual_seq_lengths_query/kv per rank using
          _adjust_tnd_seq_lens (same logic as dsa_attention.py).
    """

    @staticmethod
    def _infer_softmax_layout(q_layout: Layout, layout_str: str) -> Layout:
        """Build the output layout for softmax_max and softmax_sum.

        BSND: query (B, S1, N1, D) → softmax (B, N2, S1, N1/N2)
              tensor_map: (q_tm[0], -1, q_tm[1], -1)
        TND:  query (T1, N1, D)    → softmax (N2, T1, N1/N2)
              tensor_map: (-1, q_tm[0], -1)

        N2 and N1/N2 are always replicated because N2=1 and N1 is forbidden
        from sharding.

        Args:
            q_layout: Layout of the query input.
            layout_str: 'BSND' or 'TND'. Any value other than 'BSND' takes the
                TND branch, so callers are expected to validate it beforehand.

        Returns:
            Layout for softmax_max / softmax_sum, built on the query's mesh.
        """
        q_tm = q_layout.tensor_map
        out_layout = Layout.from_device_mesh(q_layout.mesh)
        if layout_str == 'BSND':
            out_tm = (q_tm[0], -1, q_tm[1], -1)
        else:
            out_tm = (-1, q_tm[0], -1)
        out_layout.set_tensor_map(out_tm)
        out_layout.tensor_map_to_placement()
        return out_layout

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (may contain DTensors).
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where
                local_args = (query_local, key_local, value_local,
                              sparse_indices_local, scale_value),
                local_kwargs contains all remaining arguments,
                cache_values = [q_layout, k_layout, v_layout, si_layout, layout_query_str].

        Raises:
            ValueError: If layout_kv is 'PA_BSND'.
        """
        norm_args, _ = _normalize_sfa_args(*args, **kwargs)
        query = norm_args[0]
        key = norm_args[1]
        value = norm_args[2]
        sparse_indices = norm_args[3]
        scale_value = norm_args[4]
        layout_query_str = norm_args[11]
        layout_kv_str = norm_args[12]

        # PA_BSND is a KV-only layout, so it can never show up in layout_query.
        # cache_values carries only the query layout string, which means the
        # PA_BSND check in _validate_input_layouts cannot see layout_kv; reject
        # it here, where it is checked on every call.
        if layout_kv_str == 'PA_BSND':
            raise ValueError(
                "For npu_sparse_flash_attention, PA_BSND layout is not supported "
                "in distributed mode."
            )

        local_args = (
            _to_local(query),
            _to_local(key),
            _to_local(value),
            _to_local(sparse_indices),
            scale_value,
        )
        local_kwargs = {
            'block_table': _to_local(norm_args[5]),
            'actual_seq_lengths_query': _to_local(norm_args[6]),
            'actual_seq_lengths_kv': _to_local(norm_args[7]),
            'query_rope': _to_local(norm_args[8]),
            'key_rope': _to_local(norm_args[9]),
            'sparse_block_size': norm_args[10],
            'layout_query': layout_query_str,
            'layout_kv': layout_kv_str,
            'sparse_mode': norm_args[13],
            'pre_tokens': norm_args[14],
            'next_tokens': norm_args[15],
            'attention_mode': norm_args[16],
            'return_softmax_lse': norm_args[17],
        }

        cache_values = [
            query.layout,
            key.layout,
            value.layout,
            sparse_indices.layout,
            layout_query_str,
        ]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _validate_input_layouts(
            q_layout: Layout,
            k_layout: Layout,
            v_layout: Layout,
            si_layout: Layout,
            layout_str: str,
    ) -> None:
        """Validate sharding constraints for all input tensors.

        BSND rules (shapes: (B,S1,N1,D) / (B,S2,N2,D) / (B,S2,N2,D) / (B,S1,N2,sparse_size)):
          - N1 (dim 2) and D (dim 3) of query must be replicated.
          - S2 (dim 1), N2 (dim 2), D (dim 3) of key and value must be replicated.
          - N2 (dim 2) and sparse_size (dim 3) of sparse_indices must be replicated.
          - B sharding of key, value, and sparse_indices must match query.
          - S1 sharding of sparse_indices must match query.

        TND rules (shapes: (T1,N1,D) / (T2,N2,D) / (T2,N2,D) / (T1,N2,sparse_size)):
          - N1 (dim 1) and D (dim 2) of query must be replicated.
          - N2 (dim 1) and D (dim 2) of key and value must be replicated.
          - N2 (dim 1) and sparse_size (dim 2) of sparse_indices must be replicated.
          - T2 sharding of key and value must match.
          - T1 sharding of sparse_indices must match query.

        PA_BSND is not supported in distributed mode.

        Args:
            q_layout: Layout of query.
            k_layout: Layout of key.
            v_layout: Layout of value.
            si_layout: Layout of sparse_indices.
            layout_str: 'BSND' or 'TND'.

        Raises:
            ValueError: If layout_str is 'PA_BSND' or is not a supported layout,
                if any required dimension is sharded, or if batch/sequence
                consistency constraints are violated.
        """
        if layout_str == 'PA_BSND':
            raise ValueError(
                "For npu_sparse_flash_attention, PA_BSND layout is not supported "
                "in distributed mode."
            )

        op = "npu_sparse_flash_attention"
        # Without this guard an unrecognised layout would skip every
        # replicated-dim check and then be validated as if it were TND.
        if layout_str not in _REPLICATED_DIMS:
            raise ValueError(
                f"For {op}, layout_query should be one of {sorted(_REPLICATED_DIMS)} "
                f"in distributed mode, but got {layout_str!r}"
            )

        q_tm = q_layout.tensor_map
        k_tm = k_layout.tensor_map
        v_tm = v_layout.tensor_map
        si_tm = si_layout.tensor_map
        tms = {
            'q': (q_tm, 'query'),
            'k': (k_tm, 'key'),
            'v': (v_tm, 'value'),
            'si': (si_tm, 'sparse_indices'),
        }
        # A tensor_map entry of -1 marks a replicated dimension.
        for role, dims in _REPLICATED_DIMS[layout_str].items():
            tm, tensor_name = tms[role]
            for dim, label in dims.items():
                if tm[dim] != -1:
                    raise ValueError(
                        f"For {op}, {label} (dim {dim}) of {tensor_name} should be replicated, "
                        f"but got tensor_map={tm}"
                    )

        if layout_str == 'BSND':
            if q_tm[0] != k_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of key should match query, "
                    f"but got query={q_tm[0]}, key={k_tm[0]}"
                )
            if q_tm[0] != v_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of value should match query, "
                    f"but got query={q_tm[0]}, value={v_tm[0]}"
                )
            if q_tm[0] != si_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of sparse_indices should match query, "
                    f"but got query={q_tm[0]}, sparse_indices={si_tm[0]}"
                )
            if q_tm[1] != si_tm[1]:
                raise ValueError(
                    f"For {op}, S1 (dim 1) sharding of sparse_indices should match query, "
                    f"but got query={q_tm[1]}, sparse_indices={si_tm[1]}"
                )
        else:  # TND
            if k_tm[0] != v_tm[0]:
                raise ValueError(
                    f"For {op}, T2 (dim 0) sharding of value should match key, "
                    f"but got key={k_tm[0]}, value={v_tm[0]}"
                )
            if q_tm[0] != si_tm[0]:
                raise ValueError(
                    f"For {op}, T1 (dim 0) sharding of sparse_indices should match query, "
                    f"but got query={q_tm[0]}, sparse_indices={si_tm[0]}"
                )

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Infer output layouts for all three outputs.

        Rules:
          1. Partial inputs are not allowed on any of the four primary tensors.
          2. Sharding constraints are validated (see _validate_input_layouts);
             unsupported query layouts are rejected there.
          3. attention_out inherits query layout (deep copy).
          4. softmax_max and softmax_sum share the same layout derived from
             query layout with N2 and N1/N2 dims always replicated.
          5. All three output layouts are independent objects.

        Args:
            cache_values: [q_layout, k_layout, v_layout, si_layout, layout_str]

        Returns:
            tuple: ((attn_layout, softmax_max_layout, softmax_sum_layout), None)

        Raises:
            ValueError: If the layout string is unsupported or sharding
                constraints are violated. Partial inputs are rejected by
                ``_check_partial_inputs`` (defined outside this class body).
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        v_layout = cache_values[2]
        si_layout = cache_values[3]
        layout_str = cache_values[4]

        self._check_partial_inputs([q_layout, k_layout, v_layout, si_layout])
        self._validate_input_layouts(q_layout, k_layout, v_layout, si_layout, layout_str)

        attn_layout = copy.deepcopy(q_layout)
        softmax_layout = self._infer_softmax_layout(q_layout, layout_str)
        # softmax_sum gets its own copy so the two softmax layouts never alias.
        return (attn_layout, softmax_layout, copy.deepcopy(softmax_layout)), None

    def get_expand_impl(  # pylint: disable=W0237
            self,
            func: Optional[Callable],
            infer_result: tuple,
            cache_values: list,
            extra_args: Optional[tuple] = None,
    ) -> Optional[Callable]:
        """Return a custom callable if context-parallel adjustment is needed.

        BSND (S1 not sharded): returns None — k/v are Replicated; sparse_indices
            reference the full k directly.
        BSND+CP (S1 sharded): wraps func to slice k, v, and key_rope to the
            causal window ``k[:, :S1_local*(split_id+1), :, :]`` before calling
            the kernel. Mirrors MindFormers adjust_bsnd_input logic, ensuring
            that sparse_indices produced by lightning_indexer (which applies the
            same truncation) remain valid.
        TND+CP: wraps func to adjust actual_seq_lengths_query/kv per rank,
            using the same algorithm as dsa_attention._sparse_flash_attention_forward.
        TND (no CP): same wrapper with cp_rank fixed to 0. In both TND cases the
            wrapper calls func unchanged when either sequence-length argument
            is None.

        Args:
            func: The underlying op callable.
            infer_result: Output from infer_layout (not read here).
            cache_values: [q_layout, k_layout, v_layout, si_layout, layout_str].
            extra_args: Unused; kept for interface compatibility.

        Returns:
            Callable wrapper or None.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        layout_str = cache_values[4]

        if layout_str == 'BSND':
            if q_layout.tensor_map[1] == -1:
                # S1 not sharded: pure DP or fully replicated.
                # k/v are Replicate on the CP dimension, so sparse_indices reference
                # the full k directly; no truncation needed.
                return None
            split_id = q_layout.get_split_id(1)

            def _bsnd_cp_impl(*args, **kwargs):
                local_q, local_k, local_v = args[0], args[1], args[2]
                s1_local = local_q.shape[1]
                sliced_k = _adjust_bsnd_key(local_k, s1_local, split_id)
                sliced_v = _adjust_bsnd_key(local_v, s1_local, split_id)
                key_rope = kwargs.get('key_rope')
                # key_rope is optional; slice it only when the caller supplied one.
                new_kwargs = (
                    {**kwargs, 'key_rope': _adjust_bsnd_key(key_rope, s1_local, split_id)}
                    if key_rope is not None else kwargs
                )
                return func(local_q, sliced_k, sliced_v, *args[3:], **new_kwargs)

            return _bsnd_cp_impl

        # TND: CP applies when q's T1 is sharded more finely than k's T2.
        q_split = q_layout.get_dim_split_num(0)
        k_split = k_layout.get_dim_split_num(0)
        split_id = q_layout.get_split_id(0) if q_split > k_split else 0
        # Guard the division in case k reports a non-positive split count.
        cp_size = q_split // k_split if k_split > 0 else 1
        cp_rank = split_id % cp_size if cp_size > 1 else 0

        def _tnd_cp_impl(*args, **kwargs):
            local_q, local_k = args[0], args[1]
            qlen_tensor = kwargs.get('actual_seq_lengths_query')
            klen_tensor = kwargs.get('actual_seq_lengths_kv')
            if qlen_tensor is None or klen_tensor is None:
                # Nothing to adjust without both sequence-length arguments.
                return func(*args, **kwargs)
            adj_q, adj_k = _adjust_tnd_seq_lens(
                local_q, local_k, qlen_tensor, klen_tensor,
                cp_rank=cp_rank,
            )
            return func(*args, **{
                **kwargs,
                'actual_seq_lengths_query': adj_q,
                'actual_seq_lengths_kv': adj_k,
            })

        return _tnd_cp_impl