Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_chunk_view.py 81.4% 51-53,55-56,58-59,124
hyper_parallel/core/shard/ops/parallel_histc_ext.py 79.5% 29,57-62,109,114
hyper_parallel/core/shard/ops/parallel_inplace_scatter_value.py 69.2% 27,44-45,47,54,59-60,117
hyper_parallel/core/shard/ops/parallel_matmul.py 47.4% 92,110-115,139,145,155,185,203-208,235,242,267,288,360,378-383,411,416,428,441,460-465,495,500,524
hyper_parallel/core/shard/ops/parallel_scatter_update.py 64.0% 47-48,50,56,61-62,89,106,131
hyper_parallel/core/shard/ops/parallel_stack.py 75.0% 92
hyper_parallel/core/shard/ops/parallel_chunk_view.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_chunk_view_args(*args, **kwargs)
        input_tensor, chunks, dim = args
        input_shape = input_tensor.shape

        local_args = (input_tensor.to_local(), chunks, dim)
        local_kwargs = {}

        cache_values = [input_tensor.layout, chunks, dim, input_shape]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layouts for ChunkView operator.
120
121
122
123
124
125
126
127
128
            )

        mapping = alias_map[dim]
        if isinstance(mapping, (list, tuple)):
            is_sharded = any(m != "None" for m in mapping)
        else:
            is_sharded = mapping != "None"

        if is_sharded:
hyper_parallel/core/shard/ops/parallel_histc_ext.py
25
26
27
28
29
30
31
32
33
platform = get_platform()


def _normalize_histc_args(x, bins=100, min=0, max=0):  # pylint: disable=W0622
    return (x, bins, min, max), {}


class HistcExtDistributedOp(DistributedOp):
    """
53
54
55
56
57
58
59
60
61
62
63
64
65
66

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_histc_args(*args, **kwargs)
        x, bins, min_val, max_val = args
        local_args = (x.to_local(), bins, min_val, max_val)
        local_kwargs = {}
        cache_values = [x.layout, bins, min_val, max_val]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for HistcExt operator.
105
106
107
108
109
110
111
112
113
114
115
116
117
118
                f"For {self.op_name}, bins should be a positive integer, "
                f"but got {bins}"
            )
        if not isinstance(min_val, (int, float)):
            raise ValueError(
                f"For {self.op_name}, min should be a number, "
                f"but got {type(min_val).__name__}"
            )
        if not isinstance(max_val, (int, float)):
            raise ValueError(
                f"For {self.op_name}, max should be a number, "
                f"but got {type(max_val).__name__}"
            )
        if min_val > max_val:
hyper_parallel/core/shard/ops/parallel_inplace_scatter_value.py
23
24
25
26
27
28
29
30
31


def _normalize_inplace_scatter_value_args(input_tensor, dim, index, value):
    """Normalize InplaceScatterValue arguments to positional args + empty kwargs."""
    return (input_tensor, dim, index, value), {}


class InplaceScatterValueDistributedOp(DistributedOp):
    """Distributed implementation for InplaceScatterValue operator."""
40
41
42
43
44
45
46
47
48
49
50
51

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_inplace_scatter_value_args(*args, **kwargs)
        input_tensor, dim, index, value = args

        local_args = (
            input_tensor.to_local() if hasattr(input_tensor, '_layout') else input_tensor,
            dim,
            index.to_local() if hasattr(index, '_layout') else index,
            value,
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
            index.to_local() if hasattr(index, '_layout') else index,
            value,
        )

        cache_values = [
            input_tensor.layout if hasattr(input_tensor, '_layout') else None,
            index.layout if hasattr(index, '_layout') else None,
            dim,
        ]
        local_kwargs = kwargs
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for InplaceScatterValue operator.
113
114
115
116
117
118
119
120
                f"For {self.op_name}, dim {original_dim} is out of bounds for tensor with {ndim} dims"
            )

        if len(input_map) != len(index_map):
            raise ValueError(
                f"For {self.op_name}, input and index must have the same number of dimensions, "
                f"but got input rank={len(input_map)}, index rank={len(index_map)}"
            )
hyper_parallel/core/shard/ops/parallel_matmul.py
88
89
90
91
92
93
94
95
96
            out_layout.set_partial_by_dev_axis(axis_alias, op)


def _normalize_matmul_ext_args(x, w):
    return (x, w), {}


class MatMulExtDistributedOp(DistributedOp):
    """Distributed implementation for MatMul operator."""
106
107
108
109
110
111
112
113
114
115
116
117
118
119
        Returns:
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors for x and w; cache_values contains [x_layout, w_layout].
        """
        args, kwargs = _normalize_matmul_ext_args(*args, **kwargs)
        x_tensor, w_tensor = args[0], args[1]
        local_args = (x_tensor.to_local(), w_tensor.to_local())
        local_kwargs = {}
        cache_values = [x_tensor.layout, w_tensor.layout]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for MatMul operator (output = x @ w).
