1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed implementation for RotaryPositionEmbedding operator."""
16import copy
17from typing import Optional, Tuple
18
19from .parallel_ops import DistributedOp
20
21
22def _normalize_rpe_args(x, cos, sin, mode=0):
23 """Normalize positional and keyword arguments into a canonical positional tuple.
24
25 Args:
26 x: Input tensor.
27 cos: Cosine position encoding tensor.
28 sin: Sine position encoding tensor.
29 mode: Rotation mode. 0=rotate_half, 1=rotate_interleaved, 2=quarter,
30 3=interleave-half. Defaults to 0.
31
32 Returns:
33 tuple: (positional_args_tuple, empty_kwargs_dict)
34 """
35 return (x, cos, sin, mode), {}
36
37
class RotaryPositionEmbeddingDistributedOp(DistributedOp):
    """Distributed wrapper for the RotaryPositionEmbedding operator.

    The operator applies rotary position embedding element-wise::

        y = x * cos + x_rotate * sin

    ``x_rotate`` is produced by rotating ``x`` inside its last (D) dimension,
    and the output has exactly the shape of ``x``.

    Sharding constraints:
        - The last dimension (D) of x, cos and sin has to be replicated,
          because the rotation happens inside D and cannot be split there.
        - The remaining dimensions (B, N, S) are independent per position
          and may be sharded freely.
        - cos/sin may be replicated on any non-D dimension (broadcast case);
          where they are sharded, the sharding must be the same as x's on
          that dimension.

    MODE has no influence on layout inference: every mode keeps the output
    shape equal to the x shape and keeps B/N/S independent. It is forwarded
    unchanged as a kernel parameter.

    Output:
        One tensor whose shape and layout are those of x.
    """

    @staticmethod
    def _validate_input_layouts(x_layout, cos_layout, sin_layout) -> None:
        """Check the sharding rules on x, cos and sin.

        A ``tensor_map`` entry of ``-1`` is treated as "replicated". The
        checks run in this order and stop at the first violation:

        1. The last entry of x's tensor_map must be ``-1``.
        2. For cos, then for sin: the last entry must be ``-1``, and every
           earlier entry must be either ``-1`` or equal to x's entry at the
           same index. Indices beyond x's non-D dimensions are compared
           against ``-1``.

        Args:
            x_layout: Layout of the x tensor.
            cos_layout: Layout of the cos tensor.
            sin_layout: Layout of the sin tensor.

        Raises:
            ValueError: If D is sharded for any input, or if cos/sin is
                sharded on a non-D dimension differently from x.
        """
        op = "rotary_position_embedding"
        x_map = x_layout.tensor_map

        if x_map[-1] != -1:
            raise ValueError(
                f"For {op}, D (last dim) of x must be replicated, "
                f"but got tensor_map={x_map}"
            )

        # Number of leading (non-D) dimensions of x; loop-invariant.
        x_leading_dims = len(x_map) - 1

        for name, layout in (("cos", cos_layout), ("sin", sin_layout)):
            enc_map = layout.tensor_map
            if enc_map[-1] != -1:
                raise ValueError(
                    f"For {op}, D (last dim) of {name} must be replicated, "
                    f"but got tensor_map={enc_map}"
                )

            for dim in range(len(enc_map) - 1):
                enc_axis = enc_map[dim]
                # x has nothing to match past its own non-D dims, so the
                # encoding must be replicated there.
                x_axis = x_map[dim] if dim < x_leading_dims else -1
                if enc_axis == -1 or enc_axis == x_axis:
                    continue
                raise ValueError(
                    f"For {op}, {name} sharding on dim {dim} must match x or be replicated, "
                    f"but got x={x_axis}, {name}={enc_axis}"
                )

    def preprocess(self, args: tuple, kwargs: dict) -> Optional[tuple]:
        """Turn the call arguments into local tensors plus cached layouts.

        x, cos and sin are expected to be DTensors: each must provide
        ``to_local()`` and a ``layout`` attribute.

        Args:
            args: Positional arguments of the operator call.
            kwargs: Keyword arguments of the operator call.

        Returns:
            tuple: ``(local_args, local_kwargs, cache_values)`` with
                local_args = (x_local, cos_local, sin_local, mode),
                local_kwargs = {},
                cache_values = [x_layout, cos_layout, sin_layout].
        """
        (x, cos, sin, mode), _ = _normalize_rpe_args(*args, **kwargs)
        tensors = (x, cos, sin)

        local_args = (*(tensor.to_local() for tensor in tensors), mode)
        cache_values = [tensor.layout for tensor in tensors]
        return local_args, {}, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Derive the layout of the operator's single output.

        Steps:
            1. Hand the three input layouts to ``_check_partial_inputs``
               (Partial inputs are rejected).
            2. Enforce the sharding rules via ``_validate_input_layouts``:
               D replicated everywhere, cos/sin matching x or replicated
               on non-D dims.
            3. Return a deep copy of the x layout, since the output shape
               equals the x shape.

        Args:
            cache_values: [x_layout, cos_layout, sin_layout]

        Returns:
            tuple: ((output_layout,), None)

        Raises:
            ValueError: If any input has Partial status, D is sharded,
                or cos/sin sharding is inconsistent with x.
        """
        input_layouts = (cache_values[0], cache_values[1], cache_values[2])

        self._check_partial_inputs(list(input_layouts))
        self._validate_input_layouts(*input_layouts)

        output_layout = copy.deepcopy(input_layouts[0])
        return (output_layout,), None