Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_matmul_reduce_scatter.py: 82%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for MatmulReduceScatter (MC2 fusion) operator.""" 

16import copy 

17from typing import Tuple 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

23def _normalize_mrs_args( 

24 x1, 

25 x2, 

26 group, 

27 world_size, 

28 reduce_op='sum', 

29 bias=None, 

30 comm_turn=0, 

31 trans_input=False, 

32 trans_x2=False, 

33) -> Tuple[tuple, dict]: 

34 """Normalize positional and keyword arguments into a canonical positional tuple. 

35 

36 Args: 

37 x1: Input tensor. Physical shape (m, k); trans_input=False only (MS constraint). 

38 x2: Weight tensor with shape (k, n) if trans_x2=False, or (n, k) if trans_x2=True. 

39 group: Communication group name string. 

40 world_size: Number of ranks in the communication group. 

41 reduce_op: Reduce operation string (only 'sum' is supported). 

42 bias: Must be None (MS constraint; bias is not supported). 

43 comm_turn: Communication turn (only 0 is supported). 

44 trans_input: Must be False (MS constraint). 

45 trans_x2: If True, x2 physical shape is (n, k); MindSpore transposes x2 before CANN call. 

46 

47 Returns: 

48 tuple: (positional_args_tuple, empty_kwargs_dict) 

