Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/pipeline_parallel/backward.py: 9%

350 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""parallel grad helper (dx/dw split)""" 

16from __future__ import absolute_import 

17from collections import deque, defaultdict 

18import warnings 

19import logging 

20from mindspore.utils._pytree import tree_flatten, tree_leaves, tree_unflatten 

21from mindspore.common.api import _pynative_executor, _GradientEdge 

22from mindspore._c_expression import run_backward 

23from mindspore.common.tensor import Tensor 

24from mindspore import ops 

25 

26 

def _fill_grads(output_tensor):
    """Build all-ones sensitivities shaped like ``output_tensor``.

    Args:
        output_tensor: A ``Tensor``, or a list/tuple of tensors.

    Returns:
        ``ops.ones_like(output_tensor)`` for a single ``Tensor``; a tuple holding one
        ``ones_like`` per element for a list/tuple; ``None`` for anything else.
    """
    if isinstance(output_tensor, Tensor):
        ones = ops.ones_like(output_tensor)
    elif isinstance(output_tensor, (list, tuple)):
        ones = tuple(map(ops.ones_like, output_tensor))
    else:
        ones = None
    return ones

34 

35 

36def _validate_grad_config(grad_position, weights): 

37 """Validate gradient configuration.""" 

38 if grad_position is None and not weights: 

39 raise ValueError("grad_position and weights cannot both be None!") 

40 

41 

42def _set_requires_grad(inputs, kwargs, grad_position): 

43 """Set requires_grad flag for specified inputs.""" 

44 if grad_position is None: 

45 return 

46 if grad_position == -1: 

47 flatten_inputs, _ = tree_flatten(inputs, tensors_only_leaf=True) 

48 for inp in flatten_inputs: 

49 inp._requires_grad = True # pylint: disable=protected-access 

50 flatten_kwargs, _ = tree_flatten(kwargs, tensors_only_leaf=True) 

51 for kwarg in flatten_kwargs: 

52 kwarg._requires_grad = True # pylint: disable=protected-access 

53 return 

54 if isinstance(grad_position, int) and isinstance(inputs[grad_position], Tensor): 

55 inputs[grad_position]._requires_grad = True # pylint: disable=protected-access 

56 return 

57 if isinstance(grad_position, tuple): 

58 for idx in grad_position: 

59 if isinstance(inputs[idx], Tensor): 

60 inputs[idx]._requires_grad = True # pylint: disable=protected-access 

61 

62 

63def _get_grad_node(tensor): 

64 """Get the grad_fn or grad accumulator for a tensor.""" 

65 return tensor._grad_node # pylint: disable=protected-access 

66 

67 

68def _get_node_id(node): 

69 """Return stable integer id for BackwardNode (C++ `_unique_id`).""" 

70 return int(node._unique_id()) # pylint: disable=protected-access 

71 

72 

73def _accumulate_grads(target_tensors, grads): 

74 """Accumulate returned grads onto leaf tensors' internal grad storage.""" 

75 if grads is None: 

76 return 

77 for tensor, grad in zip(target_tensors, grads): 

78 if tensor is None or grad is None: 

79 continue 

80 current_grad = getattr(tensor, "_grad", None) 

81 if current_grad is None: 

82 tensor._grad = grad # pylint: disable=protected-access 

83 else: 

84 current_grad += grad 

85 

86 

