Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/context_parallel/dsa_context_parallel.py: 82%

250 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-21 07:51 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Context parallel style for DeepSeek Sparse Attention (DSA). 

16 

17These PyNative module-level styles are companions to the DSA distributed 

18operators. The distributed operators define the per-op layout rules for 

19``lightning_indexer``, ``npu_sparse_flash_attention`` and indexer-loss custom 

20ops; these styles prepare module inputs so those rules can be selected by the 

21DTensor dispatcher. 

22 

23The first implementation intentionally supports only Colossal-style CP: 

24query-side tensors are sharded on sequence, while key-side tensors are gathered 

25to CP-replicated layouts. Ulysses/head sharding is rejected because the current 

26DSA kernels require attention head, index head, head dim and sparse top-k dims to 

27stay replicated. 

28""" 

29from typing import Any, Optional 

30 

31from hyper_parallel.core.context_parallel.context_parallel import _ensure_1d, _gather_seq 

32from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

33from hyper_parallel.core.dtensor.dtensor import DTensor 

34from hyper_parallel.core.dtensor.placement_types import Replicate, Shard 

35from hyper_parallel.core.tensor_parallel.style import ParallelStyle 

36from hyper_parallel.platform import get_platform 

37 

# Platform object resolved once at import time; this module uses it for
# ``is_tensor``, ``get_rank`` and ``register_forward_pre_hook``.
platform = get_platform()
# Module base type exposed by the platform; used for annotations below.
Module = platform.Module


# Layout names accepted by ``_validate_layout_and_mode``.
_SUPPORTED_LAYOUTS = ("BSND", "TND")

43 

44 

def _is_tensor_or_dtensor(value: Any) -> bool:
    """Tell whether ``value`` is a DTensor or something ``platform.is_tensor`` accepts."""
    if isinstance(value, DTensor):
        return True
    # Anything that is not a DTensor is judged by the platform's own check.
    return platform.is_tensor(value)

48 

49 

def _to_sequence_shard(value: Any, device_mesh: DeviceMesh, seq_dim: int) -> Any:
    """Return ``value`` placed as ``Shard(seq_dim)`` on ``device_mesh``.

    Args:
        value: Candidate input; anything that is neither a DTensor nor a
            platform tensor is passed through untouched.
        device_mesh: Mesh the resulting DTensor is bound to.
        seq_dim: Dimension index used for the ``Shard`` placement.

    Returns:
        A DTensor for tensor inputs, otherwise ``value`` itself.
    """
    if isinstance(value, DTensor):
        # Already distributed: move it to the sequence-sharded placement.
        return value.redistribute(device_mesh, (Shard(seq_dim),))
    if not platform.is_tensor(value):
        return value
    # Plain tensor: wrap it as the local piece of a sequence-sharded DTensor.
    return DTensor.from_local(value, device_mesh, (Shard(seq_dim),))

57 

58 

def _to_sequence_replicate(value: Any, device_mesh: DeviceMesh, seq_dim: int) -> Any:
    """Return ``value`` in a CP-replicated form on ``device_mesh``.

    Args:
        value: Candidate input; anything that is neither a DTensor nor a
            platform tensor is passed through untouched.
        device_mesh: Mesh used for the redistribution or gather.
        seq_dim: Sequence dimension forwarded to ``_gather_seq`` for plain
            tensors; it is not used on the DTensor path.

    Returns:
        The redistributed DTensor, the result of ``_gather_seq``, or
        ``value`` itself for non-tensor inputs.
    """
    if isinstance(value, DTensor):
        return value.redistribute(device_mesh, (Replicate(),))
    if not platform.is_tensor(value):
        return value
    # Plain tensor: treated as a local sequence shard and gathered on the mesh.
    return _gather_seq(value, device_mesh, seq_dim)

66 

67 

68def _maybe_replace_arg(args: list, index: Optional[int], fn) -> None: 

69 """Apply ``fn`` to ``args[index]`` when the index points to an existing arg.""" 

70 if index is None or index >= len(args): 

71 return 

72 args[index] = fn(args[index]) 

73 

74 

75def _maybe_replace_kwarg(kwargs: dict, name: Optional[str], fn) -> None: 

76 """Apply ``fn`` to ``kwargs[name]`` when the key exists.""" 

77 if name is None or name not in kwargs: 

78 return 

79 kwargs[name] = fn(kwargs[name]) 

80 

81 

def _validate_layout_and_mode(style_name: str, layout: str, mode: str) -> tuple[str, int]:
    """Check the requested layout/mode and derive the sequence dimension.

    Args:
        style_name: Class name of the calling style, used in the mode error.
        layout: Layout name; compared case-insensitively against
            ``_SUPPORTED_LAYOUTS``.
        mode: CP mode; only the exact string ``"colossal"`` is accepted.

    Returns:
        ``(layout, seq_dim)`` where ``layout`` is upper-cased and ``seq_dim``
        is ``1`` for ``"BSND"`` and ``0`` otherwise.

    Raises:
        ValueError: If the layout is unsupported or the mode is not
            ``"colossal"``. The layout is checked first.
    """
    layout = layout.upper()
    layout_ok = layout in _SUPPORTED_LAYOUTS
    if not layout_ok:
        raise ValueError(f"layout must be one of {_SUPPORTED_LAYOUTS}, but got {layout!r}.")
    if mode != "colossal":
        raise ValueError(f"{style_name} currently supports only mode='colossal'.")
    if layout == "BSND":
        seq_dim = 1
    else:
        seq_dim = 0
    return layout, seq_dim

90 

91 

def _finalize_output(value: Any, use_local_output: bool) -> Any:
    """Turn DTensor outputs into local tensors when ``use_local_output`` is set.

    Handles a bare DTensor and the direct elements of a tuple or list; deeper
    nesting is not traversed. Tuples and lists are always rebuilt as a new
    ``tuple`` / ``list``; any other value is returned as is.
    """

    def _localize(item: Any) -> Any:
        # Only DTensors are converted, and only when local output is requested.
        if use_local_output and isinstance(item, DTensor):
            return item.to_local()
        return item

    if isinstance(value, DTensor):
        return _localize(value)
    if isinstance(value, tuple):
        return tuple(_localize(item) for item in value)
    if isinstance(value, list):
        return [_localize(item) for item in value]
    return value

101 

102 

def _dtensor_has_partial(value: DTensor) -> bool:
    """Report whether at least one placement of ``value`` answers ``is_partial()``."""
    for placement in value.placements:
        if placement.is_partial():
            return True
    return False

106 

107 

def _dtensor_to_local_reducing_partial(value: Any) -> Any:
    """Return the local tensor of a DTensor, reducing Partial placements first if needed.

    Non-DTensor values are returned unchanged. ``reduce_partial()`` is only
    invoked when the DTensor has a Partial placement and its mesh holds more
    than one element.
    """
    if not isinstance(value, DTensor):
        return value
    needs_reduce = _dtensor_has_partial(value) and value.device_mesh.mesh.numel() > 1
    reduced = value.reduce_partial() if needs_reduce else value
    return reduced.to_local()

115 

116 

def _register_boundary_hooks(module: Module, pre_hook, use_local_output: bool) -> None:
    """Attach ``pre_hook`` and an output-conversion forward hook to ``module``.

    Args:
        module: Module/cell receiving both hooks.
        pre_hook: Forward pre-hook, registered through the platform with
            ``with_kwargs=True``.
        use_local_output: Forwarded to ``_finalize_output`` for every output.
    """

    def _post_hook(_module, _args, outputs):
        return _finalize_output(outputs, use_local_output)

    platform.register_forward_pre_hook(module, pre_hook, with_kwargs=True)
    module.register_forward_hook(_post_hook)

123 

124 

def _configure_sparse_attention_boundary(  # pylint: disable=too-many-arguments
    style,
    *,
    layout: str,
    mode: str,
    query_index: Optional[int],
    key_index: Optional[int],
    value_index: Optional[int],
    topk_index: Optional[int],
    query_kwarg_name: Optional[str],
    key_kwarg_name: Optional[str],
    value_kwarg_name: Optional[str],
    topk_kwarg_name: Optional[str],
    query_rope_index: Optional[int],
    key_rope_index: Optional[int],
    query_rope_kwarg_name: Optional[str],
    key_rope_kwarg_name: Optional[str],
    use_local_output: bool,
) -> None:
    """Validate the layout/mode and copy the boundary settings onto ``style``.

    Every keyword argument is stored on ``style`` under the same name;
    ``layout`` is stored upper-cased and ``seq_dim`` is added from the
    validation result.

    Raises:
        ValueError: Propagated from ``_validate_layout_and_mode``; nothing is
            written to ``style`` in that case.
    """
    normalized_layout, seq_dim = _validate_layout_and_mode(style.__class__.__name__, layout, mode)
    # Insertion order mirrors the order in which the attributes are assigned.
    settings = {
        "layout": normalized_layout,
        "mode": mode,
        "seq_dim": seq_dim,
        "query_index": query_index,
        "key_index": key_index,
        "value_index": value_index,
        "topk_index": topk_index,
        "query_kwarg_name": query_kwarg_name,
        "key_kwarg_name": key_kwarg_name,
        "value_kwarg_name": value_kwarg_name,
        "topk_kwarg_name": topk_kwarg_name,
        "query_rope_index": query_rope_index,
        "key_rope_index": key_rope_index,
        "query_rope_kwarg_name": query_rope_kwarg_name,
        "key_rope_kwarg_name": key_rope_kwarg_name,
        "use_local_output": use_local_output,
    }
    for attr_name, attr_value in settings.items():
        setattr(style, attr_name, attr_value)

162 

163 

def _apply_sparse_attention_boundary(style, module: Module, device_mesh: DeviceMesh) -> Module:
    """Install the sparse-attention boundary hooks described by ``style`` on ``module``.

    The pre-hook shards query, top-k and query-rope inputs on the sequence
    dimension and replicates key, value and key-rope inputs. Positional
    arguments are rewritten first, then keyword arguments. ``style`` is read
    when the hook runs, except ``use_local_output`` which is read here.

    Returns:
        The same ``module`` that was passed in.
    """
    cp_mesh = _ensure_1d(device_mesh)

    def _shard(tensor: Any) -> Any:
        return _to_sequence_shard(tensor, cp_mesh, style.seq_dim)

    def _replicate(tensor: Any) -> Any:
        return _to_sequence_replicate(tensor, cp_mesh, style.seq_dim)

    def _pre_hook(hook_module, args, kwargs):
        del hook_module
        positional = list(args)
        keyword = dict(kwargs)
        by_position = (
            (style.query_index, _shard),
            (style.key_index, _replicate),
            (style.value_index, _replicate),
            (style.topk_index, _shard),
            (style.query_rope_index, _shard),
            (style.key_rope_index, _replicate),
        )
        for index, convert in by_position:
            _maybe_replace_arg(positional, index, convert)
        by_name = (
            (style.query_kwarg_name, _shard),
            (style.key_kwarg_name, _replicate),
            (style.value_kwarg_name, _replicate),
            (style.topk_kwarg_name, _shard),
            (style.query_rope_kwarg_name, _shard),
            (style.key_rope_kwarg_name, _replicate),
        )
        for name, convert in by_name:
            _maybe_replace_kwarg(keyword, name, convert)
        return tuple(positional), keyword

    _register_boundary_hooks(module, _pre_hook, style.use_local_output)
    return module

194 

195 

class DSAIndexerContextParallel(ParallelStyle):
    """Colossal-style context-parallel style for a DSA indexer boundary.

    Intended for a hookable module/cell whose forward is called like
    ``(query, key, weights, ...)``. Only the inputs at that boundary are
    rewritten:

    - ``query`` and ``weights`` are annotated as ``Shard(seq)``;
    - ``key`` is all-gathered to ``Replicate()``.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        weights_index: Optional[int] = 2,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        weights_kwarg_name: Optional[str] = None,
        use_local_output: bool = True,
    ) -> None:
        """Validate the layout/mode and record where each input is found.

        Args:
            layout: ``"BSND"`` or ``"TND"`` (case-insensitive).
            mode: Must be ``"colossal"``.
            query_index, key_index, weights_index: Positional slots to
                rewrite; ``None`` skips the slot.
            query_kwarg_name, key_kwarg_name, weights_kwarg_name: Keyword
                names to rewrite; ``None`` skips the name.
            use_local_output: Convert DTensor outputs to local tensors.

        Raises:
            ValueError: If ``layout`` or ``mode`` is rejected.
        """
        super().__init__()
        self.layout, self.seq_dim = _validate_layout_and_mode(self.__class__.__name__, layout, mode)
        self.mode = mode
        # Positional locations of the rewritten inputs.
        self.query_index = query_index
        self.key_index = key_index
        self.weights_index = weights_index
        # Keyword locations of the rewritten inputs.
        self.query_kwarg_name = query_kwarg_name
        self.key_kwarg_name = key_kwarg_name
        self.weights_kwarg_name = weights_kwarg_name
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{self.__class__.__name__}({fields})"

    def _shard_query_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Shard a query-side input on this style's sequence dimension."""
        return _to_sequence_shard(value, device_mesh, self.seq_dim)

    def _replicate_key_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Replicate a key-side input using this style's sequence dimension."""
        return _to_sequence_replicate(value, device_mesh, self.seq_dim)

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register DSA indexer CP hooks on ``module`` and return it."""
        cp_mesh = _ensure_1d(device_mesh)

        def _shard(tensor: Any) -> Any:
            return self._shard_query_side(tensor, cp_mesh)

        def _replicate(tensor: Any) -> Any:
            return self._replicate_key_side(tensor, cp_mesh)

        def _pre_hook(hook_module, args, kwargs):
            del hook_module
            positional = list(args)
            keyword = dict(kwargs)
            # Positional slots are rewritten first, then keyword names.
            for index, convert in (
                (self.query_index, _shard),
                (self.key_index, _replicate),
                (self.weights_index, _shard),
            ):
                _maybe_replace_arg(positional, index, convert)
            for name, convert in (
                (self.query_kwarg_name, _shard),
                (self.key_kwarg_name, _replicate),
                (self.weights_kwarg_name, _shard),
            ):
                _maybe_replace_kwarg(keyword, name, convert)
            return tuple(positional), keyword

        _register_boundary_hooks(module, _pre_hook, self.use_local_output)
        return module

263 

264 

class DSASparseAttentionContextParallel(ParallelStyle):
    """Colossal-style context-parallel style for a DSA sparse-attention boundary.

    Intended for a hookable module/cell whose forward is called like
    ``(query, key, value, topk_indices, query_rope, key_rope, ...)``. Only
    the inputs at that boundary are rewritten:

    - ``query``, ``topk_indices`` and ``query_rope`` are annotated as ``Shard(seq)``;
    - ``key``, ``value`` and ``key_rope`` are all-gathered to ``Replicate()``.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        value_index: Optional[int] = 2,
        topk_index: Optional[int] = 3,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        value_kwarg_name: Optional[str] = None,
        topk_kwarg_name: Optional[str] = None,
        query_rope_index: Optional[int] = 4,
        key_rope_index: Optional[int] = 5,
        query_rope_kwarg_name: Optional[str] = "query_rope",
        key_rope_kwarg_name: Optional[str] = "key_rope",
        use_local_output: bool = True,
    ) -> None:
        """Validate the layout/mode and store the boundary settings on ``self``.

        Each ``*_index`` names a positional slot and each ``*_kwarg_name`` a
        keyword name to rewrite; ``None`` disables that entry.

        Raises:
            ValueError: If ``layout`` or ``mode`` is rejected.
        """
        super().__init__()
        _configure_sparse_attention_boundary(
            self,
            layout=layout,
            mode=mode,
            use_local_output=use_local_output,
            # Positional slots.
            query_index=query_index,
            key_index=key_index,
            value_index=value_index,
            topk_index=topk_index,
            query_rope_index=query_rope_index,
            key_rope_index=key_rope_index,
            # Keyword names.
            query_kwarg_name=query_kwarg_name,
            key_kwarg_name=key_kwarg_name,
            value_kwarg_name=value_kwarg_name,
            topk_kwarg_name=topk_kwarg_name,
            query_rope_kwarg_name=query_rope_kwarg_name,
            key_rope_kwarg_name=key_rope_kwarg_name,
        )

    def __repr__(self) -> str:
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{self.__class__.__name__}({fields})"

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register DSA sparse-attention CP hooks on ``module`` and return it."""
        hooked = _apply_sparse_attention_boundary(self, module, device_mesh)
        return hooked

325 

326 

class DSAIndexerLossContextParallel(ParallelStyle):
    """Colossal-style context-parallel style for a DSA indexer-loss kernel boundary.

    Intended for a hookable module/cell whose forward is called like
    ``(query, key, query_index, key_index, weights, topk_indices,
    softmax_max, softmax_sum, query_rope, key_rope, ...)``. The boundary is
    expected to start after MF has already done local-only bookkeeping such
    as ``stop_gradient`` and ``split``.

    Placements:
        - ``query``, ``query_index``, ``weights``, ``topk_indices`` and
          ``query_rope`` are annotated as ``Shard(seq)``;
        - ``key``, ``key_index`` and ``key_rope`` are all-gathered to
          ``Replicate()``.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        query_indexer_index: Optional[int] = 2,
        key_indexer_index: Optional[int] = 3,
        weights_index: Optional[int] = 4,
        topk_index: Optional[int] = 5,
        query_rope_index: Optional[int] = 8,
        key_rope_index: Optional[int] = 9,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        query_indexer_kwarg_name: Optional[str] = None,
        key_indexer_kwarg_name: Optional[str] = None,
        weights_kwarg_name: Optional[str] = None,
        topk_kwarg_name: Optional[str] = None,
        query_rope_kwarg_name: Optional[str] = None,
        key_rope_kwarg_name: Optional[str] = None,
        use_local_output: bool = True,
    ) -> None:
        """Validate the layout/mode and record where each input is found.

        Each ``*_index`` names a positional slot and each ``*_kwarg_name`` a
        keyword name to rewrite; ``None`` disables that entry.

        Raises:
            ValueError: If ``layout`` or ``mode`` is rejected.
        """
        super().__init__()
        self.layout, self.seq_dim = _validate_layout_and_mode(self.__class__.__name__, layout, mode)
        self.mode = mode
        # Positional locations of the rewritten inputs.
        self.query_index = query_index
        self.key_index = key_index
        self.query_indexer_index = query_indexer_index
        self.key_indexer_index = key_indexer_index
        self.weights_index = weights_index
        self.topk_index = topk_index
        self.query_rope_index = query_rope_index
        self.key_rope_index = key_rope_index
        # Keyword locations of the rewritten inputs.
        self.query_kwarg_name = query_kwarg_name
        self.key_kwarg_name = key_kwarg_name
        self.query_indexer_kwarg_name = query_indexer_kwarg_name
        self.key_indexer_kwarg_name = key_indexer_kwarg_name
        self.weights_kwarg_name = weights_kwarg_name
        self.topk_kwarg_name = topk_kwarg_name
        self.query_rope_kwarg_name = query_rope_kwarg_name
        self.key_rope_kwarg_name = key_rope_kwarg_name
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{self.__class__.__name__}({fields})"

    def _shard_query_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Shard a query-side input on this style's sequence dimension."""
        return _to_sequence_shard(value, device_mesh, self.seq_dim)

    def _replicate_key_side(self, value: Any, device_mesh: DeviceMesh) -> Any:
        """Replicate a key-side input using this style's sequence dimension."""
        return _to_sequence_replicate(value, device_mesh, self.seq_dim)

    @staticmethod
    def _local_shape(value: Any) -> Optional[tuple]:
        """Return the local shape of a DTensor or tensor, or ``None`` for anything else."""
        if isinstance(value, DTensor):
            return value.local_shape
        return value.shape if platform.is_tensor(value) else None

    def _slice_local_key_grad(self, value: Any, module: Module) -> Any:
        """Convert d_key_index to the original local key-index shard shape.

        Uses the shape and local index that the pre-hook stored on ``module``.
        Without a stored shape the value goes through ``_finalize_output``.
        """
        if not self.use_local_output:
            return value
        expected_shape = getattr(module, "_hp_dsa_loss_key_index_local_shape", None)
        if expected_shape is None:
            return _finalize_output(value, use_local_output=True)

        # DTensors become local tensors here; other values pass through unchanged.
        local_value = _dtensor_to_local_reducing_partial(value)
        if not platform.is_tensor(local_value):
            return local_value

        shard_len = expected_shape[self.seq_dim]
        if local_value.shape[self.seq_dim] != shard_len:
            # Sequence length differs from the recorded shard: cut out this rank's slice.
            offset = getattr(module, "_hp_dsa_loss_local_idx", 0) * shard_len
            return local_value.narrow(self.seq_dim, offset, shard_len)
        return local_value

    def _process_outputs(self, module: Module, outputs: Any) -> Any:
        """Finalize indexer-loss outputs, reducing Partial values when needed.

        Only the first four entries of a tuple/list with at least four
        elements get the dedicated handling; later entries are kept as they
        are. Any other output goes through ``_finalize_output``.
        """
        if not self.use_local_output:
            return outputs
        is_sequence = isinstance(outputs, (tuple, list))
        if not is_sequence or len(outputs) < 4:
            return _finalize_output(outputs, use_local_output=True)

        # One converter per leading output, applied in positional order.
        converters = (
            _dtensor_to_local_reducing_partial,
            lambda grad: self._slice_local_key_grad(grad, module),
            _dtensor_to_local_reducing_partial,
            _dtensor_to_local_reducing_partial,
        )
        items = list(outputs)
        for position, convert in enumerate(converters):
            items[position] = convert(items[position])
        return type(outputs)(items)

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register DSA indexer-loss CP hooks on ``module`` and return it."""
        cp_mesh = _ensure_1d(device_mesh)
        rank_list = list(cp_mesh.rank_list)
        # Position of this rank inside the CP mesh; 0 when the rank is not listed.
        local_idx = 0
        if platform.get_rank() in rank_list:
            local_idx = rank_list.index(platform.get_rank())

        def _shard(tensor: Any) -> Any:
            return self._shard_query_side(tensor, cp_mesh)

        def _replicate(tensor: Any) -> Any:
            return self._replicate_key_side(tensor, cp_mesh)

        def _pre_hook(hook_module, args, kwargs):
            positional = list(args)
            keyword = dict(kwargs)

            # Record the key-index shape before the input is replicated; the
            # positional slot wins over the keyword name.
            recorded_shape = None
            if self.key_indexer_index is not None and self.key_indexer_index < len(positional):
                recorded_shape = self._local_shape(positional[self.key_indexer_index])
            elif self.key_indexer_kwarg_name and self.key_indexer_kwarg_name in keyword:
                recorded_shape = self._local_shape(keyword[self.key_indexer_kwarg_name])
            setattr(hook_module, "_hp_dsa_loss_key_index_local_shape", recorded_shape)
            setattr(hook_module, "_hp_dsa_loss_local_idx", local_idx)

            for index, convert in (
                (self.query_index, _shard),
                (self.key_index, _replicate),
                (self.query_indexer_index, _shard),
                (self.key_indexer_index, _replicate),
                (self.weights_index, _shard),
                (self.topk_index, _shard),
                (self.query_rope_index, _shard),
                (self.key_rope_index, _replicate),
            ):
                _maybe_replace_arg(positional, index, convert)
            for name, convert in (
                (self.query_kwarg_name, _shard),
                (self.key_kwarg_name, _replicate),
                (self.query_indexer_kwarg_name, _shard),
                (self.key_indexer_kwarg_name, _replicate),
                (self.weights_kwarg_name, _shard),
                (self.topk_kwarg_name, _shard),
                (self.query_rope_kwarg_name, _shard),
                (self.key_rope_kwarg_name, _replicate),
            ):
                _maybe_replace_kwarg(keyword, name, convert)
            return tuple(positional), keyword

        def _post_hook(_module, _args, outputs):
            return self._process_outputs(_module, outputs)

        platform.register_forward_pre_hook(module, _pre_hook, with_kwargs=True)
        module.register_forward_hook(_post_hook)
        return module

491 

492 

class DSAContextParallel(ParallelStyle):
    """Colossal-style context-parallel style for a low-level DSA attention boundary.

    This compatibility style intentionally handles only a direct boundary
    whose forward is called like
    ``(query, key, value, topk_indices, query_rope, key_rope, ...)``. Callers
    that own a higher-level attention module should locate its indexer and
    sparse-attention submodules themselves and apply
    :class:`DSAIndexerContextParallel` and
    :class:`DSASparseAttentionContextParallel` explicitly.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        layout: str = "BSND",
        mode: str = "colossal",
        query_index: Optional[int] = 0,
        key_index: Optional[int] = 1,
        value_index: Optional[int] = 2,
        topk_index: Optional[int] = 3,
        query_kwarg_name: Optional[str] = None,
        key_kwarg_name: Optional[str] = None,
        value_kwarg_name: Optional[str] = None,
        topk_kwarg_name: Optional[str] = None,
        query_rope_index: Optional[int] = 4,
        key_rope_index: Optional[int] = 5,
        query_rope_kwarg_name: Optional[str] = "query_rope",
        key_rope_kwarg_name: Optional[str] = "key_rope",
        use_local_output: bool = True,
    ) -> None:
        """Validate the layout/mode and store the boundary settings on ``self``.

        Each ``*_index`` names a positional slot and each ``*_kwarg_name`` a
        keyword name to rewrite; ``None`` disables that entry.

        Raises:
            ValueError: If ``layout`` or ``mode`` is rejected.
        """
        super().__init__()
        _configure_sparse_attention_boundary(
            self,
            layout=layout,
            mode=mode,
            use_local_output=use_local_output,
            # Positional slots.
            query_index=query_index,
            key_index=key_index,
            value_index=value_index,
            topk_index=topk_index,
            query_rope_index=query_rope_index,
            key_rope_index=key_rope_index,
            # Keyword names.
            query_kwarg_name=query_kwarg_name,
            key_kwarg_name=key_kwarg_name,
            value_kwarg_name=value_kwarg_name,
            topk_kwarg_name=topk_kwarg_name,
            query_rope_kwarg_name=query_rope_kwarg_name,
            key_rope_kwarg_name=key_rope_kwarg_name,
        )

    def __repr__(self) -> str:
        fields = f"layout={self.layout!r}, mode={self.mode!r}, use_local_output={self.use_local_output}"
        return f"{self.__class__.__name__}({fields})"

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register low-level DSA attention CP hooks on ``module`` and return it."""
        hooked = _apply_sparse_attention_boundary(self, module, device_mesh)
        return hooked