Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / mindspore / custom_ops / custom_op_impl.py: 57%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore custom kernel implementations and DFunction wrappers.""" 

16import os 

17import sys 

18 

19import mindspore as ms # pylint: disable=C0415 

20 

21from hyper_parallel.core.shard.dfunction import DFunction 

22 

23 

# Directory that holds this module together with the kernel ``.cc`` sources.
_CC_DIR = os.path.dirname(os.path.abspath(__file__))
# Name under which the compiled extension is imported (or JIT-built).
_MS_EXTENSION_NAME = "hyper_parallel_custom_ops_ms"
# Location searched for a pre-built copy of the extension.
_BUILD_LIB = os.path.join(_CC_DIR, "build", "lib")

# Put the build directory at the front of ``sys.path`` (once) so the
# extension can be imported by its plain module name.
if _BUILD_LIB not in sys.path:
    sys.path.insert(0, _BUILD_LIB)

try:
    _custom_ops = __import__(_MS_EXTENSION_NAME)
except ImportError:
    # Source-tree development: no pre-built .so is importable, so the
    # extension is JIT-compiled from the .cc files next to this module.
    _custom_ops = ms.ops.CustomOpBuilder(
        _MS_EXTENSION_NAME,
        [
            os.path.join(_CC_DIR, source_name)
            for source_name in (
                "module.cc",
                "dense_lightning_indexer_grad_kl_loss.cc",
                "dense_lightning_indexer_softmax_lse.cc",
                "sparse_lightning_indexer_grad_kl_loss.cc",
                "mhc_post.cc",
                "mhc_post_backward.cc",
                "mhc_pre_sinkhorn.cc",
                "mhc_pre_sinkhorn_backward.cc",
            )
        ],
        backend="Ascend",
    ).load()

49 

50 

51def _ensure_contiguous(*tensors): 

52 """Ensure all tensors are contiguous (no-op if already contiguous).""" 

53 return tuple(t.contiguous() if not t.is_contiguous() else t for t in tensors) 

54 

55 

def _to_list_int64(val):
    """Turn a MindSpore tensor into a Python list of int64 values.

    Anything that is not an ``ms.Tensor`` (for example ``None`` or an
    existing list) is returned untouched, so callers can pass optional
    sequence-length arguments straight through.
    """
    if not isinstance(val, ms.Tensor):
        return val
    as_int64 = val.asnumpy().astype("int64")
    return as_int64.tolist()

61 

62 

class NpuDenseLightningIndexerSoftmaxLseDFunction(DFunction):  # pylint: disable=W0221
    """DFunction wrapper for npu_dense_lightning_indexer_softmax_lse on MindSpore.

    Plain-tensor calls go straight to the MindSpore custom kernel, while
    DTensor calls are routed through the distributed dispatch framework via
    the DistributedOp registered under the same op_name.

    All 9 forward arguments after ``ctx`` are positional to stay compatible
    with the MindSpore autograd function convention.

    The backward is a no-op that returns ``None`` for every input, because
    the operator does not require gradients.
    """

    _op_name = "npu_dense_lightning_indexer_softmax_lse"

    @staticmethod
    def forward(ctx, query_index, key_index, weights,
                actual_seq_qlen, actual_seq_klen,
                layout, sparse_mode, pre_tokens, next_tokens):
        """Run the MindSpore Ascend custom kernel for the forward pass.

        Args:
            ctx: Autograd context (unused here).
            query_index: Lightning Indexer query input (Q̃).
            key_index: Lightning Indexer key input (K̃).
            weights: Lightning Indexer weight coefficient (W).
            actual_seq_qlen: Cumulative query sequence lengths; None for BSND.
            actual_seq_klen: Cumulative key sequence lengths; None for BSND.
            layout: Data layout format, 'BSND' or 'TND'.
            sparse_mode: Sparse computation mode (only mode 3 supported).
            pre_tokens: Number of preceding tokens for sparse attention.
            next_tokens: Number of following tokens for sparse attention.

        Returns:
            tuple[Tensor, Tensor]: (softmax_max_index, softmax_sum_index), both float32.
        """
        kernel = _custom_ops.npu_dense_lightning_indexer_softmax_lse
        # Tensor-valued sequence lengths are converted to int64 lists first.
        seq_qlen = _to_list_int64(actual_seq_qlen)
        seq_klen = _to_list_int64(actual_seq_klen)
        return kernel(
            query_index, key_index, weights,
            seq_qlen, seq_klen,
            layout, sparse_mode, pre_tokens, next_tokens,
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return no gradients: one ``None`` per forward input (9 in total)."""
        return (
            None, None, None,
            None, None,
            None, None, None, None,
        )

