Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / context_parallel / dsa_context_parallel.py: 82%
250 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-21 07:51 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-21 07:51 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""Context parallel style for DeepSeek Sparse Attention (DSA).

These PyNative module-level styles are companions to the DSA distributed
operators. The distributed operators define the per-op layout rules for
``lightning_indexer``, ``npu_sparse_flash_attention`` and indexer-loss custom
ops; these styles prepare module inputs so those rules can be selected by the
DTensor dispatcher.

The first implementation intentionally supports only Colossal-style CP:
query-side tensors are sharded on sequence, while key-side tensors are gathered
to CP-replicated layouts. Ulysses/head sharding is rejected because the current
DSA kernels require attention head, index head, head dim and sparse top-k dims to
stay replicated.
"""
29from typing import Any, Optional
31from hyper_parallel.core.context_parallel.context_parallel import _ensure_1d, _gather_seq
32from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
33from hyper_parallel.core.dtensor.dtensor import DTensor
34from hyper_parallel.core.dtensor.placement_types import Replicate, Shard
35from hyper_parallel.core.tensor_parallel.style import ParallelStyle
36from hyper_parallel.platform import get_platform
# Resolve the active framework platform once at import time; the helpers below
# use it for tensor checks, pre-hook registration and rank queries.
platform = get_platform()
# Module type exposed by the platform; used below to annotate hookable modules.
Module = platform.Module

# Tensor layouts accepted by the DSA context-parallel styles in this file.
_SUPPORTED_LAYOUTS = ("BSND", "TND")
def _is_tensor_or_dtensor(value: Any) -> bool:
    """Tell whether ``value`` is a DTensor or a tensor of the active platform.

    DTensors are recognised first; anything else is delegated to
    ``platform.is_tensor``.
    """
    if isinstance(value, DTensor):
        return True
    return platform.is_tensor(value)
def _to_sequence_shard(value: Any, device_mesh: DeviceMesh, seq_dim: int) -> Any:
    """Return ``value`` placed as ``Shard(seq_dim)`` on ``device_mesh``.

    Args:
        value: Candidate input. Anything that is neither a DTensor nor a
            platform tensor is returned unchanged.
        device_mesh: Mesh the resulting DTensor is laid out on.
        seq_dim: Dimension that is sharded.

    Returns:
        A redistributed DTensor when ``value`` already is a DTensor, a DTensor
        built with ``DTensor.from_local`` when ``value`` is a plain tensor, or
        ``value`` itself otherwise.
    """
    if _is_tensor_or_dtensor(value):
        placements = (Shard(seq_dim),)
        if isinstance(value, DTensor):
            return value.redistribute(device_mesh, placements)
        # A plain tensor is wrapped as the local shard, not redistributed.
        return DTensor.from_local(value, device_mesh, placements)
    return value
def _to_sequence_replicate(value: Any, device_mesh: DeviceMesh, seq_dim: int) -> Any:
    """Return ``value`` in a replicated layout on ``device_mesh``.

    Args:
        value: Candidate input. Anything that is neither a DTensor nor a
            platform tensor is returned unchanged.
        device_mesh: Mesh used for the redistribution / gather.
        seq_dim: Sequence dimension forwarded to ``_gather_seq`` for plain
            tensors.

    Returns:
        The DTensor redistributed to ``Replicate()`` when ``value`` is a
        DTensor, the result of ``_gather_seq`` when it is a plain tensor
        (presumably an all-gather of the local sequence shard; see that
        helper), or ``value`` itself otherwise.
    """
    if _is_tensor_or_dtensor(value):
        if isinstance(value, DTensor):
            return value.redistribute(device_mesh, (Replicate(),))
        return _gather_seq(value, device_mesh, seq_dim)
    return value
68def _maybe_replace_arg(args: list, index: Optional[int], fn) -> None:
69 """Apply ``fn`` to ``args[index]`` when the index points to an existing arg."""
70 if index is None or index >= len(args):
71 return
72 args[index] = fn(args[index])
75def _maybe_replace_kwarg(kwargs: dict, name: Optional[str], fn) -> None:
76 """Apply ``fn`` to ``kwargs[name]`` when the key exists."""
77 if name is None or name not in kwargs:
78 return
79 kwargs[name] = fn(kwargs[name])
def _validate_layout_and_mode(style_name: str, layout: str, mode: str) -> tuple[str, int]:
    """Check ``layout`` and ``mode`` and derive the sequence dimension.

    Args:
        style_name: Name of the calling style, used in the mode error message.
        layout: Tensor layout; compared case-insensitively against
            ``_SUPPORTED_LAYOUTS``.
        mode: Context-parallel mode; only the exact string ``"colossal"`` is
            accepted.

    Returns:
        ``(upper-cased layout, seq_dim)`` where ``seq_dim`` is 1 for ``"BSND"``
        and 0 for every other supported layout.

    Raises:
        ValueError: If the layout is unsupported or the mode is not
            ``"colossal"``.
    """
    normalized = layout.upper()
    if normalized not in _SUPPORTED_LAYOUTS:
        raise ValueError(f"layout must be one of {_SUPPORTED_LAYOUTS}, but got {normalized!r}.")
    if mode != "colossal":
        raise ValueError(f"{style_name} currently supports only mode='colossal'.")
    seq_dim = 1 if normalized == "BSND" else 0
    return normalized, seq_dim
92def _finalize_output(value: Any, use_local_output: bool) -> Any:
93 """Convert direct DTensor outputs, or one-level tuple/list outputs, to local tensors."""
94 if isinstance(value, DTensor):
95 return value.to_local() if use_local_output else value
96 if isinstance(value, tuple):
97 return tuple(v.to_local() if use_local_output and isinstance(v, DTensor) else v for v in value)
98 if isinstance(value, list):
99 return [v.to_local() if use_local_output and isinstance(v, DTensor) else v for v in value]
100 return value
def _dtensor_has_partial(value: DTensor) -> bool:
    """Report whether at least one placement of ``value`` is Partial.

    Each entry of ``value.placements`` is queried through ``is_partial()``;
    the scan stops at the first positive answer.
    """
    for placement in value.placements:
        if placement.is_partial():
            return True
    return False
def _dtensor_to_local_reducing_partial(value: Any) -> Any:
    """Return the local tensor of a DTensor, reducing Partial placements first.

    Non-DTensor values are returned unchanged. ``reduce_partial()`` is called
    only when the DTensor has a Partial placement and its device mesh holds
    more than one element; otherwise ``to_local()`` is applied directly.
    """
    if isinstance(value, DTensor):
        needs_reduce = _dtensor_has_partial(value) and value.device_mesh.mesh.numel() > 1
        reduced = value.reduce_partial() if needs_reduce else value
        return reduced.to_local()
    return value
def _register_boundary_hooks(module: Module, pre_hook, use_local_output: bool) -> None:
    """Attach the input pre-hook and the output-conversion hook to ``module``.

    Args:
        module: Module/cell that receives both hooks.
        pre_hook: Forward pre-hook taking ``(module, args, kwargs)``; it is
            registered through the platform with ``with_kwargs=True``.
        use_local_output: Forwarded to ``_finalize_output`` for every forward
            result.
    """

    def _post_hook(_module, _args, outputs):
        return _finalize_output(outputs, use_local_output)

    platform.register_forward_pre_hook(module, pre_hook, with_kwargs=True)
    module.register_forward_hook(_post_hook)
def _configure_sparse_attention_boundary(  # pylint: disable=too-many-arguments
    style,
    *,
    layout: str,
    mode: str,
    query_index: Optional[int],
    key_index: Optional[int],
    value_index: Optional[int],
    topk_index: Optional[int],
    query_kwarg_name: Optional[str],
    key_kwarg_name: Optional[str],
    value_kwarg_name: Optional[str],
    topk_kwarg_name: Optional[str],
    query_rope_index: Optional[int],
    key_rope_index: Optional[int],
    query_rope_kwarg_name: Optional[str],
    key_rope_kwarg_name: Optional[str],
    use_local_output: bool,
) -> None:
    """Validate and record the sparse-attention boundary settings on ``style``.

    ``layout`` and ``mode`` are checked with ``_validate_layout_and_mode``
    before anything is written, so a ``ValueError`` leaves ``style`` untouched.
    The normalized layout, the derived ``seq_dim`` and every remaining
    argument are then stored as same-named attributes of ``style``.
    """
    normalized_layout, seq_dim = _validate_layout_and_mode(style.__class__.__name__, layout, mode)
    settings = {
        "layout": normalized_layout,
        "mode": mode,
        "seq_dim": seq_dim,
        "query_index": query_index,
        "key_index": key_index,
        "value_index": value_index,
        "topk_index": topk_index,
        "query_kwarg_name": query_kwarg_name,
        "key_kwarg_name": key_kwarg_name,
        "value_kwarg_name": value_kwarg_name,
        "topk_kwarg_name": topk_kwarg_name,
        "query_rope_index": query_rope_index,
        "key_rope_index": key_rope_index,
        "query_rope_kwarg_name": query_rope_kwarg_name,
        "key_rope_kwarg_name": key_rope_kwarg_name,
        "use_local_output": use_local_output,
    }
    for attr_name, attr_value in settings.items():
        setattr(style, attr_name, attr_value)
def _apply_sparse_attention_boundary(style, module: Module, device_mesh: DeviceMesh) -> Module:
    """Install the sparse-attention boundary hooks described by ``style``.

    The pre-hook rewrites positional arguments first and keyword arguments
    second, each in the order query, key, value, topk, query_rope, key_rope.
    Query-side inputs (query, topk, query_rope) go through
    ``_to_sequence_shard``; key-side inputs (key, value, key_rope) go through
    ``_to_sequence_replicate``. Index and kwarg-name attributes are read from
    ``style`` on every forward call.

    Args:
        style: Object configured by ``_configure_sparse_attention_boundary``.
        module: Module/cell to hook.
        device_mesh: Mesh passed through ``_ensure_1d`` before use.

    Returns:
        ``module``, after hook registration.
    """
    cp_mesh = _ensure_1d(device_mesh)

    def _as_shard(tensor: Any) -> Any:
        return _to_sequence_shard(tensor, cp_mesh, style.seq_dim)

    def _as_replica(tensor: Any) -> Any:
        return _to_sequence_replicate(tensor, cp_mesh, style.seq_dim)

    # (positional-index attribute, kwarg-name attribute, transform), in the
    # order the rewrites are applied.
    rewrites = (
        ("query_index", "query_kwarg_name", _as_shard),
        ("key_index", "key_kwarg_name", _as_replica),
        ("value_index", "value_kwarg_name", _as_replica),
        ("topk_index", "topk_kwarg_name", _as_shard),
        ("query_rope_index", "query_rope_kwarg_name", _as_shard),
        ("key_rope_index", "key_rope_kwarg_name", _as_replica),
    )

    def _pre_hook(hook_module, args, kwargs):
        del hook_module
        positional = list(args)
        keyword = dict(kwargs)
        for index_attr, _, transform in rewrites:
            _maybe_replace_arg(positional, getattr(style, index_attr), transform)
        for _, name_attr, transform in rewrites:
            _maybe_replace_kwarg(keyword, getattr(style, name_attr), transform)
        return tuple(positional), keyword

    _register_boundary_hooks(module, _pre_hook, style.use_local_output)
    return module
class DSAIndexerContextParallel(ParallelStyle):
    """Colossal-style context-parallel hooks for a DSA indexer boundary.

    Intended for a hookable module/cell whose forward signature looks like
    ``(query, key, weights, ...)``. Only that boundary is rewritten:

    - ``query`` and ``weights`` become ``Shard(seq)``;
    - ``key`` is brought to ``Replicate()``.

    Each input can be addressed by positional index, by keyword name, or both;
    ``None`` disables the corresponding lookup.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        weights_index: Optional[int] = 2,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        weights_kwarg_name: Optional[str] = None,
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        normalized_layout, seq_dim = _validate_layout_and_mode(self.__class__.__name__, layout, mode)
        # Layout description.
        self.layout, self.mode, self.seq_dim = normalized_layout, mode, seq_dim
        # Where to find each input among positional arguments.
        self.query_index, self.key_index, self.weights_index = query_index, key_index, weights_index
        # Where to find each input among keyword arguments.
        self.query_kwarg_name = query_kwarg_name
        self.key_kwarg_name = key_kwarg_name
        self.weights_kwarg_name = weights_kwarg_name
        # Whether forward outputs are converted back to local tensors.
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{cls_name}({fields})"

    def _shard_query_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Place a query-side input as ``Shard(self.seq_dim)`` on ``device_mesh``."""
        return _to_sequence_shard(value, device_mesh, self.seq_dim)

    def _replicate_key_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Bring a key-side input to a replicated layout on ``device_mesh``."""
        return _to_sequence_replicate(value, device_mesh, self.seq_dim)

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register the indexer CP hooks on ``module`` and hand it back.

        Positional arguments are rewritten before keyword arguments, each in
        the order query, key, weights.
        """
        cp_mesh = _ensure_1d(device_mesh)

        def _shard(tensor):
            return self._shard_query_side(tensor, cp_mesh)

        def _replicate(tensor):
            return self._replicate_key_side(tensor, cp_mesh)

        def _pre_hook(hook_module, args, kwargs):
            del hook_module
            positional = list(args)
            keyword = dict(kwargs)
            _maybe_replace_arg(positional, self.query_index, _shard)
            _maybe_replace_arg(positional, self.key_index, _replicate)
            _maybe_replace_arg(positional, self.weights_index, _shard)
            _maybe_replace_kwarg(keyword, self.query_kwarg_name, _shard)
            _maybe_replace_kwarg(keyword, self.key_kwarg_name, _replicate)
            _maybe_replace_kwarg(keyword, self.weights_kwarg_name, _shard)
            return tuple(positional), keyword

        _register_boundary_hooks(module, _pre_hook, self.use_local_output)
        return module
class DSASparseAttentionContextParallel(ParallelStyle):
    """Colossal-style context-parallel hooks for a DSA sparse-attention boundary.

    Intended for a hookable module/cell whose forward signature looks like
    ``(query, key, value, topk_indices, query_rope, key_rope, ...)``. Only
    that boundary is rewritten:

    - ``query``, ``topk_indices`` and ``query_rope`` become ``Shard(seq)``;
    - ``key``, ``value`` and ``key_rope`` are brought to ``Replicate()``.

    Each input can be addressed by positional index, by keyword name, or both;
    ``None`` disables the corresponding lookup.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        value_index: Optional[int] = 2,
        topk_index: Optional[int] = 3,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        value_kwarg_name: Optional[str] = None,
        topk_kwarg_name: Optional[str] = None,
        query_rope_index: Optional[int] = 4,
        key_rope_index: Optional[int] = 5,
        query_rope_kwarg_name: Optional[str] = "query_rope",
        key_rope_kwarg_name: Optional[str] = "key_rope",
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        # Validation and attribute storage are shared with the other
        # sparse-attention style through the module-level helper.
        boundary_options = {
            "layout": layout,
            "mode": mode,
            "query_index": query_index,
            "key_index": key_index,
            "value_index": value_index,
            "topk_index": topk_index,
            "query_kwarg_name": query_kwarg_name,
            "key_kwarg_name": key_kwarg_name,
            "value_kwarg_name": value_kwarg_name,
            "topk_kwarg_name": topk_kwarg_name,
            "query_rope_index": query_rope_index,
            "key_rope_index": key_rope_index,
            "query_rope_kwarg_name": query_rope_kwarg_name,
            "key_rope_kwarg_name": key_rope_kwarg_name,
            "use_local_output": use_local_output,
        }
        _configure_sparse_attention_boundary(self, **boundary_options)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{cls_name}({fields})"

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register the sparse-attention CP hooks on ``module`` and hand it back."""
        hooked = _apply_sparse_attention_boundary(self, module, device_mesh)
        return hooked
class DSAIndexerLossContextParallel(ParallelStyle):
    """Colossal-style CP hook for a DSA indexer-loss kernel boundary.

    This style targets a hookable module/cell whose forward signature is shaped
    like ``(query, key, query_index, key_index, weights, topk_indices,
    softmax_max, softmax_sum, query_rope, key_rope, ...)``. The boundary is
    expected to start after MF has already done local-only bookkeeping such as
    ``stop_gradient`` and ``split``.

    Input placements applied by the pre-hook:

    - ``query``, ``query_index``, ``weights``, ``topk_indices`` and
      ``query_rope`` are annotated as ``Shard(seq)``;
    - ``key``, ``key_index`` and ``key_rope`` are all-gathered to
      ``Replicate()``.

    With the default indices, positions 6 and 7 (``softmax_max`` and
    ``softmax_sum``) are passed through untouched.

    Output handling, only when ``use_local_output`` is true: for a tuple/list
    with at least four entries, entries 0, 2 and 3 are converted to local
    tensors (reducing Partial placements when needed) and entry 1 is
    additionally cut back to the local key-index shard length recorded by the
    pre-hook. Any further entries are left as they are. Other outputs go
    through ``_finalize_output``.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        query_indexer_index: Optional[int] = 2,
        key_indexer_index: Optional[int] = 3,
        weights_index: Optional[int] = 4,
        topk_index: Optional[int] = 5,
        query_rope_index: Optional[int] = 8,
        key_rope_index: Optional[int] = 9,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        query_indexer_kwarg_name: Optional[str] = None,
        key_indexer_kwarg_name: Optional[str] = None,
        weights_kwarg_name: Optional[str] = None,
        topk_kwarg_name: Optional[str] = None,
        query_rope_kwarg_name: Optional[str] = None,
        key_rope_kwarg_name: Optional[str] = None,
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        layout, seq_dim = _validate_layout_and_mode(self.__class__.__name__, layout, mode)
        self.layout = layout
        self.mode = mode
        self.seq_dim = seq_dim
        # Positional locations of the rewritten inputs (None disables one).
        self.query_index = query_index
        self.key_index = key_index
        self.query_indexer_index = query_indexer_index
        self.key_indexer_index = key_indexer_index
        self.weights_index = weights_index
        self.topk_index = topk_index
        self.query_rope_index = query_rope_index
        self.key_rope_index = key_rope_index
        # Keyword names of the rewritten inputs (None disables one).
        self.query_kwarg_name = query_kwarg_name
        self.key_kwarg_name = key_kwarg_name
        self.query_indexer_kwarg_name = query_indexer_kwarg_name
        self.key_indexer_kwarg_name = key_indexer_kwarg_name
        self.weights_kwarg_name = weights_kwarg_name
        self.topk_kwarg_name = topk_kwarg_name
        self.query_rope_kwarg_name = query_rope_kwarg_name
        self.key_rope_kwarg_name = key_rope_kwarg_name
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"layout={self.layout!r}, mode={self.mode!r}, "
            f"use_local_output={self.use_local_output})"
        )

    def _shard_query_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Place a query-side input as ``Shard(self.seq_dim)`` on ``device_mesh``."""
        return _to_sequence_shard(value, device_mesh, self.seq_dim)

    def _replicate_key_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Bring a key-side input to a replicated layout on ``device_mesh``."""
        return _to_sequence_replicate(value, device_mesh, self.seq_dim)

    @staticmethod
    def _local_shape(value: Any) -> Optional[tuple]:
        """Return the per-rank shape of ``value``, or ``None`` for non-tensors.

        DTensors report ``local_shape``; platform tensors report ``shape``.
        """
        if isinstance(value, DTensor):
            return value.local_shape
        if platform.is_tensor(value):
            return value.shape
        return None

    def _slice_local_key_grad(self, value: Any, module: Module) -> Any:
        """Convert d_key_index to the original local key-index shard shape.

        The target shape and the local rank position are read from the
        ``_hp_dsa_loss_key_index_local_shape`` and ``_hp_dsa_loss_local_idx``
        attributes that the pre-hook stores on ``module``.

        Returns ``value`` unchanged when ``use_local_output`` is false. When no
        shape was recorded, falls back to ``_finalize_output``. Otherwise the
        value is converted to a local tensor and, if its sequence length
        differs from the recorded one, narrowed to this rank's slice.
        """
        if not self.use_local_output:
            return value
        target_shape = getattr(module, "_hp_dsa_loss_key_index_local_shape", None)
        if target_shape is None:
            return _finalize_output(value, use_local_output=True)

        if isinstance(value, DTensor):
            value = _dtensor_to_local_reducing_partial(value)
        if not platform.is_tensor(value):
            return value

        target_len = target_shape[self.seq_dim]
        if value.shape[self.seq_dim] == target_len:
            # Already shard-sized; nothing to cut.
            return value

        # Pick the slice that belongs to this rank within the gathered sequence.
        local_idx = getattr(module, "_hp_dsa_loss_local_idx", 0)
        start = local_idx * target_len
        return value.narrow(self.seq_dim, start, target_len)

    def _process_outputs(self, module: Module, outputs: Any) -> Any:
        """Finalize indexer-loss outputs, reducing Partial values when needed.

        Outputs that are not a tuple/list of at least four entries are handed
        to ``_finalize_output``. The returned container has the same type as
        ``outputs``; namedtuples are rebuilt through ``_make`` so they are not
        broken by the single-iterable constructor call.
        """
        if not self.use_local_output:
            return outputs
        if not isinstance(outputs, (tuple, list)) or len(outputs) < 4:
            return _finalize_output(outputs, use_local_output=True)

        processed = list(outputs)
        processed[0] = _dtensor_to_local_reducing_partial(processed[0])
        processed[1] = self._slice_local_key_grad(processed[1], module)
        processed[2] = _dtensor_to_local_reducing_partial(processed[2])
        processed[3] = _dtensor_to_local_reducing_partial(processed[3])

        output_type = type(outputs)
        # namedtuple constructors take one positional argument per field, so
        # ``output_type(processed)`` would raise TypeError for them.
        make = getattr(output_type, "_make", None)
        if callable(make):
            return make(processed)
        return output_type(processed)

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register DSA indexer-loss CP hooks on ``module`` and return it."""
        cp_mesh = _ensure_1d(device_mesh)
        rank_list = list(cp_mesh.rank_list)
        # Query the rank once; position 0 is used when this rank is not part
        # of the CP mesh.
        current_rank = platform.get_rank()
        local_idx = rank_list.index(current_rank) if current_rank in rank_list else 0

        def _shard(tensor):
            return self._shard_query_side(tensor, cp_mesh)

        def _replicate(tensor):
            return self._replicate_key_side(tensor, cp_mesh)

        def _pre_hook(hook_module, args, kwargs):
            new_args = list(args)
            new_kwargs = dict(kwargs)

            # Record the incoming key-index shape before it is gathered so the
            # output hook can cut the matching output back to shard size.
            key_shape = None
            if self.key_indexer_index is not None and self.key_indexer_index < len(new_args):
                key_shape = self._local_shape(new_args[self.key_indexer_index])
            elif self.key_indexer_kwarg_name and self.key_indexer_kwarg_name in new_kwargs:
                key_shape = self._local_shape(new_kwargs[self.key_indexer_kwarg_name])
            setattr(hook_module, "_hp_dsa_loss_key_index_local_shape", key_shape)
            setattr(hook_module, "_hp_dsa_loss_local_idx", local_idx)

            _maybe_replace_arg(new_args, self.query_index, _shard)
            _maybe_replace_arg(new_args, self.key_index, _replicate)
            _maybe_replace_arg(new_args, self.query_indexer_index, _shard)
            _maybe_replace_arg(new_args, self.key_indexer_index, _replicate)
            _maybe_replace_arg(new_args, self.weights_index, _shard)
            _maybe_replace_arg(new_args, self.topk_index, _shard)
            _maybe_replace_arg(new_args, self.query_rope_index, _shard)
            _maybe_replace_arg(new_args, self.key_rope_index, _replicate)
            _maybe_replace_kwarg(new_kwargs, self.query_kwarg_name, _shard)
            _maybe_replace_kwarg(new_kwargs, self.key_kwarg_name, _replicate)
            _maybe_replace_kwarg(new_kwargs, self.query_indexer_kwarg_name, _shard)
            _maybe_replace_kwarg(new_kwargs, self.key_indexer_kwarg_name, _replicate)
            _maybe_replace_kwarg(new_kwargs, self.weights_kwarg_name, _shard)
            _maybe_replace_kwarg(new_kwargs, self.topk_kwarg_name, _shard)
            _maybe_replace_kwarg(new_kwargs, self.query_rope_kwarg_name, _shard)
            _maybe_replace_kwarg(new_kwargs, self.key_rope_kwarg_name, _replicate)
            return tuple(new_args), new_kwargs

        platform.register_forward_pre_hook(module, _pre_hook, with_kwargs=True)
        module.register_forward_hook(lambda _module, _args, outputs: self._process_outputs(_module, outputs))
        return module
class DSAContextParallel(ParallelStyle):
    """Colossal-style context-parallel hooks for a low-level DSA attention boundary.

    Compatibility style that deliberately covers only a direct boundary whose
    forward signature looks like
    ``(query, key, value, topk_indices, query_rope, key_rope, ...)``. Callers
    that own a higher-level attention module should find its indexer and
    sparse-attention submodules themselves and apply
    :class:`DSAIndexerContextParallel` and
    :class:`DSASparseAttentionContextParallel` explicitly.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        value_index: Optional[int] = 2,
        topk_index: Optional[int] = 3,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        value_kwarg_name: Optional[str] = None,
        topk_kwarg_name: Optional[str] = None,
        query_rope_index: Optional[int] = 4,
        key_rope_index: Optional[int] = 5,
        query_rope_kwarg_name: Optional[str] = "query_rope",
        key_rope_kwarg_name: Optional[str] = "key_rope",
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        # Validation and attribute storage are delegated to the shared
        # module-level helper.
        boundary_options = {
            "layout": layout,
            "mode": mode,
            "query_index": query_index,
            "key_index": key_index,
            "value_index": value_index,
            "topk_index": topk_index,
            "query_kwarg_name": query_kwarg_name,
            "key_kwarg_name": key_kwarg_name,
            "value_kwarg_name": value_kwarg_name,
            "topk_kwarg_name": topk_kwarg_name,
            "query_rope_index": query_rope_index,
            "key_rope_index": key_rope_index,
            "query_rope_kwarg_name": query_rope_kwarg_name,
            "key_rope_kwarg_name": key_rope_kwarg_name,
            "use_local_output": use_local_output,
        }
        _configure_sparse_attention_boundary(self, **boundary_options)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{cls_name}({fields})"

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register the low-level DSA attention CP hooks on ``module`` and hand it back."""
        hooked = _apply_sparse_attention_boundary(self, module, device_mesh)
        return hooked