Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / custom_ops / experimental / __init__.py: 0%

16 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Experimental custom operators for HyperParallel. 

16 

17.. warning:: 

18 This is an experimental API that is subject to change or deletion. 

19 

20These operators delegate to the platform-specific ``custom_ops`` interface. 

21On MindSpore they wrap Ascend NPU custom C++ kernels through the 

22``DFunction`` distributed dispatch framework. On PyTorch they raise 

23``NotImplementedError``. 

24 

25Usage:: 

26 

27 from hyper_parallel.custom_ops.experimental import npu_dense_lightning_indexer_softmax_lse 

28 softmax_max, softmax_sum = npu_dense_lightning_indexer_softmax_lse( 

29 query_index, key_index, weights, layout='BSND') 

30 

31When inputs are ``DTensor`` objects, the call is automatically routed 

32through the registered ``DistributedOp`` for layout inference and 

33re-distribution. 

34""" 

35from typing import Optional, Tuple 

36 

37from mindspore import Tensor 

38 

39from hyper_parallel.platform import get_platform 

40 

# Names exported by ``from hyper_parallel.custom_ops.experimental import *``.
__all__ = [
    "npu_dense_lightning_indexer_grad_kl_loss",
    "npu_dense_lightning_indexer_softmax_lse",
    "npu_mhc_post",
    "npu_mhc_pre_sinkhorn",
    "npu_sparse_lightning_indexer_grad_kl_loss",
]

# Platform object resolved once, at import time; the wrappers in this module
# dispatch through its ``custom_ops`` attribute.
_platform = get_platform()

# Largest signed 64-bit integer (2**63 - 1). Used as the default value of the
# ``pre_tokens`` / ``next_tokens`` keyword arguments.
_MAX_INT64 = (1 << 63) - 1

52 

53 

def npu_dense_lightning_indexer_softmax_lse(
    query_index,
    key_index,
    weights,
    *,
    actual_seq_qlen: Optional[Tensor] = None,
    actual_seq_klen: Optional[Tensor] = None,
    layout: str = 'BSND',
    sparse_mode: int = 3,
    pre_tokens: int = _MAX_INT64,
    next_tokens: int = _MAX_INT64,
) -> Tuple:
    """Pre-compute the softmax max and sum values for Lightning Indexer attention.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Computing these values ahead of time is intended to reduce memory usage.

    This function is a thin wrapper: it forwards every argument positionally
    to ``npu_dense_lightning_indexer_softmax_lse`` on the platform
    ``custom_ops`` layer, which delegates to a ``DFunction`` wrapping the
    Ascend custom C++ kernel. ``DTensor`` inputs are handled transparently
    by distributed dispatch.

    Args:
        query_index: Lightning Indexer query input (Q̃). dtype bfloat16/float16.
        key_index: Lightning Indexer key input (K̃). Same dtype as
            ``query_index``.
        weights: Weight coefficient (W). dtype bfloat16/float16/float32.
        actual_seq_qlen: Cumulative query sequence lengths (int32 Tensor).
        actual_seq_klen: Cumulative key sequence lengths (int32 Tensor).
        layout: Data layout format, ``'BSND'`` (default) or ``'TND'``.
        sparse_mode: Sparse computation mode; only mode 3 is supported.
        pre_tokens: Preceding token window size for sparse attention (int64).
        next_tokens: Following token window size for sparse attention (int64).

    Returns:
        tuple[Tensor, Tensor]: ``(softmax_max_index, softmax_sum_index)``.
    """
    kernel = _platform.custom_ops.npu_dense_lightning_indexer_softmax_lse

    # The platform op is called positionally, so the order of these groups
    # (tensors, sequence lengths, options) must match the original call.
    tensor_inputs = (query_index, key_index, weights)
    seq_lens = (actual_seq_qlen, actual_seq_klen)
    options = (layout, sparse_mode, pre_tokens, next_tokens)
    return kernel(*tensor_inputs, *seq_lens, *options)

96 

97 

def npu_dense_lightning_indexer_grad_kl_loss(
    query,
    key,
    query_index,
    key_index,
    weights,
    softmax_max,
    softmax_sum,
    softmax_max_index,
    softmax_sum_index,
    scale_value,
    *,
    query_rope=None,
    key_rope=None,
    actual_seq_qlen: Optional[Tensor] = None,
    actual_seq_klen: Optional[Tensor] = None,
    layout: str = 'BSND',
    sparse_mode: int = 3,
    pre_tokens: int = _MAX_INT64,
    next_tokens: int = _MAX_INT64,
) -> Tuple:
    """Backward gradients and KL-divergence loss for the dense Lightning Indexer.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper around ``npu_dense_lightning_indexer_grad_kl_loss`` on the
    platform ``custom_ops`` layer. All arguments, including the keyword-only
    ones, are forwarded positionally in the order they are declared in this
    signature.

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            ``(d_query_index, d_key_index, d_weights, loss)``.
    """
    kernel = _platform.custom_ops.npu_dense_lightning_indexer_grad_kl_loss

    # Grouped only for readability; the concatenated order is exactly the
    # positional order the platform op is called with.
    attention_inputs = (query, key, query_index, key_index, weights)
    softmax_stats = (
        softmax_max, softmax_sum, softmax_max_index, softmax_sum_index,
    )
    rope_inputs = (query_rope, key_rope)
    seq_lens = (actual_seq_qlen, actual_seq_klen)
    options = (layout, sparse_mode, pre_tokens, next_tokens)
    return kernel(
        *attention_inputs,
        *softmax_stats,
        scale_value,
        *rope_inputs,
        *seq_lens,
        *options,
    )

139 

140 

def npu_sparse_lightning_indexer_grad_kl_loss(
    query,
    key,
    query_index,
    key_index,
    weights,
    sparse_indices,
    softmax_max,
    softmax_sum,
    scale_value,
    *,
    query_rope=None,
    key_rope=None,
    actual_seq_qlen: Optional[Tensor] = None,
    actual_seq_klen: Optional[Tensor] = None,
    layout: str = 'BSND',
    sparse_mode: int = 3,
    pre_tokens: int = _MAX_INT64,
    next_tokens: int = _MAX_INT64,
) -> Tuple:
    """Backward gradients and KL-divergence loss for the sparse Lightning Indexer.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper around ``npu_sparse_lightning_indexer_grad_kl_loss`` on the
    platform ``custom_ops`` layer. All arguments, including the keyword-only
    ones, are forwarded positionally in the order they are declared in this
    signature.

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            ``(d_query_index, d_key_index, d_weights, loss)``.
    """
    kernel = _platform.custom_ops.npu_sparse_lightning_indexer_grad_kl_loss

    # Grouped only for readability; the concatenated order is exactly the
    # positional order the platform op is called with.
    attention_inputs = (query, key, query_index, key_index, weights)
    sparse_inputs = (sparse_indices, softmax_max, softmax_sum, scale_value)
    rope_inputs = (query_rope, key_rope)
    seq_lens = (actual_seq_qlen, actual_seq_klen)
    options = (layout, sparse_mode, pre_tokens, next_tokens)
    return kernel(
        *attention_inputs,
        *sparse_inputs,
        *rope_inputs,
        *seq_lens,
        *options,
    )

178 

179 

def npu_mhc_post(x, h_res, h_out, h_post) -> Tuple:
    """Run MHC post-processing with a residual connection.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper that forwards its four arguments, unchanged and in order,
    to ``npu_mhc_post`` on the platform ``custom_ops`` layer.

    Returns:
        Output tensor with the same shape and dtype as ``x``.

        NOTE(review): the return annotation is ``Tuple`` while this
        description is of a single tensor -- confirm which one is correct.
    """
    kernel = _platform.custom_ops.npu_mhc_post
    return kernel(x, h_res, h_out, h_post)

189 

190 

def npu_mhc_pre_sinkhorn(
    x,
    phi,
    alpha,
    bias,
    *,
    hc_mult: int = 4,
    num_iters: int = 20,
    hc_eps: float = 1e-6,
    norm_eps: float = 1e-6,
    out_flag: bool = True,
) -> Tuple:
    """Run MHC pre-processing with Sinkhorn normalization.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Thin wrapper around ``npu_mhc_pre_sinkhorn`` on the platform
    ``custom_ops`` layer. All arguments, including the keyword-only ones,
    are forwarded positionally in the order they are declared in this
    signature.

    Returns:
        8 output tensors.
    """
    kernel = _platform.custom_ops.npu_mhc_pre_sinkhorn

    # Positional order matters: tensors first, then the scalar settings.
    tensor_inputs = (x, phi, alpha, bias)
    settings = (hc_mult, num_iters, hc_eps, norm_eps, out_flag)
    return kernel(*tensor_inputs, *settings)