Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/integration/llamafactory/utils.py: 0%

313 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Accelerate-style FSDP2 utilities backed by HyperParallel fully_shard.""" 

16import copy 

17import functools 

18import re 

19import warnings 

20from collections.abc import Iterable 

21from typing import cast 

22 

23import torch 

24import torch.distributed as dist 

25from torch import nn 

26from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy 

27 

28from hyper_parallel import init_device_mesh 

29from hyper_parallel.core.dtensor.dtensor import DTensor, distribute_tensor 

30from hyper_parallel.core.fully_shard.api import HSDPModule, fully_shard 

31from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy 

32from hyper_parallel.platform import get_platform 

33 

# Maps each accepted dtype spelling (long form and short alias) to its torch dtype.
_DTYPE_MAP = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}

42 

43 

def _resolve_device_type(hp_args) -> str:
    """Pick the device type string used for HyperParallel wrapping.

    An explicit ``hp_args.device_type`` (anything other than ``"auto"``) is
    returned unchanged. Otherwise the first available backend wins, probed in
    the order NPU, CUDA, then CPU as the fallback.
    """
    requested = hp_args.device_type
    if requested != "auto":
        return requested

    # ``torch.npu`` only exists when an NPU extension has registered itself.
    npu_ready = hasattr(torch, "npu") and torch.npu.is_available()  # pylint: disable=no-member
    if npu_ready:
        return "npu"
    return "cuda" if torch.cuda.is_available() else "cpu"

53 

54 

def _build_device_mesh(accelerator, hp_args):
    """Return the device mesh to hand to ``fully_shard``.

    When the accelerator already exposes ``torch_device_mesh`` it is reused,
    narrowed to ``parallelism_config.fsdp_dim_names`` if those are set.
    Otherwise a fresh 1D mesh named ``"dp"`` spanning the whole world size is
    created on the resolved device type.
    """
    existing_mesh = getattr(accelerator, "torch_device_mesh", None)
    if existing_mesh is None:
        device_type = _resolve_device_type(hp_args)
        world_size = get_platform().get_world_size()
        return init_device_mesh(device_type, (world_size,), mesh_dim_names=("dp",))

    parallelism_config = getattr(accelerator, "parallelism_config", None)
    dim_names = getattr(parallelism_config, "fsdp_dim_names", None)
    # Index by a tuple of names to select only the FSDP-related mesh dims.
    return existing_mesh[tuple(dim_names)] if dim_names else existing_mesh

67 

68 

def _build_mp_policy(hp_args) -> MixedPrecisionPolicy:
    """Create a mixed precision policy from the HyperParallel dtype arguments.

    ``hp_args.param_dtype`` and ``hp_args.reduce_dtype`` are dtype names looked
    up in ``_DTYPE_MAP``; ``None`` is passed through untouched. The output
    dtype mirrors the parameter dtype and forward inputs are always cast.

    Raises:
        KeyError: If a dtype name is not a key of ``_DTYPE_MAP``.
    """

    def _lookup(dtype_name):
        return None if dtype_name is None else _DTYPE_MAP[dtype_name]

    param_dtype = _lookup(hp_args.param_dtype)
    reduce_dtype = _lookup(hp_args.reduce_dtype)
    return MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        output_dtype=param_dtype,
        cast_forward_inputs=True,
    )

77 

78 

def _resolve_offload_policy(fsdp2_plugin) -> OffloadPolicy:
    """Map the plugin's ``cpu_offload`` setting onto a HyperParallel offload policy.

    A value that already is a HyperParallel ``OffloadPolicy`` is returned
    as-is. ``True``, or any object whose class is named ``CPUOffloadPolicy``,
    yields a fresh default ``CPUOffloadPolicy``. Everything else (including a
    missing attribute) yields the default ``OffloadPolicy``.
    """
    setting = getattr(fsdp2_plugin, "cpu_offload", None)
    if isinstance(setting, OffloadPolicy):
        return setting

    # Name-based match accepts CPUOffloadPolicy classes that are not HyperParallel's own.
    wants_cpu_offload = setting is True or type(setting).__name__ == "CPUOffloadPolicy"
    return CPUOffloadPolicy() if wants_cpu_offload else OffloadPolicy()

89 

90 

def _is_cpu_offload_enabled(cpu_offload) -> bool:
    """Tell whether ``cpu_offload`` actually requests CPU offloading.

    True for the literal ``True``, for a HyperParallel ``CPUOffloadPolicy``
    instance, and for any object whose class is named ``CPUOffloadPolicy``.
    """
    return (
        cpu_offload is True
        or isinstance(cpu_offload, CPUOffloadPolicy)
        or type(cpu_offload).__name__ == "CPUOffloadPolicy"
    )

98 

99 

def _resolve_mp_policy(fsdp2_plugin, hp_args) -> MixedPrecisionPolicy:
    """Merge the plugin's mixed precision policy with HyperParallel overrides.

    The starting point is a copy of ``fsdp2_plugin.mixed_precision_policy``
    (or a default policy when the plugin has none). Any dtype explicitly set
    on ``hp_args`` then replaces the corresponding field; overriding
    ``param_dtype`` also overrides ``output_dtype``.
    """
    plugin_policy = getattr(fsdp2_plugin, "mixed_precision_policy", None)
    if plugin_policy is None:
        merged = MixedPrecisionPolicy()
    else:
        # Copy field by field so a foreign policy object is re-expressed as a HyperParallel one.
        merged = MixedPrecisionPolicy(
            param_dtype=getattr(plugin_policy, "param_dtype", None),
            reduce_dtype=getattr(plugin_policy, "reduce_dtype", None),
            output_dtype=getattr(plugin_policy, "output_dtype", None),
            cast_forward_inputs=getattr(plugin_policy, "cast_forward_inputs", True),
        )

    overrides = _build_mp_policy(hp_args)
    if hp_args.param_dtype is not None:
        merged.param_dtype = overrides.param_dtype
        merged.output_dtype = overrides.output_dtype
    if hp_args.reduce_dtype is not None:
        merged.reduce_dtype = overrides.reduce_dtype
    return merged

119 

120 

def _is_compiled_module(model: nn.Module) -> bool:
    """Heuristically detect a compiled module by the presence of ``_orig_mod``."""
    try:
        getattr(model, "_orig_mod")
    except AttributeError:
        return False
    return True

124 

125 

def _get_module_children_bottom_up(model: nn.Module, return_fqns: bool = False):
    """List every module of ``model`` in post-order (children before parents).

    The root module is always the final entry. With ``return_fqns=True`` each
    entry is a ``(fully_qualified_name, module)`` pair, where the root's name
    is the empty string; otherwise entries are the bare modules.
    """

    def _post_order(node: nn.Module, fqn: str):
        for child_name, child in node.named_children():
            yield from _post_order(child, f"{fqn}.{child_name}" if fqn else child_name)
        yield (fqn, node) if return_fqns else node

    return list(_post_order(model, ""))

138 

139 

def _get_non_persistent_buffers(model: nn.Module, recurse: bool = True, fqns: bool = True):
    """Gather the names of non-persistent buffers registered on ``model``.

    Args:
        model: Module whose ``_non_persistent_buffers_set`` entries are read.
        recurse: When false, only buffers registered directly on ``model``
            are considered; submodules are skipped.
        fqns: When true, names from submodules are prefixed with the
            submodule's dotted path; otherwise the local buffer name is used.

    Returns:
        A set of buffer names.
    """
    collected = set()
    for module_name, module in model.named_modules():
        if module is not model and not recurse:
            continue
        local_names = getattr(module, "_non_persistent_buffers_set", set())
        # The root module has an empty name, so its buffers are never prefixed.
        qualify = fqns and module_name
        collected.update(f"{module_name}.{name}" if qualify else name for name in local_names)
    return collected

152 

153 

def _get_module_class_from_name(module: nn.Module, class_name: str):
    """Look up a module class by its bare class name within ``module``'s tree.

    Returns the class of the first module (in ``module.modules()`` order)
    whose class name equals ``class_name``, or ``None`` when nothing matches.
    """
    matching_classes = (
        candidate.__class__ for candidate in module.modules() if candidate.__class__.__name__ == class_name
    )
    return next(matching_classes, None)

160 

161 

def _move_model_to_meta(model: nn.Module) -> nn.Module:
    """Relocate ``model`` onto the meta device ahead of ``fully_shard``.

    If the model defines ``tie_weights`` it is called afterwards so tied
    parameters are re-linked on the new device.
    """
    meta_device = torch.device("meta")
    relocated = model.to(meta_device)
    if hasattr(relocated, "tie_weights"):
        relocated.tie_weights()
    return relocated

168 

169 

170 

171def _get_parameters_from_modules(modules: Iterable[nn.Module] | str, model: nn.Module, device) -> set[nn.Parameter]: 

172 """Convert ignored modules to ignored parameters, matching Accelerate behaviour.""" 

173 if modules is None: 

174 return set() 

175 

176 parameters = [] 

177 if isinstance(modules, str): 

178 pattern = re.compile(modules) 

179 matched_modules = [] 

180 for name, module in model.named_modules(): 

181 if pattern.fullmatch(name): 

182 module.to(device) 

183 matched_modules.append(module) 

184 modules = matched_modules 

185 

186 for module in modules: 

187 parameters.extend(list(module.parameters())) 

188 return set(parameters) 

189 

190 

def _prepare_auto_wrap_policy(fsdp2_plugin, model: nn.Module):
    """Turn the plugin's auto-wrap policy into a ``module -> bool`` predicate.

    The plugin policy may be a ``functools.partial``; only its underlying
    function is inspected.

    * ``transformer_auto_wrap_policy``: class names come from
      ``fsdp2_plugin.transformer_cls_names_to_wrap`` or, when that is
      ``None``, from ``model._no_split_modules``. The predicate matches
      instances of those classes, but always returns ``False`` while the
      plugin's ``transformer_cls_names_to_wrap`` is ``None``.
    * ``size_based_auto_wrap_policy``: the predicate matches modules whose
      total parameter count exceeds ``fsdp2_plugin.min_num_params``.
    * Anything else: ``None`` is returned.

    Raises:
        ValueError: If a requested transformer layer class name is not found
            in ``model``.
    """
    configured = fsdp2_plugin.auto_wrap_policy
    base_policy = configured.func if isinstance(configured, functools.partial) else configured

    if base_policy is transformer_auto_wrap_policy:
        fallback_names = list(getattr(model, "_no_split_modules", None) or [])
        plugin_names = fsdp2_plugin.transformer_cls_names_to_wrap
        layer_names = fallback_names if plugin_names is None else plugin_names

        wrap_classes = set()
        for layer_name in layer_names:
            layer_cls = _get_module_class_from_name(model, layer_name)
            if layer_cls is None:
                raise ValueError(f"Could not find the transformer layer class {layer_name} in the model.")
            wrap_classes.add(layer_cls)
        wrap_types = tuple(wrap_classes)

        def policy(module: nn.Module) -> bool:
            # The plugin attribute is re-read on every call, not captured once.
            if fsdp2_plugin.transformer_cls_names_to_wrap is None:
                return False
            return isinstance(module, wrap_types)

        return policy

    if base_policy is size_based_auto_wrap_policy:

        def policy(module: nn.Module) -> bool:
            param_count = sum(param.numel() for param in module.parameters())
            return param_count > fsdp2_plugin.min_num_params

        return policy

    return None

224 

225 

def fsdp2_load_full_state_dict(accelerator, model: nn.Module, full_sd: dict, cpu_offload: bool = False):
    """Load full state dict into a HyperParallel-sharded model following Accelerate semantics.

    Every entry is broadcast from rank 0 to all ranks, reduced to this rank's
    local shard when the model's corresponding entry is a ``DTensor``, cast and
    made contiguous to match the tensor currently on the model, and finally
    installed with ``load_state_dict(..., assign=True)``.

    Args:
        accelerator: Object providing ``is_main_process`` and ``device``.
        model: Module whose ``state_dict()`` supplies the target layout of
            each entry (presumably already sharded and on the meta device;
            confirm against the caller).
        full_sd: Full, unsharded state dict. Only read on the main process;
            other ranks allocate empty receive buffers instead.
        cpu_offload: When true, each resulting local tensor is moved to CPU
            before being loaded.

    Returns:
        The same ``model`` instance with the new tensors assigned.
    """
    meta_sharded_sd = model.state_dict()
    local_sd = {}

    def _infer_parameter_dtype(target_model: nn.Module, param_name: str, empty_param: torch.Tensor):
        # Returns ``(to_contiguous, casting_dtype)`` derived from the tensor that
        # currently lives on ``target_model`` under ``param_name``.
        # Lookup order: parameter, then buffer, then a plain attribute on the
        # owning submodule.
        try:
            old_param = target_model.get_parameter(param_name)
        except Exception:  # pylint: disable=broad-except
            old_param = None
        if old_param is None:
            try:
                old_param = target_model.get_buffer(param_name)
            except Exception:  # pylint: disable=broad-except
                old_param = None
        if old_param is None:
            # NOTE(review): ``rsplit`` raises ValueError for a name without a dot;
            # presumably such names always resolve above. Confirm.
            base_name, local_name = param_name.rsplit(".", 1)
            old_param = getattr(target_model.get_submodule(base_name), local_name)

        is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
        casting_dtype = None
        is_param_float8 = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
        # Only floating point tensors other than float8_e4m3fn are cast to the
        # dtype of the existing tensor; everything else keeps its dtype.
        if empty_param.dtype.is_floating_point and not is_param_float8:
            casting_dtype = old_param.dtype
        if isinstance(old_param, DTensor):
            # Contiguity is judged on the local shard, not on the DTensor wrapper.
            local_param = old_param.to_local()
            return local_param is not None and local_param.is_contiguous(), casting_dtype
        return old_param is not None and old_param.is_contiguous(), casting_dtype

    def _cast_and_contiguous(tensor: torch.Tensor, to_contiguous: bool, dtype):
        # Applies the optional dtype cast and contiguity fix-up. A DTensor is
        # unwrapped, processed locally, and rebuilt on the same mesh and placements.
        if isinstance(tensor, DTensor):
            local_tensor = tensor.to_local()
            if dtype is not None:
                local_tensor = local_tensor.to(dtype=dtype)
            if to_contiguous:
                local_tensor = local_tensor.contiguous()
            return DTensor.from_local(local_tensor, tensor.device_mesh, tensor.placements)
        if dtype is not None:
            tensor = tensor.to(dtype=dtype)
        if to_contiguous:
            tensor = tensor.contiguous()
        return tensor

    # The main process walks the real weights; other ranks walk the model's own
    # state dict to learn names, shapes and dtypes.
    # NOTE(review): the per-iteration broadcast pairs entries by position, so
    # both dicts presumably need identical key order. Confirm with the caller.
    if accelerator.is_main_process:
        iterable = full_sd.items()
    else:
        iterable = meta_sharded_sd.items()

    for item in iterable:
        if accelerator.is_main_process:
            param_name, full_param = item
            sharded_param = meta_sharded_sd[param_name]
        else:
            param_name, sharded_param = item
            # Receive buffer for the broadcast, sized and typed after the model's entry.
            full_param = torch.empty(sharded_param.size(), device=accelerator.device, dtype=sharded_param.dtype)

        if isinstance(full_param, DTensor):
            full_param = full_param.to_local()

        full_param = full_param.detach().to(accelerator.device)
        # Rank 0 is the source; every rank in the WORLD group takes part.
        dist.broadcast(full_param, src=0, group=dist.group.WORLD)

        if isinstance(sharded_param, DTensor):
            # Re-shard the full tensor with the target's mesh and placements and
            # keep only this rank's local piece.
            local_param = distribute_tensor(full_param, sharded_param.device_mesh, sharded_param.placements).to_local()
        else:
            local_param = full_param

        to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, local_param)
        local_param = _cast_and_contiguous(local_param, to_contiguous, casting_dtype)
        if isinstance(local_param, DTensor):
            local_param = local_param.to_local()
        # Clone so the stored tensor does not alias the broadcast buffer.
        local_param = local_param.detach().clone()
        if not local_param.is_contiguous():
            local_param = local_param.contiguous()
        if cpu_offload:
            local_param = local_param.to("cpu")

        local_sd[param_name] = local_param

    # ``assign=True`` installs the new tensors instead of copying into the existing ones.
    cast(nn.Module, model).load_state_dict(local_sd, assign=True)
    return model

307 

308 

def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: nn.Module):
    """Public entry point for building the auto-wrap predicate.

    Delegates to ``_prepare_auto_wrap_policy`` and returns its result
    unchanged (a ``module -> bool`` callable or ``None``).
    """
    wrap_predicate = _prepare_auto_wrap_policy(fsdp2_plugin, model)
    return wrap_predicate

312 

313 

def get_parameters_from_modules(modules: Iterable[nn.Module] | str, model: nn.Module, device) -> set[nn.Parameter]:
    """Public entry point for turning ignored modules into ignored parameters.

    Delegates to ``_get_parameters_from_modules`` and returns its result
    unchanged.
    """
    ignored_parameters = _get_parameters_from_modules(modules, model, device)
    return ignored_parameters

317 

318 

def _is_fsdp2_wrapped_model(model: nn.Module) -> bool:
    """Report whether ``model`` has already been wrapped as an ``HSDPModule``.

    A compiled module counts as wrapped when the module stored in its
    ``_orig_mod`` attribute is an ``HSDPModule``.
    """
    if isinstance(model, HSDPModule):
        return True
    if not _is_compiled_module(model):
        return False
    return isinstance(model._orig_mod, HSDPModule)  # pylint: disable=protected-access

324 

325 

def _resolve_shard_size(mesh) -> int:
    """Determine how many ways parameters are sharded for the given mesh.

    The last entry of ``mesh.mesh_shape`` is used when available. Per the
    original author's note, HP ``fully_shard`` shards along mesh dim 0 for a
    1D FSDP mesh and along mesh dim 1 for a 2D HSDP mesh (see
    ``platform/*/fully_shard/scheduler.py``), so the last mesh dim is the
    shard dim in both layouts.

    Fallbacks: ``mesh.size()`` when there is no usable ``mesh_shape``, and the
    platform world size when ``mesh`` is ``None`` or exposes neither.
    """
    if mesh is None:
        return get_platform().get_world_size()

    mesh_shape = getattr(mesh, "mesh_shape", None)
    if mesh_shape:
        return int(mesh_shape[-1])

    if hasattr(mesh, "size"):
        return mesh.size()
    return get_platform().get_world_size()

341 

342 

def _collect_replicate_params(model: nn.Module, shard_size: int) -> set:
    """Find parameters whose first dimension cannot be split evenly.

    A parameter is selected when it has at least one dimension and its dim-0
    length is not a multiple of ``shard_size``. Zero-dimensional parameters
    are never selected, and nothing is selected when ``shard_size <= 1``.

    Per the original author's note, HP ``fully_shard`` raises
    ``Uneven sharding on dim 0`` for such parameters (e.g. a
    ``shared_expert_gate.weight`` of shape ``(1, hidden)``); passing them as
    ``replicate_params`` keeps them replicated along the shard dim instead.
    """
    if shard_size <= 1:
        return set()
    return {
        param
        for param in model.parameters()
        if param.dim() != 0 and param.size(0) % shard_size != 0
    }

360 

361 

def _build_fsdp2_kwargs(accelerator, model: nn.Module, hp_args, fsdp2_plugin) -> dict:
    """Assemble the keyword arguments passed to every ``fully_shard`` call.

    ``hp_args.reshard_after_forward`` takes precedence over the plugin's
    value when it is not ``None``. ``replicate_params`` is only included when
    at least one parameter cannot be sharded evenly for the resolved mesh.
    """
    mesh = _build_device_mesh(accelerator, hp_args)

    reshard = fsdp2_plugin.reshard_after_forward
    reshard_override = hp_args.reshard_after_forward
    if reshard_override is not None:
        reshard = reshard_override

    kwargs = dict(
        reshard_after_forward=reshard,
        offload_policy=_resolve_offload_policy(fsdp2_plugin),
        mp_policy=_resolve_mp_policy(fsdp2_plugin, hp_args),
        mesh=mesh,
        ignored_params=get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
        comm_fusion=True,
    )

    uneven_params = _collect_replicate_params(model, _resolve_shard_size(mesh))
    if uneven_params:
        kwargs["replicate_params"] = uneven_params
    return kwargs

380 

381 

def _model_has_4bit_params(model: nn.Module) -> bool:
    """Check whether any parameter's class is named ``Params4bit``.

    The match is by class name only, so no quantization library has to be
    imported to perform the check.
    """
    for param in model.parameters():
        if param.__class__.__name__ == "Params4bit":
            return True
    return False

385 

386 

def _prepare_cpu_ram_efficient_loading(model: nn.Module, enabled: bool) -> dict[str, torch.Tensor]:
    """Snapshot the model's non-persistent buffers so they can be restored later.

    Non-persistent buffers are not part of ``state_dict()``, so they are
    deep-copied here, keyed by fully qualified name. Returns an empty dict
    when ``enabled`` is false.
    """
    if not enabled:
        return {}

    non_persistent_names = _get_non_persistent_buffers(model, recurse=True, fqns=True)
    snapshot = {}
    for name, buffer in model.named_buffers():
        if name in non_persistent_names:
            snapshot[name] = buffer
    # Deep-copy the whole mapping at once, detaching the snapshot from the live model.
    return copy.deepcopy(snapshot)

397 

398 

def _apply_auto_wrap_policy(model: nn.Module, fsdp2_plugin, fsdp2_kwargs: dict) -> None:
    """Shard every non-root submodule selected by the plugin's auto-wrap policy.

    Submodules are visited bottom-up and the root (the last entry) is left
    for the caller. Modules that already are ``HSDPModule`` instances are not
    wrapped again. Nothing happens when the plugin yields no policy.
    """
    should_wrap = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    if should_wrap is None:
        return

    non_root_modules = _get_module_children_bottom_up(model)[:-1]
    for candidate in non_root_modules:
        if not should_wrap(candidate):
            continue
        if isinstance(candidate, HSDPModule):
            continue
        fully_shard(candidate, **fsdp2_kwargs)

408 

409 

def _setup_prefetch(model: nn.Module) -> None:
    """Chain forward and backward prefetch between wrapped child modules.

    All ``HSDPModule`` descendants of ``model`` (excluding ``model`` itself)
    are taken in ``model.modules()`` order. Each one is told to prefetch the
    single module that follows it for the forward pass, and the single module
    that precedes it for the backward pass, since backward visits the layers
    in reverse. The intent, per the original note, is to overlap the next
    all-gather with the current layer's computation.
    """
    forward_depth = 1
    backward_depth = 1
    forward_order = [
        module for module in model.modules() if isinstance(module, HSDPModule) and module is not model
    ]

    # Forward pass: look ahead in definition order. Slices clamp at the end of the list.
    for index, layer in enumerate(forward_order):
        upcoming = forward_order[index + 1:index + 1 + forward_depth]
        if upcoming:
            layer.set_modules_to_forward_prefetch(upcoming)

    # Backward pass: same look-ahead, over the reversed order.
    backward_order = forward_order[::-1]
    for index, layer in enumerate(backward_order):
        upcoming = backward_order[index + 1:index + 1 + backward_depth]
        if upcoming:
            layer.set_modules_to_backward_prefetch(upcoming)

438 

439 

def _restore_non_persistent_buffers(model: nn.Module, buffers: dict[str, torch.Tensor], device) -> None:
    """Re-register previously captured non-persistent buffers on ``model``.

    Args:
        model: Module tree that receives the buffers.
        buffers: Mapping from fully qualified buffer name to tensor. When
            empty, the function returns immediately without touching the model.
        device: Device each tensor is moved to before registration.

    After registration, ``model.tie_weights()`` is called if the model
    defines it.
    """
    if not buffers:
        return

    for qualified_name, saved_tensor in buffers.items():
        placed_tensor = saved_tensor.to(device)
        owner_path, separator, buffer_name = qualified_name.rpartition(".")
        # Without a dot the buffer belongs to the root module itself.
        owner = model.get_submodule(owner_path) if separator else model
        owner.register_buffer(buffer_name, placed_tensor, persistent=False)

    if hasattr(model, "tie_weights"):
        model.tie_weights()

457 

458 

def _maybe_upcast_trainable_params(accelerator, model: nn.Module) -> None:
    """Upcast the model to fp32 when mixed precision is active.

    Nothing happens when ``accelerator.mixed_precision`` is ``"no"`` or when
    ``model.dtype`` is already ``torch.float32``. A model without a ``dtype``
    attribute is treated as needing the upcast.

    Per the original note, ``model.to(torch.float32)`` creates new fp32
    parameters in the module tree. Each wrapped module therefore has its
    sharded parameter references reset and its parameter group's
    mixed-precision dtypes re-initialised, so comm_fusion uses the new dtype.
    The main process finally emits a warning about checkpoint precision.
    """
    model_dtype = getattr(model, "dtype", None)
    if accelerator.mixed_precision == "no":
        return
    if model_dtype is not None and model_dtype == torch.float32:
        return

    model.to(torch.float32)

    wrapped_modules = (module for module in model.modules() if isinstance(module, HSDPModule))
    for wrapped in wrapped_modules:
        state = wrapped.hsdp_scheduler.hsdp_state  # pylint: disable=protected-access
        for hsdp_param in state.hsdp_params:
            if hsdp_param.is_sharded:
                hsdp_param.reset_sharded_param()
        param_group = getattr(state, "param_group", None)
        if param_group is not None:
            param_group._init_mp_dtypes()  # pylint: disable=protected-access

    if accelerator.is_main_process:
        warnings.warn(
            "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') "
            "may affect the precision of model checkpoints."
        )

488 

489 

490 

def fsdp2_prepare_model(accelerator, model: nn.Module, hp_args) -> nn.Module:
    """
    Wrap ``model`` with HyperParallel ``fully_shard`` following the Accelerate FSDP2 flow.

    Intended to be called with the runtime ``accelerator`` instance already
    created by ``transformers.Trainer`` / ``accelerate``. A model that is
    already wrapped is returned unchanged.

    Steps, in order: set the plugin's auto-wrap policy, capture the full
    state dict, optionally snapshot non-persistent buffers and move the model
    to the meta device, shard matching children and then the root, set up
    prefetch, optionally reload the full state dict, restore the snapshotted
    buffers, and upcast parameters when mixed precision requires it.

    Accelerator attributes used:
        state.fsdp_plugin: FSDP plugin configuration used to derive wrapping
            and state-dict behaviour.
        torch_device_mesh: Optional device mesh prepared by Accelerate.
        parallelism_config.fsdp_dim_names: Optional FSDP mesh dimension names
            used when ``torch_device_mesh`` is available.
        device: Current process device, used for ignored module parameter
            materialization and buffer restoration.
        is_main_process: Whether the current rank is the main process during
            full state-dict distribution.
        mixed_precision: Mixed precision mode string, used for the final
            parameter upcast behavior.
    """
    if _is_fsdp2_wrapped_model(model):
        return model

    plugin = accelerator.state.fsdp_plugin
    plugin.set_auto_wrap_policy(model)

    has_4bit_params = _model_has_4bit_params(model)
    # Captured before any move to meta so the full weights stay available for reloading.
    full_state = model.state_dict()
    use_meta_init = plugin.cpu_ram_efficient_loading and not has_4bit_params
    saved_buffers = _prepare_cpu_ram_efficient_loading(model, use_meta_init)
    if use_meta_init:
        model = _move_model_to_meta(model)

    shard_kwargs = _build_fsdp2_kwargs(accelerator, model, hp_args, plugin)

    # Children first, then the root, so nested wrapping is in place before the root is sharded.
    _apply_auto_wrap_policy(model, plugin, shard_kwargs)
    if not isinstance(model, HSDPModule):
        fully_shard(model, **shard_kwargs)

    _setup_prefetch(model)

    if plugin.cpu_ram_efficient_loading:
        offload_to_cpu = _is_cpu_offload_enabled(plugin.cpu_offload)
        fsdp2_load_full_state_dict(accelerator, model, full_state, cpu_offload=offload_to_cpu)

    _restore_non_persistent_buffers(model, saved_buffers, accelerator.device)
    _maybe_upcast_trainable_params(accelerator, model)
    return model