Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_all_gather_matmul.py: 81%

57 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for AllGatherMatmul (MC2 fusion) operator.""" 

16import copy 

17from typing import Tuple 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

23def _normalize_agm_args( 

24 x1, 

25 x2, 

26 group, 

27 world_size, 

28 bias=None, 

29 gather_index=0, 

30 gather_output=True, 

31 comm_turn=0, 

32 trans_input=False, 

33 trans_x2=False, 

34) -> Tuple[tuple, dict]: 

35 """Normalize positional and keyword arguments into a canonical positional tuple. 

36 

37 Args: 

38 x1: Input tensor. Physical shape (m, k); trans_input=False only (MS constraint). 

39 x2: Weight tensor with shape (k, n) if trans_x2=False, or (n, k) if trans_x2=True. 

40 group: Communication group name string. 

41 world_size: Number of ranks in the communication group. 

42 bias: Must be None (MS constraint; bias is not supported). 

43 gather_index: Index for gather operation (only 0 is supported). 

44 gather_output: Whether to return the gathered intermediate tensor. 

45 comm_turn: Communication turn (only 0 is supported). 

46 trans_input: Must be False (MS constraint). 

47 trans_x2: If True, x2 physical shape is (n, k); MindSpore transposes x2 before CANN call. 

48 

49 Returns: 

50 tuple: (positional_args_tuple, empty_kwargs_dict) 

51 """ 

52 return (x1, x2, group, world_size, bias, gather_index, gather_output, comm_turn, trans_input, trans_x2), {} 

53 

54 

55class AllGatherMatmulDistributedOp(DistributedOp): 

56 """Distributed operator for mindspore.ops.all_gather_matmul (MC2 fusion). 

57 

58 The CANN AllGatherMatmul kernel handles communication (AllGather) internally. 

59 HyperParallel's role is to: 

60 1. Extract local tensors from DTensor inputs via to_local(). 

61 2. Infer output DTensor layouts so downstream operators can correctly 

62 understand the distribution state of the output. 

63 

64 Shape transformation (logical, after any transpose): 

65 x1 (m_local, k_local) —[CANN AllGather on m]→ (m_global, k_local) —[matmul]→ output (m_global, n_local) 

66 gather_out (m_global, k_local) — valid only when gather_output=True 

67 

68 Sharding constraints: 

69 - x1 physical (m, k): dim 0 is m (AllGather consumes m); dim 1 (k) may be Replicate or Shard. 

70 - x1's m-dim tensor_map must not be a tuple (joint sharding across multiple mesh dims unsupported). 

71 - x1 k-dim and x2 k-dim must share the same placement (both Replicate, or both sharded on the 

72 same mesh axis). trans_x2=False: x2 k is dim 0; trans_x2=True: x2 k is dim 1. 

73 - When k is sharded, output carries Partial(sum) status on the k-dim mesh axis; the caller is 

74 responsible for applying AllReduce to obtain the correct full result. 

75 - gather_out k-dim follows x1's k-dim placement. 

76 - Partial inputs are not allowed. 

77 - gather_index=0, trans_input=False, bias=None only (current MS constraints). 

78 """ 

79 

80 def preprocess(self, args: tuple, kwargs: dict) -> tuple: 

81 """Extract local tensors and build the layout cache. 

82 

83 Args: 

84 args: Positional arguments (DTensors for x1 and x2). 

85 kwargs: Keyword arguments. 

86 

87 Returns: 

88 tuple: (local_args, local_kwargs, cache_values) where 

89 local_args contains extracted local tensors, 

90 local_kwargs contains keyword arguments, 

91 cache_values = [x1_layout, x2_layout, trans_x2, gather_output]. 

92 """ 

93 norm_args, _ = _normalize_agm_args(*args, **kwargs) 

94 x1 = norm_args[0] 

95 x2 = norm_args[1] 

96 gather_output = norm_args[6] 

97 trans_x2 = norm_args[9] 

98 

99 # MindSpore all_gather_matmul: only (input, x2, group, world_size) are 

100 # positional; bias, gather_index, gather_output, comm_turn, trans_input, 

101 # trans_x2 are keyword-only (after the '*' separator). 

102 local_args = ( 

103 x1.to_local(), 

104 x2.to_local(), 

105 norm_args[2], # group 

106 norm_args[3], # world_size 

107 ) 

108 local_kwargs = { 

109 'bias': norm_args[4], 

110 'gather_index': norm_args[5], 

111 'gather_output': gather_output, 

112 'comm_turn': norm_args[7], 

113 'trans_input': norm_args[8], 

114 'trans_x2': trans_x2, 

115 } 

116 

117 cache_values = [ 

118 x1.layout, 

119 x2.layout, 

120 trans_x2, 

121 gather_output, 

122 ] 

123 return local_args, local_kwargs, cache_values 

124 

125 @staticmethod 

126 def _set_partial_from_k(output_layout: Layout, k_placement, op: str = 'sum') -> None: 

127 """Set Partial on output_layout for the mesh axes corresponding to k_placement. 

128 

129 Args: 

130 output_layout: Layout to mark as Partial. 

131 k_placement: Tensor_map value for the k dimension (integer or tuple of integers). 