135
136
137
138
139
140
141
142
143
        """
        x_layout = cache_values[0]
        w_layout = cache_values[1]
        if not x_layout or not w_layout:
            raise ValueError(
                f"For {self.op_name}, x_layout: {x_layout}, w_layout: {w_layout}"
            )
        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
141
142
143
144
145
146
147
148
            )
        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
        if x_mesh_shape != w_mesh_shape:
            raise ValueError(
                f"For {self.op_name}, inputs must have same mesh_shape, "
                f"but got x: {x_mesh_shape} and w: {w_mesh_shape}"
            )
151
152
153
154
155
156
157
158
        w_map = w_layout.alias_tensor_map
        contract_dim = len(x_map) - 1
        w_contract_dim = len(w_map) - 2
        if x_map[contract_dim] != w_map[w_contract_dim]:
            raise ValueError(
                f"For {self.op_name}, contracting dimensions must have same layout, "
                f"but got x: {x_map[contract_dim]} and w: {w_map[w_contract_dim]}"
            )
181
182
183
184
185
186
187
188
189
        return ((out_layout,), None)


def _normalize_matmul_args(x, w, transpose_a=False, transpose_b=False):
    return (x, w, transpose_a, transpose_b), {}


class MatMulDistributedOp(DistributedOp):
    """Distributed implementation for MatMul operator."""
199
200
201
202
203
204
205
206
207
208
209
210
211
212
        Returns:
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors for x and w; cache_values contains [x_layout, w_layout, transpose_a, transpose_b].
        """
        args, kwargs = _normalize_matmul_args(*args, **kwargs)
        x_tensor, w_tensor, transpose_a, transpose_b = args
        local_args = (x_tensor.to_local(), w_tensor.to_local())
        local_kwargs = {}
        cache_values = [x_tensor.layout, w_tensor.layout, transpose_a, transpose_b]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for MatMul operator (output = x @ w, with possible transpose).
231
232
233
234
235
236
237
238
239
        transpose_a = cache_values[2]
        transpose_b = cache_values[3]

        if not x_layout or not w_layout:
            raise ValueError(
                f"For {self.op_name}, x_layout: {x_layout}, w_layout: {w_layout}"
            )

        x_mesh_shape = x_layout.mesh_shape
238
239
240
241
242
243
244
245

        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
        if x_mesh_shape != w_mesh_shape:
            raise ValueError(
                f"For {self.op_name}, inputs must have same mesh_shape, "
                f"but got x: {x_mesh_shape} and w: {w_mesh_shape}"
            )
263
264
265
266
267
268
269
270
            w_contract_dim = len(w_map) - 2  # Second to last dimension

        # Validate contracting dimensions
        if x_map[x_contract_dim] != w_map[w_contract_dim]:
            raise ValueError(
                f"For {self.op_name}, contracting dimensions must have same layout, "
                f"but got x: {x_map[x_contract_dim]} and w: {w_map[w_contract_dim]}"
            )
284
285
286
287
288
289
290
291
292
        # Set partial status
        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    out_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                out_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum')

        return ((out_layout,), None)
356
357
358
359
360
361
362
363
364
        return output_layout


def _normalize_batch_matmul_ext_args(x, w):
    return (x, w), {}


class BatchMatMulExtDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMulExt operator."""
374
375
376
377
378
379
380
381
382
383
384
385
386
387
        Returns:
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors for x and w; cache_values contains [x_layout, w_layout].
        """
        args, kwargs = _normalize_batch_matmul_ext_args(*args, **kwargs)
        x_tensor, w_tensor = args[0], args[1]
        local_args = (x_tensor.to_local(), w_tensor.to_local())
        local_kwargs = {}
        cache_values = [x_tensor.layout, w_tensor.layout]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for BatchMatMulExt operator (output = x @ w).
