Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_reshape.py: 82%
166 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Reshape operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from hyper_parallel.platform import get_platform
21from .parallel_ops import DistributedOp
22platform = get_platform()
23Tensor = platform.Tensor
26def _filter_none_split_tensor_map(tensor_map, mesh_shape):
27 """
28 Filter out the elements in tensor_map where the size of the corresponding dimension in device_matrix is 1.
30 Args:
31 tensor_map (list): A list of tensor mappings, which may contain integers or tuples.
32 device_matrix (list): A device matrix representing the device distribution across each dimension.
34 Returns:
35 list: The filtered list of tensor mappings, where invalid mappings are replaced with -1 or valid mappings are
36 retained.
37 """
38 filtered_tensor_map = []
39 for item in tensor_map:
40 if isinstance(item, tuple):
41 filtered = []
42 for i in item:
43 if mesh_shape[-1 - i] != 1:
44 filtered.append(i)
45 if len(filtered) == 0:
46 filtered_tensor_map.append(-1)
47 elif len(filtered) == 1:
48 filtered_tensor_map.append(filtered[0])
49 else:
50 filtered_tensor_map.append(tuple(filtered))
51 else:
52 filtered_tensor_map.append(item if mesh_shape[-1 - item] != 1 else -1)
53 return filtered_tensor_map
class ReshapeDistributedOp(DistributedOp):
    """Distributed implementation for Reshape operator.

    The output layout is inferred so that the data slice held by each device after the reshape is
    the same as the slice it held before the reshape.
    """

    def __init__(self, op_name):
        super().__init__(op_name)
        # NOTE(review): presumably read by DistributedOp to let inputs carrying a partial status
        # through; the consumer is not visible in this file - confirm.
        self._allow_partial_inputs = True

    def _get_dynamic_shape_info(self, shape):
        """Summarize a shape that may contain a negative (dynamic) dimension.

        Args:
            shape: Iterable of ints.

        Returns:
            tuple: (is_dynamic, dynamic_axis, total_size) where ``total_size`` is the product of
            all entries, ``is_dynamic`` is ``total_size < 0`` and ``dynamic_axis`` is the index of
            the last negative entry (-1 if there is none).
        """
        # NOTE(review): a shape with an even number of negative entries has a positive product
        # and is therefore reported as not dynamic.
        total_size = 1
        dynamic_axis = -1
        for axis, s in enumerate(shape):
            total_size *= s
            if s < 0:
                dynamic_axis = axis
        return total_size < 0, dynamic_axis, total_size

    def _handle_dynamic_shape(self, input_shape, output_shape):
        """
        Check dynamic shape. Calculate unknown axis if one of input and output shape is known. If both are unknown,
        calculate the relative multiple.
        [2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8]

        Args:
            input_shape: Shape of the input tensor.
            output_shape: Requested output shape.

        Returns:
            tuple: (input_shape, output_shape, dynamic_can_shard). Both shapes are returned as new
            lists. ``dynamic_can_shard`` is True only when both shapes are dynamic and the relative
            multiple was written into the output shape.

        Raises:
            ValueError: If both shapes are static with different element counts, or if exactly one
                shape is dynamic and its unknown dimension cannot be inferred exactly.
        """
        input_shape = list(input_shape)
        output_shape = list(output_shape)
        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)

        if not is_input_dynamic and not is_output_dynamic:
            if input_total_size != output_total_size:
                raise ValueError(f"The total elements number of input shape {input_shape} and output shape "
                                 f"{output_shape} are different.")
            return input_shape, output_shape, False

        if not is_input_dynamic:
            # Floor division would silently round a non-divisible size down (e.g. 6 elements into
            # [4, -1] would become [4, 1]), so reject it explicitly.
            if input_total_size % output_total_size != 0:
                raise ValueError(f"Can not infer the dynamic axis of output shape {output_shape} from input shape "
                                 f"{input_shape}: the total elements number is not divisible.")
            output_shape[output_dynamic_axis] = -input_total_size // output_total_size
            return input_shape, output_shape, False

        if not is_output_dynamic:
            if output_total_size % input_total_size != 0:
                raise ValueError(f"Can not infer the dynamic axis of input shape {input_shape} from output shape "
                                 f"{output_shape}: the total elements number is not divisible.")
            input_shape[input_dynamic_axis] = -output_total_size // input_total_size
            return input_shape, output_shape, False

        # Both shapes are dynamic: both totals are negative. The side with the smaller magnitude
        # gets the (negated) ratio written into its dynamic axis.
        if output_total_size >= input_total_size:
            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
            return input_shape, output_shape, True
        input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
        return input_shape, output_shape, False

    def _merge_unshared_axis(self, global_shape, tensor_map):
        """
        Merge each run of unsharded axes into the nearest sharded axis on its left.
        shape[4, 2, 6, 8], tensor map[-1, -1, 0, -1] -> merged shape[8, 48]

        Leading unsharded axes (including the case of an empty shape) form one extra group whose
        tensor map entry is -1.

        Args:
            global_shape: Shape of the tensor, one entry per axis.
            tensor_map: Tensor map with one entry per axis; -1 means the axis is not sharded.

        Returns:
            tuple: (merged_shape, merge_tensor_map).
                merge_tensor_map may contain -1 for merged unsharded axis groups.
        """
        merged_size = 1
        # Groups are collected right-to-left and reversed once at the end.
        reversed_shape = []
        reversed_map = []
        for axis in range(len(global_shape) - 1, -1, -1):
            merged_size *= global_shape[axis]
            if tensor_map[axis] != -1:
                reversed_shape.append(merged_size)
                reversed_map.append(tensor_map[axis])
                merged_size = 1
        # An empty tensor_map (0-d input) has no leading axis to inspect; treat it like an
        # unsharded leading group so callers always get at least one merged axis.
        if len(tensor_map) == 0 or tensor_map[0] == -1:
            reversed_shape.append(merged_size)
            reversed_map.append(-1)
        return reversed_shape[::-1], reversed_map[::-1]

    def _cal_output_layout_and_dst_shape(self, output_tensor_map, dst_shape, x_dict):
        """
        Calculate output layout tensor map and local dst shape.

        Args:
            output_tensor_map: One entry per output dimension: an int, a tuple of ints, or a
                negative value for an unsharded dimension.
            dst_shape: Global output shape.
            x_dict: Dict form of the input layout; ``"mesh_shape"`` and ``"alias_name"`` are read.

        Returns:
            tuple: (output_map, local_dst_shape). ``output_map`` holds the alias name (or tuple of
            alias names) for each output dimension, with the string "None" for unsharded ones.
            ``local_dst_shape`` holds the per-device size of each dimension; any dimension whose
            global size is not positive is reported as -1.
        """
        x_mesh_shape = x_dict["mesh_shape"]
        alias_name = x_dict["alias_name"]
        output_map = []
        local_dst_shape = []
        for idx, map_id in enumerate(output_tensor_map):
            if isinstance(map_id, tuple):
                # Sharded over several mesh dimensions: the shard count is their product.
                shard_size = 1
                for shard_id in map_id:
                    shard_size *= x_mesh_shape[-1 - shard_id]
                alias = tuple(alias_name[-1 - shard_id] for shard_id in map_id)
            elif map_id < 0:
                shard_size = 1
                alias = "None"
            else:
                shard_size = x_mesh_shape[-1 - map_id]
                alias = alias_name[-1 - map_id]
            output_map.append(alias)
            global_size = dst_shape[idx]
            local_dst_shape.append(global_size // shard_size if global_size > 0 else -1)
        return output_map, local_dst_shape

    def _parse_shape_args(self, extra_args):
        """Parse shape arguments from extra_args.

        Args:
            extra_args: Extra arguments containing shape info. ``None`` is treated as "no
                arguments" so that it is reported with the same ValueError as a short list.

        Returns:
            tuple: (dst_shape, input_shape)

        Raises:
            ValueError: If the arguments do not contain both shapes.
        """
        if extra_args is None:
            extra_args = ()
        if self.op_name in ["reshape", "view"]:
            return self._parse_torch_shape_args(extra_args)
        return self._parse_mindspore_shape_args(extra_args)

    def _parse_torch_shape_args(self, extra_args):
        """Parse PyTorch style shape arguments.

        The last element is the input shape. Everything before it is the target shape, given
        either as separate ints or as a single list / tuple / Tensor.
        """
        if len(extra_args) < 2:
            raise ValueError(f"{self.op_name} requires output shape and input shape.")

        input_shape = extra_args[-1]
        shape_args = extra_args[:-1]

        dst_shape = shape_args
        if len(shape_args) == 1:
            only_arg = shape_args[0]
            if isinstance(only_arg, (list, tuple)):
                dst_shape = only_arg
            elif isinstance(only_arg, Tensor):
                dst_shape = only_arg.tolist()
            # Otherwise a single bare dimension was passed; keep it wrapped in shape_args.

        return dst_shape, input_shape

    def _parse_mindspore_shape_args(self, extra_args):
        """Parse MindSpore style shape arguments: exactly (destination shape, input shape)."""
        if len(extra_args) != 2:
            raise ValueError("Reshape requires output shape and input shape.")

        return extra_args[0], extra_args[1]

    def _normalize_shape(self, dst_shape):
        """Convert a Tensor shape to a list and check that the result is a list or a tuple.

        Raises:
            ValueError: If ``dst_shape`` is neither a Tensor, a list nor a tuple.
        """
        if isinstance(dst_shape, Tensor):
            dst_shape = dst_shape.tolist()
        if not isinstance(dst_shape, (list, tuple)):
            raise ValueError("Shape should be a tensor or a tuple or a list.")
        return dst_shape

    def _compute_output_tensor_map(self, merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard,
                                   input_shape, x_map):
        """Compute output tensor_map from merged information.

        The target dimensions are consumed from right to left against the merged axes. The target
        dimension that completes a merged axis inherits that axis' sharding; every other target
        dimension is unsharded.

        Args:
            merged_shape: Merged shape from _merge_unshared_axis
            merge_tensor_map: Merged tensor_map from _merge_unshared_axis
            dst_shape: Target shape
            x_mesh_shape: Mesh shape
            dynamic_can_shard: Whether dynamic shape can be sharded
            input_shape: Original input shape
            x_map: Input tensor_map

        Returns:
            list: Output tensor_map

        Raises:
            ValueError: If the target shape does not line up with the sharded axes of the input.
        """
        # Built right-to-left, reversed once at the end.
        reversed_map = []
        cur_axis = len(merged_shape) - 1
        cur_size = merged_shape[cur_axis] if cur_axis >= 0 else 1

        for shape in reversed(dst_shape):
            if cur_axis < 0:
                # Every merged axis is already consumed. Indexing merged_shape with the negative
                # cur_axis would wrap around to the wrong axis (and eventually raise IndexError),
                # so the only thing that can legally remain is an unsharded dimension of size 1.
                if shape != 1:
                    raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
                reversed_map.append(-1)
                continue

            if cur_size % shape != 0:
                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
            cur_size = cur_size // shape

            if cur_size != 1:
                # Still inside the current merged axis: this target dimension is not sharded.
                reversed_map.append(-1)
                continue

            map_val = merge_tensor_map[cur_axis]
            if map_val != -1:
                self._validate_reshape_shard(
                    map_val, x_mesh_shape, shape,
                    dynamic_can_shard, input_shape, x_map, dst_shape
                )
            reversed_map.append(map_val)
            cur_axis -= 1
            if cur_axis >= 0:
                cur_size = merged_shape[cur_axis]

        reversed_map.reverse()
        return reversed_map

    def _validate_reshape_shard(self, map_val, x_mesh_shape, shape,
                                dynamic_can_shard, input_shape, x_map, dst_shape):
        """Validate that a sharded axis can be reshaped to the target shape dimension.

        Raises:
            ValueError: If the target dimension is dynamic (negative) while ``dynamic_can_shard``
                is False, or if it is static and smaller than, or not divisible by, the number of
                shards.
        """
        if isinstance(map_val, tuple):
            shard_size = 1
            for axis in map_val:
                shard_size *= x_mesh_shape[-axis - 1]
        else:
            shard_size = x_mesh_shape[-map_val - 1]

        if shape < 0:
            if not dynamic_can_shard:
                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
        elif shard_size > shape or shape % shard_size != 0:
            raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")

    def _apply_partial_status(self, x_layout, out_layout):
        """Apply partial status from input to output layout.

        Each non-None entry of ``x_layout.partial`` is applied to the device axis of
        ``out_layout`` at the same position; positions beyond ``out_layout.alias_name`` are skipped.
        """
        if not x_layout.is_partial():
            return
        for i, partial_op in enumerate(x_layout.partial):
            if partial_op is not None and i < len(out_layout.alias_name):
                out_layout.set_partial_by_dev_axis(out_layout.alias_name[i], partial_op)

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for reshape operator.

        For reshape operations, data slice on each device after reshape should be same as data slice before reshape.

        Args:
            layouts: Indexable collection of input layouts; only ``layouts[0]`` (layout of input x) is used.
            extra_args:
                For MindSpore Reshape: (destination shape, original shape)
                For PyTorch reshape/view: (shape_arg1, shape_arg2, ..., original shape) or (shape_tuple, original shape)

        Returns:
            tuple: (out_layout, local_dst_shape) - the layout of the output tensor and the
            per-device destination shape.

        Raises:
            ValueError: If the shape arguments are missing or malformed, or if the requested shape
                is not compatible with the sharding of the input.
        """
        x_layout = layouts[0]
        x_dict = x_layout.to_dict()
        x_mesh_shape = x_dict["mesh_shape"]

        dst_shape, input_shape = self._parse_shape_args(extra_args)
        dst_shape = self._normalize_shape(dst_shape)

        # Mesh dimensions of size 1 do not shard anything; drop them before matching axes.
        x_map = _filter_none_split_tensor_map(x_dict["tensor_map"], x_mesh_shape)

        input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape)
        merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map)

        output_tensor_map = self._compute_output_tensor_map(
            merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard, input_shape, x_map
        )

        output_layout = Layout(
            mesh_shape=x_mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map, local_dst_shape = self._cal_output_layout_and_dst_shape(output_tensor_map, dst_shape, x_dict)
        out_layout = output_layout(*output_map)

        self._apply_partial_status(x_layout, out_layout)

        return out_layout, local_dst_shape