Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/pipeline_parallel/backward.py: 9%
350 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""parallel grad helper (dx/dw split)"""
16from __future__ import absolute_import
17from collections import deque, defaultdict
18import warnings
19import logging
20from mindspore.utils._pytree import tree_flatten, tree_leaves, tree_unflatten
21from mindspore.common.api import _pynative_executor, _GradientEdge
22from mindspore._c_expression import run_backward
23from mindspore.common.tensor import Tensor
24from mindspore import ops
def _fill_grads(output_tensor):
    """Build all-ones sensitivities shaped like the given output(s).

    Args:
        output_tensor: A single Tensor, or a list/tuple of tensors.

    Returns:
        ``ops.ones_like`` of the input when it is a Tensor, a tuple of
        ``ops.ones_like`` results when it is a list/tuple, otherwise ``None``.
    """
    if isinstance(output_tensor, Tensor):
        return ops.ones_like(output_tensor)
    if not isinstance(output_tensor, (list, tuple)):
        return None
    ones = [ops.ones_like(item) for item in output_tensor]
    return tuple(ones)
36def _validate_grad_config(grad_position, weights):
37 """Validate gradient configuration."""
38 if grad_position is None and not weights:
39 raise ValueError("grad_position and weights cannot both be None!")
def _set_requires_grad(inputs, kwargs, grad_position):
    """Mark the inputs selected by `grad_position` as requiring gradients.

    Args:
        inputs: Positional arguments of the forward function.
        kwargs: Keyword arguments of the forward function.
        grad_position: Selects what is marked:

            - ``None``: nothing is marked.
            - ``-1``: every tensor leaf of `inputs` and of `kwargs` is marked.
            - int: ``inputs[grad_position]`` is marked if it is a Tensor.
            - list/tuple of int: each indexed input that is a Tensor is marked.
              Lists are accepted so this helper agrees with
              ``GradFunction._collect_input_tensors``, which also accepts them.
    """
    if grad_position is None:
        return
    if grad_position == -1:
        # Positional leaves first, then keyword leaves.
        for container in (inputs, kwargs):
            leaves, _ = tree_flatten(container, tensors_only_leaf=True)
            for leaf in leaves:
                leaf._requires_grad = True  # pylint: disable=protected-access
        return
    if isinstance(grad_position, int):
        positions = (grad_position,)
    elif isinstance(grad_position, (list, tuple)):
        positions = grad_position
    else:
        return
    for idx in positions:
        # Non-tensor inputs at a selected position are skipped silently.
        if isinstance(inputs[idx], Tensor):
            inputs[idx]._requires_grad = True  # pylint: disable=protected-access
63def _get_grad_node(tensor):
64 """Get the grad_fn or grad accumulator for a tensor."""
65 return tensor._grad_node # pylint: disable=protected-access
68def _get_node_id(node):
69 """Return stable integer id for BackwardNode (C++ `_unique_id`)."""
70 return int(node._unique_id()) # pylint: disable=protected-access
73def _accumulate_grads(target_tensors, grads):
74 """Accumulate returned grads onto leaf tensors' internal grad storage."""
75 if grads is None:
76 return
77 for tensor, grad in zip(target_tensors, grads):
78 if tensor is None or grad is None:
79 continue
80 current_grad = getattr(tensor, "_grad", None)
81 if current_grad is None:
82 tensor._grad = grad # pylint: disable=protected-access
83 else:
84 current_grad += grad
def _compute_nodes_out_degree(output_grad_fns):
    """Build reverse edges: child_id -> [(parent_node, slot_index)].

    Walks breadth-first from the output grad nodes along ``next_functions``
    (towards the inputs). For every non-``None`` child found in slot ``slot``
    of a node, ``(node, slot)`` is appended to the child's entry.

    Each node is expanded exactly once. An output node that is also a child of
    another output node is therefore not traversed a second time, which would
    otherwise record its outgoing edges twice.

    Args:
        output_grad_fns: Iterable of grad nodes (``None`` entries are skipped).

    Returns:
        defaultdict(list) mapping child node id to its list of
        ``(parent_node, slot_index_in_parent_next_functions)`` pairs.
    """
    backward_edges = defaultdict(list)
    visited = set()
    queue = deque()

    # Seed from output grad_fns.
    for node in output_grad_fns:
        if node is None:
            continue
        node_id = _get_node_id(node)
        if node_id not in visited:
            visited.add(node_id)
            queue.append(node)

    # Follow next_functions (towards inputs) while recording reverse edges.
    while queue:
        node = queue.popleft()
        for slot, (child_fn, _) in enumerate(node.next_functions):
            if child_fn is None:
                continue
            child_id = _get_node_id(child_fn)
            if child_id not in visited:
                visited.add(child_id)
                queue.append(child_fn)
            # The edge is recorded for every parent, visited child or not.
            backward_edges[child_id].append((node, slot))

    return backward_edges
