Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_one_hot_ext.py: 54%

147 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for OneHotExt operator. 

17""" 

18 

19# pylint: disable=import-outside-toplevel 

20from hyper_parallel.core.dtensor.layout import Layout 

21from hyper_parallel.platform import get_platform 

22from .parallel_ops import DistributedOp 

23 

# Platform backend handle, resolved once when this module is imported. The
# expanded OneHotExt implementation calls its differentiable_all_reduce.
platform = get_platform()

25 

26 

27class OneHotExtDistributedOp(DistributedOp): 

28 """Distributed implementation for OneHotExt operator.""" 

29 

def infer_layout(self, layouts, extra_args=None):
    """
    Derive the output layout of OneHotExt from the indices layout.

    Args:
        layouts (tuple): Input layouts; the first entry is the indices layout.
        extra_args (tuple): Additional arguments containing
            [num_classes, on_value, off_value, axis].

    Returns:
        Layout: Layout of the output, i.e. the indices layout with one extra
        unsharded dimension inserted at the one-hot axis. ``None`` when
        ``layouts`` is empty.

    Raises:
        ValueError: If the indices layout is missing, partial, has an empty
            tensor map, or any helper validation rejects the inputs.
        TypeError: If a helper validation rejects an argument type.
    """
    if not layouts:
        return None

    idx_layout = layouts[0]
    if idx_layout is None or idx_layout.mesh_shape is None:
        raise ValueError(f"{self.op_name}: indices layout cannot be None")

    if idx_layout.is_partial():
        partial_msg = (
            f"{self.op_name}: indices cannot be in partial state. "
            "Indices must contain complete index values for OneHot operation."
        )
        raise ValueError(partial_msg)

    # num_classes is only checked here; it does not influence the layout.
    self._validate_num_classes(self._get_num_classes(extra_args))
    one_hot_axis = self._get_axis(extra_args)

    src_tensor_map = idx_layout.tensor_map
    if not src_tensor_map:
        raise ValueError(f"{self.op_name}: indices tensor_map is empty")

    self._validate_multi_dim_restriction(src_tensor_map, one_hot_axis, idx_layout)
    self._validate_inputs_layouts(layouts)

    dst_tensor_map = self._infer_output_tensor_map(src_tensor_map, one_hot_axis)
    result = self._create_layout_from_tensor_map(idx_layout, dst_tensor_map)
    result.tensor_map_to_placement()
    return result

72 

def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
    """
    Build a replacement callable for OneHotExt when the indices are sharded.

    Args:
        func (callable): The original operator, invoked as
            ``func(indices, num_classes, on_value, off_value, axis)``.
        infer_result: Unused.
        layouts (tuple): Input layouts; the first entry is the indices layout.
        extra_args: Unused.

    Returns:
        callable or None: ``None`` when there is no indices layout or it uses
        no sharded device axis. Otherwise a function with the operator's
        signature that, for ``num_classes == -1``, derives the depth from the
        maximum index all-reduced over every sharded axis.
    """
    import mindspore as ms
    from mindspore import ops, Tensor

    del infer_result

    indices_layout = layouts[0] if layouts else None
    if indices_layout is None:
        return None

    comm_axes = self._get_sharded_axes(indices_layout)
    if not comm_axes:
        return None

    int32_max = 2147483647
    reduce_max = ops.ReduceMax(keep_dims=False)

    def expanded_one_hot(indices, num_classes, on_value, off_value, axis):
        self._validate_num_classes(num_classes)
        self._validate_indices_dtype(indices)

        # An explicit class count needs no communication.
        if num_classes != -1:
            return func(indices, num_classes, on_value, off_value, axis)

        shard_max = reduce_max(indices, ())
        if not isinstance(shard_max, Tensor):
            shard_max = Tensor(shard_max, ms.int64)

        # Host-side check before the value is narrowed to int32 below.
        shard_max_value = int(shard_max.asnumpy())
        if shard_max_value > int32_max:
            raise ValueError(
                f"{self.op_name}: indices max value {shard_max_value} exceeds int32 range"
            )

        was_scalar = shard_max.ndim == 0
        running_max = ops.cast(shard_max, ms.int32)

        # NOTE(review): a 0-d value is lifted to 1-d around the all-reduce and
        # squeezed back afterwards; presumably the collective does not accept
        # 0-d tensors — confirm.
        if was_scalar:
            running_max = ops.expand_dims(running_max, 0)

        for axis_name in comm_axes:
            running_max = platform.differentiable_all_reduce(
                running_max, "max", indices_layout.get_comm_group_by_axis(axis_name)
            )

        if was_scalar:
            running_max = ops.squeeze(running_max, 0)

        inferred_depth = int(running_max.asnumpy()) + 1
        return func(indices, inferred_depth, on_value, off_value, axis)

    return expanded_one_hot

128 

129 def _get_num_classes(self, extra_args): 

130 """Extract num_classes from extra arguments.""" 

131 if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 1: 

132 num_classes = extra_args[0] 

133 if isinstance(num_classes, int): 

134 return num_classes 

135 return -1 

136 

137 def _validate_num_classes(self, num_classes): 

138 """Validate num_classes parameter.""" 

139 if not isinstance(num_classes, int): 

140 raise TypeError( 

141 f"{self.op_name}: num_classes must be int, but got {type(num_classes).__name__}" 

142 ) 

143 if num_classes < -1: 

144 raise ValueError( 

145 f"{self.op_name}: num_classes must be >= -1, but got {num_classes}" 

146 ) 

147 

def _validate_indices_dtype(self, indices):
    """Check that ``indices.dtype`` is MindSpore int64.

    Raises:
        TypeError: If the dtype differs from ``mindspore.int64``.
    """
    import mindspore as ms

    if indices.dtype != ms.int64:
        message = f"{self.op_name}: indices dtype must be int64, but got {indices.dtype}"
        raise TypeError(message)

156 

157 def _get_sharded_axes(self, layout): 

158 """Get all device axes that are used for sharding.""" 

159 sharded_axes = set() 

160 

161 if layout is None or layout.alias_tensor_map is None: 

162 return [] 

163 

164 for dim_alias in layout.alias_tensor_map: 

165 if dim_alias == "None": 

166 continue 

167 

168 if isinstance(dim_alias, tuple): 

169 for axis_name in dim_alias: 

170 if axis_name != "None": 

171 sharded_axes.add(axis_name) 

172 else: 

173 sharded_axes.add(dim_alias) 

174 

175 return list(sharded_axes) 

176 

177 def _get_axis(self, extra_args): 

178 """Extract axis parameter from extra arguments.""" 

179 if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 4: 

180 axis = extra_args[3] 

181 if isinstance(axis, int): 

182 return self._validate_axis(axis) 

183 return -1 

184 

185 def _validate_axis(self, axis): 

186 """Validate axis parameter.""" 

187 if not isinstance(axis, int): 

188 raise TypeError( 

189 f"{self.op_name}: axis must be int, but got {type(axis).__name__}" 

190 ) 

191 

192 if axis > 1 or axis < -1: 

193 raise ValueError(f"{self.op_name}: axis {axis} is out of range[-1, 1]") 

194 

195 return axis 

196 

197 def _validate_multi_dim_restriction(self, in_tensor_map, axis, indices_layout): 

198 """Validate restriction for multi-dimensional inputs.""" 

199 in_rank = len(in_tensor_map) 

200 if in_rank <= 1: 

201 return 

202 

203 if axis != -1: 

204 raise ValueError( 

205 f"{self.op_name}: when input dimension is > 1, axis must be -1, but got {axis}" 

206 ) 

207 

208 alias_map = indices_layout.alias_tensor_map 

209 for i in range(1, len(alias_map)): 

210 if alias_map[i] != "None": 

211 raise ValueError( 

212 f"{self.op_name}: when input dimension is > 1, strategy must be data parallel, " 

213 f"but dimension {i} is sharded on '{alias_map[i]}'" 

214 ) 

215 

216 def _validate_inputs_layouts(self, layouts): 

217 """Validate that non-indices inputs are fully replicated.""" 

218 for layout in layouts[1:]: 

219 if layout is None: 

220 continue 

221 alias_map = layout.alias_tensor_map 

222 if alias_map and any(x != "None" for x in alias_map): 

223 raise ValueError( 

224 f"{self.op_name}: non-indices inputs must be replicated, but got {alias_map}" 

225 ) 

226 

227 def _infer_output_tensor_map(self, in_tensor_map, axis): 

228 """Infer output tensor map by inserting one-hot dimension at specified axis.""" 

229 in_rank = len(in_tensor_map) 

230 

231 if axis in (-1, in_rank): 

232 insert_pos = in_rank 

233 else: 

234 insert_pos = axis 

235 

236 if insert_pos < 0 or insert_pos > in_rank: 

237 raise ValueError( 

238 f"{self.op_name}: axis {axis} is out of range for input with rank {in_rank}" 

239 ) 

240 

241 out_tensor_map = list(in_tensor_map) 

242 out_tensor_map.insert(insert_pos, -1) 

243 return tuple(out_tensor_map) 

244 

def _create_layout_from_tensor_map(self, base_layout, out_tensor_map):
    """Build a Layout on the base layout's mesh carrying ``out_tensor_map``.

    The mesh shape, alias names and rank list are copied from ``base_layout``;
    the numeric and alias tensor maps are then set and the compact string
    refreshed.

    Returns:
        Layout: The newly constructed layout.
    """
    derived = Layout(
        mesh_shape=base_layout.mesh_shape,
        alias_name=base_layout.alias_name,
        rank_list=base_layout.rank_list,
    )
    derived.set_tensor_map(out_tensor_map)

    alias_map = self._tensor_map_to_alias_tensor_map(base_layout, out_tensor_map)
    derived.set_alias_tensor_map(alias_map)

    derived.update_compact_str()
    return derived

259 

260 def _tensor_map_to_alias_tensor_map(self, base_layout, tensor_map): 

261 """Convert numeric tensor map to alias tensor map.""" 

262 alias_name = base_layout.alias_name 

263 alias_tensor_map = [] 

264 

265 for dim in tensor_map: 

266 if dim == -1: 

267 alias_tensor_map.append("None") 

268 continue 

269 

270 if isinstance(dim, tuple): 

271 names = tuple( 

272 alias_name[len(alias_name) - 1 - d] for d in dim if d != -1 

273 ) 

274 alias_tensor_map.append(names if names else "None") 

275 continue 

276 

277 alias_tensor_map.append(alias_name[len(alias_name) - 1 - dim]) 

278 

279 return tuple(alias_tensor_map)