109 

110 

class NpuDenseLightningIndexerGradKlLossDFunction(DFunction):  # pylint: disable=W0221
    """DFunction wrapper for npu_dense_lightning_indexer_grad_kl_loss on MindSpore.

    Plain-tensor calls go straight to the MindSpore custom kernel, while
    DTensor calls are routed through the distributed dispatch framework via
    the DistributedOp registered under the same op_name.

    All 18 forward arguments after ``ctx`` are positional to stay compatible
    with the MindSpore autograd function convention.
    """

    _op_name = "npu_dense_lightning_indexer_grad_kl_loss"

    @staticmethod
    def forward(ctx, query, key, query_index, key_index, weights,
                softmax_max, softmax_sum, softmax_max_index, softmax_sum_index,
                scale_value, query_rope, key_rope,
                actual_seq_qlen, actual_seq_klen,
                layout, sparse_mode, pre_tokens, next_tokens):
        """Run the MindSpore Ascend custom kernel and stash its gradients.

        Args:
            ctx: Autograd context; receives the three gradient outputs.
            query: Main attention query (Q). dtype bfloat16/float16.
            key: Main attention key (K). dtype bfloat16/float16.
            query_index: Lightning Indexer query input (Q̃). dtype bfloat16/float16.
            key_index: Lightning Indexer key input (K̃). dtype bfloat16/float16.
            weights: Lightning Indexer weight coefficient (W).
            softmax_max: Attention softmax max values. dtype float32.
            softmax_sum: Attention softmax sum values. dtype float32.
            softmax_max_index: Index attention softmax max (from softmax_lse). dtype float32.
            softmax_sum_index: Index attention softmax sum (from softmax_lse). dtype float32.
            scale_value: Scaling factor. dtype float32.
            query_rope: Optional MLA query rope tensor.
            key_rope: Optional MLA key rope tensor.
            actual_seq_qlen: Cumulative query sequence lengths; None for BSND.
            actual_seq_klen: Cumulative key sequence lengths; None for BSND.
            layout: Data layout format, 'BSND' or 'TND'.
            sparse_mode: Sparse computation mode (only mode 3 supported).
            pre_tokens: Number of preceding tokens for sparse attention.
            next_tokens: Number of following tokens for sparse attention.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]:
                (d_query_index, d_key_index, d_weights, loss).
        """
        kernel = _custom_ops.npu_dense_lightning_indexer_grad_kl_loss
        outputs = kernel(
            query, key, query_index, key_index, weights,
            softmax_max, softmax_sum, softmax_max_index, softmax_sum_index,
            scale_value, query_rope, key_rope,
            _to_list_int64(actual_seq_qlen), _to_list_int64(actual_seq_klen),
            layout, sparse_mode, pre_tokens, next_tokens,
        )
        # The kernel already produced the gradients; keep them for backward.
        d_query_index, d_key_index, d_weights = outputs[0], outputs[1], outputs[2]
        ctx.save_for_backward(d_query_index, d_key_index, d_weights)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Hand back the gradients that the forward kernel computed.

        The upstream ``grad_outputs`` are not consulted. The saved tensors
        are made contiguous and placed at the positions of ``query_index``,
        ``key_index`` and ``weights``; the other 15 inputs get ``None``.

        Returns:
            tuple: 18 entries, one per forward input.
        """
        saved = _ensure_contiguous(*ctx.saved_tensors)
        d_query_index, d_key_index, d_weights = saved
        leading = (None, None, d_query_index, d_key_index, d_weights)
        # Remaining 13 inputs (softmax stats, scale, rope, seq lens, flags).
        return leading + (None,) * 13

172 

173 

class NpuSparseLightningIndexerGradKlLossDFunction(DFunction):  # pylint: disable=W0221
    """DFunction wrapper for npu_sparse_lightning_indexer_grad_kl_loss on MindSpore.

    Plain-tensor calls go straight to the MindSpore custom kernel, while
    DTensor calls are routed through the distributed dispatch framework via
    the DistributedOp registered under the same op_name.

    All 17 forward arguments after ``ctx`` are positional to stay compatible
    with the MindSpore autograd function convention.
    """

    _op_name = "npu_sparse_lightning_indexer_grad_kl_loss"

    @staticmethod
    def forward(ctx, query, key, query_index, key_index, weights,
                sparse_indices, softmax_max, softmax_sum, scale_value,
                query_rope, key_rope,
                actual_seq_qlen, actual_seq_klen,
                layout, sparse_mode, pre_tokens, next_tokens):
        """Run the MindSpore Ascend custom kernel and stash its gradients.

        Args:
            ctx: Autograd context; receives the three gradient outputs.
            query: Main attention query (q_t). dtype bfloat16/float16.
            key: Main attention key (K_t). dtype bfloat16/float16.
            query_index: Lightning Indexer query input (q̃_t). dtype bfloat16/float16.
            key_index: Lightning Indexer key input (K̃_t). dtype bfloat16/float16.
            weights: Lightning Indexer weight coefficient (W_t).
            sparse_indices: Sorted token indices for key/key_index. Previously
                documented as dtype bfloat16/float16 — TODO confirm, an
                integer dtype would be expected for indices.
            softmax_max: Attention softmax max values.
            softmax_sum: Attention softmax sum values.
            scale_value: Scaling factor. dtype float.
            query_rope: Optional MLA query rope tensor.
            key_rope: Optional MLA key rope tensor.
            actual_seq_qlen: Cumulative query sequence lengths; None for BSND.
            actual_seq_klen: Cumulative key sequence lengths; None for BSND.
            layout: Data layout format, 'BSND' or 'TND'.
            sparse_mode: Sparse computation mode (only mode 3 supported).
            pre_tokens: Number of preceding tokens for sparse attention.
            next_tokens: Number of following tokens for sparse attention.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]:
                (d_query_index, d_key_index, d_weights, loss).
        """
        kernel = _custom_ops.npu_sparse_lightning_indexer_grad_kl_loss
        outputs = kernel(
            query, key, query_index, key_index, weights,
            sparse_indices, softmax_max, softmax_sum, scale_value,
            query_rope, key_rope,
            _to_list_int64(actual_seq_qlen), _to_list_int64(actual_seq_klen),
            layout, sparse_mode, pre_tokens, next_tokens,
        )
        # The kernel already produced the gradients; keep them for backward.
        d_query_index, d_key_index, d_weights = outputs[0], outputs[1], outputs[2]
        ctx.save_for_backward(d_query_index, d_key_index, d_weights)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Hand back the gradients that the forward kernel computed.

        The upstream ``grad_outputs`` are not consulted. The saved tensors
        are made contiguous and placed at the positions of ``query_index``,
        ``key_index`` and ``weights``; the other 14 inputs get ``None``.

        Returns:
            tuple: 17 entries, one per forward input.
        """
        saved = _ensure_contiguous(*ctx.saved_tensors)
        d_query_index, d_key_index, d_weights = saved
        leading = (None, None, d_query_index, d_key_index, d_weights)
        # Remaining 12 inputs (indices, softmax stats, scale, rope, seq lens, flags).
        return leading + (None,) * 12