407
408
409
410
411
412
413
414
415
416
417
418
419
        x_layout = cache_values[0]
        w_layout = cache_values[1]

        if not x_layout or not w_layout:
            raise ValueError(
                f"For {self.op_name}, x_layout: {x_layout}, w_layout: {w_layout}"
            )

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError(
                f"For {self.op_name}, inputs must have same mesh_shape, "
                f"but got x: {x_layout.mesh_shape} and w: {w_layout.mesh_shape}"
            )
424
425
426
427
428
429
430
431
        # contracting dims
        x_contract = x_map[-1]
        w_contract = w_map[-2]
        if x_contract != w_contract:
            raise ValueError(
                f"For {self.op_name}, contracting (M) dim layouts must match, "
                f"but got x: {x_contract} and w: {w_contract}"
            )
437
438
439
440
441
442
443
444
445
        return ((self._build_output_layout(x_layout, w_layout, merged_batch, x_n, w_p, x_contract),), None)


def _normalize_batch_matmul_args(x, w, transpose_a=False, transpose_b=False):
    return (x, w, transpose_a, transpose_b), {}


class BatchMatMulDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMul operator."""
456
457
458
459
460
461
462
463
464
465
466
467
468
469
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors for x and w; cache_values contains
                [x_layout, w_layout, transpose_a, transpose_b].
        """
        args, kwargs = _normalize_batch_matmul_args(*args, **kwargs)
        x_tensor, w_tensor, transpose_a, transpose_b = args
        local_args = (x_tensor.to_local(), w_tensor.to_local())
        local_kwargs = {}
        cache_values = [x_tensor.layout, w_tensor.layout, transpose_a, transpose_b]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for BatchMatMul operator (output = x @ w, with possible transpose).
491
492
493
494
495
496
497
498
499
500
501
502
503
        transpose_a = cache_values[2]
        transpose_b = cache_values[3]

        if not x_layout or not w_layout:
            raise ValueError(
                f"For {self.op_name}, x_layout: {x_layout}, w_layout: {w_layout}"
            )

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError(
                f"For {self.op_name}, inputs must have same mesh_shape, "
                f"but got x: {x_layout.mesh_shape} and w: {w_layout.mesh_shape}"
            )
520
521
522
523
524
525
526
527
            w_contract = w_map[-2]
            w_p = w_map[-1]

        if x_contract != w_contract:
            raise ValueError(
                f"For {self.op_name}, contracting (M) dim layouts must match, "
                f"but got x: {x_contract} and w: {w_contract}"
            )
hyper_parallel/core/shard/ops/parallel_scatter_update.py
43
44
45
46
47
48
49
50
51
52
53
54

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_scatter_update_args(*args, **kwargs)
        x, indices, updates = args

        local_args = (
            x.to_local() if hasattr(x, '_layout') else x,
            indices.to_local() if hasattr(indices, '_layout') else indices,
            updates.to_local() if hasattr(updates, '_layout') else updates,
        )
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
            indices.to_local() if hasattr(indices, '_layout') else indices,
            updates.to_local() if hasattr(updates, '_layout') else updates,
        )

        cache_values = [
            x.layout if hasattr(x, '_layout') else None,
            indices.layout if hasattr(indices, '_layout') else None,
            updates.layout if hasattr(updates, '_layout') else None,
        ]
        local_kwargs = kwargs
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for ScatterUpdate operator.
85
86
87
88
89
90
91
92
93
        """
        input_layout, indices_layout, updates_layout = cache_values

        if input_layout is None:
            raise ValueError(
                f"For {self.op_name}, input layout should not be None"
            )

        if indices_layout is None:
102
103
104
105
106
107
108
109
110

        # Partial inputs are intentionally allowed. The output inherits Partial status
        # from the input layout (see lines below), making this a Partial-preserving op.
        if not self._allow_partial_inputs:
            self._check_partial_inputs([input_layout])

        self._validate_strategy(input_layout, indices_layout, updates_layout)

        output_layout = Layout(
127
128
129
130
131
132
133
134
135
        indices_map = indices_layout.alias_tensor_map
        updates_map = updates_layout.alias_tensor_map

        if not input_map:
            raise ValueError(
                f"For {self.op_name}, input tensor map should not be empty"
            )

        if input_map[0] != "None":
hyper_parallel/core/shard/ops/parallel_stack.py
88
89
90
91
92
93
94
95
96

        valid_layouts = [lyt for lyt in layouts if lyt is not None]

        if not valid_layouts:
            raise ValueError(
                f"For {self.op_name}, stack requires at least one input DTensor."
            )

        # Reference layout to validate consistency across all input tensors