Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_reshape.py: 82%

166 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Reshape operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from hyper_parallel.platform import get_platform 

21from .parallel_ops import DistributedOp 

# Platform backend handle, resolved once when this module is imported.
platform = get_platform()
# Tensor class of the active backend; used below to recognise shape arguments passed as tensors.
Tensor = platform.Tensor

24 

25 

26def _filter_none_split_tensor_map(tensor_map, mesh_shape): 

27 """ 

28 Filter out the elements in tensor_map where the size of the corresponding dimension in device_matrix is 1. 

29 

30 Args: 

31 tensor_map (list): A list of tensor mappings, which may contain integers or tuples. 

32 device_matrix (list): A device matrix representing the device distribution across each dimension. 

33 

34 Returns: 

35 list: The filtered list of tensor mappings, where invalid mappings are replaced with -1 or valid mappings are 

36 retained. 

37 """ 

38 filtered_tensor_map = [] 

39 for item in tensor_map: 

40 if isinstance(item, tuple): 

41 filtered = [] 

42 for i in item: 

43 if mesh_shape[-1 - i] != 1: 

44 filtered.append(i) 

45 if len(filtered) == 0: 

46 filtered_tensor_map.append(-1) 

47 elif len(filtered) == 1: 

48 filtered_tensor_map.append(filtered[0]) 

49 else: 

50 filtered_tensor_map.append(tuple(filtered)) 

51 else: 

52 filtered_tensor_map.append(item if mesh_shape[-1 - item] != 1 else -1) 

53 return filtered_tensor_map 

54 

55 

