Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_ms_flash_attention_score.py 52.0% 336-337,343,350,354,358,821-822,942-948,952,955-956,965,987,995,1004,1028-1032,1034-1036,1038-1042,1044-1045,1047-1048,1052-1055,1057,1062,1071-1072,1089
hyper_parallel/core/shard/ops/parallel_npu_flash_attention_score.py 60.9% 110,157-159,161,166,168-170,172-173,366-367,373,380,384,388,545,591,611,906-907,1013,1121-1122
hyper_parallel/core/shard/ops/parallel_outer.py 96.3% 32
hyper_parallel/core/shard/ops/parallel_scaled_dot_product_attention.py 82.9% 36,70-72,74-75,79,88,90,95,345,550
hyper_parallel/core/shard/ops/parallel_scatter.py 100%  
hyper_parallel/core/shard/ops/parallel_transpose.py 91.2% 68,114,130
hyper_parallel/core/shard/ops/parallel_ms_flash_attention_score.py
332
333
334
335
336
337
338
339
340
341
        #
        # kv_seq_split_num > 1 is blocked by a guard in
        # _compute_adjusted_sparse_params before reaching this function,
        # so local_kv_len == global_kv_len is guaranteed here.
        local_q_len = query.shape[seq_dim_idx]
        local_kv_len = key.shape[seq_dim_idx]

        if sparse_mode in (SPARSE_DEFAULT_MASK, SPARSE_BAND):
            new_pre_tokens = pre_tokens
            new_next_tokens = next_tokens
339
340
341
342
343
344
345
346
347
        if sparse_mode in (SPARSE_DEFAULT_MASK, SPARSE_BAND):
            new_pre_tokens = pre_tokens
            new_next_tokens = next_tokens
        else:
            new_pre_tokens = local_kv_len
            new_next_tokens = 0

        new_sparse_mode = SPARSE_BAND if sparse_mode != SPARSE_DEFAULT_MASK else sparse_mode
        update_mode = SPARSE_MODE_UPDATE_MAP[sparse_mode]
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
        new_sparse_mode = SPARSE_BAND if sparse_mode != SPARSE_DEFAULT_MASK else sparse_mode
        update_mode = SPARSE_MODE_UPDATE_MAP[sparse_mode]

        if update_mode == LEFT_UP_TO_LEFT_UP:
            offset = -split_id * local_q_len
            new_pre_tokens = new_pre_tokens + offset
            new_next_tokens = new_next_tokens - offset
        elif update_mode == LEFT_UP_TO_RIGHT_DOWN:
            offset = local_kv_len - (split_id + 1) * local_q_len
            new_pre_tokens = new_pre_tokens + offset
            new_next_tokens = new_next_tokens - offset
        elif update_mode == RIGHT_DOWN_TO_RIGHT_DOWN:
            offset = (split_num - split_id - 1) * local_q_len
            new_pre_tokens = new_pre_tokens + offset
            new_next_tokens = new_next_tokens - offset

        return new_sparse_mode, new_pre_tokens, new_next_tokens
