Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_matmul_reduce_scatter.py: 82%
45 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for MatmulReduceScatter (MC2 fusion) operator."""
16import copy
17from typing import Tuple
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp
23def _normalize_mrs_args(
24 x1,
25 x2,
26 group,
27 world_size,
28 reduce_op='sum',
29 bias=None,
30 comm_turn=0,
31 trans_input=False,
32 trans_x2=False,
33) -> Tuple[tuple, dict]:
34 """Normalize positional and keyword arguments into a canonical positional tuple.
36 Args:
37 x1: Input tensor. Physical shape (m, k); trans_input=False only (MS constraint).
38 x2: Weight tensor with shape (k, n) if trans_x2=False, or (n, k) if trans_x2=True.
39 group: Communication group name string.
40 world_size: Number of ranks in the communication group.
41 reduce_op: Reduce operation string (only 'sum' is supported).
42 bias: Must be None (MS constraint; bias is not supported).
43 comm_turn: Communication turn (only 0 is supported).
44 trans_input: Must be False (MS constraint).
45 trans_x2: If True, x2 physical shape is (n, k); MindSpore transposes x2 before CANN call.
47 Returns:
48 tuple: (positional_args_tuple, empty_kwargs_dict)
49 """
50 return (x1, x2, group, world_size, reduce_op, bias, comm_turn, trans_input, trans_x2), {}
class MatmulReduceScatterDistributedOp(DistributedOp):
    """Distributed operator for mindspore.ops.matmul_reduce_scatter (MC2 fusion).

    The CANN MatmulReduceScatter kernel handles communication (ReduceScatter)
    internally. HyperParallel's role is to:
    1. Extract local tensors from DTensor inputs via to_local().
    2. Infer output DTensor layouts so downstream operators can correctly
       understand the distribution state of the output.

    No Partial state is needed because CANN internally completes the ReduceScatter;
    each rank receives the correctly reduced and scattered result slice.

    Shape transformation (logical, after any transpose):
        x1 (m, k_local), x2 (k_local, n) —[matmul]→ partial (m, n)
        —[CANN ReduceScatter on m]→ output (m_local = m / world_size, n)

    Sharding constraints:
    - x1 physical (m, k): dim 1 (k) must be Shard (TP/comm axis);
      dim 0 (m) may be Replicate or Shard (DP axis).
    - x2's k-dim must be Shard and match x1's k-dim placement exactly.
    - trans_x2=False: x2 k-dim is dim 0; trans_x2=True: x2 k-dim is dim 1.
    - trans_input=False, bias=None only (current MS constraints).
      trans_input is enforced in preprocess() because layout inference
      hard-codes x1 as physical (m, k).
    - Partial inputs are not allowed.
    """

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """Extract local tensors and build the layout cache.

        Args:
            args: Positional arguments (DTensors for x1 and x2).
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where
                local_args contains extracted local tensors,
                local_kwargs contains keyword arguments,
                cache_values = [x1_layout, x2_layout, trans_x2].

        Raises:
            ValueError: If trans_input is truthy. Layout validation and inference
                in this class treat x1 dim 1 as k and dim 0 as m, so a transposed
                x1 would otherwise produce a silently wrong output layout.
        """
        norm_args, _ = _normalize_mrs_args(*args, **kwargs)
        x1 = norm_args[0]
        x2 = norm_args[1]
        trans_input = norm_args[7]
        trans_x2 = norm_args[8]

        # trans_input is not part of cache_values, so infer_layout cannot see it;
        # reject it here, before any layout is derived from the (m, k) assumption.
        if trans_input:
            raise ValueError(
                f"For matmul_reduce_scatter, trans_input must be False, "
                f"but got trans_input={trans_input}"
            )

        # MindSpore matmul_reduce_scatter: only (input, x2, group, world_size) are
        # positional; reduce_op, bias, comm_turn, trans_input, trans_x2 are
        # keyword-only (after the '*' separator).
        local_args = (
            x1.to_local(),
            x2.to_local(),
            norm_args[2],  # group
            norm_args[3],  # world_size
        )
        local_kwargs = {
            'reduce_op': norm_args[4],
            'bias': norm_args[5],
            'comm_turn': norm_args[6],
            'trans_input': trans_input,
            'trans_x2': trans_x2,
        }

        cache_values = [
            x1.layout,
            x2.layout,
            trans_x2,
        ]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _validate_input_layouts(
        x1_layout: Layout,
        x2_layout: Layout,
        trans_x2: bool,
    ) -> None:
        """Validate sharding constraints for MatmulReduceScatter inputs.

        Args:
            x1_layout: Layout of x1. Physical (m, k); trans_input=False only.
            x2_layout: Layout of x2 (k, n) if trans_x2=False, or (n, k) if trans_x2=True.
            trans_x2: Whether x2 is transposed.

        Raises:
            ValueError: If x1's k dimension is Replicate, or x2's k dimension
                layout does not match x1's k dimension layout.
        """
        op = "matmul_reduce_scatter"
        x1_tm = x1_layout.tensor_map
        x2_tm = x2_layout.tensor_map

        # trans_input=False only: x1 physical (m, k) — k is dim 1.
        x1_k_dim = 1

        # A tensor_map entry of -1 denotes Replicate on that dimension.
        if x1_tm[x1_k_dim] == -1:
            raise ValueError(
                f"For {op}, x1 k-dim (dim {x1_k_dim}) must be "
                f"Shard (not Replicate), because ReduceScatter requires k to be sharded. "
                f"Got tensor_map={x1_tm}"
            )

        # x2's contraction dimension moves with trans_x2: (k, n) vs (n, k).
        x2_k_dim = 1 if trans_x2 else 0
        if x2_tm[x2_k_dim] != x1_tm[x1_k_dim]:
            raise ValueError(
                f"For {op}, x2 dim {x2_k_dim} (k) layout must match x1 k-dim (dim {x1_k_dim}) "
                f"layout (trans_x2={trans_x2}), "
                f"but got x1_k={x1_tm[x1_k_dim]}, x2_k={x2_tm[x2_k_dim]}"
            )

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Infer output layout for MatmulReduceScatter.

        ReduceScatter converts the k-dim sharding (TP comm axis) into m-dim sharding.
        trans_input=False only: x1 physical (m, k).

            comm_mesh_dim = x1_tm[1]  (k → TP / ReduceScatter scatter axis)
            dp_mesh_dim   = x1_tm[0]  (m → DP, or -1 if Replicate)

        If dp_mesh_dim == -1 (pure TP):
            output_tm[0] = comm_mesh_dim
        Else (TP + DP joint sharding):
            output_tm[0] = (comm_mesh_dim, dp_mesh_dim)

        output_tm[1] = x2's n dimension placement
            - trans_x2=False: x2 dim 1 (n) → x2_tm[1]
            - trans_x2=True:  x2 dim 0 (n) → x2_tm[0]

        Args:
            cache_values: [x1_layout, x2_layout, trans_x2]

        Returns:
            tuple: (output_layout, None)

        Raises:
            ValueError: If any input has Partial status or sharding constraints are violated.
        """
        x1_layout = cache_values[0]
        x2_layout = cache_values[1]
        trans_x2 = cache_values[2]

        self._check_partial_inputs([x1_layout, x2_layout])
        self._validate_input_layouts(x1_layout, x2_layout, trans_x2)

        x1_tm = x1_layout.tensor_map
        x2_tm = x2_layout.tensor_map

        # trans_input=False: k is dim 1 (comm axis), m is dim 0 (DP axis or Replicate).
        comm_mesh_dim = x1_tm[1]
        dp_mesh_dim = x1_tm[0]

        if dp_mesh_dim == -1:
            output_m = comm_mesh_dim
        else:
            # NOTE(review): the tuple order (comm, dp) is kept exactly as documented
            # above; confirm it matches Layout's convention for multi-axis sharding
            # of a single tensor dimension.
            output_m = (comm_mesh_dim, dp_mesh_dim)

        n_placement = x2_tm[0] if trans_x2 else x2_tm[1]

        output_layout = Layout.from_device_mesh(x1_layout.mesh)
        output_layout.set_tensor_map((output_m, n_placement))
        output_layout.tensor_map_to_placement()

        return copy.deepcopy(output_layout), None