234 

235 

class NpuMhcPostDFunction(DFunction):  # pylint: disable=W0221
    """DFunction wrapper for npu_mhc_post on MindSpore.

    Plain-tensor calls go straight to the MindSpore custom kernel, while
    DTensor calls are routed through the distributed dispatch framework via
    the DistributedOp registered under the same op_name.

    All 4 forward arguments after ``ctx`` are positional to stay compatible
    with the MindSpore autograd function convention.
    """

    _op_name = "npu_mhc_post"

    @staticmethod
    def forward(ctx, x, h_res, h_out, h_post):
        """Save the inputs for backward and run the Ascend custom kernel.

        Args:
            ctx: Autograd context; all four inputs are saved on it.
            x: Input tensor of shape [B,S,N,D] or [T,N,D]. dtype bfloat16/float16.
            h_res: mHC h_res transformation matrix. dtype float32.
            h_out: Attention/MLP layer output. dtype bfloat16/float16.
            h_post: mHC h_post transformation matrix. dtype float32.

        Returns:
            Tensor: Output tensor with same shape and dtype as x.
        """
        inputs = (x, h_res, h_out, h_post)
        ctx.save_for_backward(*inputs)
        return _custom_ops.npu_mhc_post(*inputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute input gradients with the npu_mhc_post_backward kernel.

        Args:
            ctx: Autograd context holding the four saved forward inputs.
            grad_outputs: Upstream gradients; grad_outputs[0] is grad_y.

        Returns:
            tuple: (grad_x, grad_h_res, grad_h_out, grad_h_post).
        """
        saved_x, saved_h_res, saved_h_out, saved_h_post = ctx.saved_tensors
        # Every tensor handed to the kernel is made contiguous first.
        kernel_args = _ensure_contiguous(
            grad_outputs[0], saved_x, saved_h_res, saved_h_out, saved_h_post)
        input_grads = _custom_ops.npu_mhc_post_backward(*kernel_args)
        return input_grads[0], input_grads[1], input_grads[2], input_grads[3]

283 

284 

class NpuMhcPreSinkhornDFunction(DFunction):  # pylint: disable=W0221
    """DFunction wrapper for npu_mhc_pre_sinkhorn on MindSpore.

    Plain-tensor calls go straight to the MindSpore custom kernel, while
    DTensor calls are routed through the distributed dispatch framework via
    the DistributedOp registered under the same op_name.

    All 9 forward arguments after ``ctx`` are positional to stay compatible
    with the MindSpore autograd function convention.
    """

    _op_name = "npu_mhc_pre_sinkhorn"

    @staticmethod
    def forward(ctx, x, phi, alpha, bias, hc_mult, num_iters, hc_eps, norm_eps, out_flag):
        """Run the Ascend custom kernel and save what backward needs.

        Args:
            ctx: Autograd context; receives the tensor inputs, the five
                intermediate outputs and ``hc_eps``.
            x: Input tensor. dtype bfloat16/float16.
            phi: mHC parameter matrix. dtype float32.
            alpha: mHC scaling parameters. dtype float32.
            bias: mHC bias parameters. dtype float32.
            hc_mult: HC dimension size (currently only 4 supported).
            num_iters: Sinkhorn iteration count.
            hc_eps: H_pre sigmoid eps parameter.
            norm_eps: RmsNorm eps parameter.
            out_flag: Whether to output intermediate gradients.

        Returns:
            tuple[Tensor, ...]: 8 output tensors
                (h_in, h_post, h_res, h_pre, hc_before_norm, inv_rms, sum_out, norm_out).
        """
        outputs = _custom_ops.npu_mhc_pre_sinkhorn(
            x, phi, alpha, bias, hc_mult, num_iters, hc_eps, norm_eps, out_flag
        )
        # Only the last five outputs are needed again in backward.
        (_h_in, _h_post, _h_res,
         h_pre, hc_before_norm, inv_rms, sum_out, norm_out) = outputs
        ctx.save_for_backward(
            x, phi, alpha, bias,
            h_pre, hc_before_norm, inv_rms, sum_out, norm_out,
        )
        # Non-tensor value, so it is kept as a plain attribute on ctx.
        ctx.hc_eps = hc_eps
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute input gradients with the npu_mhc_pre_sinkhorn_backward kernel.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_outputs: Upstream gradients for the 8 forward outputs.
                grad_outputs[0]=grad_h_in, [1]=grad_h_post, [2]=grad_h_res;
                [3..7] correspond to saved intermediates and are not read.

        Returns:
            tuple: (grad_x, grad_phi, grad_alpha, grad_bias, None×5) —
                gradients for the 9 forward inputs.
        """
        (x, phi, alpha, bias,
         h_pre, hc_before_norm, inv_rms, sum_out, norm_out) = ctx.saved_tensors
        # Upstream gradients and saved tensors are all made contiguous
        # before being handed to the kernel, in the order it expects.
        kernel_tensors = _ensure_contiguous(
            grad_outputs[0], grad_outputs[1], grad_outputs[2],
            x, phi, alpha, bias,
            h_pre, hc_before_norm, inv_rms, sum_out, norm_out,
        )
        backward_kernel = _custom_ops.npu_mhc_pre_sinkhorn_backward
        input_grads = backward_kernel(*kernel_tensors, ctx.hc_eps)
        tensor_grads = (input_grads[0], input_grads[1], input_grads[2], input_grads[3])
        # hc_mult, num_iters, hc_eps, norm_eps and out_flag get no gradient.
        return tensor_grads + (None,) * 5