Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/trainer/vl_trainer.py: 0%

145 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""VLTrainer for native Qwen3-VL multimodal training.""" 

16import logging 

17from typing import Any, Dict, List 

18 

19import torch 

20from torch.utils.data import Dataset 

21 

22from hyper_parallel.trainer.base import BaseTrainer 

23 

24logger = logging.getLogger(__name__) 

25 

26 

class VLTrainer:
    """Trainer for multimodal Qwen3-VL training (text + image/video).

    The trainer holds a ``BaseTrainer`` instance as ``self.base`` and supplies
    the VL-specific build steps itself: processor loading, dataset
    construction and batch collation. Everything else is delegated to
    ``self.base``.

    Two data paths are supported:
    - ``data.type = "vl_dummy"``: deterministic synthetic multimodal tensors,
      useful for quick smoke tests without dataset preparation.
    - ``data.type = "preset_pt"``: replays pre-tokenized batches that already
      include ``pixel_values`` and ``image_grid_thw``.
    """

    def __init__(self, args):
        """Build every training component in a fixed order.

        Args:
            args: Configuration object handed to ``BaseTrainer``. This class
                reads ``args.data``, ``args.model``, ``args.train`` and the
                optional ``args.seed`` through ``self.base.args``.
        """
        self.base = BaseTrainer(args)
        self.base._setup()
        self.base._build_model()
        self.base._freeze_model()
        # VL-specific steps run between model construction and dataloader
        # construction; they populate attributes on ``self.base``.
        self._build_model_assets()
        self._build_data_transform()
        self._build_dataset()
        self._build_collate_fn()
        self.base._build_dataloader()
        self.base._build_parallelized_model()
        self.base._build_optimizer()
        self.base._build_lr_scheduler()
        self.base._build_training_context()
        self.base._init_callbacks()
        self.base.on_init_end()

    @staticmethod
    def _log_rank0(msg: str, *args: Any) -> None:
        """Emit an INFO record, preferring ``logger.info_rank0`` if it exists.

        The module logger comes from ``logging.getLogger``; a standard
        ``logging.Logger`` has no ``info_rank0`` method, so calling it
        directly raises ``AttributeError``. When the attribute is present
        (e.g. added by the surrounding framework) it is used, otherwise the
        call falls back to ``logger.info``.
        """
        emit = getattr(logger, "info_rank0", logger.info)
        emit(msg, *args)

    def _build_model_assets(self):
        """Load processor when a real VL dataset is configured.

        Sets ``self.base.processor`` and ``self.base.tokenizer``. Both stay
        ``None`` for ``data.type == "vl_dummy"`` (also the default when
        ``data.type`` is absent).

        Raises:
            ValueError: If a non-dummy data type is configured and none of
                ``data.processor_path``, ``model.tokenizer_path`` or
                ``model.weights_path`` is set.
        """
        self.base.processor = None
        self.base.tokenizer = None
        data_type = getattr(self.base.args.data, "type", "vl_dummy")
        if data_type == "vl_dummy":
            return
        # First non-empty candidate wins.
        processor_path = (
            getattr(self.base.args.data, "processor_path", None)
            or getattr(self.base.args.model, "tokenizer_path", None)
            or getattr(self.base.args.model, "weights_path", None)
        )
        if not processor_path:
            raise ValueError("VL real-data mode requires data.processor_path or model.weights_path")
        # Imported lazily so the dummy path does not need ``transformers``.
        from transformers import AutoProcessor  # pylint: disable=C0415

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to be executed; only point processor_path at trusted
        # sources.
        self.base.processor = AutoProcessor.from_pretrained(
            processor_path, trust_remote_code=True,
        )
        self.base.tokenizer = getattr(self.base.processor, "tokenizer", None)
        logger.info("Processor loaded from %s", processor_path)

    def _build_data_transform(self):
        """Set ``self.base.data_transform`` to ``None`` (no transform is used)."""
        self.base.data_transform = None

    def _build_dataset(self):
        """Build dataset (internal).

        Dispatches on ``data.type``: ``"preset_pt"`` is handled by
        ``_build_preset_pt_dataset``; ``"vl_dummy"`` (the default) builds a
        synthetic dataset of ``max_steps * global_batch_size`` samples and
        stores it in ``self.base.train_dataset``.

        Raises:
            NotImplementedError: For any other ``data.type``.
        """
        data_type = getattr(self.base.args.data, "type", "vl_dummy")
        if data_type == "preset_pt":
            self._build_preset_pt_dataset()
            return
        if data_type != "vl_dummy":
            raise NotImplementedError(
                f"VL trainer supports data.type 'vl_dummy' or 'preset_pt', got '{data_type}'"
            )

        max_steps = getattr(self.base.args.train, "max_steps", 1)
        global_bs = getattr(self.base.args.train, "global_batch_size", 1)
        total_samples = max_steps * global_bs
        data_cfg = self.base.args.data
        model_cfg = self.base.args.model
        extra = getattr(model_cfg, "config_overrides", None) or {}
        vision_extra = extra.get("vision_config", {}) if isinstance(extra, dict) else {}
        patch_size = int(vision_extra.get("patch_size", 16))
        temporal_patch_size = int(vision_extra.get("temporal_patch_size", 2))
        in_channels = int(vision_extra.get("in_channels", 3))
        spatial_merge = int(vision_extra.get("spatial_merge_size", 2))
        grid_t = int(getattr(data_cfg, "vl_grid_t", 2))
        grid_h = int(getattr(data_cfg, "vl_grid_h", 2))
        grid_w = int(getattr(data_cfg, "vl_grid_w", 2))
        # presumably the tokenizer's image placeholder id — confirm against
        # the model's tokenizer.
        image_token_id = int(getattr(data_cfg, "image_token_id", 151655))
        # Number of placeholder tokens written into input_ids per sample.
        image_tokens = grid_t * grid_h * grid_w // (spatial_merge ** 2)
        # Width of one pixel_values row.
        row_width = in_channels * temporal_patch_size * patch_size * patch_size
        # Leave room for the leading token, the image tokens and a short tail.
        seq_len = max(int(getattr(data_cfg, "max_seq_len", 16)), image_tokens + 4)
        base_seed = int(getattr(self.base.args, "seed", 42))

        class DeterministicVLDataset(Dataset):
            """Deterministic synthetic VL dataset for smoke / FSDP regression."""

            def __len__(self):
                return total_samples

            def __getitem__(self, idx):
                # Per-index generator: the same idx always yields the same
                # pixel values.
                g = torch.Generator().manual_seed(base_seed + idx)
                pixel_values = torch.randn(
                    grid_t * grid_h * grid_w, row_width, generator=g,
                    dtype=torch.float32,
                )
                input_ids = torch.full((seq_len,), 100, dtype=torch.long)
                # presumably a special/BOS-like token id — confirm.
                input_ids[0] = 151643
                input_ids[1: 1 + image_tokens] = image_token_id
                # Fill the remainder with a consecutive id range whose start
                # depends on idx, so samples differ in their text part.
                tail_start = 200 + idx % 17
                tail = torch.arange(
                    tail_start,
                    tail_start + seq_len - 1 - image_tokens,
                    dtype=torch.long,
                )
                input_ids[1 + image_tokens:] = tail
                labels = input_ids.clone()
                # 1 marks image-token positions, 0 everything else.
                mm_token_type_ids = torch.zeros(seq_len, dtype=torch.int32)
                mm_token_type_ids[1: 1 + image_tokens] = 1
                return {
                    "input_ids": input_ids,
                    "labels": labels,
                    "attention_mask": torch.ones(seq_len, dtype=torch.long),
                    "mm_token_type_ids": mm_token_type_ids,
                    "pixel_values": pixel_values,
                    "image_grid_thw": torch.tensor([grid_t, grid_h, grid_w], dtype=torch.long),
                }

        self.base.train_dataset = DeterministicVLDataset()
        self.base.state.max_steps = max_steps
        self._log_rank0(
            "VL dummy dataset created: samples=%d seq_len=%d grid=(%d,%d,%d) image_tokens=%d",
            total_samples, seq_len, grid_t, grid_h, grid_w, image_tokens,
        )

    def _build_preset_pt_dataset(self):
        """Replay pre-tokenized VL batches from a .pt file.

        Each entry is a dict of tensors (input_ids/labels/attention_mask plus
        VL fields pixel_values/image_grid_thw/mm_token_type_ids). Sample
        layout follows the same convention as ``llm_trainer._build_preset_pt_dataset``.

        The loaded object must be a non-empty list whose items are either
        such a dict or a list of such dicts. Every batch is split along
        dimension 0 of ``input_ids`` into per-sample records, which are
        stored in ``self.base.train_dataset``. ``self.base.state.max_steps``
        is set only when ``train.max_steps`` is truthy.

        Raises:
            ValueError: If ``data.train_path`` is empty, or the loaded object
                is not a non-empty list.
        """
        train_path = self.base.args.data.train_path
        if not train_path:
            raise ValueError("data.train_path is required when data.type='preset_pt'")
        # NOTE(review): weights_only=False unpickles arbitrary objects, which
        # can execute code; only load files from trusted sources.
        batches = torch.load(train_path, map_location="cpu", weights_only=False)
        if not isinstance(batches, list):
            raise ValueError(f"preset_pt expects List, got {type(batches)}")
        if not batches:
            raise ValueError("preset_pt expects a non-empty List, got an empty list")

        def _expand_dict(b: Dict[str, Any]) -> List[Dict[str, Any]]:
            """Split one batched dict into a list of per-sample dicts."""
            ids = b["input_ids"]
            labels = b["labels"]
            num_samples = ids.shape[0]
            pv = b.get("pixel_values")
            thw = b.get("image_grid_thw")
            # Grids are assumed to be distributed evenly over the samples of
            # the batch; anything else (missing fields, non-2D grid tensor,
            # fewer grids than samples) disables the visual fields.
            grids_per_sample = 0
            if pv is not None and thw is not None and thw.dim() == 2 and num_samples > 0:
                grids_per_sample = thw.shape[0] // num_samples

            out: List[Dict[str, Any]] = []
            # Running row offset into pixel_values; equals the sum of
            # t*h*w over all grids of the preceding samples.
            offset = 0
            for i in range(num_samples):
                rec = {"input_ids": ids[i].clone(), "labels": labels[i].clone()}
                for k in ("attention_mask", "mm_token_type_ids"):
                    v = b.get(k)
                    # Only batched (>= 2-D) tensors can be split per sample.
                    if v is not None and v.dim() >= 2:
                        rec[k] = v[i].clone()
                if grids_per_sample > 0:
                    thw_i = thw[i * grids_per_sample:(i + 1) * grids_per_sample].clone()
                    # Each grid contributes t*h*w rows of pixel_values.
                    pv_count = int(thw_i.prod(dim=-1).sum().item())
                    rec["pixel_values"] = pv[offset:offset + pv_count].clone()
                    rec["image_grid_thw"] = thw_i
                    offset += pv_count
                out.append(rec)
            return out

        per_sample: List[Dict[str, Any]] = []
        for b in batches:
            if isinstance(b, list):
                for br in b:
                    per_sample.extend(_expand_dict(br))
            else:
                per_sample.extend(_expand_dict(b))

        class PresetPtVLDataset(Dataset):
            """In-memory dataset over the expanded per-sample records."""

            def __init__(self, samples):
                self.samples = samples

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                return self.samples[idx]

        self.base.train_dataset = PresetPtVLDataset(per_sample)
        max_steps = getattr(self.base.args.train, "max_steps", None)
        if max_steps:
            self.base.state.max_steps = int(max_steps)
        self._log_rank0(
            "preset_pt VL dataset: %d samples loaded from %s", len(per_sample), train_path,
        )

    def _build_collate_fn(self):
        """Build collate fn (internal).

        Stores a function in ``self.base.collate_fn`` that turns a list of
        per-sample dicts into one batch dict. ``input_ids``, ``labels`` and
        ``attention_mask`` are stacked; ``pixel_values`` are concatenated
        along dim 0; ``image_grid_thw`` is stacked when samples carry a 1-D
        grid and concatenated otherwise. Optional keys are included based on
        the first sample of the batch.
        """
        def _vl_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            first = batch[0]
            # attention_mask is optional in preset_pt records, so fall back
            # to an all-ones mask shaped like input_ids when it is absent.
            masks = [
                x["attention_mask"] if x.get("attention_mask") is not None
                else torch.ones_like(x["input_ids"])
                for x in batch
            ]
            out = {
                "input_ids": torch.stack([x["input_ids"] for x in batch]),
                "labels": torch.stack([x["labels"] for x in batch]),
                "attention_mask": torch.stack(masks),
            }
            if "mm_token_type_ids" in first:
                out["mm_token_type_ids"] = torch.stack([x["mm_token_type_ids"] for x in batch])
            if first.get("pixel_values") is not None:
                out["pixel_values"] = torch.cat([x["pixel_values"] for x in batch], dim=0)
            if first.get("image_grid_thw") is not None:
                grids = [x["image_grid_thw"] for x in batch]
                if grids[0].dim() == 1:
                    # One [t, h, w] vector per sample -> (batch, 3).
                    out["image_grid_thw"] = torch.stack(grids)
                else:
                    # Already (num_grids, 3) per sample -> concatenate rows.
                    out["image_grid_thw"] = torch.cat(grids, dim=0)
            return out

        self.base.collate_fn = _vl_collate

    def train(self):
        """Run training by delegating to ``self.base.train()`` and return its result."""
        return self.base.train()