Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_mhc_post.py: 98%

46 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for npu_mhc_post operator.""" 

16from typing import Tuple, Dict, Any 

17 

18from hyper_parallel.core.dtensor.layout import Layout 

19from .parallel_ops import DistributedOp 

20 

21 

22def _normalize_mhc_post_args( 

23 x, 

24 h_res, 

25 h_out, 

26 h_post): 

27 """Normalize positional and keyword arguments into a canonical positional tuple. 

28 

29 Args: 

30 x: Input tensor [B,S,N,D] or [T,N,D]. 

31 h_res: mHC h_res transformation matrix [B,S,N,N] or [T,N,N]. 

32 h_out: Attention/MLP output [B,S,D] or [T,D]. 

33 h_post: mHC h_post transformation matrix [B,S,N] or [T,N]. 

34 

35 Returns: 

36 tuple: (positional_args_tuple, empty_kwargs_dict) 

37 """ 

38 return (x, h_res, h_out, h_post), {} 

39 

40 

41# Validation rules table for npu_mhc_post 

42# Key: tensor_map length (format identifier) 

43# Value: validation rules for that format 

44_MHC_POST_VALIDATION_RULES: Dict[int, Dict[str, Any]] = { 

45 4: { 

46 "op_name": "npu_mhc_post", 

47 "forbidden_dims": {2: "N", 3: "D"}, 

48 "dim_requirements": { 

49 "h_res": 4, 

50 "h_out": 3, 

51 "h_post": 3, 

52 }, 

53 }, 

54 3: { 

55 "op_name": "npu_mhc_post", 

56 "forbidden_dims": {1: "N", 2: "D"}, 

57 "dim_requirements": { 

58 "h_res": 3, 

59 "h_out": 2, 

60 "h_post": 2, 

61 }, 

62 }, 

63} 

64 

65 

66def _validate_tensor_map_dims( 

67 tensor_map: tuple, 

68 op_name: str, 

69 forbidden_dims: Dict[int, str], 

70) -> None: 

71 """Check that specified dimensions are not sharded (replicated). 

72 

73 Args: 

74 tensor_map: The tensor_map to check. 

75 op_name: Operator name for error message. 

76 forbidden_dims: Dict mapping dim index to dim name. 

77 

78 Raises: 

79 ValueError: If any forbidden dimension is sharded. 

80 """ 

81 for dim_idx, dim_name in forbidden_dims.items(): 

82 dim_value = tensor_map[dim_idx] 

83 if dim_value != -1: 

84 raise ValueError( 

85 f"For {op_name}, {dim_name} dimension (dim {dim_idx}) of x " 

86 f"should be replicated, but got {dim_value}" 

87 ) 

88 

89 

90def _validate_tensor_dimensions( 

91 actual_dims: Dict[str, int], 

92 required_dims: Dict[str, int], 

93 op_name: str, 

94) -> None: 

95 """Validate that each input tensor has the expected number of dimensions. 

96 

97 Args: 

98 actual_dims: Dict mapping input name to actual dimension count. 

99 required_dims: Dict mapping input name to expected dimension count. 

100 op_name: Operator name for error message. 

101 

102 Raises: 

103 ValueError: If any input has wrong dimension count. 

