Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/custom_ops/custom_op_impl.py: 57%
76 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore custom kernel implementations and DFunction wrappers."""
16import os
17import sys
19import mindspore as ms # pylint: disable=C0415
21from hyper_parallel.core.shard.dfunction import DFunction
_CC_DIR = os.path.dirname(os.path.abspath(__file__))
_MS_EXTENSION_NAME = "hyper_parallel_custom_ops_ms"
_BUILD_LIB = os.path.join(_CC_DIR, "build", "lib")

# Put the directory that may hold a pre-built extension at the front of the
# import path (only once), so the ``__import__`` below looks there first.
if _BUILD_LIB not in sys.path:
    sys.path.insert(0, _BUILD_LIB)

try:
    # Fast path: the compiled extension module is already importable.
    _custom_ops = __import__(_MS_EXTENSION_NAME)
except ImportError:
    # Source-tree development: the .so is not pre-built, so JIT-compile it
    # for the Ascend backend from the .cc files next to this module.
    _custom_ops = ms.ops.CustomOpBuilder(
        _MS_EXTENSION_NAME,
        [
            os.path.join(_CC_DIR, source_name)
            for source_name in (
                "module.cc",
                "dense_lightning_indexer_grad_kl_loss.cc",
                "dense_lightning_indexer_softmax_lse.cc",
                "sparse_lightning_indexer_grad_kl_loss.cc",
                "mhc_post.cc",
                "mhc_post_backward.cc",
                "mhc_pre_sinkhorn.cc",
                "mhc_pre_sinkhorn_backward.cc",
            )
        ],
        backend="Ascend",
    ).load()
51def _ensure_contiguous(*tensors):
52 """Ensure all tensors are contiguous (no-op if already contiguous)."""
53 return tuple(t.contiguous() if not t.is_contiguous() else t for t in tensors)
def _to_list_int64(val):
    """Turn a MindSpore Tensor into a plain Python list of int64 values.

    The aclnn kernels consume integer lists for these arguments rather than
    tensors. Anything that is not an ``ms.Tensor`` (e.g. ``None`` or an
    already-converted list) is handed back untouched.

    Args:
        val: An ``ms.Tensor`` (the original note says int32) or any other value.

    Returns:
        list[int] for a Tensor input, otherwise ``val`` itself.
    """
    if not isinstance(val, ms.Tensor):
        return val
    as_int64 = val.asnumpy().astype("int64")
    return as_int64.tolist()
class NpuDenseLightningIndexerSoftmaxLseDFunction(DFunction):  # pylint: disable=W0221
    """MindSpore DFunction for the ``npu_dense_lightning_indexer_softmax_lse`` kernel.

    Plain tensors go straight to the MindSpore Ascend custom kernel; DTensor
    inputs are dispatched through the distributed framework using the
    DistributedOp registered under the same ``_op_name``.

    ``forward`` takes 9 positional arguments after ``ctx``. The operator does
    not require gradients, so ``backward`` returns ``None`` for each of them.
    """

    _op_name = "npu_dense_lightning_indexer_softmax_lse"

    @staticmethod
    def forward(ctx, query_index, key_index, weights,
                actual_seq_qlen, actual_seq_klen,
                layout, sparse_mode, pre_tokens, next_tokens):
        """Compute the indexer softmax max/sum with the Ascend custom kernel.

        Args:
            ctx: Autograd context (unused).
            query_index: Lightning Indexer query input (Q̃).
            key_index: Lightning Indexer key input (K̃).
            weights: Lightning Indexer weight coefficient (W).
            actual_seq_qlen: Cumulative query sequence lengths; None for BSND.
            actual_seq_klen: Cumulative key sequence lengths; None for BSND.
            layout: Data layout format, 'BSND' or 'TND'.
            sparse_mode: Sparse computation mode (only mode 3 supported).
            pre_tokens: Number of preceding tokens for sparse attention.
            next_tokens: Number of following tokens for sparse attention.

        Returns:
            tuple[Tensor, Tensor]: (softmax_max_index, softmax_sum_index), both float32.
        """
        # Tensor sequence lengths must be handed to the kernel as int lists.
        seq_qlen = _to_list_int64(actual_seq_qlen)
        seq_klen = _to_list_int64(actual_seq_klen)
        return _custom_ops.npu_dense_lightning_indexer_softmax_lse(
            query_index, key_index, weights,
            seq_qlen, seq_klen,
            layout, sparse_mode, pre_tokens, next_tokens,
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return no gradients: one ``None`` for each of the 9 forward inputs."""
        no_grads = tuple(None for _ in range(9))
        return no_grads
