Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_mhc_pre_sinkhorn.py: 98%

55 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for npu_mhc_pre_sinkhorn operator.""" 

16from typing import Tuple, Dict, Any 

17 

18from hyper_parallel.core.dtensor.layout import Layout 

19from hyper_parallel.platform import get_platform 

20from hyper_parallel.platform.platform import PlatformType 

21from .parallel_ops import DistributedOp 

22 

# Backend handle resolved once at import time; preprocess() reads its
# platform_type to decide how scalar options are passed to the local operator.
platform = get_platform()

# Fallbacks used by _normalize_mhc_pre_sinkhorn_args when the caller omits the
# corresponding optional argument.
_HC_MULT_DEFAULT: int = 4
_NUM_ITERS_DEFAULT: int = 20
_HC_EPS_DEFAULT: float = 1e-6
_NORM_EPS_DEFAULT: float = 1e-6

29 

30 

def _normalize_mhc_pre_sinkhorn_args(
        x,
        phi,
        alpha,
        bias,
        hc_mult=_HC_MULT_DEFAULT,
        num_iters=_NUM_ITERS_DEFAULT,
        hc_eps=_HC_EPS_DEFAULT,
        norm_eps=_NORM_EPS_DEFAULT,
        out_flag=True):
    """Flatten an operator call into one fixed-order positional tuple.

    Binding the call through this signature fills in every omitted optional
    argument, so callers can index the result without caring whether a value
    was passed positionally, by keyword, or not at all.

    Args:
        x: Input tensor [B,S,N,C] or [T,N,C].
        phi: mHC parameter matrix [N*N+2*N, N*C].
        alpha: mHC scaling parameters [3].
        bias: mHC bias parameters [N*N+2*N].
        hc_mult: HC dimension size (currently only 4 supported).
        num_iters: Sinkhorn iteration count.
        hc_eps: H_pre sigmoid eps parameter.
        norm_eps: RmsNorm eps parameter.
        out_flag: Whether to output intermediate gradients.

    Returns:
        tuple: ``(positional_args, kwargs)`` where ``positional_args`` is the
        9-tuple ``(x, phi, alpha, bias, hc_mult, num_iters, hc_eps, norm_eps,
        out_flag)`` and ``kwargs`` is always a new empty dict.
    """
    tensor_args = (x, phi, alpha, bias)
    scalar_args = (hc_mult, num_iters, hc_eps, norm_eps, out_flag)
    return tensor_args + scalar_args, {}

61 

62 

# Layout validation rules for npu_mhc_pre_sinkhorn, keyed by the length of x's
# tensor_map (4 for [B,S,N,C], 3 for [T,N,C]).
# Each "*_forbidden_dims" entry maps a dimension index to the label used in the
# error message; every listed dimension must be replicated (tensor_map == -1).
# The two formats differ only in where N sits inside x, so the table is
# generated from (rank, index of N) pairs.
_MHC_PRE_SINKHORN_VALIDATION_RULES: Dict[int, Dict[str, Any]] = {
    rank: {
        "op_name": "npu_mhc_pre_sinkhorn",
        "forbidden_dims": {n_dim: "N"},
        "phi_forbidden_dims": {0: "dim0", 1: "dim1"},
        "alpha_forbidden_dims": {0: "dim0"},
        "bias_forbidden_dims": {0: "dim0"},
    }
    for rank, n_dim in ((4, 2), (3, 1))
}

82 

83 

def _validate_tensor_map_dims(
        tensor_map: tuple,
        op_name: str,
        forbidden_dims: Dict[int, str],
        tensor_name: str = "x",
) -> None:
    """Check that the given dimensions of a tensor_map are replicated.

    A dimension counts as replicated when its tensor_map entry equals -1.
    The first offending dimension, in ``forbidden_dims`` iteration order,
    aborts the check.

    Args:
        tensor_map: The tensor_map to check.
        op_name: Operator name used in the error message.
        forbidden_dims: Maps each dimension index that must be replicated to
            the label used for it in the error message.
        tensor_name: Name of the tensor that owns ``tensor_map``, used in the
            error message. Defaults to "x", which reproduces the message
            emitted before this parameter existed.

    Raises:
        ValueError: If any dimension listed in ``forbidden_dims`` is sharded.
        IndexError: If an index in ``forbidden_dims`` lies outside
            ``tensor_map``.
    """
    for dim_idx, dim_name in forbidden_dims.items():
        dim_value = tensor_map[dim_idx]
        if dim_value != -1:
            raise ValueError(
                f"For {op_name}, {dim_name} dimension (dim {dim_idx}) of {tensor_name} "
                f"should be replicated, but got {dim_value}"
            )

106 

107 

class NpuMhcPreSinkhornDistributedOp(DistributedOp):
    """DistributedOp for npu_mhc_pre_sinkhorn operator.

    Implements layout inference for the MHC pre-processing with Sinkhorn operation.
    Outputs 8 tensors: hin, h_post, h_res, h_pre, hc_before_norm, inv_rms, sum_out, norm_out.
    """

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Split an operator call into local tensors, options and layouts.

        Args:
            args: Positional arguments of the operator call.
            kwargs: Keyword arguments of the operator call.

        Returns:
            tuple: ``(local_args, local_kwargs, cache_values)``.
            On the MindSpore platform all nine arguments are passed
            positionally and ``local_kwargs`` is empty; on any other platform
            only the four local tensors are positional and the five scalar
            options are passed by keyword. ``cache_values`` lists the layouts
            of x, phi, alpha and bias, in that order.
        """
        norm_args, _ = _normalize_mhc_pre_sinkhorn_args(*args, **kwargs)
        dtensors = norm_args[:4]  # x, phi, alpha, bias
        hc_mult, num_iters, hc_eps, norm_eps, out_flag = norm_args[4:]

        local_tensors = tuple(dtensor.to_local() for dtensor in dtensors)

        if platform.platform_type == PlatformType.MINDSPORE:
            local_args = local_tensors + (hc_mult, num_iters, hc_eps, norm_eps, out_flag)
            local_kwargs = {}
        else:
            local_args = local_tensors
            local_kwargs = {
                'hc_mult': hc_mult,
                'num_iters': num_iters,
                'hc_eps': hc_eps,
                'norm_eps': norm_eps,
                'out_flag': out_flag,
            }

        cache_values = [dtensor.layout for dtensor in dtensors]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Validate the input layouts and derive the output layouts.

        Args:
            cache_values: Layouts of x, phi, alpha and bias, as produced by
                ``preprocess``.

        Returns:
            tuple: ``(out_layouts, None)`` where ``out_layouts`` holds one
            layout per operator output.

        Raises:
            ValueError: If x's tensor_map length is not 3 or 4, or if a
                dimension that must be replicated is sharded.
        """
        x_layout, phi_layout, alpha_layout, bias_layout = cache_values

        # _check_partial_inputs is not defined in this class; presumably
        # inherited from DistributedOp.
        self._check_partial_inputs([x_layout, phi_layout, alpha_layout, bias_layout])

        self._validate_input_layouts_mhc_pre_sinkhorn(
            x_layout, phi_layout, alpha_layout, bias_layout
        )

        out_layouts = self._infer_output_layouts(x_layout)
        return out_layouts, None

    @staticmethod
    def _validate_input_layouts_mhc_pre_sinkhorn(
            x_layout: Layout,
            phi_layout: Layout,
            alpha_layout: Layout,
            bias_layout: Layout,
    ) -> None:
        """Validate input layouts for npu_mhc_pre_sinkhorn operator.

        Args:
            x_layout: Layout of x; its tensor_map length selects the rule set.
            phi_layout: Layout of phi.
            alpha_layout: Layout of alpha.
            bias_layout: Layout of bias.

        Raises:
            ValueError: If x's tensor_map length is not 3 or 4, if the N
                dimension of x is sharded, or if a checked dimension of phi,
                alpha or bias is sharded.
        """
        x_tm = x_layout.tensor_map
        x_tm_len = len(x_tm)

        # Get validation rules from table based on tensor_map length
        rules = _MHC_PRE_SINKHORN_VALIDATION_RULES.get(x_tm_len)
        if rules is None:
            raise ValueError(
                f"For npu_mhc_pre_sinkhorn, tensor_map length should be 4 or 3, but got {x_tm_len}"
            )
        op_name = rules["op_name"]

        # Validate forbidden dimensions (N must be replicated)
        _validate_tensor_map_dims(x_tm, op_name, rules["forbidden_dims"])

        # phi, alpha and bias must be replicated on every dimension the rules list.
        # Checked inline rather than through the three-argument helper call,
        # which reports every failure as "of x"; here the error names the
        # tensor that is actually sharded.
        param_checks = (
            ("phi", phi_layout, "phi_forbidden_dims"),
            ("alpha", alpha_layout, "alpha_forbidden_dims"),
            ("bias", bias_layout, "bias_forbidden_dims"),
        )
        for tensor_name, layout, rule_key in param_checks:
            tensor_map = layout.tensor_map
            for dim_idx, dim_name in rules[rule_key].items():
                dim_value = tensor_map[dim_idx]
                if dim_value != -1:
                    raise ValueError(
                        f"For {op_name}, {dim_name} dimension (dim {dim_idx}) of {tensor_name} "
                        f"should be replicated, but got {dim_value}"
                    )

    @staticmethod
    def _infer_output_layouts(
            x_layout: Layout,
    ) -> Tuple[Layout, Layout, Layout, Layout, Layout, Layout, Layout, Layout]:
        """Build the layouts of the 8 operator outputs from x's layout.

        Args:
            x_layout: Layout of the input x.

        Returns:
            tuple: Eight references to one new Layout that lives on x's mesh
            and carries x's tensor_map.
        """
        out_layout = Layout.from_device_mesh(x_layout.mesh)
        out_layout.set_tensor_map(x_layout.tensor_map)
        out_layout.tensor_map_to_placement()

        # NOTE(review): all eight entries are the same Layout object, so an
        # in-place change to one output's layout would affect the others.
        # NOTE(review): every output reuses x's tensor_map; confirm that each
        # output really has the same rank and sharding as x.
        return (
            out_layout, out_layout, out_layout, out_layout,
            out_layout, out_layout, out_layout, out_layout,
        )