Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_mhc_post.py: 98%
46 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for npu_mhc_post operator."""
16from typing import Tuple, Dict, Any
18from hyper_parallel.core.dtensor.layout import Layout
19from .parallel_ops import DistributedOp
def _normalize_mhc_post_args(
    x,
    h_res,
    h_out,
    h_post,
):
    """Normalize positional and keyword arguments into a canonical positional tuple.

    Binding the call through this signature lets Python resolve any mix of
    positional and keyword arguments, so the caller always gets the four
    inputs back in one fixed order.

    Args:
        x: Input tensor [B,S,N,D] or [T,N,D].
        h_res: mHC h_res transformation matrix [B,S,N,N] or [T,N,N].
        h_out: Attention/MLP output [B,S,D] or [T,D].
        h_post: mHC h_post transformation matrix [B,S,N] or [T,N].

    Returns:
        tuple: (positional_args_tuple, empty_kwargs_dict)
    """
    return (x, h_res, h_out, h_post), {}
# Validation rules table for npu_mhc_post
# Key: tensor_map length (format identifier) of x
# Value: validation rules for that format
#   "op_name":          operator name used in error messages
#   "forbidden_dims":   x tensor_map index -> dim name; these entries must be
#                       -1 (replicated)
#   "dim_requirements": input name -> required tensor_map length for that input
_MHC_POST_VALIDATION_RULES: Dict[int, Dict[str, Any]] = {
    # 4-entry tensor_map for x: the trailing two dims (N, D) must not be sharded.
    4: {
        "op_name": "npu_mhc_post",
        "forbidden_dims": {2: "N", 3: "D"},
        "dim_requirements": {
            "h_res": 4,
            "h_out": 3,
            "h_post": 3,
        },
    },
    # 3-entry tensor_map for x: same rule, shifted by one leading dim.
    3: {
        "op_name": "npu_mhc_post",
        "forbidden_dims": {1: "N", 2: "D"},
        "dim_requirements": {
            "h_res": 3,
            "h_out": 2,
            "h_post": 2,
        },
    },
}
def _validate_tensor_map_dims(
    tensor_map: tuple,
    op_name: str,
    forbidden_dims: Dict[int, str],
) -> None:
    """Check that specified dimensions are not sharded (i.e. are replicated).

    A tensor_map entry of ``-1`` is treated as "replicated"; any other value
    at a forbidden index is rejected.

    Args:
        tensor_map: The tensor_map to check.
        op_name: Operator name for error message.
        forbidden_dims: Dict mapping dim index to dim name.

    Raises:
        ValueError: If any forbidden dimension is sharded.
    """
    for dim_idx, dim_name in forbidden_dims.items():
        dim_value = tensor_map[dim_idx]
        if dim_value != -1:
            raise ValueError(
                f"For {op_name}, {dim_name} dimension (dim {dim_idx}) of x "
                f"should be replicated, but got {dim_value}"
            )
def _validate_tensor_dimensions(
    actual_dims: Dict[str, int],
    required_dims: Dict[str, int],
    op_name: str,
) -> None:
    """Validate that each input tensor has the expected number of dimensions.

    Only the names listed in ``required_dims`` are checked. A name that is
    absent from ``actual_dims`` is looked up as ``None`` and therefore fails
    the check, with ``None`` reported in the message.

    Args:
        actual_dims: Dict mapping input name to actual dimension count.
        required_dims: Dict mapping input name to expected dimension count.
        op_name: Operator name for error message.

    Raises:
        ValueError: If any input has wrong dimension count.
    """
    for input_name, required_dim in required_dims.items():
        actual_dim = actual_dims.get(input_name)
        if actual_dim != required_dim:
            raise ValueError(
                f"For {op_name}, {input_name} tensor should have {required_dim} dimensions, "
                f"but got {actual_dim}"
            )
class NpuMhcPostDistributedOp(DistributedOp):
    """DistributedOp for npu_mhc_post operator.

    Implements layout inference for the MHC post-processing operation:
        x_{l+1} = (H_l^res)^T × x_l + h_l^out ⊗ H_t^post
    """

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Split the distributed inputs into local tensors and their layouts.

        Args:
            args: Positional arguments of the operator call.
            kwargs: Keyword arguments of the operator call.

        Returns:
            tuple: ``(local_args, local_kwargs, cache_values)`` where
            ``local_args`` holds the ``to_local()`` result of x, h_res, h_out
            and h_post (in that order), ``local_kwargs`` is an empty dict and
            ``cache_values`` lists the four inputs' ``layout`` attributes in
            the same order.
        """
        # Resolve any positional/keyword mix into the fixed (x, h_res, h_out, h_post) order.
        norm_args, _ = _normalize_mhc_post_args(*args, **kwargs)
        dtensor_x, dtensor_h_res, dtensor_h_out, dtensor_h_post = (
            norm_args[0], norm_args[1], norm_args[2], norm_args[3]
        )

        local_args = (
            dtensor_x.to_local(),
            dtensor_h_res.to_local(),
            dtensor_h_out.to_local(),
            dtensor_h_post.to_local(),
        )
        local_kwargs = {}

        # Layouts are handed to infer_layout(), which unpacks them in this order.
        cache_values = [
            dtensor_x.layout,
            dtensor_h_res.layout,
            dtensor_h_out.layout,
            dtensor_h_post.layout,
        ]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Validate the input layouts and derive the output layout.

        Args:
            cache_values: The four layouts produced by ``preprocess``, ordered
                x, h_res, h_out, h_post.

        Returns:
            A pair ``((out_layout,), None)``; the output layout is derived
            from x's layout only.

        Raises:
            ValueError: If the input layouts violate the npu_mhc_post rules.
        """
        x_layout, h_res_layout, h_out_layout, h_post_layout = cache_values

        # NOTE(review): _check_partial_inputs is not defined in this class;
        # presumably inherited from DistributedOp -- confirm its contract there.
        self._check_partial_inputs([x_layout, h_res_layout, h_out_layout, h_post_layout])

        self._validate_input_layouts_mhc_post(
            x_layout, h_res_layout, h_out_layout, h_post_layout
        )

        out_layout = self._infer_output_layout(x_layout)
        return (out_layout,), None

    @staticmethod
    def _validate_input_layouts_mhc_post(
        x_layout: Layout,
        h_res_layout: Layout,
        h_out_layout: Layout,
        h_post_layout: Layout,
    ) -> None:
        """Validate input layouts for npu_mhc_post operator.

        The length of x's tensor_map selects a rule set; that rule set then
        constrains which x dims must be replicated and how long the
        tensor_map of every other input must be.

        Raises:
            ValueError: If x's tensor_map length has no rule set, if a
                forbidden x dim is sharded, or if another input has the wrong
                tensor_map length.
        """
        x_tm = x_layout.tensor_map
        x_tm_len = len(x_tm)

        # Get validation rules from table based on tensor_map length
        rules = _MHC_POST_VALIDATION_RULES.get(x_tm_len)
        if rules is None:
            raise ValueError(
                f"For npu_mhc_post, tensor_map length should be 4 or 3, but got {x_tm_len}"
            )

        # Validate forbidden dimensions (N, D must be replicated)
        _validate_tensor_map_dims(x_tm, rules["op_name"], rules["forbidden_dims"])

        # Validate dimension counts for each input
        actual_dims = {
            "h_res": len(h_res_layout.tensor_map),
            "h_out": len(h_out_layout.tensor_map),
            "h_post": len(h_post_layout.tensor_map),
        }
        _validate_tensor_dimensions(actual_dims, rules["dim_requirements"], rules["op_name"])

    @staticmethod
    def _infer_output_layout(x_layout: Layout) -> Layout:
        """Build the output layout on x's mesh with x's tensor_map.

        Args:
            x_layout: Layout of the input tensor x.

        Returns:
            Layout: A new layout created from ``x_layout.mesh`` and given
            ``x_layout.tensor_map``.
        """
        out_layout = Layout.from_device_mesh(x_layout.mesh)
        out_layout.set_tensor_map(x_layout.tensor_map)
        # NOTE(review): presumably refreshes the placements from the tensor_map
        # just set -- confirm against Layout.tensor_map_to_placement.
        out_layout.tensor_map_to_placement()
        return out_layout