Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_one_hot_ext.py: 54%
147 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for OneHotExt operator.
17"""
19# pylint: disable=import-outside-toplevel
20from hyper_parallel.core.dtensor.layout import Layout
21from hyper_parallel.platform import get_platform
22from .parallel_ops import DistributedOp
24platform = get_platform()
27class OneHotExtDistributedOp(DistributedOp):
28 """Distributed implementation for OneHotExt operator."""
def infer_layout(self, layouts, extra_args=None):
    """
    Derive the OneHotExt output layout from the layout of the indices input.

    The output keeps the tensor map of the indices and gains one extra
    unsharded (-1) entry for the one-hot dimension.

    Args:
        layouts (tuple): Input layouts. ``layouts[0]`` is the indices layout;
            every later entry must be None or fully replicated.
        extra_args (tuple): Additional arguments, read positionally as
            [num_classes, on_value, off_value, axis].

    Returns:
        Layout: Output layout with the one-hot dimension inserted at ``axis``,
        or None when ``layouts`` is empty.

    Raises:
        ValueError: If the indices layout (or its mesh_shape) is None, the
            indices are partial, the indices tensor_map is empty, or one of
            the num_classes / axis / sharding checks fails.
    """
    if not layouts:
        return None

    indices_layout = layouts[0]
    layout_missing = indices_layout is None or indices_layout.mesh_shape is None
    if layout_missing:
        raise ValueError(f"{self.op_name}: indices layout cannot be None")

    if indices_layout.is_partial():
        raise ValueError(
            f"{self.op_name}: indices cannot be in partial state. "
            f"Indices must contain complete index values for OneHot operation."
        )

    # Argument checks run before any tensor_map inspection.
    self._validate_num_classes(self._get_num_classes(extra_args))
    axis = self._get_axis(extra_args)

    indices_tensor_map = indices_layout.tensor_map
    if not indices_tensor_map:
        raise ValueError(f"{self.op_name}: indices tensor_map is empty")

    self._validate_multi_dim_restriction(indices_tensor_map, axis, indices_layout)
    self._validate_inputs_layouts(layouts)

    result_tensor_map = self._infer_output_tensor_map(indices_tensor_map, axis)
    result_layout = self._create_layout_from_tensor_map(indices_layout, result_tensor_map)

    result_layout.tensor_map_to_placement()
    return result_layout
def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
    """
    Build a wrapper around ``func`` for sharded indices.

    Args:
        func: The original operator callable, invoked as
            ``func(indices, num_classes, on_value, off_value, axis)``.
        infer_result: Unused; discarded immediately.
        layouts: Input layouts; only ``layouts[0]`` (indices) is consulted.
        extra_args: Unused by this method.

    Returns:
        None when there is no indices layout or the indices are not sharded
        on any device axis. Otherwise a callable that, for
        ``num_classes == -1``, takes the max of the local indices, max-reduces
        it over every sharded axis group and calls ``func`` with that
        maximum plus one as the class count.
    """
    import mindspore as ms
    from mindspore import ops, Tensor

    del infer_result

    if not layouts or layouts[0] is None:
        return None
    indices_layout = layouts[0]

    comm_axes = self._get_sharded_axes(indices_layout)
    if not comm_axes:
        return None

    int32_max = 2147483647
    max_over_all_dims = ops.ReduceMax(keep_dims=False)

    def expanded_one_hot(indices, num_classes, on_value, off_value, axis):
        self._validate_num_classes(num_classes)
        self._validate_indices_dtype(indices)

        # An explicit class count needs no communication.
        if num_classes != -1:
            return func(indices, num_classes, on_value, off_value, axis)

        shard_max = max_over_all_dims(indices, ())
        if not isinstance(shard_max, Tensor):
            shard_max = Tensor(shard_max, ms.int64)

        # The value is cast to int32 below, so reject anything that does not fit.
        shard_max_value = int(shard_max.asnumpy())
        if shard_max_value > int32_max:
            raise ValueError(
                f"{self.op_name}: indices max value {shard_max_value} exceeds int32 range"
            )

        is_scalar = shard_max.ndim == 0
        reduced = ops.cast(shard_max, ms.int32)

        # A 0-d value is lifted to 1-d for the all-reduce and squeezed back afterwards.
        if is_scalar:
            reduced = ops.expand_dims(reduced, 0)

        for mesh_axis in comm_axes:
            group = indices_layout.get_comm_group_by_axis(mesh_axis)
            reduced = platform.differentiable_all_reduce(reduced, "max", group)

        if is_scalar:
            reduced = ops.squeeze(reduced, 0)

        depth = int(reduced.asnumpy()) + 1
        return func(indices, depth, on_value, off_value, axis)

    return expanded_one_hot
129 def _get_num_classes(self, extra_args):
130 """Extract num_classes from extra arguments."""
131 if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 1:
132 num_classes = extra_args[0]
133 if isinstance(num_classes, int):
134 return num_classes
135 return -1
137 def _validate_num_classes(self, num_classes):
138 """Validate num_classes parameter."""
139 if not isinstance(num_classes, int):
140 raise TypeError(
141 f"{self.op_name}: num_classes must be int, but got {type(num_classes).__name__}"
142 )
143 if num_classes < -1:
144 raise ValueError(
145 f"{self.op_name}: num_classes must be >= -1, but got {num_classes}"
146 )
def _validate_indices_dtype(self, indices):
    """
    Check that ``indices`` has dtype int64.

    Raises:
        TypeError: If ``indices.dtype`` is anything other than ``mindspore.int64``.
    """
    import mindspore as ms

    required_dtype = ms.int64
    if indices.dtype != required_dtype:
        raise TypeError(
            f"{self.op_name}: indices dtype must be int64, but got {indices.dtype}"
        )
157 def _get_sharded_axes(self, layout):
158 """Get all device axes that are used for sharding."""
159 sharded_axes = set()
161 if layout is None or layout.alias_tensor_map is None:
162 return []
164 for dim_alias in layout.alias_tensor_map:
165 if dim_alias == "None":
166 continue
168 if isinstance(dim_alias, tuple):
169 for axis_name in dim_alias:
170 if axis_name != "None":
171 sharded_axes.add(axis_name)
172 else:
173 sharded_axes.add(dim_alias)
175 return list(sharded_axes)
177 def _get_axis(self, extra_args):
178 """Extract axis parameter from extra arguments."""
179 if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 4:
180 axis = extra_args[3]
181 if isinstance(axis, int):
182 return self._validate_axis(axis)
183 return -1
185 def _validate_axis(self, axis):
186 """Validate axis parameter."""
187 if not isinstance(axis, int):
188 raise TypeError(
189 f"{self.op_name}: axis must be int, but got {type(axis).__name__}"
190 )
192 if axis > 1 or axis < -1:
193 raise ValueError(f"{self.op_name}: axis {axis} is out of range[-1, 1]")
195 return axis
197 def _validate_multi_dim_restriction(self, in_tensor_map, axis, indices_layout):
198 """Validate restriction for multi-dimensional inputs."""
199 in_rank = len(in_tensor_map)
200 if in_rank <= 1:
201 return
203 if axis != -1:
204 raise ValueError(
205 f"{self.op_name}: when input dimension is > 1, axis must be -1, but got {axis}"
206 )
208 alias_map = indices_layout.alias_tensor_map
209 for i in range(1, len(alias_map)):
210 if alias_map[i] != "None":
211 raise ValueError(
212 f"{self.op_name}: when input dimension is > 1, strategy must be data parallel, "
213 f"but dimension {i} is sharded on '{alias_map[i]}'"
214 )
216 def _validate_inputs_layouts(self, layouts):
217 """Validate that non-indices inputs are fully replicated."""
218 for layout in layouts[1:]:
219 if layout is None:
220 continue
221 alias_map = layout.alias_tensor_map
222 if alias_map and any(x != "None" for x in alias_map):
223 raise ValueError(
224 f"{self.op_name}: non-indices inputs must be replicated, but got {alias_map}"
225 )
227 def _infer_output_tensor_map(self, in_tensor_map, axis):
228 """Infer output tensor map by inserting one-hot dimension at specified axis."""
229 in_rank = len(in_tensor_map)
231 if axis in (-1, in_rank):
232 insert_pos = in_rank
233 else:
234 insert_pos = axis
236 if insert_pos < 0 or insert_pos > in_rank:
237 raise ValueError(
238 f"{self.op_name}: axis {axis} is out of range for input with rank {in_rank}"
239 )
241 out_tensor_map = list(in_tensor_map)
242 out_tensor_map.insert(insert_pos, -1)
243 return tuple(out_tensor_map)
def _create_layout_from_tensor_map(self, base_layout, out_tensor_map):
    """
    Build a new Layout that reuses the mesh of ``base_layout`` with a different tensor map.

    The mesh_shape, alias_name and rank_list are copied from
    ``base_layout``. The numeric tensor map and its alias form are then set
    and the compact string is refreshed.

    Returns:
        Layout: The newly constructed layout.
    """
    mesh_kwargs = {
        "mesh_shape": base_layout.mesh_shape,
        "alias_name": base_layout.alias_name,
        "rank_list": base_layout.rank_list,
    }
    new_layout = Layout(**mesh_kwargs)

    new_layout.set_tensor_map(out_tensor_map)
    alias_map = self._tensor_map_to_alias_tensor_map(base_layout, out_tensor_map)
    new_layout.set_alias_tensor_map(alias_map)
    new_layout.update_compact_str()
    return new_layout
260 def _tensor_map_to_alias_tensor_map(self, base_layout, tensor_map):
261 """Convert numeric tensor map to alias tensor map."""
262 alias_name = base_layout.alias_name
263 alias_tensor_map = []
265 for dim in tensor_map:
266 if dim == -1:
267 alias_tensor_map.append("None")
268 continue
270 if isinstance(dim, tuple):
271 names = tuple(
272 alias_name[len(alias_name) - 1 - d] for d in dim if d != -1
273 )
274 alias_tensor_map.append(names if names else "None")
275 continue
277 alias_tensor_map.append(alias_name[len(alias_name) - 1 - dim])
279 return tuple(alias_tensor_map)