49 """ 

50 return (x1, x2, group, world_size, reduce_op, bias, comm_turn, trans_input, trans_x2), {} 

51 

52 

class MatmulReduceScatterDistributedOp(DistributedOp):
    """Distributed operator for mindspore.ops.matmul_reduce_scatter (MC2 fusion).

    The CANN MatmulReduceScatter kernel handles communication (ReduceScatter)
    internally. HyperParallel's role is to:
      1. Extract local tensors from DTensor inputs via to_local().
      2. Infer output DTensor layouts so downstream operators can correctly
         understand the distribution state of the output.

    No Partial state is needed because CANN internally completes the ReduceScatter;
    each rank receives the correctly reduced and scattered result slice.

    Shape transformation (logical, after any transpose):
        x1 (m, k_local), x2 (k_local, n) —[matmul]→ partial (m, n)
        —[CANN ReduceScatter on m]→ output (m_local = m / world_size, n)

    Sharding constraints:
      - x1 and x2 must both be 2-D.
      - x1 physical (m, k): dim 1 (k) must be Shard (TP/comm axis);
        dim 0 (m) may be Replicate or Shard (DP axis).
      - x2's k-dim must be Shard and match x1's k-dim placement exactly.
      - trans_x2=False: x2 k-dim is dim 0; trans_x2=True: x2 k-dim is dim 1.
      - x2's n-dim must not be sharded on any mesh dimension used by x1,
        because all of x1's mesh dimensions end up on the output's m-dim.
      - trans_input=False, bias=None only (current MS constraints).
      - Partial inputs are not allowed.
    """

    @staticmethod
    def _as_mesh_dims(entry) -> tuple:
        """Return the mesh dimensions referenced by one tensor_map entry as a flat tuple.

        A tensor_map entry is a single mesh dimension, -1 (Replicate), or a
        tuple/list of mesh dimensions. Any -1 values are dropped.
        """
        if isinstance(entry, (tuple, list)):
            return tuple(dim for dim in entry if dim != -1)
        return () if entry == -1 else (entry,)

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (DTensors for x1 and x2).
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where
                local_args contains extracted local tensors plus group and world_size,
                local_kwargs contains the keyword-only arguments,
                cache_values = [x1_layout, x2_layout, trans_x2].
        """
        norm_args, _ = _normalize_mrs_args(*args, **kwargs)
        (x1, x2, group, world_size,
         reduce_op, bias, comm_turn, trans_input, trans_x2) = norm_args

        # MindSpore matmul_reduce_scatter: only (input, x2, group, world_size) are
        # positional; reduce_op, bias, comm_turn, trans_input, trans_x2 are
        # keyword-only (after the '*' separator).
        local_args = (
            x1.to_local(),
            x2.to_local(),
            group,
            world_size,
        )
        local_kwargs = {
            'reduce_op': reduce_op,
            'bias': bias,
            'comm_turn': comm_turn,
            'trans_input': trans_input,
            'trans_x2': trans_x2,
        }

        cache_values = [
            x1.layout,
            x2.layout,
            trans_x2,
        ]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _validate_input_layouts(
        x1_layout: Layout,
        x2_layout: Layout,
        trans_x2: bool,
    ) -> None:
        """Validate sharding constraints for MatmulReduceScatter inputs.

        Partial status is not checked here; infer_layout checks it separately
        via _check_partial_inputs before calling this method.

        Args:
            x1_layout: Layout of x1. Physical (m, k); trans_input=False only.
            x2_layout: Layout of x2 (k, n) if trans_x2=False, or (n, k) if trans_x2=True.
            trans_x2: Whether x2 is transposed.

        Raises:
            ValueError: If either input is not 2-D, x1's k dimension is Replicate,
                x2's k dimension layout does not match x1's k dimension layout,
                or x2's n dimension is sharded on a mesh dimension already used by x1.
        """
        op = "matmul_reduce_scatter"
        x1_tm = x1_layout.tensor_map
        x2_tm = x2_layout.tensor_map

        # All indexing below assumes exactly two tensor_map entries per operand;
        # report a clear error instead of an IndexError.
        for name, tm in (("x1", x1_tm), ("x2", x2_tm)):
            if len(tm) != 2:
                raise ValueError(
                    f"For {op}, {name} must be a 2-D tensor, but got tensor_map={tm}"
                )

        # trans_input=False only: x1 physical (m, k) — k is dim 1.
        x1_k_dim = 1

        if x1_tm[x1_k_dim] == -1:
            raise ValueError(
                f"For {op}, x1 k-dim (dim {x1_k_dim}) must be "
                f"Shard (not Replicate), because ReduceScatter requires k to be sharded. "
                f"Got tensor_map={x1_tm}"
            )

        x2_k_dim = 1 if trans_x2 else 0
        if x2_tm[x2_k_dim] != x1_tm[x1_k_dim]:
            raise ValueError(
                f"For {op}, x2 dim {x2_k_dim} (k) layout must match x1 k-dim (dim {x1_k_dim}) "
                f"layout (trans_x2={trans_x2}), "
                f"but got x1_k={x1_tm[x1_k_dim]}, x2_k={x2_tm[x2_k_dim]}"
            )

        # The output m-dim is sharded on every mesh dim x1 uses (comm + DP), and
        # the output n-dim inherits x2's n placement. A shared mesh dim would map
        # the same mesh dimension onto two output tensor dims.
        x2_n_dim = 0 if trans_x2 else 1
        as_dims = MatmulReduceScatterDistributedOp._as_mesh_dims
        x1_mesh_dims = set(as_dims(x1_tm[0])) | set(as_dims(x1_tm[1]))
        overlap = set(as_dims(x2_tm[x2_n_dim])) & x1_mesh_dims
        if overlap:
            raise ValueError(
                f"For {op}, x2 n-dim (dim {x2_n_dim}) must not be sharded on a mesh "
                f"dimension already used by x1 (trans_x2={trans_x2}), but mesh dims "
                f"{sorted(overlap)} are shared. "
                f"Got x1 tensor_map={x1_tm}, x2 tensor_map={x2_tm}"
            )

    def infer_layout(self, cache_values: list) -> Tuple[Layout, None]:
        """Infer output layout for MatmulReduceScatter.

        ReduceScatter converts the k-dim sharding (TP comm axis) into m-dim sharding.
        trans_input=False only: x1 physical (m, k).

            comm_mesh_dim = x1_tm[1]   (k → TP / ReduceScatter scatter axis)
            dp_mesh_dim   = x1_tm[0]   (m → DP, or -1 if Replicate)

        If dp_mesh_dim == -1 (pure TP):
            output_tm[0] = comm_mesh_dim
        Else (TP + DP joint sharding):
            output_tm[0] = comm mesh dims followed by DP mesh dims, as one flat tuple
            (e.g. (comm_mesh_dim, dp_mesh_dim) when both are single mesh dims).

        output_tm[1] = x2's n dimension placement
            - trans_x2=False: x2 dim 1 (n) → x2_tm[1]
            - trans_x2=True:  x2 dim 0 (n) → x2_tm[0]

        Args:
            cache_values: [x1_layout, x2_layout, trans_x2]

        Returns:
            tuple: (output_layout, None)

        Raises:
            ValueError: If any input has Partial status or sharding constraints are violated.
        """
        x1_layout, x2_layout, trans_x2 = cache_values[0], cache_values[1], cache_values[2]

        self._check_partial_inputs([x1_layout, x2_layout])
        self._validate_input_layouts(x1_layout, x2_layout, trans_x2)

        x1_tm = x1_layout.tensor_map
        x2_tm = x2_layout.tensor_map

        # trans_input=False: k is dim 1 (comm axis), m is dim 0 (DP axis or Replicate).
        comm_mesh_dim = x1_tm[1]
        dp_mesh_dim = x1_tm[0]

        if dp_mesh_dim == -1:
            output_m = comm_mesh_dim
        else:
            # Flatten so multi-axis entries do not yield a nested tuple such as
            # (comm, (dp0, dp1)); single-axis entries still give (comm, dp).
            # NOTE(review): confirm that (comm, dp) ordering matches the row-block
            # order produced by ReduceScatter under this library's tensor_map convention.
            output_m = self._as_mesh_dims(comm_mesh_dim) + self._as_mesh_dims(dp_mesh_dim)

        n_placement = x2_tm[0] if trans_x2 else x2_tm[1]

        output_layout = Layout.from_device_mesh(x1_layout.mesh)
        output_layout.set_tensor_map((output_m, n_placement))
        output_layout.tensor_map_to_placement()

        return copy.deepcopy(output_layout), None