def _compute_reachable_nodes(roots, boundary_nodes, backward_graph):
    """Breadth-first walk over reverse edges starting at `roots`.

    Args:
        roots: Iterable of start nodes (``None`` entries are skipped).
        boundary_nodes: Container of node ids at which the walk stops.
        backward_graph: Mapping child id -> list of ``(parent_node, slot)``.

    Returns:
        Tuple ``(reachable_ids, boundary_hits)``: the set of ids of the roots
        and of every parent reached without crossing a boundary, and the set
        of parent nodes whose id is in `boundary_nodes`.
    """
    visited_ids = set()
    hit_boundaries = set()
    pending = deque()

    for root in roots:
        if root is None:
            continue
        root_id = _get_node_id(root)
        if root_id in visited_ids:
            continue
        visited_ids.add(root_id)
        pending.append(root)

    while pending:
        current = pending.popleft()
        for parent, _slot in backward_graph.get(_get_node_id(current), []):
            if parent is None:
                continue
            parent_id = _get_node_id(parent)
            if parent_id in visited_ids:
                continue
            if parent_id in boundary_nodes:
                # Boundary parents are reported but never expanded.
                hit_boundaries.add(parent)
            else:
                visited_ids.add(parent_id)
                pending.append(parent)

    return visited_ids, hit_boundaries
def _find_boundary_intermediates_with_edge_slots(weight_root, boundary_node_ids, backward_graph):
    """Collect the boundary nodes reached from one weight root, with their edge slots.

    Walks reverse edges breadth-first from `weight_root`. A parent whose id is
    in `boundary_node_ids` is recorded together with the slot through which it
    was reached and is not expanded further; any other parent is expanded once.

    Args:
        weight_root: Start node, or ``None``.
        boundary_node_ids: Container of node ids treated as boundaries.
        backward_graph: Mapping child id -> iterable of ``(parent_node, slot)``.

    Returns:
        defaultdict(set) mapping boundary node -> set of slot indices; empty
        when `weight_root` is ``None``.
    """
    slots_by_boundary = defaultdict(set)
    if weight_root is None:
        return slots_by_boundary

    frontier = deque((weight_root,))
    visited = {_get_node_id(weight_root)}

    while frontier:
        current = frontier.popleft()
        for parent, slot in backward_graph.get(_get_node_id(current), ()):
            if parent is None:
                continue
            parent_id = _get_node_id(parent)
            if parent_id in boundary_node_ids:
                # The reverse graph already carries the slot index, so the
                # parent's next_functions need not be rescanned.
                slots_by_boundary[parent].add(slot)
            elif parent_id not in visited:
                visited.add(parent_id)
                frontier.append(parent)

    return slots_by_boundary