class NpuDenseLightningIndexerGradKlLossDFunction(DFunction):  # pylint: disable=W0221
    """MindSpore DFunction for the ``npu_dense_lightning_indexer_grad_kl_loss`` kernel.

    Plain tensors go straight to the MindSpore Ascend custom kernel; DTensor
    inputs are dispatched through the distributed framework using the
    DistributedOp registered under the same ``_op_name``.

    ``forward`` takes 18 positional arguments after ``ctx``. The kernel itself
    returns the Lightning Indexer gradients, so ``backward`` hands back the
    values saved during ``forward`` instead of recomputing anything.
    """

    _op_name = "npu_dense_lightning_indexer_grad_kl_loss"

    @staticmethod
    def forward(ctx, query, key, query_index, key_index, weights,
                softmax_max, softmax_sum, softmax_max_index, softmax_sum_index,
                scale_value, query_rope, key_rope,
                actual_seq_qlen, actual_seq_klen,
                layout, sparse_mode, pre_tokens, next_tokens):
        """Run the fused indexer-gradient + KL-loss kernel and save its gradients.

        Args:
            ctx: Autograd context; receives the three gradient outputs.
            query: Main attention query (Q). dtype bfloat16/float16.
            key: Main attention key (K). dtype bfloat16/float16.
            query_index: Lightning Indexer query input (Q̃). dtype bfloat16/float16.
            key_index: Lightning Indexer key input (K̃). dtype bfloat16/float16.
            weights: Lightning Indexer weight coefficient (W).
            softmax_max: Attention softmax max values. dtype float32.
            softmax_sum: Attention softmax sum values. dtype float32.
            softmax_max_index: Index attention softmax max (from softmax_lse). dtype float32.
            softmax_sum_index: Index attention softmax sum (from softmax_lse). dtype float32.
            scale_value: Scaling factor. dtype float32.
            query_rope: Optional MLA query rope tensor.
            key_rope: Optional MLA key rope tensor.
            actual_seq_qlen: Cumulative query sequence lengths; None for BSND.
            actual_seq_klen: Cumulative key sequence lengths; None for BSND.
            layout: Data layout format, 'BSND' or 'TND'.
            sparse_mode: Sparse computation mode (only mode 3 supported).
            pre_tokens: Number of preceding tokens for sparse attention.
            next_tokens: Number of following tokens for sparse attention.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]:
                (d_query_index, d_key_index, d_weights, loss).
        """
        # Tensor sequence lengths must be handed to the kernel as int lists.
        seq_qlen = _to_list_int64(actual_seq_qlen)
        seq_klen = _to_list_int64(actual_seq_klen)
        outputs = _custom_ops.npu_dense_lightning_indexer_grad_kl_loss(
            query, key, query_index, key_index, weights,
            softmax_max, softmax_sum, softmax_max_index, softmax_sum_index,
            scale_value, query_rope, key_rope,
            seq_qlen, seq_klen,
            layout, sparse_mode, pre_tokens, next_tokens,
        )
        # The first three outputs are the gradients backward will return.
        grad_qi, grad_ki, grad_w = outputs[0], outputs[1], outputs[2]
        ctx.save_for_backward(grad_qi, grad_ki, grad_w)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the kernel-computed gradients for the 18 forward inputs.

        Only ``query_index``, ``key_index`` and ``weights`` (positions 2-4)
        receive gradients; every other input gets ``None``.

        NOTE(review): ``grad_outputs`` is not read, so the saved gradients are
        not scaled by the upstream gradient of ``loss``; presumably ``loss`` is
        always used directly as the objective — confirm.
        """
        grad_qi, grad_ki, grad_w = _ensure_contiguous(*ctx.saved_tensors)
        leading = (None, None)    # query, key
        trailing = (None,) * 13   # softmax_max ... next_tokens
        return leading + (grad_qi, grad_ki, grad_w) + trailing
class NpuSparseLightningIndexerGradKlLossDFunction(DFunction):  # pylint: disable=W0221
    """MindSpore DFunction for the ``npu_sparse_lightning_indexer_grad_kl_loss`` kernel.

    Plain tensors go straight to the MindSpore Ascend custom kernel; DTensor
    inputs are dispatched through the distributed framework using the
    DistributedOp registered under the same ``_op_name``.

    ``forward`` takes 17 positional arguments after ``ctx``. The kernel itself
    returns the Lightning Indexer gradients, so ``backward`` hands back the
    values saved during ``forward`` instead of recomputing anything.
    """

    _op_name = "npu_sparse_lightning_indexer_grad_kl_loss"

    @staticmethod
    def forward(ctx, query, key, query_index, key_index, weights,
                sparse_indices, softmax_max, softmax_sum, scale_value,
                query_rope, key_rope,
                actual_seq_qlen, actual_seq_klen,
                layout, sparse_mode, pre_tokens, next_tokens):
        """Run the sparse indexer-gradient + KL-loss kernel and save its gradients.

        Args:
            ctx: Autograd context; receives the three gradient outputs.
            query: Main attention query (q_t). dtype bfloat16/float16.
            key: Main attention key (K_t). dtype bfloat16/float16.
            query_index: Lightning Indexer query input (q̃_t). dtype bfloat16/float16.
            key_index: Lightning Indexer key input (K̃_t). dtype bfloat16/float16.
            weights: Lightning Indexer weight coefficient (W_t).
            sparse_indices: Sorted token indices into key/key_index.
                NOTE(review): previously documented as dtype bfloat16/float16,
                which is unusual for indices (likely an integer type) — confirm.
            softmax_max: Attention softmax max values.
            softmax_sum: Attention softmax sum values.
            scale_value: Scaling factor. dtype float.
            query_rope: Optional MLA query rope tensor.
            key_rope: Optional MLA key rope tensor.
            actual_seq_qlen: Cumulative query sequence lengths; None for BSND.
            actual_seq_klen: Cumulative key sequence lengths; None for BSND.
            layout: Data layout format, 'BSND' or 'TND'.
            sparse_mode: Sparse computation mode (only mode 3 supported).
            pre_tokens: Number of preceding tokens for sparse attention.
            next_tokens: Number of following tokens for sparse attention.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]:
                (d_query_index, d_key_index, d_weights, loss).
        """
        # Tensor sequence lengths must be handed to the kernel as int lists.
        seq_qlen = _to_list_int64(actual_seq_qlen)
        seq_klen = _to_list_int64(actual_seq_klen)
        outputs = _custom_ops.npu_sparse_lightning_indexer_grad_kl_loss(
            query, key, query_index, key_index, weights,
            sparse_indices, softmax_max, softmax_sum, scale_value,
            query_rope, key_rope,
            seq_qlen, seq_klen,
            layout, sparse_mode, pre_tokens, next_tokens,
        )
        # The first three outputs are the gradients backward will return.
        grad_qi, grad_ki, grad_w = outputs[0], outputs[1], outputs[2]
        ctx.save_for_backward(grad_qi, grad_ki, grad_w)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the kernel-computed gradients for the 17 forward inputs.

        Only ``query_index``, ``key_index`` and ``weights`` (positions 2-4)
        receive gradients; every other input gets ``None``.

        NOTE(review): ``grad_outputs`` is not read, so the saved gradients are
        not scaled by the upstream gradient of ``loss``; presumably ``loss`` is
        always used directly as the objective — confirm.
        """
        grad_qi, grad_ki, grad_w = _ensure_contiguous(*ctx.saved_tensors)
        leading = (None, None)    # query, key
        trailing = (None,) * 12   # sparse_indices ... next_tokens
        return leading + (grad_qi, grad_ki, grad_w) + trailing