def _compute_nodes_out_degree(output_grad_fns):
    """Build reverse edges of the backward graph reachable from ``output_grad_fns``.

    Starting from the output grad nodes, follow ``next_functions`` (towards the inputs)
    breadth-first. For every non-``None`` edge ``parent.next_functions[slot] -> child``,
    record ``(parent, slot)`` under the child's id.

    Every node (roots included) is expanded exactly once. This holds even when one output's
    grad node is also a child of another output's grad node, so no reverse edge is
    recorded twice.

    Args:
        output_grad_fns: Iterable of backward nodes (``None`` entries are ignored).

    Returns:
        defaultdict(list): child node id -> list of ``(parent_node, slot_index)`` pairs.
    """
    backward_edges = defaultdict(list)
    visited = set()
    queue = deque()

    # Seed from output grad_fns, deduplicated by node id.
    for node in output_grad_fns:
        if node is None:
            continue
        node_id = _get_node_id(node)
        if node_id not in visited:
            visited.add(node_id)
            queue.append(node)

    # Follow next_functions (towards inputs) while recording reverse edges.
    while queue:
        node = queue.popleft()
        for slot, (child_fn, _) in enumerate(node.next_functions):
            if child_fn is None:
                continue
            child_id = _get_node_id(child_fn)
            backward_edges[child_id].append((node, slot))
            if child_id not in visited:
                visited.add(child_id)
                queue.append(child_fn)

    return backward_edges

115 

116 

def _compute_reachable_nodes(roots, boundary_nodes, backward_graph):
    """Walk reverse edges from ``roots`` breadth-first, stopping at boundary nodes.

    Args:
        roots: Iterable of backward nodes to start from (``None`` entries are ignored).
        boundary_nodes: Set of node *ids*. A parent whose id is in this set is reported
            but not expanded.
        backward_graph: Mapping child id -> ``[(parent_node, slot), ...]``, as built by
            ``_compute_nodes_out_degree``.

    Returns:
        tuple: ``(reachable_ids, boundary_hits)``. ``reachable_ids`` is a set of node ids
        (roots included). ``boundary_hits`` is a set of the boundary node objects that
        were encountered.
    """
    reachable = set()
    boundary_hits = set()
    frontier = deque()

    for root in roots:
        if root is None:
            continue
        root_id = _get_node_id(root)
        if root_id in reachable:
            continue
        reachable.add(root_id)
        frontier.append(root)

    while frontier:
        current = frontier.popleft()
        for parent, _slot in backward_graph.get(_get_node_id(current), []):
            if parent is None:
                continue
            parent_id = _get_node_id(parent)
            if parent_id in reachable:
                continue
            if parent_id in boundary_nodes:
                boundary_hits.add(parent)
            else:
                reachable.add(parent_id)
                frontier.append(parent)

    return reachable, boundary_hits

148 

149 

def _find_boundary_intermediates_with_edge_slots(weight_root, boundary_node_ids, backward_graph):
    """Collect the boundary nodes above one weight node, with the slots that reach it.

    Walk reverse edges upward from ``weight_root``. When a parent's id is in
    ``boundary_node_ids``, record the edge slot on that parent and do not go further
    along that path. Otherwise keep walking (each node id is visited once).

    Args:
        weight_root: Backward node of a weight, or ``None``.
        boundary_node_ids: Set of node ids at which the walk stops.
        backward_graph: Mapping child id -> ``[(parent_node, slot), ...]``.

    Returns:
        defaultdict(set): boundary node object -> set of slot indices on that node whose
        edges lead (downward) towards ``weight_root``. Empty when ``weight_root`` is ``None``.
    """
    slots_by_boundary = defaultdict(set)
    if weight_root is None:
        return slots_by_boundary

    pending = deque([weight_root])
    visited = {_get_node_id(weight_root)}

    while pending:
        current = pending.popleft()
        for parent, slot in backward_graph.get(_get_node_id(current), ()):
            if parent is None:
                continue
            parent_id = _get_node_id(parent)
            if parent_id in boundary_node_ids:
                # The reverse edge already carries the slot; no need to rescan next_functions.
                slots_by_boundary[parent].add(slot)
            elif parent_id not in visited:
                visited.add(parent_id)
                pending.append(parent)

    return slots_by_boundary

175 

176 