104 """ 

105 for input_name, required_dim in required_dims.items(): 

106 actual_dim = actual_dims.get(input_name) 

107 if actual_dim != required_dim: 

108 raise ValueError( 

109 f"For {op_name}, {input_name} tensor should have {required_dim} dimensions, " 

110 f"but got {actual_dim}" 

111 ) 

112 

113 

class NpuMhcPostDistributedOp(DistributedOp):
    """DistributedOp for npu_mhc_post operator.

    Implements layout inference for the MHC post-processing operation:
        x_{l+1} = (H_l^res)^T × x_l + h_l^out ⊗ H_t^post
    """

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Split the distributed inputs into local tensors and their layouts.

        Args:
            args: Positional inputs of the operator call.
            kwargs: Keyword inputs of the operator call.

        Returns:
            tuple: ``(local_args, local_kwargs, cache_values)`` where
            ``local_args`` holds the ``to_local()`` result of each input in
            the order (x, h_res, h_out, h_post), ``local_kwargs`` is an empty
            dict, and ``cache_values`` lists the four input layouts in the
            same order.
        """
        ordered_inputs, _ = _normalize_mhc_post_args(*args, **kwargs)
        x_dt, h_res_dt, h_out_dt, h_post_dt = ordered_inputs
        dtensors = (x_dt, h_res_dt, h_out_dt, h_post_dt)

        # All local tensors are taken first, then all layouts are read.
        local_args = tuple(dtensor.to_local() for dtensor in dtensors)
        local_kwargs = {}
        cache_values = [dtensor.layout for dtensor in dtensors]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Validate the cached input layouts and derive the output layout.

        Args:
            cache_values: The four input layouts in the order
                (x, h_res, h_out, h_post).

        Returns:
            A pair ``((out_layout,), None)``; the output layout is built from
            the layout of x.

        Raises:
            ValueError: If the input layouts fail validation.
        """
        layout_x, layout_h_res, layout_h_out, layout_h_post = cache_values

        # Inherited check on the four input layouts, run before validation.
        self._check_partial_inputs([layout_x, layout_h_res, layout_h_out, layout_h_post])

        self._validate_input_layouts_mhc_post(
            layout_x, layout_h_res, layout_h_out, layout_h_post
        )

        inferred = self._infer_output_layout(layout_x)
        return (inferred,), None

    @staticmethod
    def _validate_input_layouts_mhc_post(
        x_layout: Layout,
        h_res_layout: Layout,
        h_out_layout: Layout,
        h_post_layout: Layout,
    ) -> None:
        """Validate input layouts for npu_mhc_post operator.

        The length of x's tensor_map selects a rule set from
        ``_MHC_POST_VALIDATION_RULES``. The rules then require that the N and
        D dimensions of x are not sharded and that the other inputs have the
        tensor_map length expected for that format.

        Raises:
            ValueError: If x's tensor_map length is neither 4 nor 3, if a
                forbidden dimension of x is sharded, or if another input has
                an unexpected tensor_map length.
        """
        x_map = x_layout.tensor_map
        x_rank = len(x_map)

        # The tensor_map length doubles as the format identifier.
        rule_set = _MHC_POST_VALIDATION_RULES.get(x_rank)
        if rule_set is None:
            raise ValueError(
                f"For npu_mhc_post, tensor_map length should be 4 or 3, but got {x_rank}"
            )
        op_name = rule_set["op_name"]

        # N and D of x must stay replicated.
        _validate_tensor_map_dims(x_map, op_name, rule_set["forbidden_dims"])

        # Remaining inputs are only checked for their tensor_map length.
        other_layouts = (
            ("h_res", h_res_layout),
            ("h_out", h_out_layout),
            ("h_post", h_post_layout),
        )
        observed_ranks = {
            input_name: len(layout.tensor_map) for input_name, layout in other_layouts
        }
        _validate_tensor_dimensions(observed_ranks, rule_set["dim_requirements"], op_name)

    @staticmethod
    def _infer_output_layout(x_layout: Layout) -> Layout:
        """Build the output layout on x's mesh with x's tensor_map.

        Args:
            x_layout: Layout of the input x.

        Returns:
            A new Layout created from ``x_layout.mesh`` and given
            ``x_layout.tensor_map``.
        """
        result_layout = Layout.from_device_mesh(x_layout.mesh)
        result_layout.set_tensor_map(x_layout.tensor_map)
        # Presumably refreshes the placements from the tensor_map just set;
        # confirm against Layout.
        result_layout.tensor_map_to_placement()
        return result_layout