def _group_weights_by_intermediate(input_grad_fns, weight_grad_fns, backward_graph):
    """Group weight nodes by the boundary intermediates they share.

    The input subgraph is the set of node ids reachable from `input_grad_fns`
    over reverse edges. For each weight node, the boundary intermediates are
    the first input-subgraph nodes met when walking reverse edges from the
    weight. Weights that share at least one intermediate (directly or through
    another weight) end up in the same group.

    Every intermediate always maps to exactly one live group. When a weight
    touches intermediates owned by several existing groups, all of them are
    folded into one group and every intermediate is re-pointed to it, so no
    weight can appear in two returned groups.

    Weights with no boundary intermediate are not placed in any group.

    Args:
        input_grad_fns: Grad nodes of the inputs.
        weight_grad_fns: Grad nodes of the weights.
        backward_graph: Mapping child id -> list of ``(parent_node, slot)``.

    Returns:
        List of group dicts with keys ``'params'`` (tuple of weight nodes),
        ``'intermediates'`` (tuple of boundary nodes) and ``'edge_slots'``
        (node id -> tuple of slot indices); tuples are sorted for determinism.
    """
    # Step 1: Build input subgraph - all nodes reachable from inputs.
    input_subgraph_ids, _ = _compute_reachable_nodes(input_grad_fns, set(), backward_graph)

    # intermediate node id -> the group that currently owns that intermediate
    weight_groups = {}

    # Step 2: For each weight, find its boundary nodes and union the groups.
    for weight_fn in weight_grad_fns:
        inter_to_slots = _find_boundary_intermediates_with_edge_slots(weight_fn, input_subgraph_ids, backward_graph)
        if not inter_to_slots:
            continue

        merged = {
            'params': {weight_fn},
            'intermediates': set(inter_to_slots),
            'edge_slots': {_get_node_id(node): set(slots) for node, slots in inter_to_slots.items()},
        }

        # Fold in every distinct existing group that owns one of our intermediates.
        absorbed = set()
        for node in inter_to_slots:
            existing = weight_groups.get(_get_node_id(node))
            if existing is None or id(existing) in absorbed:
                continue
            absorbed.add(id(existing))
            merged['params'] |= existing['params']
            merged['intermediates'] |= existing['intermediates']
            for nid, slots in existing['edge_slots'].items():
                merged['edge_slots'].setdefault(nid, set()).update(slots)

        # Re-point all intermediates (new and absorbed) so no stale group survives.
        for node in merged['intermediates']:
            weight_groups[_get_node_id(node)] = merged

    # Return unique weight groups (normalize to deterministic order for hooks/consumption).
    seen_groups = set()
    unique_groups = []
    for weight_group in weight_groups.values():
        group_id = id(weight_group)
        if group_id in seen_groups:
            continue
        seen_groups.add(group_id)
        weight_group['intermediates'] = tuple(sorted(weight_group['intermediates'], key=_get_node_id))
        weight_group['params'] = tuple(sorted(weight_group['params'], key=_get_node_id))
        weight_group['edge_slots'] = {
            nid: tuple(sorted(slots)) for nid, slots in weight_group['edge_slots'].items()
        }
        unique_groups.append(weight_group)

    return unique_groups