class ReshapeDistributedOp(DistributedOp):
    """Distributed implementation for Reshape operator.

    The output layout is chosen so that the data slice held by every device after the reshape is the same as the
    slice it held before the reshape.
    """

    def __init__(self, op_name):
        super().__init__(op_name)
        # Partial inputs are allowed here; their partial status is copied to the output in _apply_partial_status.
        self._allow_partial_inputs = True

    def _get_dynamic_shape_info(self, shape):
        """
        Summarize a possibly dynamic shape.

        Args:
            shape (list): Shape whose unknown axis, if any, holds a negative value.

        Returns:
            tuple: (is_dynamic, dynamic_axis, total_size). ``total_size`` is the plain product of all dims, so it is
            negative when exactly one dim is negative; ``is_dynamic`` is ``total_size < 0``; ``dynamic_axis`` is the
            index of the last negative dim, or -1 when there is none.
        """
        # NOTE(review): two negative dims give a positive product, so such a shape is reported as static.
        total_size = 1
        dynamic_axis = -1
        for axis, size in enumerate(shape):
            total_size *= size
            if size < 0:
                dynamic_axis = axis
        return total_size < 0, dynamic_axis, total_size

    def _handle_dynamic_shape(self, input_shape, output_shape):
        """
        Resolve the unknown (negative) axis of the input or output shape.

        * Both static: the element counts must match.
        * Exactly one side dynamic: its unknown axis is computed from the other side's element count.
        * Both dynamic: the relative multiple is recorded on one side, e.g.
          ``[2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8]``.

        Args:
            input_shape (Sequence[int]): Global input shape.
            output_shape (Sequence[int]): Requested global output shape.

        Returns:
            tuple: (input_shape, output_shape, dynamic_can_shard) with both shapes as new lists.
            ``dynamic_can_shard`` is True only when the output's unknown axis was rewritten in the both-dynamic case.

        Raises:
            ValueError: If two static shapes have different element counts, or if the unknown axis of one shape
                cannot be inferred exactly (the element count is not divisible).
        """
        input_shape = list(input_shape)
        output_shape = list(output_shape)
        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)
        dynamic_can_shard = False

        if not is_input_dynamic and not is_output_dynamic:
            if input_total_size != output_total_size:
                raise ValueError(f"The total elements number of input shape {input_shape} and output shape "
                                 f"{output_shape} are different.")
            return input_shape, output_shape, dynamic_can_shard

        if not is_input_dynamic:
            # output_total_size is negative here, so (-input_total_size) // output_total_size is the unknown dim.
            # Without this check, floor division would silently produce a wrong size.
            if input_total_size % output_total_size != 0:
                raise ValueError(f"Can not infer the unknown axis of output shape {output_shape} from input shape "
                                 f"{input_shape}: the total elements number is not divisible.")
            output_shape[output_dynamic_axis] = -input_total_size // output_total_size
            return input_shape, output_shape, dynamic_can_shard

        if not is_output_dynamic:
            if output_total_size % input_total_size != 0:
                raise ValueError(f"Can not infer the unknown axis of input shape {input_shape} from output shape "
                                 f"{output_shape}: the total elements number is not divisible.")
            input_shape[input_dynamic_axis] = -output_total_size // input_total_size
            return input_shape, output_shape, dynamic_can_shard

        # Both dynamic: both totals are negative. The side with the larger (less negative) total gets the negated
        # relative multiple written into its unknown axis.
        if output_total_size >= input_total_size:
            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
            dynamic_can_shard = True
        else:
            input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
        return input_shape, output_shape, dynamic_can_shard

    def _merge_unshared_axis(self, global_shape, tensor_map):
        """
        Merge unsharded axes into the nearest sharded axis above them (towards axis 0).

        Each sharded axis absorbs the unsharded axes that follow it. Unsharded axes before the first sharded axis
        form a leading group mapped to -1.
        shape[4, 2, 6, 8], tensor map[-1, -1, 0, -1] -> merged shape[8, 48], merged tensor map[-1, 0]

        Args:
            global_shape (list): Global input shape.
            tensor_map (list): Filtered input tensor map, one entry per axis of ``global_shape``.

        Returns:
            tuple: (merged_shape, merge_tensor_map).
                merge_tensor_map may contain -1 for merged unsharded axis groups.
        """
        if not global_shape:
            # A 0-d tensor holds one element and has no axis to shard.
            return [1], [-1]
        merged_size = 1
        merged_shape = []
        merged_tensor_map = []
        for axis in range(len(global_shape) - 1, -1, -1):
            merged_size *= global_shape[axis]
            if tensor_map[axis] != -1:
                merged_shape.insert(0, merged_size)
                merged_tensor_map.insert(0, tensor_map[axis])
                merged_size = 1
        if tensor_map[0] == -1:
            merged_shape.insert(0, merged_size)
            merged_tensor_map.insert(0, -1)
        return merged_shape, merged_tensor_map

    def _cal_output_layout_and_dst_shape(self, output_tensor_map, dst_shape, x_dict):
        """
        Convert the output tensor map into layout placements and compute the local destination shape.

        Args:
            output_tensor_map (list): Per-axis mesh-dim indices (int, tuple of ints, or negative for unsharded),
                counted from the last mesh dim.
            dst_shape (list): Global destination shape; non-positive dims are treated as unknown.
            x_dict (dict): ``to_dict()`` of the input layout; its "mesh_shape" and "alias_name" entries are read.

        Returns:
            tuple: (output_map, local_dst_shape). ``output_map`` holds, per axis, an alias name, a tuple of alias
            names, or the string "None". ``local_dst_shape`` is the global dim divided by its shard count, or -1
            where the global dim is not positive.
        """
        x_mesh_shape = x_dict["mesh_shape"]
        alias_names = x_dict["alias_name"]
        output_map = []
        local_dst_shape = []
        for idx, map_id in enumerate(output_tensor_map):
            global_dim = dst_shape[idx]
            if isinstance(map_id, tuple):
                mesh_dims = map_id
            elif map_id < 0:
                output_map.append("None")
                local_dst_shape.append(global_dim if global_dim > 0 else -1)
                continue
            else:
                mesh_dims = (map_id,)

            shard_size = 1
            for mesh_dim in mesh_dims:
                shard_size *= x_mesh_shape[-1 - mesh_dim]
            names = tuple(alias_names[-1 - mesh_dim] for mesh_dim in mesh_dims)
            # A tuple map keeps a tuple of names; a single mesh dim maps to a bare name.
            output_map.append(names if isinstance(map_id, tuple) else names[0])
            local_dst_shape.append(global_dim // shard_size if global_dim > 0 else -1)
        return output_map, local_dst_shape

    def _parse_shape_args(self, extra_args):
        """Parse shape arguments from extra_args.

        Args:
            extra_args: Extra arguments containing shape info; ``None`` is treated as empty.

        Returns:
            tuple: (dst_shape, input_shape)

        Raises:
            ValueError: If the required shape arguments are missing.
        """
        if extra_args is None:
            # Report a missing shape as ValueError, like any other malformed argument list.
            extra_args = ()
        if self.op_name in ["reshape", "view"]:
            return self._parse_torch_shape_args(extra_args)
        return self._parse_mindspore_shape_args(extra_args)

    def _parse_torch_shape_args(self, extra_args):
        """Parse PyTorch style shape arguments.

        The last element is the original input shape. The preceding elements are either a single list/tuple/Tensor
        holding the target shape, or the target dims given one by one.
        """
        if len(extra_args) < 2:
            raise ValueError(f"{self.op_name} requires output shape and input shape.")

        input_shape = extra_args[-1]
        shape_args = extra_args[:-1]

        if len(shape_args) == 1:
            first_arg = shape_args[0]
            if isinstance(first_arg, (list, tuple)):
                dst_shape = first_arg
            elif isinstance(first_arg, Tensor):
                dst_shape = first_arg.tolist()
            else:
                dst_shape = shape_args
        else:
            dst_shape = shape_args

        return dst_shape, input_shape

    def _parse_mindspore_shape_args(self, extra_args):
        """Parse MindSpore style shape arguments: exactly (destination shape, original shape)."""
        if len(extra_args) != 2:
            raise ValueError("Reshape requires output shape and input shape.")

        return extra_args[0], extra_args[1]

    def _normalize_shape(self, dst_shape):
        """Normalize dst_shape to list/tuple format, converting a Tensor with ``tolist()``."""
        if isinstance(dst_shape, Tensor):
            dst_shape = dst_shape.tolist()
        if not isinstance(dst_shape, (list, tuple)):
            raise ValueError("Shape should be a tensor or a tuple or a list.")
        return dst_shape

    @staticmethod
    def _reshape_error(input_shape, dst_shape, x_map):
        """Build the error raised when the input sharding cannot be carried through the reshape."""
        return ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")

    def _compute_output_tensor_map(self, merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard,
                                   input_shape, x_map):
        """Compute the output tensor_map by carving the target dims out of the merged input axis groups.

        ``dst_shape`` is walked from its last dim to its first, dividing the current merged group each time. The
        target dim that exhausts a group (remaining size 1) inherits that group's mapping, once
        ``_validate_reshape_shard`` accepts it. All other target dims are unsharded (-1).

        Args:
            merged_shape: Merged shape from _merge_unshared_axis
            merge_tensor_map: Merged tensor_map from _merge_unshared_axis
            dst_shape: Target shape
            x_mesh_shape: Mesh shape
            dynamic_can_shard: Whether dynamic shape can be sharded
            input_shape: Original input shape
            x_map: Input tensor_map

        Returns:
            list: Output tensor_map

        Raises:
            ValueError: If the target dims do not split the merged groups evenly, or a sharded group cannot be
                assigned to its target dim.
        """
        output_tensor_map = []
        cur_axis = len(merged_shape) - 1
        cur_size = merged_shape[cur_axis]

        for shape in reversed(dst_shape):
            if cur_axis < 0:
                # Every merged group is consumed; only size-1 dims may remain, and they carry no sharding. Without
                # this branch, merged_shape[-1] / merge_tensor_map[-1] would be re-read through negative indexing.
                if shape not in (1, -1):
                    raise self._reshape_error(input_shape, dst_shape, x_map)
                output_tensor_map.insert(0, -1)
                continue

            if cur_size % shape != 0:
                raise self._reshape_error(input_shape, dst_shape, x_map)
            cur_size //= shape

            if cur_size == 1:
                map_val = merge_tensor_map[cur_axis]
                if map_val != -1:
                    self._validate_reshape_shard(
                        map_val, x_mesh_shape, shape,
                        dynamic_can_shard, input_shape, x_map, dst_shape
                    )
                output_tensor_map.insert(0, map_val)
                cur_axis -= 1
                if cur_axis >= 0:
                    cur_size = merged_shape[cur_axis]
            else:
                output_tensor_map.insert(0, -1)

        return output_tensor_map

    def _validate_reshape_shard(self, map_val, x_mesh_shape, shape,
                                dynamic_can_shard, input_shape, x_map, dst_shape):
        """Validate that a sharded axis can be reshaped to the target shape dimension.

        The shard count is the product of the mapped mesh dims. A dynamic (negative) target dim is accepted only
        when ``dynamic_can_shard`` is set; a static dim must be divisible by the shard count.
        """
        if isinstance(map_val, tuple):
            shard_size = 1
            for axis in map_val:
                shard_size *= x_mesh_shape[-axis - 1]
        else:
            shard_size = x_mesh_shape[-map_val - 1]

        if shape < 0:
            if not dynamic_can_shard:
                raise self._reshape_error(input_shape, dst_shape, x_map)
        elif shard_size > shape or shape % shard_size != 0:
            raise self._reshape_error(input_shape, dst_shape, x_map)

    def _apply_partial_status(self, x_layout, out_layout):
        """Apply partial status from input to output layout.

        Each non-None entry ``i`` of ``x_layout.partial`` is set on the output layout for ``alias_name[i]``.
        NOTE(review): assumes ``partial`` is indexed like ``alias_name`` — confirm against Layout.
        """
        if x_layout.is_partial():
            input_partial = x_layout.partial
            for i, partial_op in enumerate(input_partial):
                if partial_op is not None and i < len(out_layout.alias_name):
                    out_layout.set_partial_by_dev_axis(out_layout.alias_name[i], partial_op)

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for reshape operator.

        For reshape operations, the data slice on each device after reshape must be the same as the data slice
        before reshape. The output dim that completes each merged input axis group therefore inherits that group's
        sharding, and all other output dims are unsharded.

        Args:
            layouts (Sequence[Layout]): Input layouts; only ``layouts[0]`` (layout of input x) is used.
            extra_args:
                For MindSpore Reshape: (destination shape, original shape)
                For PyTorch reshape/view: (shape_arg1, shape_arg2, ..., original shape) or (shape_tuple, original shape)

        Returns:
            tuple: (out_layout, local_dst_shape). ``out_layout`` is the output Layout, with partial status copied
            from the input. ``local_dst_shape`` is the per-device destination shape, -1 where the global dim is
            unknown.

        Raises:
            ValueError: If the shape arguments are malformed, the element counts do not match, or the input sharding
                cannot be preserved through the reshape.
        """
        x_layout = layouts[0]
        x_dict = x_layout.to_dict()

        dst_shape, input_shape = self._parse_shape_args(extra_args)
        dst_shape = self._normalize_shape(dst_shape)

        x_mesh_shape = x_dict["mesh_shape"]
        x_map = _filter_none_split_tensor_map(x_dict["tensor_map"], x_mesh_shape)

        input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape)
        merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map)

        output_tensor_map = self._compute_output_tensor_map(
            merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard, input_shape, x_map
        )

        output_layout = Layout(
            mesh_shape=x_mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map, local_dst_shape = self._cal_output_layout_and_dst_shape(output_tensor_map, dst_shape, x_dict)
        out_layout = output_layout(*output_map)

        self._apply_partial_status(x_layout, out_layout)

        return out_layout, local_dst_shape