Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / utils / loss.py: 0%
14 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""Token-weighted global loss normalization utilities.

Provides ``count_loss_token``, which counts the labels in a micro-batch that
are not the ignore index, and ``mean_global_loss``, which scales a
micro-batch loss by its share of the global token count and by the FSDP
data-parallel size.
"""
from typing import Any, Dict

from hyper_parallel import get_platform

# Platform handle resolved once, at import time. Neither function in this
# module reads it. NOTE(review): presumably kept for importers of this module
# (or for the side effects of get_platform()); confirm before removing.
platform = get_platform()
def count_loss_token(batch: Dict[str, Any], ignore_index: int = -100) -> int:
    """Count non-padding tokens in a micro-batch.

    A token is considered valid (non-padding) when its label differs from
    ``ignore_index``. The default of ``-100`` is the conventional ignore index
    used in cross-entropy loss.

    Args:
        batch: Mapping that may contain a ``"labels"`` tensor. The comparison
            is element-wise, so any shape works (typically
            ``(batch_size, seq_len)``).
        ignore_index: Label value marking positions that do not contribute to
            the loss. Defaults to ``-100``.

    Returns:
        Integer count of elements where ``labels != ignore_index``, or ``0``
        when ``batch`` has no ``"labels"`` entry or it is ``None``.
    """
    labels = batch.get("labels")
    if labels is None:
        return 0
    # ``.item()`` pulls the scalar count off the device; ``int`` normalises
    # the result to a plain Python int.
    return int((labels != ignore_index).sum().item())
def mean_global_loss(
    loss: Any,
    micro_batch_tokens: int,
    total_tokens: int,
    fsdp_size: int,
) -> Any:
    """Compute token-weighted, globally normalised loss for one micro-batch.

    Each micro-batch contributes a fraction proportional to how many of the
    global step's tokens it contains. ``micro_batch_tokens`` is a *per-rank*
    count for this micro-batch, while ``total_tokens`` is the *global* count
    across all micro-batches and all data-parallel ranks. FSDP averages
    gradients across data-parallel ranks (i.e. divides by ``fsdp_size``).
    Multiplying by ``fsdp_size`` cancels that averaging, so the reduced
    gradient is normalised by the global token count alone.

    Formula::

        normalised_loss = raw_loss * (micro_tokens * fsdp_size / global_tokens)

    The scale factor is computed first (as a Python number when the counts
    are Python ints) and applied to ``loss`` with a single multiplication.
    This is one tensor op instead of two, and a low-precision loss is rounded
    only once.

    Args:
        loss: Raw loss scalar returned by the model (may be a DTensor partial).
        micro_batch_tokens: Number of non-padding tokens in this micro-batch
            on this rank.
        total_tokens: Total non-padding tokens across **all** micro-batches and
            all data-parallel ranks in this global step.
        fsdp_size: Number of data-parallel (FSDP) ranks participating in
            gradient reduction.

    Returns:
        Scaled loss with the same type as ``loss``. If ``total_tokens`` is
        zero or negative, returns ``loss`` unchanged to avoid division by zero.

    Raises:
        ValueError: If ``fsdp_size`` is not positive.
    """
    if fsdp_size <= 0:
        raise ValueError(f"fsdp_size must be a positive integer, got {fsdp_size}")
    if total_tokens <= 0:
        # No tokens to normalise against (e.g. every label is padding): leave
        # the loss untouched rather than dividing by zero.
        return loss
    scale = micro_batch_tokens * fsdp_size / total_tokens
    return loss * scale