Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / custom_ops / experimental / __init__.py: 0%
16 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Experimental custom operators for HyperParallel.
17.. warning::
18 This is an experimental API that is subject to change or deletion.
20These operators delegate to the platform-specific ``custom_ops`` interface.
21On MindSpore they wrap Ascend NPU custom C++ kernels through the
22``DFunction`` distributed dispatch framework. On PyTorch they raise
23``NotImplementedError``.
25Usage::
27 from hyper_parallel.custom_ops.experimental import npu_dense_lightning_indexer_softmax_lse
28 softmax_max, softmax_sum = npu_dense_lightning_indexer_softmax_lse(
29 query_index, key_index, weights, layout='BSND')
31When inputs are ``DTensor`` objects, the call is automatically routed
32through the registered ``DistributedOp`` for layout inference and
33re-distribution.
34"""
35from typing import Optional, Tuple
37from mindspore import Tensor
39from hyper_parallel.platform import get_platform
# Public names exported by this package (what ``import *`` exposes).
__all__ = [
    "npu_dense_lightning_indexer_grad_kl_loss",
    "npu_dense_lightning_indexer_softmax_lse",
    "npu_mhc_post",
    "npu_mhc_pre_sinkhorn",
    "npu_sparse_lightning_indexer_grad_kl_loss",
]

# Platform backend, resolved once when this module is imported. Every wrapper
# below forwards to an attribute of ``_platform.custom_ops``.
_platform = get_platform()

# Largest signed 64-bit integer (2**63 - 1 == 9223372036854775807); used as
# the default for the ``pre_tokens`` / ``next_tokens`` parameters.
_MAX_INT64 = (1 << 63) - 1
def npu_dense_lightning_indexer_softmax_lse(
    query_index,
    key_index,
    weights,
    *,
    actual_seq_qlen: Optional[Tensor] = None,
    actual_seq_klen: Optional[Tensor] = None,
    layout: str = 'BSND',
    sparse_mode: int = 3,
    pre_tokens: int = _MAX_INT64,
    next_tokens: int = _MAX_INT64,
) -> Tuple:
    """Pre-compute the softmax max/sum values for Lightning Indexer attention.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Computing the softmax max and sum ahead of time reduces memory usage.
    This function is a thin wrapper: it forwards all arguments, positionally
    and unchanged, to ``npu_dense_lightning_indexer_softmax_lse`` on the
    platform ``custom_ops`` layer, which delegates to a ``DFunction`` wrapping
    the Ascend custom C++ kernel. DTensor inputs are handled transparently by
    distributed dispatch.

    Args:
        query_index: Lightning Indexer query input (Q̃). dtype bfloat16/float16.
        key_index: Lightning Indexer key input (K̃). Same dtype as
            ``query_index``.
        weights: Weight coefficient (W). dtype bfloat16/float16/float32.
        actual_seq_qlen: Cumulative query sequence lengths (int32 Tensor).
            Keyword-only; defaults to ``None``.
        actual_seq_klen: Cumulative key sequence lengths (int32 Tensor).
            Keyword-only; defaults to ``None``.
        layout: Data layout format, ``'BSND'`` (default) or ``'TND'``.
        sparse_mode: Sparse computation mode; only mode 3 is supported.
        pre_tokens: Preceding token window size for sparse attention (int64).
            Defaults to the maximum signed 64-bit integer.
        next_tokens: Following token window size for sparse attention (int64).
            Defaults to the maximum signed 64-bit integer.

    Returns:
        tuple[Tensor, Tensor]: ``(softmax_max_index, softmax_sum_index)``.
    """
    kernel = _platform.custom_ops.npu_dense_lightning_indexer_softmax_lse
    # The platform op takes everything positionally, in exactly this order.
    forwarded = (
        query_index, key_index, weights,
        actual_seq_qlen, actual_seq_klen,
        layout, sparse_mode, pre_tokens, next_tokens,
    )
    return kernel(*forwarded)
def npu_dense_lightning_indexer_grad_kl_loss(
    query,
    key,
    query_index,
    key_index,
    weights,
    softmax_max,
    softmax_sum,
    softmax_max_index,
    softmax_sum_index,
    scale_value,
    *,
    query_rope=None,
    key_rope=None,
    actual_seq_qlen: Optional[Tensor] = None,
    actual_seq_klen: Optional[Tensor] = None,
    layout: str = 'BSND',
    sparse_mode: int = 3,
    pre_tokens: int = _MAX_INT64,
    next_tokens: int = _MAX_INT64,
) -> Tuple:
    """Backward gradients and KL-divergence loss for the dense Lightning Indexer.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper: every argument is forwarded positionally and unchanged to
    ``npu_dense_lightning_indexer_grad_kl_loss`` on the platform
    ``custom_ops`` layer.

    Args:
        query, key, query_index, key_index, weights: Forwarded as given.
            NOTE(review): expected shapes/dtypes are defined by the platform
            kernel and are not visible in this module.
        softmax_max, softmax_sum: Forwarded as given.
        softmax_max_index, softmax_sum_index: Forwarded as given; presumably
            the outputs of ``npu_dense_lightning_indexer_softmax_lse`` (TODO
            confirm against the kernel documentation).
        scale_value: Forwarded as given.
        query_rope: Keyword-only; defaults to ``None``.
        key_rope: Keyword-only; defaults to ``None``.
        actual_seq_qlen: Keyword-only optional Tensor; defaults to ``None``.
        actual_seq_klen: Keyword-only optional Tensor; defaults to ``None``.
        layout: Keyword-only layout string; defaults to ``'BSND'``.
        sparse_mode: Keyword-only int; defaults to ``3``.
        pre_tokens: Keyword-only int; defaults to the maximum signed 64-bit
            integer.
        next_tokens: Keyword-only int; defaults to the maximum signed 64-bit
            integer.

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            ``(d_query_index, d_key_index, d_weights, loss)``.
    """
    kernel = _platform.custom_ops.npu_dense_lightning_indexer_grad_kl_loss
    # Positional order matters: the platform op is called without keywords.
    forwarded = (
        query, key, query_index, key_index, weights,
        softmax_max, softmax_sum, softmax_max_index, softmax_sum_index,
        scale_value,
        query_rope, key_rope,
        actual_seq_qlen, actual_seq_klen,
        layout, sparse_mode,
        pre_tokens, next_tokens,
    )
    return kernel(*forwarded)
def npu_sparse_lightning_indexer_grad_kl_loss(
    query,
    key,
    query_index,
    key_index,
    weights,
    sparse_indices,
    softmax_max,
    softmax_sum,
    scale_value,
    *,
    query_rope=None,
    key_rope=None,
    actual_seq_qlen: Optional[Tensor] = None,
    actual_seq_klen: Optional[Tensor] = None,
    layout: str = 'BSND',
    sparse_mode: int = 3,
    pre_tokens: int = _MAX_INT64,
    next_tokens: int = _MAX_INT64,
) -> Tuple:
    """Backward gradients and KL-divergence loss for the sparse Lightning Indexer.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper: every argument is forwarded positionally and unchanged to
    ``npu_sparse_lightning_indexer_grad_kl_loss`` on the platform
    ``custom_ops`` layer.

    Args:
        query, key, query_index, key_index, weights: Forwarded as given.
            NOTE(review): expected shapes/dtypes are defined by the platform
            kernel and are not visible in this module.
        sparse_indices: Forwarded as given.
        softmax_max, softmax_sum: Forwarded as given.
        scale_value: Forwarded as given.
        query_rope: Keyword-only; defaults to ``None``.
        key_rope: Keyword-only; defaults to ``None``.
        actual_seq_qlen: Keyword-only optional Tensor; defaults to ``None``.
        actual_seq_klen: Keyword-only optional Tensor; defaults to ``None``.
        layout: Keyword-only layout string; defaults to ``'BSND'``.
        sparse_mode: Keyword-only int; defaults to ``3``.
        pre_tokens: Keyword-only int; defaults to the maximum signed 64-bit
            integer.
        next_tokens: Keyword-only int; defaults to the maximum signed 64-bit
            integer.

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            ``(d_query_index, d_key_index, d_weights, loss)``.
    """
    kernel = _platform.custom_ops.npu_sparse_lightning_indexer_grad_kl_loss
    # Positional order matters: the platform op is called without keywords.
    forwarded = (
        query, key, query_index, key_index, weights,
        sparse_indices, softmax_max, softmax_sum, scale_value,
        query_rope, key_rope,
        actual_seq_qlen, actual_seq_klen,
        layout, sparse_mode,
        pre_tokens, next_tokens,
    )
    return kernel(*forwarded)
def npu_mhc_post(x, h_res, h_out, h_post) -> Tuple:
    """Run MHC post-processing with a residual connection.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper: the four arguments are forwarded positionally and unchanged
    to ``npu_mhc_post`` on the platform ``custom_ops`` layer.

    Args:
        x: Forwarded as the first positional argument.
        h_res: Forwarded as the second positional argument.
        h_out: Forwarded as the third positional argument.
        h_post: Forwarded as the fourth positional argument.

    Returns:
        Output tensor with the same shape and dtype as ``x``.
        NOTE(review): the signature is annotated ``-> Tuple`` although a
        single tensor is described here; confirm which is correct against the
        platform op.
    """
    kernel = _platform.custom_ops.npu_mhc_post
    forwarded = (x, h_res, h_out, h_post)
    return kernel(*forwarded)
def npu_mhc_pre_sinkhorn(
    x,
    phi,
    alpha,
    bias,
    *,
    hc_mult: int = 4,
    num_iters: int = 20,
    hc_eps: float = 1e-6,
    norm_eps: float = 1e-6,
    out_flag: bool = True,
) -> Tuple:
    """Run MHC pre-processing with Sinkhorn normalization.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper: every argument is forwarded positionally and unchanged to
    ``npu_mhc_pre_sinkhorn`` on the platform ``custom_ops`` layer.

    Args:
        x, phi, alpha, bias: Forwarded as given.
            NOTE(review): expected shapes/dtypes are defined by the platform
            kernel and are not visible in this module.
        hc_mult: Keyword-only int; defaults to ``4``.
        num_iters: Keyword-only int; defaults to ``20``. Presumably the number
            of Sinkhorn iterations (TODO confirm).
        hc_eps: Keyword-only float; defaults to ``1e-6``.
        norm_eps: Keyword-only float; defaults to ``1e-6``.
        out_flag: Keyword-only bool; defaults to ``True``.

    Returns:
        8 output tensors.
    """
    kernel = _platform.custom_ops.npu_mhc_pre_sinkhorn
    # Positional order matters: the platform op is called without keywords.
    forwarded = (
        x, phi, alpha, bias,
        hc_mult, num_iters,
        hc_eps, norm_eps, out_flag,
    )
    return kernel(*forwarded)