Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/trainer/utils/loss.py: 0%

14 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Token-weighted global loss normalization utilities. 

16 

17""" 

18from typing import Any, Dict 

19 

20from hyper_parallel import get_platform 

21 

# Backend platform handle resolved once, at import time, via hyper_parallel.
# NOTE(review): not referenced by the functions in this module as shown —
# confirm whether it is still needed here (e.g. for import-time side effects).
platform = get_platform()

23 

def count_loss_token(batch: Dict[str, Any], *, ignore_index: int = -100) -> int:
    """Count non-padding tokens in a micro-batch.

    A token is considered valid (non-padding) when its label differs from
    ``ignore_index``. The default of -100 is the conventional ignore index
    used by cross-entropy losses (e.g. PyTorch's ``CrossEntropyLoss``).

    Args:
        batch: Mapping expected to contain a ``"labels"`` tensor/array,
            typically of shape ``(batch_size, seq_len)``. Every element is
            counted, so the exact shape does not matter.
        ignore_index: Label value marking tokens that are excluded from the
            loss. Keyword-only; defaults to -100 for backward compatibility.

    Returns:
        Integer count of elements where ``labels != ignore_index``, or 0 when
        ``batch`` has no ``"labels"`` entry (or that entry is ``None``).
    """
    labels = batch.get("labels")
    if labels is None:
        return 0
    # ``.sum()`` reduces over all dimensions; ``.item()`` extracts the scalar
    # so callers receive a plain Python int rather than a 0-d tensor/array.
    return int((labels != ignore_index).sum().item())

41 

def mean_global_loss(
    loss: Any,
    micro_batch_tokens: int,
    total_tokens: int,
    fsdp_size: int,
) -> Any:
    """Compute token-weighted, globally normalised loss for one micro-batch.

    Each micro-batch contributes a fraction proportional to how many of the
    global step's tokens it contains. ``micro_batch_tokens`` is counted on
    this rank only, while ``total_tokens`` spans every rank. Because FSDP
    averages gradients across the ``fsdp_size`` data-parallel ranks (i.e.
    divides by ``fsdp_size``), the weight is multiplied by ``fsdp_size`` so
    that the cross-rank average turns back into a sum of contributions.

    The weighting assumes ``loss`` is a per-token mean over this micro-batch
    (e.g. cross-entropy with ``reduction="mean"``) — confirm against the
    model's loss function.

    Formula::

        normalised_loss = raw_loss * (micro_tokens / global_tokens * fsdp_size)

    Args:
        loss: Raw loss scalar returned by the model (may be a DTensor partial).
        micro_batch_tokens: Number of non-padding tokens in this micro-batch
            on this rank.
        total_tokens: Total non-padding tokens across **all** micro-batches and
            all data-parallel ranks in this global step.
        fsdp_size: Number of data-parallel (FSDP) ranks participating in
            gradient reduction.

    Returns:
        Scaled loss with the same type as ``loss``. If ``total_tokens`` is
        zero or negative, returns ``loss`` unchanged to avoid division by
        zero.

    Raises:
        ValueError: If ``fsdp_size`` is not positive.
    """
    if fsdp_size <= 0:
        raise ValueError(f"fsdp_size must be a positive integer, got {fsdp_size}")
    if total_tokens <= 0:
        # No valid tokens in the whole step: there is nothing to weight by,
        # so pass the loss through instead of dividing by zero.
        return loss
    # Fold the full scale into one Python float (double precision) so the
    # loss is multiplied once: a single op/autograd node and a single
    # rounding step when ``loss`` is held in a low-precision dtype.
    scale = micro_batch_tokens / total_tokens * fsdp_size
    return loss * scale