Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_rotary_position_embedding.py: 78%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed implementation for RotaryPositionEmbedding operator.""" 

16import copy 

17from typing import Optional, Tuple 

18 

19from .parallel_ops import DistributedOp 

20 

21 

22def _normalize_rpe_args(x, cos, sin, mode=0): 

23 """Normalize positional and keyword arguments into a canonical positional tuple. 

24 

25 Args: 

26 x: Input tensor. 

27 cos: Cosine position encoding tensor. 

28 sin: Sine position encoding tensor. 

29 mode: Rotation mode. 0=rotate_half, 1=rotate_interleaved, 2=quarter, 

30 3=interleave-half. Defaults to 0. 

31 

32 Returns: 

33 tuple: (positional_args_tuple, empty_kwargs_dict) 

34 """ 

35 return (x, cos, sin, mode), {} 

36 

37 

class RotaryPositionEmbeddingDistributedOp(DistributedOp):
    """Distributed operator for RotaryPositionEmbedding.

    The underlying kernel evaluates, element-wise:
        y = x * cos + x_rotate * sin

    with x_rotate produced by a rotation inside the last (D) dimension.
    The output has exactly the shape of x.

    Sharding constraints:
        - The last dim (D) of x, cos and sin must be replicated, because the
          rotation happens inside D and cannot be split along it.
        - The remaining dims (B, N, S) are independent per position and may
          be sharded freely.
        - cos/sin may be replicated on any non-D dim (broadcast case); where
          they are sharded, the sharding must equal that of x on the same
          dim.

    MODE has no influence on layout inference: every mode yields an output
    shaped like x and keeps B/N/S independent. It is only forwarded to the
    kernel.

    Output:
        A single tensor sharing the shape and layout of x.
    """

    @staticmethod
    def _validate_input_layouts(x_layout, cos_layout, sin_layout) -> None:
        """Check the sharding rules that x, cos and sin must satisfy.

        Rules (applied to both 4-D BNSD/BSND/SBND and 3-D TND layouts):
            - The last entry of x's tensor_map must be -1 (replicated).
            - The last entry of cos's and sin's tensor_map must be -1.
            - For every earlier dim ``d`` of cos/sin: the entry is either -1
              or equal to x's entry at the same index. Dims that x does not
              have before its own last dim are treated as -1 on the x side.

        Args:
            x_layout: Layout of the x tensor.
            cos_layout: Layout of the cos tensor.
            sin_layout: Layout of the sin tensor.

        Raises:
            ValueError: If D is sharded for any input, or if cos/sin sharding
                disagrees with x on a non-D dimension.
        """
        op_name = "rotary_position_embedding"
        x_map = x_layout.tensor_map

        if x_map[-1] != -1:
            message = (
                f"For {op_name}, D (last dim) of x must be replicated, "
                f"but got tensor_map={x_map}"
            )
            raise ValueError(message)

        # Index of x's D dim; only dims strictly before it are comparable.
        x_last = len(x_map) - 1

        for label, layout in (("cos", cos_layout), ("sin", sin_layout)):
            cur_map = layout.tensor_map

            if cur_map[-1] != -1:
                message = (
                    f"For {op_name}, D (last dim) of {label} must be replicated, "
                    f"but got tensor_map={cur_map}"
                )
                raise ValueError(message)

            for dim in range(len(cur_map) - 1):
                axis = cur_map[dim]
                if axis == -1:
                    # Replicated on this dim: always compatible with x.
                    continue

                x_axis = x_map[dim] if dim < x_last else -1
                if axis != x_axis:
                    message = (
                        f"For {op_name}, {label} sharding on dim {dim} must match x or be replicated, "
                        f"but got x={x_axis}, {label}={axis}"
                    )
                    raise ValueError(message)

    def preprocess(self, args: tuple, kwargs: dict) -> Optional[tuple]:
        """Unwrap the inputs to local tensors and collect their layouts.

        x, cos and sin are all expected to be DTensors: each must provide
        ``to_local()`` and a ``layout`` attribute.

        Args:
            args: Positional arguments, may include DTensors.
            kwargs: Keyword arguments.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where
                local_args = (x_local, cos_local, sin_local, mode),
                local_kwargs = {},
                cache_values = [x_layout, cos_layout, sin_layout].
        """
        (x, cos, sin, mode), _ = _normalize_rpe_args(*args, **kwargs)
        tensors = (x, cos, sin)

        local_args = tuple(tensor.to_local() for tensor in tensors) + (mode,)
        cache_values = [tensor.layout for tensor in tensors]
        return local_args, {}, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """Derive the layout of the operator's single output.

        Steps:
            1. Pass all three layouts to ``_check_partial_inputs`` (not
               defined in this class; presumably inherited, and expected to
               reject Partial inputs).
            2. Enforce the D-replicated and cos/sin-vs-x sharding rules via
               ``_validate_input_layouts``.
            3. Return a deep copy of x_layout as the output layout (output
               shape == x shape).

        Args:
            cache_values: [x_layout, cos_layout, sin_layout]

        Returns:
            tuple: ((output_layout,), None)

        Raises:
            ValueError: If any input has Partial status, D is sharded,
                or cos/sin sharding is inconsistent with x.
        """
        x_layout, cos_layout, sin_layout = (
            cache_values[0],
            cache_values[1],
            cache_values[2],
        )

        self._check_partial_inputs([x_layout, cos_layout, sin_layout])
        self._validate_input_layouts(x_layout, cos_layout, sin_layout)

        # Deep copy so the output layout never aliases the input's layout.
        output_layout = copy.deepcopy(x_layout)
        return (output_layout,), None