class GradFunction:
    """
    A wrapper class used to build forward output and gradient functions.
    This class supports separate computation of input gradients (dx) and weight gradients (dw),
    which is essential for pipeline parallelism.

    Args:
        output (Any): Forward output value of the network.
        inputs (tuple[Tensor, ...] | Tensor | Any): Original inputs used for forward computation.
        kwargs (dict): Keyword arguments used in forward computation.
        weights (tuple[Parameter, ...] | list[Parameter] | Parameter | None): Parameters used for
            weight gradient calculation.
        has_aux (bool): Whether the forward output contains auxiliary values.
        grad_position (int | tuple[int, ...] | list[int] | None): Positions of inputs to compute
            gradients for; ``-1`` selects every tensor leaf of `inputs` and `kwargs`.

    Raises:
        TypeError: If `output` is not a list or tuple when `has_aux` is ``True``
            (raised when gradients are requested, not at construction).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self, output, inputs, kwargs, weights, has_aux, grad_position):
        self.output = output
        self.inputs = inputs
        # Number of tensor leaves in `inputs`; set when grad_position == -1.
        self.flatten_input_size = 0
        self.kwargs = kwargs
        self.weights = weights
        self.has_aux = has_aux
        self.grad_position = grad_position
        # Weight groups (with captured gradients) saved by compute_input_grad.
        self._saved_intermediates = []
        # Tree specs used to unflatten gradients back into the input structure.
        self.aux_inputs_data = None
        self.aux_kwargs_data = None
        # Kept for backward compatibility; not read anywhere in this class.
        self.aux_weights_data = None

    def _clear_res(self):
        """Drop references to output, inputs, weights and saved weight groups."""
        self.output = None
        self.inputs = None
        self.weights = None
        self._saved_intermediates = []

    def _collect_weight_tensors(self):
        """Return the weights as a list (empty when `weights` is ``None``).

        Lists and tuples are copied element-wise; any other value is wrapped
        in a one-element list.
        """
        if self.weights is None:
            return []
        if isinstance(self.weights, (list, tuple)):
            return list(self.weights)
        return [self.weights]

    def _setup_intermediate_hooks(self, output_tensors, input_tensors, weight_tensors):
        """Register prehooks on boundary nodes to capture gradients for dw computation.

        Returns:
            List of hook handles; empty when there is no output grad node or no
            weight grad node. The computed weight groups are stored in
            ``self._saved_intermediates``.
        """
        hook_handles = []
        output_grad_fns = [_get_grad_node(t) for t in output_tensors if isinstance(t, Tensor)]
        input_grad_fns = [_get_grad_node(t) for t in input_tensors]
        weight_grad_fns = [_get_grad_node(w) for w in weight_tensors]

        # Filter out None values
        output_grad_fns = [fn for fn in output_grad_fns if fn is not None]
        input_grad_fns = [fn for fn in input_grad_fns if fn is not None]
        weight_grad_fns = [fn for fn in weight_grad_fns if fn is not None]

        if not output_grad_fns or not weight_grad_fns:
            return hook_handles

        backward_graph = _compute_nodes_out_degree(output_grad_fns)
        weight_groups = _group_weights_by_intermediate(input_grad_fns, weight_grad_fns, backward_graph)

        def make_hook(wg, idx):
            # Bind group and index per hook to avoid late-binding of loop variables.
            def prehook_fn(grad_inputs):
                if wg.get('grads', None) is None:
                    wg['grads'] = [None] * len(wg['intermediates'])
                wg['grads'][idx] = grad_inputs
                return grad_inputs
            return prehook_fn

        for weight_group in weight_groups:
            for i, intermediate in enumerate(weight_group['intermediates']):
                handle = intermediate.register_prehook(make_hook(weight_group, i))
                hook_handles.append(handle)

        self._saved_intermediates = weight_groups
        return hook_handles

    def _format_input_grads(self, input_grads, input_size):
        """Format input gradients based on grad_position configuration.

        For ``grad_position == -1`` the flat gradients are unflattened into the
        structure of `inputs` (and, when kwargs contributed tensors, a second
        structure matching `kwargs`). Otherwise a single gradient is returned
        bare and multiple gradients are returned as given.
        """
        if isinstance(self.grad_position, int) and self.grad_position == -1:
            if self.flatten_input_size == input_size:
                return tree_unflatten(self.aux_inputs_data, input_grads[:input_size])
            return (tree_unflatten(self.aux_inputs_data, input_grads[:self.flatten_input_size]),
                    tree_unflatten(self.aux_kwargs_data, input_grads[self.flatten_input_size:input_size]))
        return input_grads[0] if len(input_grads) == 1 else input_grads

    def _prune_intermediate_edges(self):
        """Prune intermediate next_edges based on recorded edge slots.

        On every saved intermediate node, each slot that was not recorded as
        leading towards a weight of the group is set to ``None``.
        """
        for weight_group in self._saved_intermediates:
            edge_slots = weight_group.get('edge_slots', {})
            for intermediate in weight_group.get('intermediates', ()):
                if intermediate is None:
                    continue
                keep_slots = set(edge_slots.get(_get_node_id(intermediate), ()))
                for slot in range(len(intermediate.next_functions)):
                    if slot not in keep_slots:
                        intermediate._set_next_edge(slot, None)  # pylint: disable=protected-access

    def _process_weight_group(self, weight_group, grad_node_to_weight, keep_graph):
        """Process a single weight group and compute gradients.

        Args:
            weight_group: Group dict produced during compute_input_grad.
            grad_node_to_weight: Mapping grad node id -> weight tensor.
            keep_graph: Forwarded to ``run_backward``; when falsy the group's
                node and captured-gradient references are dropped.

        Returns:
            Dict weight -> gradient (zeros when ``run_backward`` yields ``None``
            for that weight); empty when nothing was captured for the group or
            none of its weights is known.
        """
        grad_edges = []
        grad_outputs = []

        # A group whose hooks never fired has no 'grads' entry at all.
        captured = weight_group.get('grads') or ()
        for captured_grads, intermediate in zip(captured, weight_group['intermediates']):
            if captured_grads is None:
                continue
            if isinstance(captured_grads, (tuple, list)):
                for slot_idx, grad_item in enumerate(captured_grads):
                    if grad_item is not None:
                        grad_edges.append(_GradientEdge(
                            grad_node=intermediate, output_index=slot_idx, keep_alive_token=None))
                        grad_outputs.append(grad_item)
            else:
                grad_edges.append(_GradientEdge(
                    grad_node=intermediate, output_index=0, keep_alive_token=None))
                grad_outputs.append(captured_grads)

        if not keep_graph:
            weight_group.pop('intermediates', None)

        if not grad_edges:
            return {}

        group_weights = [grad_node_to_weight.get(_get_node_id(weight_fn))
                         for weight_fn in weight_group.get('params', set())]
        group_weights = [w for w in group_weights if w is not None]

        if not group_weights:
            return {}

        weight_grads = run_backward(
            tuple(grad_edges), tuple(grad_outputs),
            keep_graph, keep_graph,
            tuple(group_weights), allow_unreachable=True, accumulate_grad=False
        )
        _accumulate_grads(group_weights, weight_grads)
        if not keep_graph:
            weight_group.pop('grads', None)

        collected = {}
        for weight, grad in zip(group_weights, weight_grads):
            collected[weight] = grad if grad is not None else ops.zeros_like(weight)
        return collected

    def _prepare_output_and_sens(self, sens):
        """
        Prepare output tensor and sensitivity based on has_aux flag.

        With ``has_aux`` only the first output (and the first element of a
        list/tuple `sens`) is used. Otherwise outputs are flattened to their
        tensor leaves and a missing `sens` is replaced by all-ones tensors.

        Args:
            sens: Gradient of output tensor for gradient computation.

        Returns:
            Tuple of (output_tensors, processed_sens), both tuples.

        Raises:
            TypeError: If ``has_aux`` is set and the output is not a list or tuple.
        """
        output_tensor = self.output
        if self.has_aux:
            if not isinstance(self.output, (tuple, list)):
                raise TypeError(
                    f"The output of fn should be list or tuple when has_aux=True, "
                    f"but got {type(self.output)}"
                )
            output_tensor = output_tensor[0]
            if isinstance(sens, (tuple, list)):
                sens = sens[0]
            # NOTE(review): a None `sens` is forwarded unchanged on this path - confirm
            # run_backward treats it as all-ones.
            return (output_tensor,), (sens,)
        flatten_outputs = tree_leaves(output_tensor, tensors_only_leaf=True)
        if sens is None:
            sens = _fill_grads(flatten_outputs)
        else:
            sens = tree_leaves(sens, tensors_only_leaf=True)
        return tuple(flatten_outputs), tuple(sens)

    def _collect_input_tensors(self, collect_weights=True):
        """Collect input tensors based on grad_position and weights.

        Args:
            collect_weights: When true, the weights are appended after the inputs.

        Returns:
            Tuple ``(tensors, input_size)`` where ``input_size`` counts only the
            entries that come from `inputs`/`kwargs`.
        """
        input_tensors = []
        if isinstance(self.grad_position, int) and self.grad_position == -1:
            flatten_inputs, self.aux_inputs_data = tree_flatten(self.inputs, tensors_only_leaf=True)
            flatten_kwargs, self.aux_kwargs_data = tree_flatten(self.kwargs, tensors_only_leaf=True)
            self.flatten_input_size = len(flatten_inputs)
            input_tensors.extend(flatten_inputs)
            input_tensors.extend(flatten_kwargs)
        elif isinstance(self.grad_position, int) and isinstance(self.inputs[self.grad_position], Tensor):
            input_tensors.append(self.inputs[self.grad_position])
        elif isinstance(self.grad_position, (list, tuple)):
            input_tensors.extend(self.inputs[idx] for idx in self.grad_position if isinstance(self.inputs[idx], Tensor))

        input_size = len(input_tensors)
        if self.weights is not None and collect_weights:
            if isinstance(self.weights, (list, tuple)):
                input_tensors.extend(self.weights)
            else:
                input_tensors.append(self.weights)

        return tuple(input_tensors), input_size

    def __call__(self, sens=None, keep_graph=False):
        """
        Compute gradients with respect to both inputs and weights.

        Args:
            sens: gradient of output tensor for gradient computation.
            keep_graph: Whether to keep the computation graph.

        Returns:
            Gradients with respect to inputs and/or weights. With both present the
            weight gradients are the last element of the returned tuple.
        """
        weights = self.weights
        input_tensors, input_size = self._collect_input_tensors()
        output_tensors, sens = self._prepare_output_and_sens(sens)

        grads = run_backward(
            output_tensors, sens, keep_graph, keep_graph,
            input_tensors, allow_unreachable=True, accumulate_grad=False
        )
        if input_size > 0:
            _accumulate_grads(input_tensors[:input_size], grads[:input_size])
        weight_tensors = self._collect_weight_tensors()
        if weight_tensors:
            _accumulate_grads(weight_tensors, grads[input_size:])
        if not keep_graph:
            self._clear_res()
        if input_size == 0:
            return grads
        if weights is None:
            if isinstance(self.grad_position, int) and self.grad_position == -1:
                if self.flatten_input_size == input_size:
                    return tree_unflatten(self.aux_inputs_data, grads)
                return (tree_unflatten(self.aux_inputs_data, grads[:self.flatten_input_size]),
                        tree_unflatten(self.aux_kwargs_data, grads[self.flatten_input_size:input_size]))
            return grads[0] if len(grads) == 1 else grads
        if isinstance(self.grad_position, int) and self.grad_position == -1:
            if self.flatten_input_size == input_size:
                return tree_unflatten(self.aux_inputs_data, grads[:input_size]), grads[input_size:]
            return (tree_unflatten(self.aux_inputs_data, grads[:self.flatten_input_size]),
                    tree_unflatten(self.aux_kwargs_data, grads[self.flatten_input_size:input_size]),
                    grads[input_size:])
        return grads[:input_size], grads[input_size:]

    def compute_input_grad(self, sens=None):
        """
        Compute gradients with respect to inputs only (dx).

        This is the first stage of dx/dw split computation. It computes input gradients
        while capturing intermediate gradients at the boundaries between input and weight subgraphs.

        Implementation Strategy:
            1. Compute graph out degree from output grad_fns
            2. Find intermediate nodes (boundaries between input/weight subgraphs)
            3. Register prehooks on these intermediate nodes to capture gradients
            4. Save captured intermediate gradients for later dw computation

        Args:
            sens: gradient of output tensor for gradient computation. If None, will use ones_like.

        Returns:
            Gradients with respect to inputs. Returns single tensor if one input,
            tuple of tensors if multiple inputs, and an empty tuple when no input
            tensor is selected.

            When grad_position=-1 (default), the return type matches the input structure,
            supporting complex input types (tuple, dict, etc.). The gradients are automatically
            unflattened to preserve the original input structure.

        Raises:
            ValueError: If grad_position is None (no inputs specified for differentiation).
        """
        if self.grad_position is None:
            raise ValueError(
                "compute_input_grad requires grad_position to be specified. "
                "Cannot compute input gradients when grad_position is None."
            )
        input_tensors, input_size = self._collect_input_tensors(collect_weights=False)
        if not input_tensors:
            logging.info("No valid input tensors found for gradient computation.")
            return ()

        weight_tensors = self._collect_weight_tensors()
        self._saved_intermediates = []
        output_tensors, sens = self._prepare_output_and_sens(sens)

        hook_handles = []
        try:
            if weight_tensors:
                hook_handles = self._setup_intermediate_hooks(output_tensors, input_tensors, weight_tensors)

            # The graph is always retained here so compute_weight_grad can reuse it.
            input_grads = run_backward(
                output_tensors, sens, True, True,
                tuple(input_tensors), allow_unreachable=True, accumulate_grad=False
            )
            _accumulate_grads(input_tensors, input_grads)
        finally:
            # Hooks must not outlive this call, even when backward fails.
            for handle in hook_handles:
                handle.remove()

        return self._format_input_grads(input_grads, input_size)

    def compute_weight_grad(self, keep_graph=False):
        """
        Compute gradients with respect to weights only (dw).

        This is the second stage of dx/dw split computation. It uses the intermediate
        gradients captured during compute_input_grad to compute weight gradients efficiently,
        starting from intermediate nodes rather than recomputing from the output.

        Implementation Strategy:
            1. Use saved intermediate gradients and GradientEdges from compute_input_grad
            2. Start backward computation from these intermediate points (not from output)
            3. This avoids recomputing the portion of the graph already computed in dx

        Args:
            keep_graph: Whether to keep the computation graph after this computation.
                Default is False as this is typically the final gradient computation.

        Returns:
            Tuple with one entry per weight, in the order of `weights`. An entry is
            ``None`` when no weight group produced a gradient for that weight.
            When `weights` is ``None`` a warning is issued and ``()`` is returned.

        Raises:
            RuntimeError: If compute_input_grad was not called before (no saved intermediates).
            ValueError: If `weights` is an empty collection.

        Note:
            This function must be called after compute_input_grad, as it relies on
            the intermediate gradients captured during that computation. The computation
            graph must still be available (retained by keep_graph=True in compute_input_grad).
        """

        if self.weights is None:
            warnings.warn(
                "compute_weight_grad requires weights to be specified. "
                "Cannot compute weight gradients when weights is None."
            )
            self._clear_res()
            return ()
        if not self._saved_intermediates:
            raise RuntimeError("Before calling compute_weight_grad, you need first call compute_input_grad!")

        weight_tensors = self._collect_weight_tensors()
        if not weight_tensors:
            raise ValueError("No valid weight tensors found for gradient computation.")

        grad_node_to_weight = {_get_node_id(_get_grad_node(w)): w
                               for w in weight_tensors if _get_grad_node(w) is not None}

        self._prune_intermediate_edges()

        collected_grads = {}
        for weight_group in self._saved_intermediates:
            group_grads = self._process_weight_group(weight_group, grad_node_to_weight, keep_graph)
            collected_grads.update(group_grads)

        if not keep_graph:
            self._clear_res()

        result_grads = [collected_grads.get(weight) for weight in weight_tensors]
        return tuple(result_grads)
def forward_and_gradfn(fn, *inputs, weights=None, has_aux=False, grad_position=-1, **kwargs):
    """
    A wrapper function to generate the function to calculate forward output and gradient function.

    As for gradient, three typical cases are included:

    1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is ``None``.
    2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not ``None``.
    3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not ``None``.

    Args:
        fn (Union[Cell, Function]): Function to do GradOperation.
        *inputs: Variable length argument list of inputs to the function `fn`.
        weights (Union[ParameterTuple, Parameter, list[Parameter]], optional):
            The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
            Default: ``None`` .
        has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
            while the other outputs will be returned straightly. It means the `fn` must return more than one outputs
            in this case.
            Default: ``False`` .
        grad_position (Union[NoneType, int, tuple[int]], optional): Index to specify which inputs
            to be differentiated. Default: ``-1`` means all inputs are differentiated.

            - If int, get the gradient with respect to single input.
            - If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
            - If None, none derivative of any input will be solved, and in this case, `weights` is required.

        **kwargs: Variable length keyword argument dictionary. Additional keyword arguments passed to the function `fn`.

    Returns:
        Tuple of (output, grad_fn):

        - output: The output value of function `fn`.
        - grad_fn: A :class:`GradFunction` instance used to compute gradients.

        The :class:`GradFunction` class provides methods for gradient computation:

        - :meth:`__call__`: Compute gradients with respect to both inputs and weights at once.
        - :meth:`compute_input_grad`: Compute gradients with respect to inputs only (dx).
          This is the first stage of dx/dw split computation, which captures intermediate
          gradients at weight nodes for efficient dw computation.
        - :meth:`compute_weight_grad`: Compute gradients with respect to weights only (dw).
          This is the second stage of dx/dw split computation, which uses the intermediate
          gradients captured during :meth:`compute_input_grad` to compute weight gradients
          efficiently without recomputing from the output.

    Raises:
        ValueError: If both `grad_position` and `weights` are ``None``.
        TypeError: If type of Args does not belong to required ones.

    Examples:
        When grad_position=-1 (default), the gradient return type matches the input structure,
        supporting complex input types (tuple, dict, etc.):

        >>> def fn(x, tuple_input, scale=None):
        ...     a, b = tuple_input
        ...     return x * a + x * b + scale * x
        >>> x = Tensor([2.0])
        >>> a, b = Tensor([3.0]), Tensor([4.0])
        >>> scale = Tensor([0.5])
        >>> _, grad_fn = forward_and_gradfn(fn, x, (a, b), grad_position=-1, scale=scale)
        >>> dx_grads = grad_fn.compute_input_grad()
        >>> # When both args and kwargs have tensors, returns (args_grads, kwargs_grads)
        >>> # args_grads structure matches (x, (a, b)) -> (Tensor, (Tensor, Tensor))
        >>> # kwargs_grads structure matches {'scale': scale} -> {'scale': Tensor}
        >>> print(type(dx_grads))
        <class 'tuple'>
        >>> print(len(dx_grads))
        2
        >>> print(type(dx_grads[0]))
        <class 'tuple'>
        >>> print(type(dx_grads[0][1]))
        <class 'tuple'>
        >>> print(type(dx_grads[1]))
        <class 'dict'>

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    _validate_grad_config(grad_position, weights)
    _set_requires_grad(inputs, kwargs, grad_position)
    # Run the forward pass with the grad flag set, then restore the previous flag.
    prev_grad_flag = _pynative_executor.grad_flag()
    _pynative_executor.set_grad_flag(True)
    try:
        res = fn(*inputs, **kwargs)
    except Exception:
        # Clear the executor's resources for the failed forward pass, then
        # re-raise with the original traceback untouched.
        _pynative_executor.clear_res()
        raise
    finally:
        _pynative_executor.set_grad_flag(prev_grad_flag)
    grad_fn = GradFunction(res, inputs, kwargs, weights, has_aux, grad_position)
    return res, grad_fn