Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_all_gather_matmul.py: 81%
57 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for AllGatherMatmul (MC2 fusion) operator."""
16import copy
17from typing import Tuple
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp
23def _normalize_agm_args(
24 x1,
25 x2,
26 group,
27 world_size,
28 bias=None,
29 gather_index=0,
30 gather_output=True,
31 comm_turn=0,
32 trans_input=False,
33 trans_x2=False,
34) -> Tuple[tuple, dict]:
35 """Normalize positional and keyword arguments into a canonical positional tuple.
37 Args:
38 x1: Input tensor. Physical shape (m, k); trans_input=False only (MS constraint).
39 x2: Weight tensor with shape (k, n) if trans_x2=False, or (n, k) if trans_x2=True.
40 group: Communication group name string.
41 world_size: Number of ranks in the communication group.
42 bias: Must be None (MS constraint; bias is not supported).
43 gather_index: Index for gather operation (only 0 is supported).
44 gather_output: Whether to return the gathered intermediate tensor.
45 comm_turn: Communication turn (only 0 is supported).
46 trans_input: Must be False (MS constraint).
47 trans_x2: If True, x2 physical shape is (n, k); MindSpore transposes x2 before CANN call.
49 Returns:
50 tuple: (positional_args_tuple, empty_kwargs_dict)
51 """
52 return (x1, x2, group, world_size, bias, gather_index, gather_output, comm_turn, trans_input, trans_x2), {}
class AllGatherMatmulDistributedOp(DistributedOp):
    """Distributed operator for mindspore.ops.all_gather_matmul (MC2 fusion).

    The CANN AllGatherMatmul kernel handles communication (AllGather) internally.
    HyperParallel's role is to:
      1. Extract local tensors from DTensor inputs via to_local().
      2. Infer output DTensor layouts so downstream operators can correctly
         understand the distribution state of the output.

    Shape transformation (logical, after any transpose):
        x1 (m_local, k_local) —[CANN AllGather on m]→ (m_global, k_local) —[matmul]→ output (m_global, n_local)
        gather_out (m_global, k_local) — valid only when gather_output=True

    Sharding constraints:
      - x1 and x2 must both be 2-D (their tensor_maps must have exactly two entries).
      - x1 physical (m, k): dim 0 is m (AllGather consumes m); dim 1 (k) may be Replicate or Shard.
      - x1's m-dim tensor_map must not be a tuple (joint sharding across multiple mesh dims unsupported).
      - x1 k-dim and x2 k-dim must share the same placement (both Replicate, or both sharded on the
        same mesh axis). trans_x2=False: x2 k is dim 0; trans_x2=True: x2 k is dim 1.
      - When k is sharded, output carries Partial(sum) status on the k-dim mesh axis; the caller is
        responsible for applying AllReduce to obtain the correct full result.
      - gather_out k-dim follows x1's k-dim placement.
      - Partial inputs are not allowed.
      - gather_index=0, trans_input=False, bias=None only (current MS constraints).
    """

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (DTensors for x1 and x2, then the
                remaining all_gather_matmul arguments).
            kwargs: Keyword arguments for all_gather_matmul.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where
                local_args = (x1_local, x2_local, group, world_size),
                local_kwargs holds the keyword-only arguments,
                cache_values = [x1_layout, x2_layout, trans_x2, gather_output].
        """
        norm_args, _ = _normalize_agm_args(*args, **kwargs)
        (x1, x2, group, world_size, bias, gather_index,
         gather_output, comm_turn, trans_input, trans_x2) = norm_args

        # MindSpore all_gather_matmul: only (input, x2, group, world_size) are
        # positional; bias, gather_index, gather_output, comm_turn, trans_input,
        # trans_x2 are keyword-only (after the '*' separator).
        local_args = (x1.to_local(), x2.to_local(), group, world_size)
        local_kwargs = {
            'bias': bias,
            'gather_index': gather_index,
            'gather_output': gather_output,
            'comm_turn': comm_turn,
            'trans_input': trans_input,
            'trans_x2': trans_x2,
        }

        cache_values = [x1.layout, x2.layout, trans_x2, gather_output]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _set_partial_from_k(output_layout: Layout, k_placement, op: str = 'sum') -> None:
        """Set Partial on output_layout for the mesh axes corresponding to k_placement.

        Args:
            output_layout: Layout to mark as Partial (modified in place).
            k_placement: Tensor_map value for the k dimension (integer or tuple of integers).
            op: Reduction operation, default 'sum'.
        """
        alias = output_layout.alias_name
        n = len(alias)
        # Treat a single mesh axis and a joint (tuple) placement uniformly.
        axes = k_placement if isinstance(k_placement, tuple) else (k_placement,)
        for axis in axes:
            # tensor_map values count mesh dims from the end of alias_name,
            # so value v maps to alias index n - 1 - v.
            output_layout.set_partial_by_dev_axis(alias[n - 1 - axis], op)

    @staticmethod
    def _validate_input_layouts(
        x1_layout: Layout,
        x2_layout: Layout,
        trans_x2: bool,
    ) -> None:
        """Validate sharding constraints for AllGatherMatmul inputs.

        Partial status is NOT checked here; infer_layout rejects Partial inputs
        separately through the inherited _check_partial_inputs.

        Args:
            x1_layout: Layout of x1. Physical (m, k); trans_input=False only.
            x2_layout: Layout of x2 (k, n) if trans_x2=False, or (n, k) if trans_x2=True.
            trans_x2: Whether x2 is transposed.

        Raises:
            ValueError: If x1 or x2 is not 2-D, x1's m-dim tensor_map is a tuple,
                or the k-dim placements of x1 and x2 do not match.
        """
        op = "all_gather_matmul"
        x1_tm = x1_layout.tensor_map
        x2_tm = x2_layout.tensor_map

        # Both operands are 2-D matrices; without this check a lower-rank
        # layout fails with a bare IndexError below, and a higher-rank one
        # would silently read the wrong dims as k/n.
        for name, tm in (("x1", x1_tm), ("x2", x2_tm)):
            if len(tm) != 2:
                raise ValueError(
                    f"For {op}, {name} must be 2-D, but its tensor_map {tm} "
                    f"has {len(tm)} dims."
                )

        # trans_input=False only: x1 physical (m, k) — k is dim 1, m is dim 0.
        x1_m_dim = 0

        if isinstance(x1_tm[x1_m_dim], tuple):
            raise ValueError(
                f"For {op}, x1 m-dim (dim {x1_m_dim}) "
                f"with tensor_map={x1_tm[x1_m_dim]} is jointly sharded across multiple "
                f"mesh dims, which is not supported in this version."
            )

        # k-dim placement must match between x1 and x2.
        x2_k_dim = 1 if trans_x2 else 0
        if x1_tm[1] != x2_tm[x2_k_dim]:
            raise ValueError(
                f"For {op}, x1 k-dim (dim 1) placement {x1_tm[1]} must match "
                f"x2 k-dim (dim {x2_k_dim}) placement {x2_tm[x2_k_dim]} "
                f"(trans_x2={trans_x2})."
            )

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Infer output layouts for (output, gather_out).

        AllGather on m dim: output dim 0 is always -1 (Replicate), because
        AllGather unconditionally makes the m dimension global.

        n dim: follows x2's n placement.
          - trans_x2=False: n is x2 dim 1 → output_tm[1] = x2_tm[1]
          - trans_x2=True:  n is x2 dim 0 → output_tm[1] = x2_tm[0]

        k dim (contraction): when k is sharded, output carries Partial(sum) on the
        k-dim mesh axis; the caller must apply AllReduce to get the correct result.

        gather_out layout: m is Replicate (-1); k follows x1's k-dim placement
        when gather_output=True, otherwise all-Replicate.

        Args:
            cache_values: [x1_layout, x2_layout, trans_x2, gather_output]

        Returns:
            tuple: ((output_layout, gather_out_layout), None)

        Raises:
            ValueError: If sharding constraints are violated (see
                _validate_input_layouts). Partial inputs are rejected by the
                inherited _check_partial_inputs.
        """
        x1_layout, x2_layout, trans_x2, gather_output = cache_values[:4]

        self._check_partial_inputs([x1_layout, x2_layout])
        self._validate_input_layouts(x1_layout, x2_layout, trans_x2)

        x1_tm = x1_layout.tensor_map
        x2_tm = x2_layout.tensor_map
        k_placement = x1_tm[1]
        n_placement = x2_tm[0] if trans_x2 else x2_tm[1]

        # output: m is Replicate (-1) because AllGather consumed the m sharding;
        # n inherits from x2's n dim.
        output_layout = Layout.from_device_mesh(x1_layout.mesh)
        output_layout.set_tensor_map((-1, n_placement))
        output_layout.tensor_map_to_placement()

        # When k is sharded, output is a partial sum; mark Partial so the framework
        # can insert AllReduce downstream.
        if k_placement != -1:
            self._set_partial_from_k(output_layout, k_placement)

        # gather_out: gather_output=True → m Replicate (-1), k follows x1's k placement.
        # gather_output=False → CANN returns a 1-D empty tensor; force all-Replicate so
        # the layout is compatible with any tensor rank returned by the kernel.
        gather_out_layout = Layout.from_device_mesh(x1_layout.mesh)
        gather_k = k_placement if gather_output else -1
        gather_out_layout.set_tensor_map((-1, gather_k))
        gather_out_layout.tensor_map_to_placement()

        return (copy.deepcopy(output_layout), copy.deepcopy(gather_out_layout)), None