817
818
819
820
821
822
823
824
825
826
        is_dynamic: bool,
    ) -> Tuple[int, int, int]:
        """Compute adjusted sparse parameters based on dynamic or static shape."""
        if is_dynamic:
            if kv_seq_split_num > 1:
                raise NotImplementedError(
                    f"For {self.op_name}, dynamic shape with KV sequence sharding "
                    f"(kv_seq_split_num={kv_seq_split_num}) is not yet supported. "
                    f"The dynamic path currently uses local KV length directly, "
                    f"while the static path multiplies by kv_seq_split_num to obtain "
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
            # Scalar parameters from pyboost may arrive as Tensor objects.
            # Arithmetic and comparison operations on such Tensor scalars
            # would trigger device kernel calls, so convert them to Python
            # native types early.
            head_num = int(self._to_python_scalar(p_head_num))
            keep_prob = self._to_python_scalar(p_keep_prob)
            scale_value = self._to_python_scalar(p_scale_value)
            pre_tokens = int(self._to_python_scalar(p_pre_tokens))
            next_tokens = int(self._to_python_scalar(p_next_tokens))
            inner_precise = int(self._to_python_scalar(p_inner_precise))
            sparse_mode = int(self._to_python_scalar(p_sparse_mode))

            # Ensure the runtime input_layout matches the cached value used
            # for sharding derivation and validation.
            runtime_input_layout = _resolve_input_layout(
                self._to_python_scalar(p_input_layout)
            )
            if runtime_input_layout != input_layout:
                raise ValueError(
                    f"For {self.op_name}, runtime input_layout {runtime_input_layout!r} "
                    f"does not match the cached input_layout {input_layout!r} "
                    f"used for sharding inference. This may indicate an incorrect "
                    f"dispatcher cache key."
961
962
963
964
965
966
967
968
969
                )

            is_varlen = input_layout == "TND" and actual_seq_qlen is not None
            self._validate_attn_mask(attn_mask, sparse_mode, input_layout, is_varlen)
            FlashAttentionScoreDistributedOp._validate_real_shift_configuration(
                real_shift, sparse_mode)

            split_info = self._get_split_info(query_layout, input_layout)
            head_split_num = split_info["head"]
983
984
985
986
987
988
989
990
991
                return FlashAttentionScoreDistributedOp._truncate_result(result)

            adjusted_head_num = self._adjust_head_num(head_num, head_split_num)

            (adjusted_sparse_mode, adjusted_pre_tokens, adjusted_next_tokens,
             adjusted_actual_seq_qlen, adjusted_actual_seq_kvlen) = self._apply_seq_split_adjustments(
                query, key, query_layout, key_layout, input_layout,
                sparse_mode, pre_tokens, next_tokens,
                actual_seq_qlen, actual_seq_kvlen,
991
992
993
994
995
996
997
998
999
                actual_seq_qlen, actual_seq_kvlen,
                seq_split_num, lb_split_id, lb_split_num,
            )

            result = func(
                query, key, value,
                real_shift, drop_mask, padding_mask, attn_mask, prefix,
                adjusted_actual_seq_qlen, adjusted_actual_seq_kvlen,
                int(adjusted_head_num), keep_prob, scale_value,
1000
1001
1002
1003
1004
1005
1006
1007
1008
                int(adjusted_pre_tokens), int(adjusted_next_tokens), inner_precise,
                p_input_layout, int(adjusted_sparse_mode),
            )

            return FlashAttentionScoreDistributedOp._truncate_result(result)

        return _expanded_impl

    def _apply_seq_split_adjustments(  # pylint: disable=too-many-arguments,too-many-locals
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
        Returns:
            Tuple of (adjusted_sparse_mode, adjusted_pre_tokens, adjusted_next_tokens,
                      adjusted_actual_seq_qlen, adjusted_actual_seq_kvlen).
        """
        adjusted_sparse_mode = sparse_mode
        adjusted_pre_tokens = pre_tokens
        adjusted_next_tokens = next_tokens
        adjusted_actual_seq_qlen = actual_seq_qlen
        adjusted_actual_seq_kvlen = actual_seq_kvlen

        if seq_split_num > 1 or lb_split_id is not None:
            dynamic_info = self._get_dynamic_shape_info(query, key, input_layout)
            is_dynamic = dynamic_info.get('is_dynamic', False)

            if lb_split_id is not None:
                if lb_split_num is None:
                    raise ValueError("lb_split_num must not be None when lb_split_id is set")
                split_id = lb_split_id
                seq_split_num = lb_split_num
            else:
                split_id = self._get_split_id(query_layout, input_layout)
            seq_dim_idx = self._get_seq_dim_idx(self._layout_dims.get(input_layout, {}))

            if seq_dim_idx is None:
                raise ValueError(
                    f"Cannot infer seq/total dim for input_layout={input_layout}"
                )

            kv_seq_split_num = 1
            if key_layout is not None:
                kv_split_info = self._get_split_info(key_layout, input_layout)
                kv_seq_split_num = kv_split_info["seq"]

            self._check_seq_sharding_compatibility(
                query_layout, key_layout, input_layout,
                seq_dim_idx, seq_split_num, kv_seq_split_num
            )

            (adjusted_sparse_mode,
             adjusted_pre_tokens,
             adjusted_next_tokens) = self._compute_adjusted_sparse_params(
                query, key,
                sparse_mode, pre_tokens, next_tokens,
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
                split_id, seq_split_num, seq_dim_idx,
                kv_seq_split_num, is_dynamic,
            )

            if input_layout == "TND":
                (adjusted_sparse_mode,
                 adjusted_pre_tokens,
                 adjusted_next_tokens,
                 adjusted_actual_seq_qlen,
                 adjusted_actual_seq_kvlen) = self._adjust_tnd_layout_params(
1085
1086
1087
1088
1089
1090
1091
1092
1093
                        kv_seq_split_num=kv_seq_split_num, is_dynamic=is_dynamic,
                    ),
                )

        return (adjusted_sparse_mode, adjusted_pre_tokens, adjusted_next_tokens,
                adjusted_actual_seq_qlen, adjusted_actual_seq_kvlen)

    def _get_seq_dim_idx(self, dims: dict) -> Optional[int]:
        """Get the sequence dimension index."""
hyper_parallel/core/shard/ops/parallel_npu_flash_attention_score.py
106
107
108
109
110
111
112
113
114

    Returns:
        tuple: (positional_args_tuple, empty_kwargs_dict)
    """
    return (
        query, key, value, head_num, input_layout,
        pse, padding_mask, atten_mask,
        scale, keep_prob, pre_tockens, next_tockens,
        inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen,
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_npu_fusion_attention_args(*args, **kwargs)
        query, key, value = args[0], args[1], args[2]
        input_layout = args[4]

        local_args = (
            query.to_local() if hasattr(query, '_layout') else query,
            key.to_local() if hasattr(key, '_layout') else key,
            value.to_local() if hasattr(value, '_layout') else value,
        ) + args[3:]
        local_kwargs = {}

        query_layout = query.layout if hasattr(query, "_layout") else None
        key_layout = key.layout if hasattr(key, "_layout") else None
        value_layout = value.layout if hasattr(value, "_layout") else None

        cache_values = [query_layout, key_layout, value_layout, input_layout]
        return local_args, local_kwargs, cache_values

    def _is_dynamic_shape(self, tensor: Tensor, dim: int) -> bool:
        """Check if tensor has dynamic shape at given dimension."""
        try:
362
363
364
365
366
367
368
369
370
371
        #
        # kv_seq_split_num > 1 is blocked by a guard in
        # _compute_adjusted_sparse_params before reaching this function,
        # so local_kv_len == global_kv_len is guaranteed here.
        local_q_len = query.shape[seq_dim_idx]
        local_kv_len = key.shape[seq_dim_idx]

        if sparse_mode in (SPARSE_DEFAULT_MASK, SPARSE_BAND):
            new_pre_tockens = pre_tockens
            new_next_tockens = next_tockens
369
370
371
372
373
374
375
376
377
        if sparse_mode in (SPARSE_DEFAULT_MASK, SPARSE_BAND):
            new_pre_tockens = pre_tockens
            new_next_tockens = next_tockens
        else:
            new_pre_tockens = local_kv_len
            new_next_tockens = 0

        new_sparse_mode = SPARSE_BAND if sparse_mode != SPARSE_DEFAULT_MASK else sparse_mode
        update_mode = SPARSE_MODE_UPDATE_MAP[sparse_mode]
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
        new_sparse_mode = SPARSE_BAND if sparse_mode != SPARSE_DEFAULT_MASK else sparse_mode
        update_mode = SPARSE_MODE_UPDATE_MAP[sparse_mode]

        if update_mode == LEFT_UP_TO_LEFT_UP:
            offset = -split_id * local_q_len
            new_pre_tockens = new_pre_tockens + offset
            new_next_tockens = new_next_tockens - offset
        elif update_mode == LEFT_UP_TO_RIGHT_DOWN:
            offset = local_kv_len - (split_id + 1) * local_q_len
            new_pre_tockens = new_pre_tockens + offset
            new_next_tockens = new_next_tockens - offset
        elif update_mode == RIGHT_DOWN_TO_RIGHT_DOWN:
            offset = (split_num - split_id - 1) * local_q_len
            new_pre_tockens = new_pre_tockens + offset
            new_next_tockens = new_next_tockens - offset

        return new_sparse_mode, new_pre_tockens, new_next_tockens
541
542
543
544
545
546
547
548
549
        Raises:
            ValueError: If any validation rule is violated.
        """
        if query_layout is None:
            raise ValueError(
                f"For {op_name}, query layout cannot be None"
            )

        NPUFlashAttentionScoreDistributedOp._validate_sharding_consistency(
587
588
589
590
591
592
593
594
        value_layout = cache_values[2]
        input_layout_str = cache_values[3]

        if not isinstance(input_layout_str, str):
            raise ValueError(
                f"For {self.op_name}, input_layout should be a string, "
                f"but got {type(input_layout_str)}"
            )
607
608
609
610
611
612
613
614
615
        )

        attention_out_layout = copy.deepcopy(query_layout)
        if attention_out_layout.placements is None and attention_out_layout.tensor_map is not None:
            attention_out_layout.tensor_map_to_placement()

        softmax_layout = self._infer_softmax_layout_by_input_layout(
            query_layout, input_layout_str, ""
        )
902
903
904
905
906
907
908
909
910
911
        is_dynamic: bool,
    ) -> Tuple[int, int, int]:
        """Compute adjusted sparse parameters based on dynamic or static shape."""
        if is_dynamic:
            if kv_seq_split_num > 1:
                raise NotImplementedError(
                    f"For {self.op_name}, dynamic shape with KV sequence sharding "
                    f"(kv_seq_split_num={kv_seq_split_num}) is not yet supported. "
                    f"The dynamic path currently uses local KV length directly, "
                    f"while the static path multiplies by kv_seq_split_num to obtain "
1009
1010
1011
1012
1013
1014
1015
1016
1017
        query_layout = cache_values[0]
        key_layout = cache_values[1]

        if query_layout is None:
            return None

        def _expanded_impl(  # pylint: disable=R0913
            query,
            key,
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
            dynamic_info = self._get_dynamic_shape_info(query, key, input_layout)
            is_dynamic = dynamic_info.get('is_dynamic', False)

            if lb_split_id is not None:
                if lb_split_num is None:
                    raise ValueError(
                        "lb_split_num must not be None when lb_split_id is set"
                    )
                split_id = lb_split_id
                seq_split_num = lb_split_num
hyper_parallel/core/shard/ops/parallel_outer.py
28
29
30
31
32
33
34
35
36

def _get_alias_shard_set(dim_alias):
    if isinstance(dim_alias, str):
        return {dim_alias} if dim_alias != "None" else set()
    return set(dim_alias)


class OuterDistributedOp(DistributedOp):
    """Distributed implementation for torch.outer."""
hyper_parallel/core/shard/ops/parallel_scaled_dot_product_attention.py
32
33
34
35
36
37
38
39
40


def _normalize_sdpa_args(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
                         enable_gqa=False):
    return (query, key, value, attn_mask, dropout_p, is_causal, scale), {'enable_gqa': enable_gqa}


class ScaledDotProductAttentionDistributedOp(DistributedOp):
    """Distributed operator for torch.nn.functional.scaled_dot_product_attention.
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
        Returns:
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors and runtime scalars, and cache_values contains Layout objects.
        """
        args, kwargs = _normalize_sdpa_args(*args, **kwargs)
        query, key, value, attn_mask, dropout_p, is_causal, scale = args
        enable_gqa = kwargs['enable_gqa']

        if hasattr(attn_mask, '_layout'):
            raise NotImplementedError(
                f"For {self.op_name}, DTensor attn_mask is not supported yet."
            )

        local_args = (
            query.to_local() if hasattr(query, '_layout') else query,
            key.to_local() if hasattr(key, '_layout') else key,
            value.to_local() if hasattr(value, '_layout') else value,
            attn_mask,
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
            dropout_p,
            is_causal,
            scale,
        )
        local_kwargs = {'enable_gqa': enable_gqa}

        cache_values = [
            query.layout if hasattr(query, '_layout') else None,
            key.layout if hasattr(key, '_layout') else None,
            value.layout if hasattr(value, '_layout') else None,
        ]
        return local_args, local_kwargs, cache_values

    @staticmethod
    def _normalize_dim_map(dim_map):
        """Normalize dim_map to string representation."""
341
342
343
344
345
346
347
348
349
                f"Query ndim: {query_ndim}\n"
                f"Key ndim: {len(key_layout.alias_tensor_map)}"
            )
        if value_layout is not None and len(value_layout.alias_tensor_map) != query_ndim:
            raise ValueError(
                f"For {op_name}, Query, Key and Value must have the same rank.\n"
                f"Query ndim: {query_ndim}\n"
                f"Value ndim: {len(value_layout.alias_tensor_map)}"
            )
546
547
548
549
550
551
552
553
554
                if lb_split_id is not None:
                    if lb_split_num is None:
                        raise ValueError("lb_split_num must not be None when lb_split_id is set")
                    split_id = lb_split_id
                    seq_split_num = lb_split_num
                else:
                    split_id = self._get_split_id(query_layout, dims)
                local_q_len = query.shape[dims["seq"]]
                global_kv_len = key.shape[dims["seq"]]
hyper_parallel/core/shard/ops/parallel_transpose.py
64
65
66
67
68
69
70
71
72
            local_args = (input_tensor.to_local(), dim0, dim1)
            local_kwargs = {}
            cache_values = [input_tensor.layout, dim0, dim1]
        else:
            raise ValueError(
                f"For TransposeDistributedOp, unsupported op_name: {self.op_name}. "
                f"Expected 'Transpose', 'transpose', 'permute', "
                f"'TransposeView', or 'TransposeExtView'."
            )
110
111
112
113
114
115
116
117
        if self.op_name in ("Transpose", "permute", "TransposeView"):
            axis = cache_values[1]

            if not isinstance(axis, (list, tuple)):
                raise ValueError(
                    f"For {self.op_name}, axis should be a list or tuple, "
                    f"but got {type(axis)}."
                )
126
127
128
129
130
131
132
133
134
            # check if axis is a permutation
            seen = set()
            for v in axis:
                if not isinstance(v, int):
                    raise ValueError(
                        f"For {self.op_name}, axis elements must be integers, "
                        f"but got {type(v)}."
                    )
                if v < 0 or v >= ndim or v in seen: