Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_mhc_pre_sinkhorn.py: 98%
55 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for npu_mhc_pre_sinkhorn operator."""
16from typing import Tuple, Dict, Any
18from hyper_parallel.core.dtensor.layout import Layout
19from hyper_parallel.platform import get_platform
20from hyper_parallel.platform.platform import PlatformType
21from .parallel_ops import DistributedOp
# Module-level platform handle, resolved once at import time; preprocess()
# reads its platform_type to decide how attributes are handed to the kernel.
platform = get_platform()

# Fallback values used when the caller omits the optional operator attributes.
_HC_MULT_DEFAULT: int = 4
_NUM_ITERS_DEFAULT: int = 20
_HC_EPS_DEFAULT: float = 1e-6
_NORM_EPS_DEFAULT: float = 1e-6
def _normalize_mhc_pre_sinkhorn_args(
        x,
        phi,
        alpha,
        bias,
        hc_mult=_HC_MULT_DEFAULT,
        num_iters=_NUM_ITERS_DEFAULT,
        hc_eps=_HC_EPS_DEFAULT,
        norm_eps=_NORM_EPS_DEFAULT,
        out_flag=True):
    """Flatten a call to the operator into one fully-populated positional tuple.

    Callers forward ``*args, **kwargs`` here; Python's own argument binding
    resolves positional/keyword mixing and fills in defaults, so the result
    always has the same nine slots in signature order.

    Args:
        x: Input tensor, documented as [B,S,N,C] or [T,N,C].
        phi: mHC parameter matrix, documented as [N*N+2*N, N*C].
        alpha: mHC scaling parameters, documented as [3].
        bias: mHC bias parameters, documented as [N*N+2*N].
        hc_mult: HC dimension size (documented as currently only 4 supported).
        num_iters: Sinkhorn iteration count.
        hc_eps: H_pre sigmoid eps parameter.
        norm_eps: RmsNorm eps parameter.
        out_flag: Whether to output intermediate gradients.

    Returns:
        tuple: ``(positional_args, kwargs)`` where ``positional_args`` is the
        9-tuple ``(x, phi, alpha, bias, hc_mult, num_iters, hc_eps, norm_eps,
        out_flag)`` and ``kwargs`` is always an empty dict.
    """
    tensor_inputs = (x, phi, alpha, bias)
    attributes = (hc_mult, num_iters, hc_eps, norm_eps, out_flag)
    return tensor_inputs + attributes, {}
# Layout validation table for npu_mhc_pre_sinkhorn.
# Key: length of x's tensor_map (4 for the 4-D form, 3 for the 3-D form).
# Value: per-input maps of {dim index: dim name} that must stay replicated.
# The two formats differ only in where the N axis of x sits, so the table is
# generated from (tensor_map length, N index) pairs.
_MHC_PRE_SINKHORN_VALIDATION_RULES: Dict[int, Dict[str, Any]] = {
    tm_len: {
        "op_name": "npu_mhc_pre_sinkhorn",
        "forbidden_dims": {n_index: "N"},
        "phi_forbidden_dims": {0: "dim0", 1: "dim1"},
        "alpha_forbidden_dims": {0: "dim0"},
        "bias_forbidden_dims": {0: "dim0"},
    }
    for tm_len, n_index in ((4, 2), (3, 1))
}
def _validate_tensor_map_dims(
        tensor_map: tuple,
        op_name: str,
        forbidden_dims: Dict[int, str],
        tensor_name: str = "x",
) -> None:
    """Check that the listed dimensions of a tensor_map are replicated.

    A dimension counts as replicated when its tensor_map entry equals ``-1``;
    any other entry is reported as sharded.

    Args:
        tensor_map: The tensor_map to check.
        op_name: Operator name used in the error message.
        forbidden_dims: Dict mapping a dim index that must not be sharded to a
            human-readable name for that dim.
        tensor_name: Name of the tensor that ``tensor_map`` belongs to, used in
            the error message. Defaults to ``"x"`` so existing callers keep the
            previous wording.

    Raises:
        ValueError: If a forbidden dimension is sharded, or if ``tensor_map``
            is too short to contain a forbidden dimension.
    """
    rank = len(tensor_map)
    for dim_idx, dim_name in forbidden_dims.items():
        # Report a rank mismatch explicitly instead of leaking an IndexError.
        if not -rank <= dim_idx < rank:
            raise ValueError(
                f"For {op_name}, cannot check {dim_name} dimension (dim {dim_idx}) "
                f"of {tensor_name}: its tensor_map has length {rank}"
            )
        dim_value = tensor_map[dim_idx]
        if dim_value != -1:
            raise ValueError(
                f"For {op_name}, {dim_name} dimension (dim {dim_idx}) of {tensor_name} "
                f"should be replicated, but got {dim_value}"
            )
class NpuMhcPreSinkhornDistributedOp(DistributedOp):
    """DistributedOp for npu_mhc_pre_sinkhorn operator.

    Implements layout inference for the MHC pre-processing with Sinkhorn operation.
    Outputs 8 tensors: hin, h_post, h_res, h_pre, hc_before_norm, inv_rms, sum_out, norm_out.
    """

    # Keyword names of the non-tensor attributes, in operator signature order.
    _ATTR_NAMES = ('hc_mult', 'num_iters', 'hc_eps', 'norm_eps', 'out_flag')

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Convert distributed inputs into local call arguments.

        Args:
            args: Positional arguments of the original operator call.
            kwargs: Keyword arguments of the original operator call.

        Returns:
            tuple: ``(local_args, local_kwargs, cache_values)``. On the
            MindSpore platform all nine arguments are passed positionally and
            ``local_kwargs`` is empty; on other platforms only the four local
            tensors are positional and the attributes are passed by keyword.
            ``cache_values`` holds the layouts of x, phi, alpha and bias.
        """
        norm_args, _ = _normalize_mhc_pre_sinkhorn_args(*args, **kwargs)
        dtensors = norm_args[:4]
        attrs = norm_args[4:]

        positional_attrs = platform.platform_type == PlatformType.MINDSPORE
        local_tensors = tuple(dtensor.to_local() for dtensor in dtensors)
        if positional_attrs:
            local_args = local_tensors + attrs
            local_kwargs = {}
        else:
            local_args = local_tensors
            local_kwargs = dict(zip(self._ATTR_NAMES, attrs))

        cache_values = [dtensor.layout for dtensor in dtensors]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Validate the cached input layouts and derive the output layouts.

        Args:
            cache_values: The four layouts ``[x, phi, alpha, bias]`` produced by
                :meth:`preprocess`.

        Returns:
            tuple: ``(out_layouts, None)`` where ``out_layouts`` is an 8-tuple
            of layouts, one per operator output.

        Raises:
            ValueError: If an input layout violates the sharding constraints.
        """
        x_layout, phi_layout, alpha_layout, bias_layout = cache_values

        self._check_partial_inputs([x_layout, phi_layout, alpha_layout, bias_layout])

        self._validate_input_layouts_mhc_pre_sinkhorn(
            x_layout, phi_layout, alpha_layout, bias_layout
        )

        out_layouts = self._infer_output_layouts(x_layout)
        return out_layouts, None

    @staticmethod
    def _require_replicated(
            tensor_name: str,
            tensor_map: tuple,
            op_name: str,
            forbidden_dims: dict,
    ) -> None:
        """Raise ValueError unless every listed dim of ``tensor_map`` is -1 (replicated).

        The error message names ``tensor_name`` so a bad phi/alpha/bias layout
        is not misreported as a problem with x.
        """
        rank = len(tensor_map)
        for dim_idx, dim_name in forbidden_dims.items():
            # Report a rank mismatch explicitly instead of leaking an IndexError.
            if not -rank <= dim_idx < rank:
                raise ValueError(
                    f"For {op_name}, cannot check {dim_name} dimension (dim {dim_idx}) "
                    f"of {tensor_name}: its tensor_map has length {rank}"
                )
            dim_value = tensor_map[dim_idx]
            if dim_value != -1:
                raise ValueError(
                    f"For {op_name}, {dim_name} dimension (dim {dim_idx}) of {tensor_name} "
                    f"should be replicated, but got {dim_value}"
                )

    @staticmethod
    def _validate_input_layouts_mhc_pre_sinkhorn(
            x_layout: Layout,
            phi_layout: Layout,
            alpha_layout: Layout,
            bias_layout: Layout,
    ) -> None:
        """Validate input layouts for npu_mhc_pre_sinkhorn operator.

        The N dimension of x must be replicated, and phi, alpha and bias must
        be replicated on every dimension listed in the validation table.

        Raises:
            ValueError: If x's tensor_map length is not 4 or 3, or if any
                constrained dimension is sharded.
        """
        x_tm = x_layout.tensor_map
        x_tm_len = len(x_tm)

        # The tensor_map length selects the rule set (4-D or 3-D form of x).
        rules = _MHC_PRE_SINKHORN_VALIDATION_RULES.get(x_tm_len)
        if rules is None:
            raise ValueError(
                f"For npu_mhc_pre_sinkhorn, tensor_map length should be 4 or 3, but got {x_tm_len}"
            )
        op_name = rules["op_name"]

        # N of x must be replicated.
        _validate_tensor_map_dims(x_tm, op_name, rules["forbidden_dims"])

        # phi, alpha and bias must be replicated; each is checked under its own
        # name so the error points at the offending tensor.
        param_layouts = (
            ("phi", phi_layout),
            ("alpha", alpha_layout),
            ("bias", bias_layout),
        )
        for tensor_name, layout in param_layouts:
            NpuMhcPreSinkhornDistributedOp._require_replicated(
                tensor_name,
                layout.tensor_map,
                op_name,
                rules[f"{tensor_name}_forbidden_dims"],
            )

    @staticmethod
    def _infer_output_layouts(
            x_layout: Layout,
    ) -> Tuple[Layout, Layout, Layout, Layout, Layout, Layout, Layout, Layout]:
        """Build the layouts of the 8 outputs from the layout of x.

        A single layout is created on x's mesh with x's tensor_map and returned
        for every output.

        NOTE(review): all eight entries are the same Layout object and all
        reuse x's tensor_map; presumably some outputs have a different rank
        than x - confirm this is intended.
        """
        out_layout = Layout.from_device_mesh(x_layout.mesh)
        out_layout.set_tensor_map(x_layout.tensor_map)
        out_layout.tensor_map_to_placement()

        return (out_layout,) * 8