def _group_weights_by_intermediate(input_grad_fns, weight_grad_fns, backward_graph):
    """Partition weight nodes into groups that share boundary intermediates.

    An *intermediate* is a node in the input subgraph (reachable from the inputs' grad
    nodes by following ``backward_graph`` reverse edges) that sits directly above a
    weight-side path. Weights that share any intermediate, directly or transitively, end
    up in one group. As a result, every intermediate and every weight belongs to at most
    one group. Weights that reach no intermediate are not placed in any group.

    Args:
        input_grad_fns: Backward nodes of the differentiated inputs.
        weight_grad_fns: Backward nodes of the weights.
        backward_graph: Mapping child id -> ``[(parent_node, slot), ...]``.

    Returns:
        list[dict]: One dict per group, with keys:
            - ``'params'``: tuple of weight nodes, sorted by node id.
            - ``'intermediates'``: tuple of boundary nodes, sorted by node id.
            - ``'edge_slots'``: dict mapping boundary node id to a sorted tuple of slots on
              that node that lead towards the group's weights.
    """
    # Step 1: ids of every node reachable from the inputs along reverse edges.
    input_subgraph_ids, _ = _compute_reachable_nodes(input_grad_fns, set(), backward_graph)

    # Intermediate node id -> group. All intermediates of one group map to the same dict.
    # Inside a group, nodes are keyed by id so duplicate Python wrappers of one node collapse.
    group_by_inter_id = {}

    # Step 2: for each weight, find its boundary nodes and union the affected groups.
    for weight_fn in weight_grad_fns:
        inter_to_slots = _find_boundary_intermediates_with_edge_slots(
            weight_fn, input_subgraph_ids, backward_graph)
        if not inter_to_slots:
            continue

        group = {'params': {_get_node_id(weight_fn): weight_fn}, 'intermediates': {}, 'edge_slots': {}}
        for node, slots in inter_to_slots.items():
            node_id = _get_node_id(node)
            group['intermediates'].setdefault(node_id, node)
            group['edge_slots'].setdefault(node_id, set()).update(slots)

        # Find every existing group that shares an intermediate with this weight.
        absorbed = {}
        for node_id in group['intermediates']:
            existing = group_by_inter_id.get(node_id)
            if existing is not None:
                absorbed[id(existing)] = existing

        # Merge them all into the new group.
        for existing in absorbed.values():
            for pid, pnode in existing['params'].items():
                group['params'].setdefault(pid, pnode)
            for nid, inode in existing['intermediates'].items():
                group['intermediates'].setdefault(nid, inode)
            for nid, slots in existing['edge_slots'].items():
                group['edge_slots'].setdefault(nid, set()).update(slots)

        # Re-point every intermediate (including those of absorbed groups) to the merged
        # group, so no stale group survives.
        for node_id in group['intermediates']:
            group_by_inter_id[node_id] = group

    # Step 3: return unique groups in a deterministic (id-sorted) form.
    seen_groups = set()
    unique_groups = []
    for group in group_by_inter_id.values():
        if id(group) in seen_groups:
            continue
        seen_groups.add(id(group))
        unique_groups.append({
            'params': tuple(node for _, node in sorted(group['params'].items())),
            'intermediates': tuple(node for _, node in sorted(group['intermediates'].items())),
            'edge_slots': {nid: tuple(sorted(slots)) for nid, slots in group['edge_slots'].items()},
        })

    return unique_groups

226 

227 