class NpuMhcPostDFunction(DFunction):  # pylint: disable=W0221
    """MindSpore DFunction for the ``npu_mhc_post`` kernel.

    Plain tensors go straight to the MindSpore Ascend custom kernel; DTensor
    inputs are dispatched through the distributed framework using the
    DistributedOp registered under the same ``_op_name``.

    ``forward`` takes 4 positional arguments after ``ctx``; all four are saved
    and fed to ``npu_mhc_post_backward`` in ``backward``.
    """

    _op_name = "npu_mhc_post"

    @staticmethod
    def forward(ctx, x, h_res, h_out, h_post):
        """Apply the mHC post transformation with the Ascend custom kernel.

        Args:
            ctx: Autograd context; receives all four inputs for backward.
            x: Input tensor of shape [B,S,N,D] or [T,N,D]. dtype bfloat16/float16.
            h_res: mHC h_res transformation matrix. dtype float32.
            h_out: Attention/MLP layer output. dtype bfloat16/float16.
            h_post: mHC h_post transformation matrix. dtype float32.

        Returns:
            Tensor: Output tensor with same shape and dtype as x.
        """
        inputs = (x, h_res, h_out, h_post)
        ctx.save_for_backward(*inputs)
        return _custom_ops.npu_mhc_post(*inputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute input gradients with the ``npu_mhc_post_backward`` kernel.

        Args:
            ctx: Autograd context holding (x, h_res, h_out, h_post).
            grad_outputs: Upstream gradients; only grad_outputs[0] (grad_y) is used.

        Returns:
            tuple: (grad_x, grad_h_res, grad_h_out, grad_h_post).
        """
        x, h_res, h_out, h_post = ctx.saved_tensors
        # The kernel is fed contiguous buffers: grad_y first, then the saved inputs.
        kernel_args = _ensure_contiguous(grad_outputs[0], x, h_res, h_out, h_post)
        grads = _custom_ops.npu_mhc_post_backward(*kernel_args)
        return tuple(grads[i] for i in range(4))
class NpuMhcPreSinkhornDFunction(DFunction):  # pylint: disable=W0221
    """MindSpore DFunction for the ``npu_mhc_pre_sinkhorn`` kernel.

    Plain tensors go straight to the MindSpore Ascend custom kernel; DTensor
    inputs are dispatched through the distributed framework using the
    DistributedOp registered under the same ``_op_name``.

    ``forward`` takes 9 positional arguments after ``ctx``. Gradients flow only
    to the four tensor inputs (x, phi, alpha, bias); the five scalar/config
    arguments receive ``None``.
    """

    _op_name = "npu_mhc_pre_sinkhorn"

    @staticmethod
    def forward(ctx, x, phi, alpha, bias, hc_mult, num_iters, hc_eps, norm_eps, out_flag):
        """Run the mHC pre/Sinkhorn kernel and save what backward needs.

        Args:
            ctx: Autograd context; receives the tensor inputs, the five
                intermediate outputs and ``hc_eps``.
            x: Input tensor. dtype bfloat16/float16.
            phi: mHC parameter matrix. dtype float32.
            alpha: mHC scaling parameters. dtype float32.
            bias: mHC bias parameters. dtype float32.
            hc_mult: HC dimension size (currently only 4 supported).
            num_iters: Sinkhorn iteration count.
            hc_eps: H_pre sigmoid eps parameter.
            norm_eps: RmsNorm eps parameter.
            out_flag: Whether to output intermediate gradients.

        Returns:
            tuple[Tensor, ...]: 8 output tensors
                (h_in, h_post, h_res, h_pre, hc_before_norm, inv_rms, sum_out, norm_out).
        """
        outputs = _custom_ops.npu_mhc_pre_sinkhorn(
            x, phi, alpha, bias, hc_mult, num_iters, hc_eps, norm_eps, out_flag
        )
        # NOTE(review): this unpack requires exactly 8 outputs whatever the
        # value of out_flag — confirm the kernel always returns all of them.
        (_h_in, _h_post, _h_res,
         h_pre, hc_before_norm, inv_rms, sum_out, norm_out) = outputs
        ctx.save_for_backward(x, phi, alpha, bias,
                              h_pre, hc_before_norm, inv_rms, sum_out, norm_out)
        ctx.hc_eps = hc_eps
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute input gradients with the ``npu_mhc_pre_sinkhorn_backward`` kernel.

        Args:
            ctx: Autograd context holding the saved inputs/intermediates and hc_eps.
            grad_outputs: Upstream gradients for the 8 forward outputs. Only
                [0]=grad_h_in, [1]=grad_h_post and [2]=grad_h_res are used;
                [3..7] (the intermediates) are ignored.

        Returns:
            tuple: (grad_x, grad_phi, grad_alpha, grad_bias, None×5) —
                gradients for the 9 forward inputs.
        """
        (x, phi, alpha, bias,
         h_pre, hc_before_norm, inv_rms, sum_out, norm_out) = ctx.saved_tensors
        # Upstream grads first, then saved inputs, then saved intermediates,
        # all made contiguous for the kernel.
        kernel_args = _ensure_contiguous(
            grad_outputs[0], grad_outputs[1], grad_outputs[2],
            x, phi, alpha, bias,
            h_pre, hc_before_norm, inv_rms, sum_out, norm_out)
        grads = _custom_ops.npu_mhc_pre_sinkhorn_backward(*kernel_args, ctx.hc_eps)
        tensor_grads = (grads[0], grads[1], grads[2], grads[3])
        # hc_mult, num_iters, hc_eps, norm_eps, out_flag get no gradient.
        return tensor_grads + (None,) * 5