Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / llm_trainer.py: 0%
197 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""LLMTrainer — Language Model pretraining and SFT.
17``LLMTrainer`` holds a ``BaseTrainer`` instance and calls
18its ``_build_*`` methods selectively. It overrides ``_build_model_assets``,
19``_build_data_transform``, ``_build_dataset``, and ``_build_collate_fn``
20to provide a complete real-data training pipeline.
22"""
23import logging
24from typing import Any, Dict, List
26import torch
27from torch.utils.data import Dataset
29from hyper_parallel.trainer.base import BaseTrainer
31logger = logging.getLogger(__name__)
class LLMTrainer:
    """Trainer for LM pretraining and SFT.

    Composition pattern — owns a ``BaseTrainer`` and calls its ``_build_*``
    methods in order, substituting its own implementations for the
    data-pipeline steps (tokenizer, transform, dataset, collator).

    Supports:
        - ``data.type = "dummy"``: random tokens for quick FSDP validation
          (dataset construction is delegated to ``BaseTrainer``)
        - ``data.type = "hf_datasets"``: real HuggingFace datasets with
          tokenization
        - ``data.type = "json_file"``: local JSON file with
          ``instruction`` / ``input`` / ``output`` fields (alpaca format)
        - ``data.type = "preset_pt"``: pre-tokenized batches stored in a
          ``.pt`` file

    Args:
        args: Training configuration parsed from YAML.
    """

    # Label value used for prompt tokens and for label padding.
    _IGNORE_INDEX = -100

    def __init__(self, args):
        self.base = BaseTrainer(args)

        # Build steps run in a fixed order. Calls on ``self.base`` use
        # BaseTrainer's implementation; the four ``self._build_*`` calls are
        # this class's replacements.
        self.base._setup()
        self.base._build_model()
        self.base._freeze_model()
        self._build_model_assets()      # replacement: load tokenizer
        self._build_data_transform()    # replacement: tokenize function
        self._build_dataset()           # replacement: real dataset load + tokenize
        self._build_collate_fn()        # replacement: variable-length padding
        self.base._build_dataloader()
        self.base._build_parallelized_model()
        self.base._build_optimizer()
        self.base._build_lr_scheduler()
        self.base._build_training_context()
        self.base._init_callbacks()
        # Fire one-shot ``on_init_end`` AFTER every ``_build_*`` — this is
        # the canonical "trainer is fully built" lifecycle hook.
        self.base.on_init_end()

    # ------------------------------------------------------------------
    # Overridden _build_* methods
    # ------------------------------------------------------------------

    def _build_model_assets(self):
        """Build the tokenizer used for data processing.

        Sets ``self.base.tokenizer``:

        - ``data.type == 'dummy'``: ``None`` (no tokenizer needed).
        - ``data.type == 'preset_pt'`` with no tokenizer path configured:
          ``None`` — the data is already tokenized, so a tokenizer is
          optional (it is only consulted for the pad id when collating).
        - otherwise: an HF ``AutoTokenizer`` loaded from
          ``model.tokenizer_path``, falling back to ``model.weights_path``.
          If the tokenizer has no pad token, the EOS token is used.

        Raises:
            ValueError: A tokenizer is required but neither
                ``model.tokenizer_path`` nor ``model.weights_path`` is set.
        """
        data_type = getattr(self.base.args.data, 'type', 'dummy')
        if data_type == 'dummy':
            self.base.tokenizer = None
            return

        # Try tokenizer_path first, fall back to weights_path.
        model_cfg = self.base.args.model
        tokenizer_path = (
            getattr(model_cfg, 'tokenizer_path', None)
            or getattr(model_cfg, 'weights_path', None)
        )

        if not tokenizer_path:
            if data_type == 'preset_pt':
                # Pre-tokenized data does not need a tokenizer.
                self.base.tokenizer = None
                return
            raise ValueError(
                f"data.type='{data_type}' requires model.tokenizer_path or "
                "model.weights_path to load tokenizer."
            )

        from transformers import AutoTokenizer  # pylint: disable=C0415 # optional dep

        # NOTE(security): trust_remote_code=True lets the tokenizer repo run
        # its own Python code; only point this at trusted paths.
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, trust_remote_code=True
        )
        # Ensure pad token exists.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.base.tokenizer = tokenizer
        logger.info("Tokenizer loaded: %s (vocab=%d)",
                    tokenizer_path, len(tokenizer))

    @staticmethod
    def _infer_seqlen(s_len, t_len, cutoff):
        """Split a ``cutoff`` token budget between prompt and response.

        Args:
            s_len: Prompt (source) length in tokens.
            t_len: Response (target) length in tokens.
            cutoff: Maximum combined length.

        Returns:
            Tuple ``(new_s_len, new_t_len)`` of lengths to keep.
        """
        if t_len * 2 < cutoff:
            # Short response: it may use the whole budget.
            max_t = cutoff
        elif s_len * 2 < cutoff:
            # Short prompt: the response gets whatever the prompt leaves.
            max_t = cutoff - s_len
        else:
            # Both long: share the budget proportionally.
            max_t = int(cutoff * (t_len / (s_len + t_len)))
        new_t = min(max_t, t_len)
        new_s = min(max(cutoff - new_t, 0), s_len)
        return new_s, new_t

    def _build_data_transform(self):
        """Build the tokenization transform.

        Sets ``self.base.data_transform`` to ``None`` when there is no
        tokenizer, otherwise to a batched function mapping raw examples to
        ``{"input_ids": [...], "labels": [...]}``. Examples that tokenize to
        nothing are dropped.

        Three input layouts are handled (the first two only when
        ``data.type == "json_file"`` and the batch has an ``instruction``
        column):

        - ``data.template == "empty"`` (the default): prompt is
          ``instruction`` plus optional ``"\\n" + input``; prompt labels are
          masked with -100 and only the response is supervised.
        - any other template: a ``Human: ... Assistant: ...`` text is built
          and labels are a copy of ``input_ids``.
        - plain text read from ``data.text_key``; labels are a copy of
          ``input_ids``.
        """
        tokenizer = self.base.tokenizer
        if tokenizer is None:
            self.base.data_transform = None
            return

        data_cfg = self.base.args.data
        max_seq_len = getattr(data_cfg, 'max_seq_len', 2048)
        text_key = getattr(data_cfg, 'text_key', 'text')
        data_type = getattr(data_cfg, 'type', 'dummy')
        template = getattr(data_cfg, 'template', 'empty')
        ignore_index = self._IGNORE_INDEX
        infer_seqlen = self._infer_seqlen

        def _tokenize_sft(instructions, inputs, outputs):
            """Tokenize prompt/response separately and mask prompt labels."""
            result = {"input_ids": [], "labels": []}
            for inst, inp, out in zip(instructions, inputs, outputs):
                prompt_text = inst + (("\n" + inp) if inp else "")
                prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
                response_ids = tokenizer(out, add_special_tokens=False)["input_ids"]
                s_len, t_len = infer_seqlen(len(prompt_ids), len(response_ids), max_seq_len)
                prompt_ids = prompt_ids[:s_len]
                response_ids = response_ids[:t_len]
                ids = prompt_ids + response_ids
                if ids:
                    result["input_ids"].append(ids)
                    result["labels"].append(
                        [ignore_index] * len(prompt_ids) + list(response_ids)
                    )
            return result

        def _tokenize_fn(examples):
            """Tokenize one batch of examples and create causal LM labels."""
            if "instruction" in examples and data_type == "json_file":
                instructions = examples["instruction"]
                inputs = examples.get("input", [""] * len(instructions))
                outputs = examples["output"]
                if template == "empty":
                    return _tokenize_sft(instructions, inputs, outputs)
                # Alpaca format with chat-style template (legacy default).
                texts = [
                    f"Human: {inst}\n{inp}\n\nAssistant: {out}" if inp
                    else f"Human: {inst}\n\nAssistant: {out}"
                    for inst, inp, out in zip(instructions, inputs, outputs)
                ]
            else:
                # Plain text format.
                texts = examples[text_key]
                if isinstance(texts, str):
                    texts = [texts]

            tokenized = tokenizer(
                texts,
                truncation=True,
                max_length=max_seq_len,
                padding=False,
                return_attention_mask=False,
            )

            result = {"input_ids": [], "labels": []}
            for ids in tokenized["input_ids"]:
                if len(ids) > 0:
                    result["input_ids"].append(ids)
                    result["labels"].append(list(ids))
            return result

        self.base.data_transform = _tokenize_fn
        logger.info("Data transform: tokenize max_seq_len=%d, format=%s",
                    max_seq_len, "alpaca" if data_type == "json_file" else text_key)

    def _build_dataset(self):
        """Build the training dataset with the full tokenization pipeline.

        - ``dummy``: delegates to ``BaseTrainer._build_dataset``.
        - ``preset_pt``: delegates to ``_build_preset_pt_dataset``.
        - ``hf_datasets`` / ``json_file``: loads the ``train`` split,
          optionally truncates it to ``data.train_size``, tokenizes it with
          ``self.base.data_transform``, filters empty examples, and wraps the
          result for a torch ``DataLoader``. ``state.max_steps`` is capped at
          ``len(dataset) // global_batch_size``; when ``train.max_steps`` is
          not set, that cap is used directly.

        Raises:
            ValueError: Unsupported ``data.type`` or missing
                ``data.train_path``.
        """
        data_cfg = self.base.args.data
        data_type = getattr(data_cfg, 'type', 'dummy')

        if data_type == 'dummy':
            self.base._build_dataset()  # pylint: disable=protected-access
            return

        if data_type == 'preset_pt':
            self._build_preset_pt_dataset()
            return

        if data_type not in ('hf_datasets', 'json_file'):
            raise ValueError(
                f"LLMTrainer supports data.type 'dummy', 'hf_datasets', 'json_file', or 'preset_pt', "
                f"got '{data_type}'"
            )

        train_path = getattr(data_cfg, 'train_path', None)
        if not train_path:
            raise ValueError(
                f"data.train_path is required when data.type='{data_type}'"
            )

        from datasets import load_dataset  # pylint: disable=C0415 # optional dep

        data_subset = getattr(data_cfg, 'subset', None)
        logger.info("Loading dataset: type=%s, path=%s", data_type, train_path)

        if data_type == 'json_file':
            # Load local JSON file (alpaca format: instruction/input/output).
            ds = load_dataset("json", data_files=train_path, split="train")
        elif data_subset:
            ds = load_dataset(train_path, data_subset, split="train")
        else:
            ds = load_dataset(train_path, split="train")

        # Limit dataset size if specified.
        train_size = getattr(data_cfg, 'train_size', None)
        if train_size and train_size < len(ds):
            ds = ds.select(range(train_size))
            logger.info("Dataset truncated to %d samples", train_size)

        # Tokenize; the raw columns are replaced by input_ids / labels.
        if self.base.data_transform:
            ds = ds.map(
                self.base.data_transform,
                batched=True,
                remove_columns=ds.column_names,
                desc="Tokenizing",
            )

        # Filter empty sequences.
        ds = ds.filter(lambda x: len(x["input_ids"]) > 0)

        class TokenizedDataset(Dataset):
            """Wrap HF dataset for torch DataLoader (rows become long tensors)."""

            def __init__(self, hf_ds):
                self.data = hf_ds

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                item = self.data[idx]
                return {
                    "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
                    "labels": torch.tensor(item["labels"], dtype=torch.long),
                }

        self.base.train_dataset = TokenizedDataset(ds)

        train_cfg = self.base.args.train
        steps_available = len(self.base.train_dataset) // max(
            train_cfg.global_batch_size, 1
        )
        configured_steps = getattr(train_cfg, 'max_steps', None)
        # An unset max_steps means "as many full batches as the data provides".
        if configured_steps is None:
            self.base.state.max_steps = steps_available
        else:
            self.base.state.max_steps = min(configured_steps, steps_available)
        logger.info("Dataset ready: %d samples, max_steps=%d",
                    len(self.base.train_dataset), self.base.state.max_steps)

    @staticmethod
    def _split_preset_batch(batch):
        """Yield one per-sample dict for each row of a stacked batch dict.

        Args:
            batch: Dict with ``input_ids`` and ``labels`` tensors indexed by
                sample along dim 0, and optionally ``attention_mask``.

        Yields:
            Dicts holding cloned row tensors. ``attention_mask`` is included
            only when the batch provides a 2-D mask.
        """
        ids = batch["input_ids"]
        labels = batch["labels"]
        attn = batch.get("attention_mask")
        keep_attn = attn is not None and attn.dim() == 2
        for row in range(ids.shape[0]):
            rec = {
                "input_ids": ids[row].clone(),
                "labels": labels[row].clone(),
            }
            if keep_attn:
                rec["attention_mask"] = attn[row].clone()
            yield rec

    def _build_preset_pt_dataset(self):
        """Load pre-tokenized batches from a .pt file (List[Dict[str, Tensor]]).

        ``data.train_path`` is the .pt file. Each entry is either a dict of
        stacked tensors ``{input_ids: (B, S), labels: (B, S)}`` or a list of
        such dicts (one per rank, which preserves per-rank dynamic seq_len).
        The dataset returns a flat sequence of per-sample dicts so the
        standard DataLoader can batch them. Use this when the dataset has
        already been tokenized offline and the token stream should be
        replayed deterministically.

        ``state.max_steps`` is set only when ``train.max_steps`` is truthy.

        Raises:
            ValueError: ``data.train_path`` is empty, or the file does not
                contain a non-empty list.
        """
        train_path = self.base.args.data.train_path
        if not train_path:
            raise ValueError("data.train_path is required when data.type='preset_pt'")

        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # the file can execute code on load; only use trusted files.
        batches = torch.load(train_path, map_location="cpu", weights_only=False)
        if not isinstance(batches, list):
            raise ValueError(f"preset_pt expects List, got {type(batches)}")
        if not batches:
            raise ValueError("preset_pt expects a non-empty List, got an empty list")

        per_sample = []
        for entry in batches:
            # Normalise both formats to "list of stacked dicts".
            stacked_dicts = entry if isinstance(entry, list) else [entry]
            for stacked in stacked_dicts:
                per_sample.extend(self._split_preset_batch(stacked))

        class PresetPtDataset(Dataset):
            """Expose the flattened per-sample dicts to a torch DataLoader."""

            def __init__(self, samples):
                self.samples = samples

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                return self.samples[idx]

        self.base.train_dataset = PresetPtDataset(per_sample)
        max_steps = getattr(self.base.args.train, "max_steps", None)
        if max_steps:
            self.base.state.max_steps = int(max_steps)
        logger.info("preset_pt dataset: %d samples loaded from %s", len(per_sample), train_path)

    def _build_collate_fn(self):
        """Build the collator with right-padding.

        Sets ``self.base.collate_fn`` to a function that pads every sample in
        a batch to the longest ``input_ids`` in that batch: ``input_ids`` with
        the tokenizer's ``pad_token_id`` (0 when there is no tokenizer or no
        pad id) and ``labels`` with -100. Only ``input_ids`` and ``labels``
        are returned; any other per-sample keys are dropped.
        """
        pad_id = 0
        if self.base.tokenizer and self.base.tokenizer.pad_token_id is not None:
            pad_id = self.base.tokenizer.pad_token_id
        label_pad = self._IGNORE_INDEX

        def _right_pad(seq, pad_len, fill):
            """Append ``pad_len`` copies of ``fill`` to a 1-D tensor."""
            if pad_len <= 0:
                return seq
            # new_full keeps the dtype and device of the sample tensor.
            return torch.cat([seq, seq.new_full((pad_len,), fill)])

        def _lm_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            """Pad sequences to max length in batch and stack them."""
            max_len = max(item["input_ids"].size(0) for item in batch)
            input_rows = []
            label_rows = []
            for item in batch:
                # Labels receive the same amount of padding as input_ids.
                pad_len = max_len - item["input_ids"].size(0)
                input_rows.append(_right_pad(item["input_ids"], pad_len, pad_id))
                label_rows.append(_right_pad(item["labels"], pad_len, label_pad))
            return {
                "input_ids": torch.stack(input_rows),
                "labels": torch.stack(label_rows),
            }

        self.base.collate_fn = _lm_collate

    # ------------------------------------------------------------------
    # Delegated methods
    # ------------------------------------------------------------------

    def train(self):
        """Delegate to BaseTrainer.train()."""
        self.base.train()

    def train_step(self, data_iterator):
        """Delegate to BaseTrainer.train_step() and return its result."""
        return self.base.train_step(data_iterator)