class GradFunction:
    """
    A wrapper class used to build forward output and gradient functions.
    This class supports separate computation of input gradients (dx) and weight gradients (dw),
    which is essential for pipeline parallelism.

    Args:
        output (Any): Forward output value of the network.
        inputs (tuple[Tensor, ...] | Tensor | Any): Original inputs used for forward computation.
        kwargs (dict): Keyword arguments used in forward computation.
        weights (tuple[Parameter, ...] | list[Parameter] | Parameter | None): Parameters used for
            weight gradient calculation.
        has_aux (bool): Whether the forward output contains auxiliary values.
        grad_position (int | tuple[int, ...] | list[int] | None): Positions of inputs to compute
            gradients for. ``-1`` selects every tensor in ``inputs`` and ``kwargs``.

    Raises:
        TypeError: When gradients are computed (not at construction), if `output` is not a
            list or tuple while `has_aux` is ``True``.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self, output, inputs, kwargs, weights, has_aux, grad_position):
        self.output = output
        self.inputs = inputs
        self.flatten_input_size = 0
        self.kwargs = kwargs
        self.weights = weights
        self.has_aux = has_aux
        self.grad_position = grad_position
        # Weight groups (plus hook-captured intermediate gradients) produced by compute_input_grad.
        self._saved_intermediates = []
        # Pytree specs recorded by _collect_input_tensors when grad_position == -1.
        self.aux_inputs_data = None
        self.aux_kwargs_data = None
        self.aux_weights_data = None

    def _clear_res(self):
        """Drop references to outputs, inputs, kwargs, weights and saved intermediates."""
        self.output = None
        self.inputs = None
        self.kwargs = None
        self.weights = None
        self._saved_intermediates = []

    def _collect_weight_tensors(self):
        """Return ``self.weights`` as a list (empty when no weights were given)."""
        if self.weights is None:
            return []
        # Accept lists as well as tuples / ParameterTuple, matching _collect_input_tensors.
        if isinstance(self.weights, (list, tuple)):
            return list(self.weights)
        return [self.weights]

    def _setup_intermediate_hooks(self, output_tensors, input_tensors, weight_tensors):
        """Register prehooks on boundary intermediates to capture their gradients for dw.

        The weight groups are stored in ``self._saved_intermediates``. Each prehook writes the
        gradient it sees into ``group['grads'][index]``.

        Returns:
            list: Hook handles. The caller must ``remove()`` them; the list is empty when
            there are no output or weight grad nodes.
        """
        hook_handles = []
        output_grad_fns = [_get_grad_node(t) for t in output_tensors if isinstance(t, Tensor)]
        input_grad_fns = [_get_grad_node(t) for t in input_tensors]
        weight_grad_fns = [_get_grad_node(w) for w in weight_tensors]

        # Filter out None values
        output_grad_fns = [fn for fn in output_grad_fns if fn is not None]
        input_grad_fns = [fn for fn in input_grad_fns if fn is not None]
        weight_grad_fns = [fn for fn in weight_grad_fns if fn is not None]

        if not output_grad_fns or not weight_grad_fns:
            return hook_handles

        backward_graph = _compute_nodes_out_degree(output_grad_fns)
        weight_groups = _group_weights_by_intermediate(input_grad_fns, weight_grad_fns, backward_graph)

        def make_hook(group, index):
            def prehook_fn(grad_inputs):
                if group.get('grads', None) is None:
                    group['grads'] = [None] * len(group['intermediates'])
                group['grads'][index] = grad_inputs
                return grad_inputs
            return prehook_fn

        try:
            for weight_group in weight_groups:
                for i, intermediate in enumerate(weight_group['intermediates']):
                    hook_handles.append(intermediate.register_prehook(make_hook(weight_group, i)))
        except Exception:
            # Do not leave partially registered hooks behind on the graph.
            for handle in hook_handles:
                handle.remove()
            raise

        self._saved_intermediates = weight_groups
        return hook_handles

    def _format_input_grads(self, input_grads, input_size):
        """Format input gradients based on grad_position configuration.

        With ``grad_position == -1``, gradients are unflattened back into the structure of
        ``inputs``. If keyword tensors were also collected, an ``(args_grads, kwargs_grads)``
        pair is returned. Otherwise a single gradient is unwrapped and multiple gradients
        are returned as-is.
        """
        if isinstance(self.grad_position, int) and self.grad_position == -1:
            if self.flatten_input_size == input_size:
                return tree_unflatten(self.aux_inputs_data, input_grads[:input_size])
            return (tree_unflatten(self.aux_inputs_data, input_grads[:self.flatten_input_size]),
                    tree_unflatten(self.aux_kwargs_data, input_grads[self.flatten_input_size:input_size]))
        return input_grads[0] if len(input_grads) == 1 else input_grads

    def _prune_intermediate_edges(self):
        """Cut every next edge of each intermediate except the recorded weight-side slots.

        Note: this mutates the autograd graph in place (via ``_set_next_edge``) regardless of
        ``keep_graph``.
        """
        for weight_group in self._saved_intermediates:
            edge_slots = weight_group.get('edge_slots', {})
            for intermediate in weight_group.get('intermediates', ()):
                if intermediate is None:
                    continue
                keep_slots = set(edge_slots.get(_get_node_id(intermediate), ()))
                for slot in range(len(intermediate.next_functions)):
                    if slot not in keep_slots:
                        intermediate._set_next_edge(slot, None)  # pylint: disable=protected-access

    def _process_weight_group(self, weight_group, grad_node_to_weight, keep_graph):
        """Run the dw backward pass for one weight group.

        Args:
            weight_group (dict): A group from ``_group_weights_by_intermediate``. After
                compute_input_grad it may carry ``'grads'``: one captured entry per
                intermediate, ``None`` where that hook did not fire.
            grad_node_to_weight (dict): Weight grad-node id -> weight tensor.
            keep_graph (bool): Whether to retain the graph. When ``False``, the group's
                ``'intermediates'`` and ``'grads'`` entries are dropped.

        Returns:
            dict: Weight tensor -> gradient, with ``ops.zeros_like(weight)`` where backward
            returned ``None``. Empty if nothing was captured or no weight is known.
        """
        captured_per_node = weight_group.get('grads')
        intermediates = weight_group.get('intermediates', ())
        if not keep_graph:
            # Release references held by the group; the locals keep what this call needs.
            weight_group.pop('intermediates', None)
            weight_group.pop('grads', None)

        if captured_per_node is None:
            # None of this group's prehooks fired during compute_input_grad.
            return {}

        grad_edges = []
        grad_outputs = []
        for captured_grads, intermediate in zip(captured_per_node, intermediates):
            if captured_grads is None:
                continue
            if isinstance(captured_grads, (tuple, list)):
                for slot_idx, grad_item in enumerate(captured_grads):
                    if grad_item is not None:
                        grad_edges.append(_GradientEdge(
                            grad_node=intermediate, output_index=slot_idx, keep_alive_token=None))
                        grad_outputs.append(grad_item)
            else:
                grad_edges.append(_GradientEdge(
                    grad_node=intermediate, output_index=0, keep_alive_token=None))
                grad_outputs.append(captured_grads)

        if not grad_edges:
            return {}

        group_weights = [grad_node_to_weight.get(_get_node_id(weight_fn))
                         for weight_fn in weight_group.get('params', ())]
        group_weights = [w for w in group_weights if w is not None]

        if not group_weights:
            return {}

        weight_grads = run_backward(
            tuple(grad_edges), tuple(grad_outputs),
            keep_graph, keep_graph,
            tuple(group_weights), allow_unreachable=True, accumulate_grad=False
        )
        _accumulate_grads(group_weights, weight_grads)

        collected = {}
        for weight, grad in zip(group_weights, weight_grads):
            collected[weight] = grad if grad is not None else ops.zeros_like(weight)
        return collected

    def _prepare_output_and_sens(self, sens):
        """
        Prepare output tensor and sensitivity based on has_aux flag.

        Args:
            sens: Gradient of output tensor for gradient computation.

        Returns:
            Tuple of (output_tensors, processed_sens). With ``has_aux``, only the first output
            (and the first sens element, if sens is a list/tuple) is used. Otherwise outputs
            are flattened to tensor leaves and ``sens`` defaults to ones.

        Raises:
            TypeError: If ``has_aux`` is set but the output is not a list or tuple.
        """
        output_tensor = self.output
        if self.has_aux:
            if not isinstance(self.output, (tuple, list)):
                raise TypeError(
                    f"The output of fn should be list or tuple when has_aux=True, "
                    f"but got {type(self.output)}"
                )
            output_tensor = output_tensor[0]
            if isinstance(sens, (tuple, list)):
                sens = sens[0]
            return (output_tensor,), (sens,)
        flatten_outputs = tree_leaves(output_tensor, tensors_only_leaf=True)
        if sens is None:
            sens = _fill_grads(flatten_outputs)
        else:
            sens = tree_leaves(sens, tensors_only_leaf=True)
        return tuple(flatten_outputs), tuple(sens)

    def _collect_input_tensors(self, collect_weights=True):
        """Collect input tensors based on grad_position and, optionally, append the weights.

        Returns:
            tuple: ``(tensors, input_size)``, where the first ``input_size`` entries are inputs.
        """
        input_tensors = []
        if isinstance(self.grad_position, int) and self.grad_position == -1:
            flatten_inputs, self.aux_inputs_data = tree_flatten(self.inputs, tensors_only_leaf=True)
            flatten_kwargs, self.aux_kwargs_data = tree_flatten(self.kwargs, tensors_only_leaf=True)
            self.flatten_input_size = len(flatten_inputs)
            input_tensors.extend(flatten_inputs)
            input_tensors.extend(flatten_kwargs)
        elif isinstance(self.grad_position, int) and isinstance(self.inputs[self.grad_position], Tensor):
            input_tensors.append(self.inputs[self.grad_position])
        elif isinstance(self.grad_position, (list, tuple)):
            input_tensors.extend(self.inputs[idx] for idx in self.grad_position
                                 if isinstance(self.inputs[idx], Tensor))

        input_size = len(input_tensors)
        if self.weights is not None and collect_weights:
            if isinstance(self.weights, (list, tuple)):
                input_tensors.extend(self.weights)
            else:
                input_tensors.append(self.weights)

        return tuple(input_tensors), input_size

    def __call__(self, sens=None, keep_graph=False):
        """
        Compute gradients with respect to both inputs and weights.

        Args:
            sens: gradient of output tensor for gradient computation.
            keep_graph: Whether to keep the computation graph.

        Returns:
            Gradients with respect to inputs and/or weights.
        """
        weights = self.weights
        input_tensors, input_size = self._collect_input_tensors()
        output_tensors, sens = self._prepare_output_and_sens(sens)

        grads = run_backward(
            output_tensors, sens, keep_graph, keep_graph,
            input_tensors, allow_unreachable=True, accumulate_grad=False
        )
        if input_size > 0:
            _accumulate_grads(input_tensors[:input_size], grads[:input_size])
        weight_tensors = self._collect_weight_tensors()
        if weight_tensors:
            _accumulate_grads(weight_tensors, grads[input_size:])
        if not keep_graph:
            self._clear_res()
        if input_size == 0:
            return grads
        if weights is None:
            if isinstance(self.grad_position, int) and self.grad_position == -1:
                if self.flatten_input_size == input_size:
                    return tree_unflatten(self.aux_inputs_data, grads)
                return (tree_unflatten(self.aux_inputs_data, grads[:self.flatten_input_size]),
                        tree_unflatten(self.aux_kwargs_data, grads[self.flatten_input_size:input_size]))
            return grads[0] if len(grads) == 1 else grads
        if isinstance(self.grad_position, int) and self.grad_position == -1:
            if self.flatten_input_size == input_size:
                return tree_unflatten(self.aux_inputs_data, grads[:input_size]), grads[input_size:]
            return (tree_unflatten(self.aux_inputs_data, grads[:self.flatten_input_size]),
                    tree_unflatten(self.aux_kwargs_data, grads[self.flatten_input_size:input_size]),
                    grads[input_size:])
        return grads[:input_size], grads[input_size:]

    def compute_input_grad(self, sens=None):
        """
        Compute gradients with respect to inputs only (dx).

        This is the first stage of dx/dw split computation. It computes input gradients
        while capturing intermediate gradients at the boundaries between input and weight subgraphs.

        Implementation Strategy:
            1. Build reverse edges of the backward graph from output grad_fns
            2. Find intermediate nodes (boundaries between input/weight subgraphs)
            3. Register prehooks on these intermediate nodes to capture gradients
            4. Save captured intermediate gradients for later dw computation

        The backward pass always retains the graph so compute_weight_grad can run afterwards.
        The prehooks are removed even if the backward pass raises.

        Args:
            sens: gradient of output tensor for gradient computation. If None, will use ones_like.

        Returns:
            Gradients with respect to inputs. Returns a single tensor if there is one input,
            and a tuple of tensors if there are several. Returns ``()`` when no input tensor
            is selected.

            When grad_position=-1 (default), the return type matches the input structure,
            supporting complex input types (tuple, dict, etc.). The gradients are automatically
            unflattened to preserve the original input structure.

        Raises:
            ValueError: If grad_position is None (no inputs specified for differentiation).
        """
        if self.grad_position is None:
            raise ValueError(
                "compute_input_grad requires grad_position to be specified. "
                "Cannot compute input gradients when grad_position is None."
            )
        input_tensors, input_size = self._collect_input_tensors(collect_weights=False)
        if not input_tensors:
            logging.info("No valid input tensors found for gradient computation.")
            return ()

        weight_tensors = self._collect_weight_tensors()
        self._saved_intermediates = []
        output_tensors, sens = self._prepare_output_and_sens(sens)

        hook_handles = []
        if weight_tensors:
            hook_handles = self._setup_intermediate_hooks(output_tensors, input_tensors, weight_tensors)

        try:
            input_grads = run_backward(
                output_tensors, sens, True, True,
                tuple(input_tensors), allow_unreachable=True, accumulate_grad=False
            )
        finally:
            # Detach capture hooks even on failure so they don't fire on later passes.
            for handle in hook_handles:
                handle.remove()
        _accumulate_grads(input_tensors, input_grads)

        return self._format_input_grads(input_grads, input_size)

    def compute_weight_grad(self, keep_graph=False):
        """
        Compute gradients with respect to weights only (dw).

        This is the second stage of dx/dw split computation. It uses the intermediate
        gradients captured during compute_input_grad to compute weight gradients efficiently,
        starting from intermediate nodes rather than recomputing from the output.

        Implementation Strategy:
            1. Use saved intermediate gradients and GradientEdges from compute_input_grad
            2. Start backward computation from these intermediate points (not from output)
            3. This avoids recomputing the portion of the graph already computed in dx

        Args:
            keep_graph: Whether to keep the computation graph after this computation.
                Default is False as this is typically the final gradient computation.

        Returns:
            tuple: One entry per weight, in the order of ``weights``. An entry is ``None`` for
            a weight that belongs to no processed group. If ``weights`` is None, a warning is
            issued, references are cleared, and ``()`` is returned.

        Raises:
            RuntimeError: If compute_input_grad was not called before (no saved intermediates).
            ValueError: If ``weights`` resolves to an empty collection.

        Note:
            This function must be called after compute_input_grad, as it relies on
            the intermediate gradients captured during that computation. The computation
            graph must still be available (retained by keep_graph=True in compute_input_grad).
        """
        if self.weights is None:
            warnings.warn(
                "compute_weight_grad requires weights to be specified. "
                "Cannot compute weight gradients when weights is None."
            )
            self._clear_res()
            return ()
        if not self._saved_intermediates:
            raise RuntimeError("Before calling compute_weight_grad, you need first call compute_input_grad!")

        weight_tensors = self._collect_weight_tensors()
        if not weight_tensors:
            raise ValueError("No valid weight tensors found for gradient computation.")

        grad_node_to_weight = {}
        for weight in weight_tensors:
            grad_node = _get_grad_node(weight)
            if grad_node is not None:
                grad_node_to_weight[_get_node_id(grad_node)] = weight

        self._prune_intermediate_edges()

        collected_grads = {}
        for weight_group in self._saved_intermediates:
            collected_grads.update(self._process_weight_group(weight_group, grad_node_to_weight, keep_graph))

        if not keep_graph:
            self._clear_res()

        return tuple(collected_grads.get(weight) for weight in weight_tensors)

586 

587 

def forward_and_gradfn(fn, *inputs, weights=None, has_aux=False, grad_position=-1, **kwargs):
    """
    A wrapper function to generate the function to calculate forward output and gradient function.

    As for gradient, three typical cases are included:

    1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is ``None``.
    2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not ``None``.
    3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not ``None``.

    Args:
        fn (Union[Cell, Function]): Function to do GradOperation.
        *inputs: Variable length argument list of inputs to the function `fn`.
        weights (Union[ParameterTuple, Parameter, list[Parameter]], optional):
            The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
            Default: ``None`` .
        has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
            while the other outputs are returned as-is. This means `fn` must return more than one output
            in this case.
            Default: ``False`` .
        grad_position (Union[NoneType, int, tuple[int]], optional): Index to specify which inputs
            to be differentiated. Default: ``-1`` means all inputs are differentiated.

            - If int, get the gradient with respect to single input.
            - If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
            - If None, none derivative of any input will be solved, and in this case, `weights` is required.

        **kwargs: Variable length keyword argument dictionary. Additional keyword arguments passed to the function `fn`.

    Returns:
        Tuple of (output, grad_fn):

        - output: The output value of function `fn`.
        - grad_fn: A :class:`GradFunction` instance used to compute gradients.

        The :class:`GradFunction` class provides methods for gradient computation:

        - :meth:`__call__`: Compute gradients with respect to both inputs and weights at once.
        - :meth:`compute_input_grad`: Compute gradients with respect to inputs only (dx).
          This is the first stage of dx/dw split computation, which captures intermediate
          gradients at weight nodes for efficient dw computation.
        - :meth:`compute_weight_grad`: Compute gradients with respect to weights only (dw).
          This is the second stage of dx/dw split computation, which uses the intermediate
          gradients captured during :meth:`compute_input_grad` to compute weight gradients
          efficiently without recomputing from the output.

    Examples:
        When grad_position=-1 (default), the gradient return type matches the input structure,
        supporting complex input types (tuple, dict, etc.):

        >>> def fn(x, tuple_input, scale=None):
        ...     a, b = tuple_input
        ...     return x * a + x * b + scale * x
        >>> x = Tensor([2.0])
        >>> a, b = Tensor([3.0]), Tensor([4.0])
        >>> scale = Tensor([0.5])
        >>> _, grad_fn = forward_and_gradfn(fn, x, (a, b), grad_position=-1, scale=scale)
        >>> dx_grads = grad_fn.compute_input_grad()
        >>> # When both args and kwargs have tensors, returns (args_grads, kwargs_grads)
        >>> # args_grads structure matches (x, (a, b)) -> (Tensor, (Tensor, Tensor))
        >>> # kwargs_grads structure matches {'scale': scale} -> {'scale': Tensor}
        >>> print(type(dx_grads))
        <class 'tuple'>
        >>> print(len(dx_grads))
        2
        >>> print(type(dx_grads[0]))
        <class 'tuple'>
        >>> print(type(dx_grads[0][1]))
        <class 'tuple'>
        >>> print(type(dx_grads[1]))
        <class 'dict'>

    Raises:
        ValueError: If both `grad_position` and `weights` are ``None``.
        TypeError: If type of Args does not belong to required ones.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    _validate_grad_config(grad_position, weights)
    _set_requires_grad(inputs, kwargs, grad_position)
    prev_grad_flag = _pynative_executor.grad_flag()
    _pynative_executor.set_grad_flag(True)
    try:
        res = fn(*inputs, **kwargs)
    except Exception:
        # Discard partially recorded executor state, then re-raise with the original traceback.
        _pynative_executor.clear_res()
        raise
    finally:
        # Restore the caller's grad flag whether or not the forward succeeded.
        _pynative_executor.set_grad_flag(prev_grad_flag)
    grad_fn = GradFunction(res, inputs, kwargs, weights, has_aux, grad_position)
    return res, grad_fn