Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_lightning_indexer.py: 87%
93 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for lightning_indexer operator."""
16import copy
17from typing import Callable, Optional, Tuple
19from hyper_parallel.core.dtensor.layout import Layout
20from hyper_parallel.platform import get_platform
21from .parallel_ops import DistributedOp
22from .parallel_npu_dense_lightning_indexer_softmax_lse import (
23 _adjust_bsnd_key,
24 _adjust_tnd_seq_lens,
25)
platform = get_platform()

# Largest signed 64-bit integer; used as the "unbounded" default for
# pre_tokens / next_tokens.
_MAX_INT64 = (1 << 63) - 1

# Per-layout tables of dimensions that must be replicated (tensor_map entry == -1).
# Each table maps tensor role -> {dim_index: dim_label}, where the roles are
# 'q' = query, 'k' = key, 'w' = weights. Role order is significant: it decides
# which violation is reported first.
_BSND_REPLICATED_DIMS = {
    'q': {2: 'N1', 3: 'D'},
    'k': {1: 'S2', 2: 'N2', 3: 'D'},
    'w': {2: 'N1'},
}
_TND_REPLICATED_DIMS = {
    'q': {1: 'N1', 2: 'D'},
    'k': {1: 'N2', 2: 'D'},
    'w': {1: 'N1'},
}

# Maps layout_str -> tensor role -> {dim_index: dim_label} for replicated-dim checks.
_REPLICATED_DIMS = {
    'BSND': _BSND_REPLICATED_DIMS,
    'TND': _TND_REPLICATED_DIMS,
}
def _normalize_lightning_indexer_args(
        query,
        key,
        weights,
        actual_seq_lengths_query=None,
        actual_seq_lengths_key=None,
        block_table=None,
        layout_query='BSND',
        layout_key='BSND',
        sparse_count=2048,
        sparse_mode=3,
        pre_tokens=_MAX_INT64,
        next_tokens=_MAX_INT64,
        return_value=False):
    """Bind a lightning_indexer call to its signature and split tensors from options.

    Whatever mix of positional and keyword arguments the caller used, the result
    always has the three tensor inputs as a tuple and every remaining parameter
    (with defaults filled in) as a dict keyed by parameter name.

    Args:
        query: Query tensor.
        key: Key tensor.
        weights: Weight tensor.
        actual_seq_lengths_query: Cumulative query sequence lengths (TND only).
        actual_seq_lengths_key: Cumulative key sequence lengths (TND only).
        block_table: Block table for PageAttention (optional).
        layout_query: Input layout string for query, 'BSND' or 'TND'.
        layout_key: Input layout string for key, 'BSND', 'TND', or 'PA_BSND'.
        sparse_count: Number of top-k blocks to retain.
        sparse_mode: Sparse attention mode (0=defaultMask, 3=rightDownCausal).
        pre_tokens: Sparse pre-tokens count.
        next_tokens: Sparse next-tokens count.
        return_value: Whether to output sparse_values.

    Returns:
        tuple: ``((query, key, weights), options)`` where ``options`` is a dict
        holding all non-tensor parameters in signature order.
    """
    tensor_inputs = (query, key, weights)
    # dict(...) keeps keyword order, so the keys follow the signature order.
    options = dict(
        actual_seq_lengths_query=actual_seq_lengths_query,
        actual_seq_lengths_key=actual_seq_lengths_key,
        block_table=block_table,
        layout_query=layout_query,
        layout_key=layout_key,
        sparse_count=sparse_count,
        sparse_mode=sparse_mode,
        pre_tokens=pre_tokens,
        next_tokens=next_tokens,
        return_value=return_value,
    )
    return tensor_inputs, options
class LightningIndexerDistributedOp(DistributedOp):
    """Distributed operator for MindSpore built-in lightning_indexer.

    LightningIndexer computes the top-k most relevant key positions for each query
    token in sparse attention. It is a MindSpore built-in op (accessed via
    ``ops.lightning_indexer``), not a custom op, so only the distributed sharding
    logic is implemented here.

    Supports BSND and TND input layouts on both MindSpore and PyTorch platforms.
    Any other ``layout_query`` value is rejected during layout validation.

    Output shapes:
        - BSND: query (B, S1, N1, D) → outputs (B, S1, N2, sparse_count)
        - TND:  query (T1, N1, D)    → outputs (T1, N2, sparse_count)

    Context parallelism (CP) is handled in ``get_expand_impl``:
        - BSND+CP: key S2 is sliced to the causal window for each rank.
        - TND+CP: ``actual_seq_lengths_query`` / ``actual_seq_lengths_key`` are
          adjusted per rank.
    """

    @staticmethod
    def _infer_output_layout(q_layout: Layout, layout_str: str) -> Layout:
        """Build the output layout for both sparse outputs from the query layout.

        BSND: input (B, S1, N1, D) → output (B, S1, N2, sparse_count)
              tensor_map: (q_tm[0], q_tm[1], -1, -1)
        TND:  input (T1, N1, D)    → output (T1, N2, sparse_count)
              tensor_map: (q_tm[0], -1, -1)

        N2 is always replicated (key's head dimension constraint).
        sparse_count is always replicated (int scalar attribute).

        Args:
            q_layout: Layout of the query input.
            layout_str: 'BSND' or 'TND'. Callers are expected to have validated
                this already; anything other than 'BSND' is handled as TND here.

        Returns:
            Layout for the output tensors.
        """
        q_tm = q_layout.tensor_map
        out_layout = Layout.from_device_mesh(q_layout.mesh)
        if layout_str == 'BSND':
            out_tm = (q_tm[0], q_tm[1], -1, -1)
        else:
            out_tm = (q_tm[0], -1, -1)
        out_layout.set_tensor_map(out_tm)
        out_layout.tensor_map_to_placement()
        return out_layout

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (may contain DTensors).
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where cache_values is
                [q_layout, k_layout, w_layout, layout_str].
        """
        norm_args, local_kwargs = _normalize_lightning_indexer_args(*args, **kwargs)

        query_index, key_index, weights = norm_args[0], norm_args[1], norm_args[2]
        # The query layout string drives all sharding rules for this op.
        layout_str = local_kwargs['layout_query']

        # Sequence-length inputs are optional; when given they are converted to
        # their local shards just like the three main tensors.
        qlen_kw = local_kwargs.get('actual_seq_lengths_query')
        klen_kw = local_kwargs.get('actual_seq_lengths_key')
        if qlen_kw is not None:
            local_kwargs['actual_seq_lengths_query'] = qlen_kw.to_local()
        if klen_kw is not None:
            local_kwargs['actual_seq_lengths_key'] = klen_kw.to_local()

        local_args = (query_index.to_local(), key_index.to_local(), weights.to_local())

        cache_values = [query_index.layout, key_index.layout, weights.layout, layout_str]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _validate_input_layouts(
            q_layout: Layout,
            k_layout: Layout,
            w_layout: Layout,
            layout_str: str,
    ) -> None:
        """Validate sharding constraints for all input tensors.

        BSND rules (query/key/weights shapes: (B,S1,N1,D) / (B,S2,N2,D) / (B,S1,N1)):
            - N1 (dim 2) and D (dim 3) of query must be replicated.
            - S2 (dim 1), N2 (dim 2), D (dim 3) of key must be replicated.
            - B sharding of query and key must be identical.
            - B and S1 sharding of weights must match query; N1 must be replicated.

        TND rules (query/key/weights shapes: (T1,N1,D) / (T2,N2,D) / (T1,N1)):
            - N1 (dim 1) and D (dim 2) of query must be replicated.
            - N2 (dim 1) and D (dim 2) of key must be replicated.
            - T1 sharding of weights must match query; N1 must be replicated.

        Args:
            q_layout: Layout of query.
            k_layout: Layout of key.
            w_layout: Layout of weights.
            layout_str: 'BSND' or 'TND'.

        Raises:
            ValueError: If ``layout_str`` is not a supported layout, or if any
                sharding constraint is violated.
        """
        op = "lightning_indexer"
        # Reject unknown layouts explicitly. Without this guard an unsupported
        # string would skip every replicated-dim check and then be validated
        # (and later executed) as if it were TND.
        if layout_str not in _REPLICATED_DIMS:
            raise ValueError(
                f"For {op}, layout_query should be one of {tuple(_REPLICATED_DIMS)}, "
                f"but got {layout_str!r}"
            )

        q_tm = q_layout.tensor_map
        k_tm = k_layout.tensor_map
        w_tm = w_layout.tensor_map
        tms = {'q': (q_tm, 'query'), 'k': (k_tm, 'key'), 'w': (w_tm, 'weights')}
        # A tensor_map entry of -1 means the dimension is replicated.
        for role, dims in _REPLICATED_DIMS[layout_str].items():
            tm, tensor_name = tms[role]
            for dim, label in dims.items():
                if tm[dim] != -1:
                    raise ValueError(
                        f"For {op}, {label} (dim {dim}) of {tensor_name} should be replicated, "
                        f"but got tensor_map={tm}"
                    )
        if layout_str == 'BSND':
            if q_tm[0] != k_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of query and key should match, "
                    f"but got query={q_tm[0]}, key={k_tm[0]}"
                )
            if w_tm[0] != q_tm[0]:
                raise ValueError(
                    f"For {op}, B (dim 0) sharding of weights should match query, "
                    f"but got weights={w_tm[0]}, query={q_tm[0]}"
                )
            if w_tm[1] != q_tm[1]:
                raise ValueError(
                    f"For {op}, S1 (dim 1) sharding of weights should match query, "
                    f"but got weights={w_tm[1]}, query={q_tm[1]}"
                )
        else:  # TND (guaranteed by the supported-layout check above)
            if w_tm[0] != q_tm[0]:
                raise ValueError(
                    f"For {op}, T1 (dim 0) sharding of weights should match query, "
                    f"but got weights={w_tm[0]}, query={q_tm[0]}"
                )

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Infer output layouts for sparse_indices and sparse_values outputs.

        Rules:
            1. No Partial inputs are allowed on any of the three input tensors.
            2. Input sharding constraints are validated per layout_str (see
               ``_validate_input_layouts`` for the full rule set).
            3. Output tensor shape depends on layout_str:
               - BSND: query (B, S1, N1, D) → outputs (B, S1, N2, sparse_count).
                 B and S1 sharding are inherited from query;
                 N2 and sparse_count are always replicated.
               - TND: query (T1, N1, D) → outputs (T1, N2, sparse_count).
                 T1 sharding is inherited from query;
                 N2 and sparse_count are always replicated.
            4. Both sparse_indices and sparse_values outputs share the same layout
               (independent deep copies so callers can mutate them safely).

        Args:
            cache_values: [q_layout, k_layout, w_layout, layout_str]

        Returns:
            tuple: ((indices_layout, values_layout), None)

        Raises:
            ValueError: If any input has Partial status, the layout string is
                unsupported, or sharding constraints are violated.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        w_layout = cache_values[2]
        layout_str = cache_values[3]

        self._check_partial_inputs([q_layout, k_layout, w_layout])
        self._validate_input_layouts(q_layout, k_layout, w_layout, layout_str)

        out_layout = self._infer_output_layout(q_layout, layout_str)
        return (out_layout, copy.deepcopy(out_layout)), None

    def get_expand_impl(  # pylint: disable=W0237
            self,
            func: Optional[Callable],
            infer_result: tuple,
            cache_values: list,
            extra_args: Optional[tuple] = None,
    ) -> Optional[Callable]:
        """Return a custom callable if per-rank input adjustments are needed.

        BSND, S1 not sharded: returns None (dispatcher calls ``func`` directly).
        BSND+CP: wraps ``func`` so that key is first passed through
            ``_adjust_bsnd_key`` (slices key's S2 to the causal window).
        TND: always returns a wrapper. It adjusts ``actual_seq_lengths_query`` /
            ``actual_seq_lengths_key`` per rank via ``_adjust_tnd_seq_lens``, and
            falls back to a plain ``func`` call when either of them is None.

        Args:
            func: The underlying op callable.
            infer_result: Output from ``infer_layout``. Unused here.
            cache_values: [q_layout, k_layout, w_layout, layout_str].
            extra_args: Unused; kept for interface compatibility.

        Returns:
            Callable wrapper or None.
        """
        q_layout = cache_values[0]
        k_layout = cache_values[1]
        layout_str = cache_values[3]

        if layout_str == 'BSND':
            # S1 is dim 1 of query; if not sharded, no CP adjustment needed.
            if q_layout.tensor_map[1] == -1:
                return None
            split_id = q_layout.get_split_id(1)

            def _bsnd_cp_impl(*args, **kwargs):
                local_q, local_k = args[0], args[1]
                sliced_k = _adjust_bsnd_key(local_k, local_q.shape[1], split_id)
                return func(local_q, sliced_k, *args[2:], **kwargs)

            return _bsnd_cp_impl

        # TND: DP always requires seq_len adjustment; CP additionally
        # requires token-level offset adjustment.
        dp_size = k_layout.get_dim_split_num(0)  # DP splits on k's T2
        split_id = q_layout.get_split_id(0)
        # The query's T1 split count is treated as dp_size * cp_size; guard
        # against a zero divisor.
        cp_size = (q_layout.get_dim_split_num(0) // dp_size
                   if dp_size > 0 else 1)
        cp_rank = split_id % cp_size if cp_size > 1 else 0

        def _tnd_impl(*args, **kwargs):
            local_q, local_k = args[0], args[1]

            qlen_tensor = kwargs.get('actual_seq_lengths_query')
            klen_tensor = kwargs.get('actual_seq_lengths_key')

            # Without both sequence-length inputs there is nothing to adjust.
            if qlen_tensor is None or klen_tensor is None:
                return func(*args, **kwargs)

            adj_q, adj_k = _adjust_tnd_seq_lens(
                local_q, local_k, qlen_tensor, klen_tensor,
                cp_rank=cp_rank,
            )

            # Override only the two sequence-length kwargs; keep the rest as-is.
            return func(*args, **{**kwargs, 'actual_seq_lengths_query': adj_q,
                                  'actual_seq_lengths_key': adj_k})

        return _tnd_impl