132 op: Reduction operation, default 'sum'. 

133 """ 

134 alias = output_layout.alias_name 

135 n = len(alias) 

136 if isinstance(k_placement, tuple): 

137 for v in k_placement: 

138 output_layout.set_partial_by_dev_axis(alias[n - 1 - v], op) 

139 else: 

140 output_layout.set_partial_by_dev_axis(alias[n - 1 - k_placement], op) 

141 

142 @staticmethod 

143 def _validate_input_layouts( 

144 x1_layout: Layout, 

145 x2_layout: Layout, 

146 trans_x2: bool, 

147 ) -> None: 

148 """Validate sharding constraints for AllGatherMatmul inputs. 

149 

150 Args: 

151 x1_layout: Layout of x1. Physical (m, k); trans_input=False only. 

152 x2_layout: Layout of x2 (k, n) if trans_x2=False, or (n, k) if trans_x2=True. 

153 trans_x2: Whether x2 is transposed. 

154 

155 Raises: 

156 ValueError: If x1's m-dim tensor_map is a tuple, the k-dim placements of x1 and x2 

157 do not match, or any input has Partial status. 

158 """ 

159 op = "all_gather_matmul" 

160 x1_tm = x1_layout.tensor_map 

161 x2_tm = x2_layout.tensor_map 

162 

163 # trans_input=False only: x1 physical (m, k) — k is dim 1, m is dim 0. 

164 x1_m_dim = 0 

165 

166 if isinstance(x1_tm[x1_m_dim], tuple): 

167 raise ValueError( 

168 f"For {op}, x1 m-dim (dim {x1_m_dim}) " 

169 f"with tensor_map={x1_tm[x1_m_dim]} is jointly sharded across multiple " 

170 f"mesh dims, which is not supported in this version." 

171 ) 

172 

173 # k-dim placement must match between x1 and x2. 

174 x2_k_dim = 1 if trans_x2 else 0 

175 if x1_tm[1] != x2_tm[x2_k_dim]: 

176 raise ValueError( 

177 f"For {op}, x1 k-dim (dim 1) placement {x1_tm[1]} must match " 

178 f"x2 k-dim (dim {x2_k_dim}) placement {x2_tm[x2_k_dim]} " 

179 f"(trans_x2={trans_x2})." 

180 ) 

181 

182 def infer_layout(self, cache_values: list) -> Tuple[tuple, None]: 

183 """Infer output layouts for (output, gather_out). 

184 

185 AllGather on m dim: output dim 0 is always -1 (Replicate), because 

186 AllGather unconditionally makes the m dimension global. 

187 

188 n dim: follows x2's n placement. 

189 - trans_x2=False: n is x2 dim 1 → output_tm[1] = x2_tm[1] 

190 - trans_x2=True: n is x2 dim 0 → output_tm[1] = x2_tm[0] 

191 

192 k dim (contraction): when k is sharded, output carries Partial(sum) on the 

193 k-dim mesh axis; the caller must apply AllReduce to get the correct result. 

194 

195 gather_out layout: m is Replicate (-1); k follows x1's k-dim placement. 

196 

197 Args: 

198 cache_values: [x1_layout, x2_layout, trans_x2, gather_output] 

199 

200 Returns: 

201 tuple: ((output_layout, gather_out_layout), None) 

202 

203 Raises: 

204 ValueError: If any input has Partial status or sharding constraints are violated. 

205 """ 

206 x1_layout = cache_values[0] 

207 x2_layout = cache_values[1] 

208 trans_x2 = cache_values[2] 

209 gather_output = cache_values[3] 

210 

211 self._check_partial_inputs([x1_layout, x2_layout]) 

212 self._validate_input_layouts(x1_layout, x2_layout, trans_x2) 

213 

214 x1_tm = x1_layout.tensor_map 

215 x2_tm = x2_layout.tensor_map 

216 k_placement = x1_tm[1] 

217 n_placement = x2_tm[0] if trans_x2 else x2_tm[1] 

218 

219 # output: m is Replicate (-1) because AllGather consumed the m sharding; 

220 # n inherits from x2's n dim. 

221 output_layout = Layout.from_device_mesh(x1_layout.mesh) 

222 output_layout.set_tensor_map((-1, n_placement)) 

223 output_layout.tensor_map_to_placement() 

224 

225 # When k is sharded, output is a partial sum; mark Partial so the framework 

226 # can insert AllReduce downstream. 

227 if k_placement != -1: 

228 self._set_partial_from_k(output_layout, k_placement) 

229 

230 # gather_out: gather_output=True → m Replicate (-1), k follows x1's k placement. 

231 # gather_output=False → CANN returns a 1-D empty tensor; force all-Replicate so 

232 # the layout is compatible with any tensor rank returned by the kernel. 

233 gather_out_layout = Layout.from_device_mesh(x1_layout.mesh) 

234 gather_k = k_placement if gather_output else -1 

235 gather_out_layout.set_tensor_map((-1, gather_k)) 

236 gather_out_layout.tensor_map_to_placement() 

237 

238 return (copy.deepcopy(output_layout), copy.deepcopy(gather_out_layout)), None