Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/trainer/llm_trainer.py: 0%

197 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

"""LLMTrainer — language-model pretraining and SFT.

``LLMTrainer`` holds a ``BaseTrainer`` instance and calls its ``_build_*``
methods selectively. It overrides ``_build_model_assets``,
``_build_data_transform``, ``_build_dataset``, and ``_build_collate_fn``
to provide a complete real-data training pipeline (``dummy``,
``hf_datasets``, ``json_file`` and ``preset_pt`` data types).
"""
import logging
from typing import Any, Dict, List

import torch
from torch.utils.data import Dataset

from hyper_parallel.trainer.base import BaseTrainer

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

32 

class LLMTrainer:
    """Trainer for LM pretraining and SFT.

    Composition pattern: holds a ``BaseTrainer`` (``self.base``) and calls
    its ``_build_*`` methods in order, overriding the data-pipeline steps so
    that real tokenized data can be used.

    Supported ``data.type`` values:

    - ``"dummy"``: random tokens (delegated to ``BaseTrainer``) for quick
      FSDP validation.
    - ``"hf_datasets"``: HuggingFace datasets loaded with
      ``datasets.load_dataset`` and tokenized on the fly.
    - ``"json_file"``: a local JSON file; Alpaca-style records
      (``instruction`` / ``input`` / ``output``) are supported.
    - ``"preset_pt"``: pre-tokenized batches stored in a ``.pt`` file.

    Args:
        args: Training configuration parsed from YAML.
    """

    # Label value used to mask prompt tokens and padding positions; it is the
    # default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    IGNORE_INDEX = -100

    def __init__(self, args):
        self.base = BaseTrainer(args)

        # 13 build steps: call the base trainer's methods in order and
        # override the data-pipeline steps with LLM-specific versions.
        self.base._setup()
        self.base._build_model()
        self.base._freeze_model()
        self._build_model_assets()      # override: load the tokenizer
        self._build_data_transform()    # override: build the tokenize function
        self._build_dataset()           # override: load + tokenize the real dataset
        self._build_collate_fn()        # override: variable-length padding
        self.base._build_dataloader()
        self.base._build_parallelized_model()
        self.base._build_optimizer()
        self.base._build_lr_scheduler()
        self.base._build_training_context()
        self.base._init_callbacks()
        # Fire the one-shot ``on_init_end`` hook only after every
        # ``_build_*`` step: this is the canonical "trainer is fully built"
        # lifecycle point.
        self.base.on_init_end()

    # ------------------------------------------------------------------
    # Overridden _build_* methods
    # ------------------------------------------------------------------

    def _build_model_assets(self):
        """Build the tokenizer used for data processing.

        - ``dummy`` data: no tokenizer (``self.base.tokenizer = None``).
        - ``preset_pt`` data is already tokenized, so a tokenizer is loaded
          only when a tokenizer/weights path is configured (the collator
          then uses its ``pad_token_id``); otherwise it is set to ``None``.
        - Any other data type: loads an HF ``AutoTokenizer`` from
          ``model.tokenizer_path``, falling back to ``model.weights_path``.
          If the tokenizer has no pad token, the EOS token is reused.

        Raises:
            ValueError: if a tokenizer is required but neither
                ``model.tokenizer_path`` nor ``model.weights_path`` is set.
        """
        data_type = getattr(self.base.args.data, 'type', 'dummy')
        if data_type == 'dummy':
            self.base.tokenizer = None
            return

        # Prefer an explicit tokenizer_path; fall back to the weights dir.
        model_cfg = self.base.args.model
        tokenizer_path = (getattr(model_cfg, 'tokenizer_path', None)
                          or getattr(model_cfg, 'weights_path', None))

        if not tokenizer_path:
            if data_type == 'preset_pt':
                # Pre-tokenized data does not need a tokenizer; the collator
                # falls back to pad id 0 (padded labels are masked anyway).
                self.base.tokenizer = None
                return
            raise ValueError(
                f"data.type='{data_type}' requires model.tokenizer_path or "
                "model.weights_path to load tokenizer."
            )

        from transformers import AutoTokenizer  # pylint: disable=C0415 # optional dep

        # NOTE: trust_remote_code=True executes Python code shipped with the
        # tokenizer checkpoint; only point this at trusted paths.
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, trust_remote_code=True
        )
        # Ensure a pad token exists; reuse EOS when the tokenizer lacks one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.base.tokenizer = tokenizer
        logger.info("Tokenizer loaded: %s (vocab=%d)",
                    tokenizer_path, len(tokenizer))

    @staticmethod
    def _infer_seqlen(source_len, target_len, cutoff):
        """Split a ``cutoff`` token budget between prompt and response.

        Truncation prioritises the response (target):

        - a short response (``2 * target_len < cutoff``) may use the whole
          budget;
        - otherwise a short prompt (``2 * source_len < cutoff``) is kept
          whole and the response gets the rest;
        - otherwise both are truncated proportionally to their lengths.

        Args:
            source_len: Number of prompt tokens.
            target_len: Number of response tokens.
            cutoff: Maximum total number of tokens.

        Returns:
            ``(new_source_len, new_target_len)``; neither exceeds its input
            length and, for ``cutoff >= 0``, their sum is at most ``cutoff``.
        """
        if source_len + target_len == 0:
            # Nothing to keep; also avoids a division by zero below.
            return 0, 0
        if target_len * 2 < cutoff:
            max_target = cutoff
        elif source_len * 2 < cutoff:
            max_target = cutoff - source_len
        else:
            max_target = int(cutoff * (target_len / (source_len + target_len)))
        new_target = min(max_target, target_len)
        max_source = max(cutoff - new_target, 0)
        new_source = min(max_source, source_len)
        return new_source, new_target

    def _build_data_transform(self):
        """Build the batched tokenization transform (``self.base.data_transform``).

        The transform maps a batch of raw examples to
        ``{"input_ids": [...], "labels": [...]}``. Empty sequences are
        dropped. Behaviour:

        - ``json_file`` Alpaca records with ``data.template == 'empty'``:
          prompt (instruction + optional input) and response are tokenized
          separately without special tokens, truncated jointly to
          ``data.max_seq_len`` (see ``_infer_seqlen``), and prompt labels
          are masked with ``IGNORE_INDEX`` (SFT).
        - ``json_file`` Alpaca records with any other template: rendered as
          ``"Human: ...\\n\\nAssistant: ..."`` text and trained as plain
          causal LM.
        - Everything else: plain text from ``data.text_key``; labels are a
          copy of ``input_ids`` (causal LM).

        When no tokenizer was built, the transform is ``None``.
        """
        if self.base.tokenizer is None:
            self.base.data_transform = None
            return

        data_cfg = self.base.args.data
        max_seq_len = getattr(data_cfg, 'max_seq_len', 2048)
        text_key = getattr(data_cfg, 'text_key', 'text')
        data_type = getattr(data_cfg, 'type', 'dummy')
        template = getattr(data_cfg, 'template', 'empty')
        tokenizer = self.base.tokenizer
        # Bind plain values/functions (not ``self``) so the closures stay
        # cheap for ``datasets.map`` to hash/pickle.
        ignore_index = self.IGNORE_INDEX
        infer_seqlen = LLMTrainer._infer_seqlen

        def _alpaca_fields(examples):
            """Yield ``(instruction, input, output)`` triples of a batch."""
            instructions = examples["instruction"]
            inputs = examples.get("input", [""] * len(instructions))
            return zip(instructions, inputs, examples["output"])

        def _tokenize_sft(examples):
            """Tokenize Alpaca records with prompt tokens masked in labels."""
            result = {"input_ids": [], "labels": []}
            for inst, inp, out in _alpaca_fields(examples):
                prompt_text = inst + (("\n" + inp) if inp else "")
                prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
                response_ids = tokenizer(out, add_special_tokens=False)["input_ids"]
                s_len, t_len = infer_seqlen(len(prompt_ids), len(response_ids), max_seq_len)
                prompt_ids = prompt_ids[:s_len]
                response_ids = response_ids[:t_len]
                ids = prompt_ids + response_ids
                if len(ids) > 0:
                    result["input_ids"].append(ids)
                    result["labels"].append(
                        [ignore_index] * len(prompt_ids) + list(response_ids)
                    )
            return result

        def _tokenize_causal(texts):
            """Tokenize plain texts; labels are a copy of input_ids."""
            tokenized = tokenizer(
                texts,
                truncation=True,
                max_length=max_seq_len,
                padding=False,
                return_attention_mask=False,
            )
            result = {"input_ids": [], "labels": []}
            for ids in tokenized["input_ids"]:
                if len(ids) > 0:
                    result["input_ids"].append(ids)
                    result["labels"].append(list(ids))
            return result

        def _tokenize_fn(examples):
            """Tokenize one batch of raw examples into input_ids + labels."""
            is_alpaca = "instruction" in examples and data_type == "json_file"
            if is_alpaca and template == "empty":
                return _tokenize_sft(examples)

            if is_alpaca:
                # Alpaca records rendered with a chat-style template.
                texts = [
                    f"Human: {inst}\n{inp}\n\nAssistant: {out}" if inp
                    else f"Human: {inst}\n\nAssistant: {out}"
                    for inst, inp, out in _alpaca_fields(examples)
                ]
            else:
                # Plain text format.
                texts = examples[text_key]
                if isinstance(texts, str):
                    texts = [texts]
            return _tokenize_causal(texts)

        self.base.data_transform = _tokenize_fn
        logger.info("Data transform: tokenize max_seq_len=%d, format=%s",
                    max_seq_len, "alpaca" if data_type == "json_file" else text_key)

    def _build_dataset(self):
        """Build the training dataset and derive ``state.max_steps``.

        - ``dummy``: delegates to ``BaseTrainer._build_dataset``.
        - ``preset_pt``: delegates to ``_build_preset_pt_dataset``.
        - ``hf_datasets`` / ``json_file``: loads the ``train`` split,
          optionally truncates it to ``data.train_size``, tokenizes it with
          ``self.base.data_transform`` and drops empty sequences.

        For the tokenized path, ``state.max_steps`` is
        ``len(dataset) // global_batch_size``, further capped by
        ``train.max_steps`` when that is set.

        Raises:
            ValueError: for an unsupported ``data.type``.
        """
        data_cfg = self.base.args.data
        data_type = getattr(data_cfg, 'type', 'dummy')

        if data_type == 'dummy':
            self.base._build_dataset()  # pylint: disable=protected-access
            return

        if data_type == 'preset_pt':
            self._build_preset_pt_dataset()
            return

        if data_type not in ('hf_datasets', 'json_file'):
            raise ValueError(
                f"LLMTrainer supports data.type 'dummy', 'hf_datasets', 'json_file', or 'preset_pt', "
                f"got '{data_type}'"
            )

        from datasets import load_dataset  # pylint: disable=C0415 # optional dep

        train_path = data_cfg.train_path
        data_subset = getattr(data_cfg, 'subset', None)

        logger.info("Loading dataset: type=%s, path=%s", data_type, train_path)

        if data_type == 'json_file':
            # Local JSON file (e.g. Alpaca format: instruction/input/output).
            ds = load_dataset("json", data_files=train_path, split="train")
        elif data_subset:
            ds = load_dataset(train_path, data_subset, split="train")
        else:
            ds = load_dataset(train_path, split="train")

        # Limit dataset size if specified.
        train_size = getattr(data_cfg, 'train_size', None)
        if train_size and train_size < len(ds):
            ds = ds.select(range(train_size))
            logger.info("Dataset truncated to %d samples", train_size)

        if self.base.data_transform is not None:
            ds = ds.map(
                self.base.data_transform,
                batched=True,
                remove_columns=ds.column_names,
                desc="Tokenizing",
            )

        # Filter empty sequences.
        ds = ds.filter(lambda x: len(x["input_ids"]) > 0)

        class TokenizedDataset(Dataset):
            """Expose a tokenized HF dataset as ``torch.long`` tensors."""

            def __init__(self, hf_ds):
                self.data = hf_ds

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                item = self.data[idx]
                return {
                    "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
                    "labels": torch.tensor(item["labels"], dtype=torch.long),
                }

        self.base.train_dataset = TokenizedDataset(ds)

        train_cfg = self.base.args.train
        steps_in_dataset = len(self.base.train_dataset) // max(
            train_cfg.global_batch_size, 1
        )
        # An unset train.max_steps means "use what the dataset provides".
        configured_steps = getattr(train_cfg, 'max_steps', None)
        if configured_steps is None:
            self.base.state.max_steps = steps_in_dataset
        else:
            self.base.state.max_steps = min(configured_steps, steps_in_dataset)
        if self.base.state.max_steps <= 0:
            logger.warning(
                "max_steps resolved to %d (%d samples, global_batch_size=%d); "
                "no training steps will run",
                self.base.state.max_steps, len(self.base.train_dataset),
                train_cfg.global_batch_size,
            )
        logger.info("Dataset ready: %d samples, max_steps=%d",
                    len(self.base.train_dataset), self.base.state.max_steps)

    @staticmethod
    def _unstack_preset_batch(batch):
        """Split one stored ``preset_pt`` batch into per-sample dicts.

        Two formats are accepted: a stacked dict ``{input_ids: (B, S),
        labels: (B, S)[, attention_mask: (B, S)]}``, or a list of such
        per-rank dicts (which preserves per-rank dynamic ``seq_len``).
        Rows are cloned so samples do not share storage with the batch.
        A 2-D ``attention_mask`` is kept per sample; any other is ignored.

        Returns:
            List of ``{"input_ids", "labels"[, "attention_mask"]}`` dicts,
            in row order.
        """
        groups = batch if isinstance(batch, list) else [batch]
        samples = []
        for group in groups:
            ids = group["input_ids"]
            labels = group["labels"]
            attn = group.get("attention_mask")
            keep_attn = attn is not None and attn.dim() == 2
            for row in range(ids.shape[0]):
                rec = {
                    "input_ids": ids[row].clone(),
                    "labels": labels[row].clone(),
                }
                if keep_attn:
                    rec["attention_mask"] = attn[row].clone()
                samples.append(rec)
        return samples

    def _build_preset_pt_dataset(self):
        """Load pre-tokenized batches from a ``.pt`` file.

        ``data.train_path`` is the ``.pt`` file, holding a non-empty list of
        batches (see ``_unstack_preset_batch`` for the accepted formats).
        The dataset returns a flat sequence of per-sample dicts so the
        standard DataLoader can re-batch them. Use this when the dataset has
        already been tokenized offline and the token stream should be
        replayed deterministically. A truthy ``train.max_steps`` is copied
        into ``state.max_steps``.

        Raises:
            ValueError: if ``data.train_path`` is missing, or the file does
                not contain a non-empty list.
        """
        train_path = getattr(self.base.args.data, 'train_path', None)
        if not train_path:
            raise ValueError("data.train_path is required when data.type='preset_pt'")
        # SECURITY: weights_only=False unpickles arbitrary objects (which can
        # execute code); only load files from trusted sources.
        batches = torch.load(train_path, map_location="cpu", weights_only=False)
        if not isinstance(batches, list):
            raise ValueError(f"preset_pt expects List, got {type(batches)}")
        if not batches:
            raise ValueError(f"preset_pt file {train_path} contains no batches")

        per_sample = [
            rec for batch in batches for rec in self._unstack_preset_batch(batch)
        ]

        class PresetPtDataset(Dataset):
            """In-memory dataset over pre-built per-sample dicts."""

            def __init__(self, samples):
                self.samples = samples

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                return self.samples[idx]

        self.base.train_dataset = PresetPtDataset(per_sample)
        max_steps = getattr(self.base.args.train, "max_steps", None)
        if max_steps:
            self.base.state.max_steps = int(max_steps)
        logger.info("preset_pt dataset: %d samples loaded from %s", len(per_sample), train_path)

    def _build_collate_fn(self):
        """Build the collator that right-pads each batch to its longest sample.

        ``input_ids`` are padded with the tokenizer's ``pad_token_id`` (or 0
        when there is no tokenizer / pad id) and ``labels`` with
        ``IGNORE_INDEX``. Both outputs are ``torch.long`` tensors of shape
        ``(batch, max_len)``.
        """
        pad_id = 0
        tokenizer = self.base.tokenizer
        if tokenizer is not None and tokenizer.pad_token_id is not None:
            pad_id = tokenizer.pad_token_id
        label_pad = self.IGNORE_INDEX
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        def _lm_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            """Pad sequences to the max length in the batch."""
            # NOTE(review): only input_ids and labels are forwarded; a
            # per-sample attention_mask (kept by preset_pt) is dropped here.
            # Confirm that this is intended.
            input_ids = pad_sequence(
                [item["input_ids"] for item in batch],
                batch_first=True, padding_value=pad_id,
            )
            labels = pad_sequence(
                [item["labels"] for item in batch],
                batch_first=True, padding_value=label_pad,
            )
            return {
                "input_ids": input_ids.to(torch.long),
                "labels": labels.to(torch.long),
            }

        self.base.collate_fn = _lm_collate

    # ------------------------------------------------------------------
    # Delegated methods
    # ------------------------------------------------------------------

    def train(self):
        """Delegate to BaseTrainer.train()."""
        self.base.train()

    def train_step(self, data_iterator):
        """Delegate to BaseTrainer.train_step()."""
        return self.base.train_step(data_iterator)