Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / trainer / vl_trainer.py: 0%
145 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""VLTrainer for native Qwen3-VL multimodal training."""
16import logging
17from typing import Any, Dict, List
19import torch
20from torch.utils.data import Dataset
22from hyper_parallel.trainer.base import BaseTrainer
24logger = logging.getLogger(__name__)
class VLTrainer:
    """Trainer for multimodal Qwen3-VL training (text + image/video).

    Two data paths are supported:

    - ``data.type = "vl_dummy"``: deterministic synthetic multimodal tensors,
      useful for quick smoke tests without dataset preparation.
    - ``data.type = "preset_pt"``: replays pre-tokenized batches that already
      include ``pixel_values`` and ``image_grid_thw``.

    The trainer composes a ``BaseTrainer`` (stored as ``self.base``): the
    VL-specific build steps defined here write their results onto
    ``self.base`` (``processor``, ``tokenizer``, ``data_transform``,
    ``train_dataset``, ``collate_fn``, ``state.max_steps``) and every other
    step is delegated to ``BaseTrainer``.
    """

    def __init__(self, args):
        """Build the whole training stack for ``args``.

        The call order below is significant: the VL-specific dataset and
        collate function are installed on ``self.base`` before
        ``BaseTrainer._build_dataloader`` runs.

        Args:
            args: Configuration object forwarded to ``BaseTrainer``. This
                class reads ``args.data``, ``args.model``, ``args.train``
                and (optionally) ``args.seed`` from it.
        """
        self.base = BaseTrainer(args)
        self.base._setup()
        self.base._build_model()
        self.base._freeze_model()
        self._build_model_assets()
        self._build_data_transform()
        self._build_dataset()
        self._build_collate_fn()
        self.base._build_dataloader()
        self.base._build_parallelized_model()
        self.base._build_optimizer()
        self.base._build_lr_scheduler()
        self.base._build_training_context()
        self.base._init_callbacks()
        self.base.on_init_end()

    @staticmethod
    def _log_rank0(msg, *fmt_args):
        """Log an info message via ``logger.info_rank0`` when it exists.

        ``info_rank0`` is not part of the standard ``logging.Logger`` API; it
        is only present if something has added it to the logger (presumably a
        project-side logging setup - TODO confirm). Falling back to
        ``logger.info`` keeps dataset construction from failing with
        ``AttributeError`` when that extension is absent.
        """
        emit = getattr(logger, "info_rank0", None)
        if not callable(emit):
            emit = logger.info
        emit(msg, *fmt_args)

    def _build_model_assets(self):
        """Load the processor when a non-dummy VL dataset is configured.

        Sets ``self.base.processor`` and ``self.base.tokenizer``. Both stay
        ``None`` for ``data.type == "vl_dummy"`` (also the default when
        ``data.type`` is missing).

        Raises:
            ValueError: If none of ``data.processor_path``,
                ``model.tokenizer_path`` or ``model.weights_path`` is set.
        """
        self.base.processor = None
        self.base.tokenizer = None
        data_type = getattr(self.base.args.data, "type", "vl_dummy")
        if data_type == "vl_dummy":
            return

        # First non-empty candidate wins, in this priority order.
        processor_path = (
            getattr(self.base.args.data, "processor_path", None)
            or getattr(self.base.args.model, "tokenizer_path", None)
            or getattr(self.base.args.model, "weights_path", None)
        )
        if not processor_path:
            raise ValueError(
                "VL real-data mode requires data.processor_path, "
                "model.tokenizer_path or model.weights_path"
            )

        # Imported lazily so the dummy path does not need transformers.
        from transformers import AutoProcessor  # pylint: disable=C0415

        # SECURITY NOTE(review): trust_remote_code=True lets the checkpoint
        # directory execute its own Python code; only point processor_path
        # at trusted sources.
        self.base.processor = AutoProcessor.from_pretrained(
            processor_path, trust_remote_code=True,
        )
        self.base.tokenizer = getattr(self.base.processor, "tokenizer", None)
        logger.info("Processor loaded from %s", processor_path)

    def _build_data_transform(self):
        """Install no data transform; samples are used as produced."""
        self.base.data_transform = None

    def _build_dataset(self):
        """Build the training dataset selected by ``data.type``.

        Raises:
            NotImplementedError: If ``data.type`` is neither ``"vl_dummy"``
                nor ``"preset_pt"``.
        """
        data_type = getattr(self.base.args.data, "type", "vl_dummy")
        if data_type == "preset_pt":
            self._build_preset_pt_dataset()
            return
        if data_type != "vl_dummy":
            raise NotImplementedError(
                f"VL trainer supports data.type 'vl_dummy' or 'preset_pt', got '{data_type}'"
            )
        self._build_dummy_dataset()

    def _build_dummy_dataset(self):
        """Create the deterministic synthetic VL dataset.

        The dataset holds ``max_steps * global_batch_size`` samples. Sizes
        come from ``model.config_overrides["vision_config"]`` and from the
        ``vl_grid_*`` / ``image_token_id`` / ``max_seq_len`` fields of
        ``args.data``, each with the defaults visible below. Also writes
        ``self.base.state.max_steps``.
        """
        max_steps = getattr(self.base.args.train, "max_steps", 1)
        global_bs = getattr(self.base.args.train, "global_batch_size", 1)
        total_samples = max_steps * global_bs
        data_cfg = self.base.args.data
        model_cfg = self.base.args.model

        extra = getattr(model_cfg, "config_overrides", None) or {}
        vision_extra = extra.get("vision_config", {}) if isinstance(extra, dict) else {}
        patch_size = int(vision_extra.get("patch_size", 16))
        temporal_patch_size = int(vision_extra.get("temporal_patch_size", 2))
        in_channels = int(vision_extra.get("in_channels", 3))
        spatial_merge = int(vision_extra.get("spatial_merge_size", 2))

        grid_t = int(getattr(data_cfg, "vl_grid_t", 2))
        grid_h = int(getattr(data_cfg, "vl_grid_h", 2))
        grid_w = int(getattr(data_cfg, "vl_grid_w", 2))
        image_token_id = int(getattr(data_cfg, "image_token_id", 151655))

        # Number of image placeholder tokens in the text sequence.
        image_tokens = grid_t * grid_h * grid_w // (spatial_merge ** 2)
        # Width of one pixel_values row (one flattened patch).
        row_width = in_channels * temporal_patch_size * patch_size * patch_size
        # Sequence must fit the leading token, the image tokens and some text.
        seq_len = max(int(getattr(data_cfg, "max_seq_len", 16)), image_tokens + 4)
        base_seed = int(getattr(self.base.args, "seed", 42))

        class DeterministicVLDataset(Dataset):
            """Deterministic synthetic VL dataset for smoke / FSDP regression."""

            def __len__(self):
                return total_samples

            def __getitem__(self, idx):
                # Per-index generator: the same idx always yields the same pixels.
                g = torch.Generator().manual_seed(base_seed + idx)
                pixel_values = torch.randn(
                    grid_t * grid_h * grid_w, row_width, generator=g,
                    dtype=torch.float32,
                )
                # Layout: [151643] + image_tokens * [image_token_id] + text tail.
                input_ids = torch.full((seq_len,), 100, dtype=torch.long)
                input_ids[0] = 151643
                input_ids[1: 1 + image_tokens] = image_token_id
                tail_start = 200 + idx % 17
                tail = torch.arange(
                    tail_start,
                    tail_start + seq_len - 1 - image_tokens,
                    dtype=torch.long,
                )
                input_ids[1 + image_tokens:] = tail
                labels = input_ids.clone()
                # 1 marks image-token positions, 0 everything else.
                mm_token_type_ids = torch.zeros(seq_len, dtype=torch.int32)
                mm_token_type_ids[1: 1 + image_tokens] = 1
                return {
                    "input_ids": input_ids,
                    "labels": labels,
                    "attention_mask": torch.ones(seq_len, dtype=torch.long),
                    "mm_token_type_ids": mm_token_type_ids,
                    "pixel_values": pixel_values,
                    "image_grid_thw": torch.tensor([grid_t, grid_h, grid_w], dtype=torch.long),
                }

        self.base.train_dataset = DeterministicVLDataset()
        self.base.state.max_steps = max_steps
        self._log_rank0(
            "VL dummy dataset created: samples=%d seq_len=%d grid=(%d,%d,%d) image_tokens=%d",
            total_samples, seq_len, grid_t, grid_h, grid_w, image_tokens,
        )

    def _build_preset_pt_dataset(self):
        """Replay pre-tokenized VL batches from a .pt file.

        The file at ``data.train_path`` must hold a non-empty list whose
        entries are either batch dicts or lists of batch dicts. Each batch
        dict has batched ``input_ids`` / ``labels`` and optionally
        ``attention_mask``, ``mm_token_type_ids``, ``pixel_values`` and
        ``image_grid_thw``. Batches are split into per-sample records.
        ``self.base.state.max_steps`` is only overwritten when
        ``train.max_steps`` is truthy.

        Raises:
            ValueError: If ``data.train_path`` is empty, or the loaded object
                is not a non-empty list.
        """
        train_path = self.base.args.data.train_path
        if not train_path:
            raise ValueError("data.train_path is required when data.type='preset_pt'")

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects; only load preset files from trusted locations.
        batches = torch.load(train_path, map_location="cpu", weights_only=False)
        if not isinstance(batches, list):
            raise ValueError(f"preset_pt expects List, got {type(batches)}")
        if not batches:
            raise ValueError(
                f"preset_pt file {train_path} contains an empty list; "
                "at least one batch is required"
            )

        def _expand_dict(b: Dict[str, Any]) -> List[Dict[str, Any]]:
            """Split one batched dict into a list of per-sample dicts."""
            ids = b["input_ids"]
            labels = b["labels"]
            batch_size = ids.shape[0]
            pv = b.get("pixel_values")
            thw = b.get("image_grid_thw")

            # Vision fields are only split when both are present and the grid
            # tensor is 2-D; grids are assumed to be evenly distributed over
            # the samples of the batch (remainder rows are ignored).
            grids_per_sample = 0
            if batch_size and pv is not None and thw is not None and thw.dim() == 2:
                grids_per_sample = thw.shape[0] // batch_size

            out: List[Dict[str, Any]] = []
            offset = 0  # running start row of the current sample in pv
            for i in range(batch_size):
                rec = {"input_ids": ids[i].clone(), "labels": labels[i].clone()}
                for k in ("attention_mask", "mm_token_type_ids"):
                    v = b.get(k)
                    # Only per-sample (>= 2-D) tensors can be indexed by sample.
                    if v is not None and v.dim() >= 2:
                        rec[k] = v[i].clone()
                if grids_per_sample > 0:
                    thw_i = thw[i * grids_per_sample:(i + 1) * grids_per_sample].clone()
                    # Rows of pixel_values owned by this sample: sum of t*h*w.
                    pv_count = int(thw_i.prod(dim=-1).sum().item())
                    rec["pixel_values"] = pv[offset:offset + pv_count].clone()
                    rec["image_grid_thw"] = thw_i
                    offset += pv_count
                out.append(rec)
            return out

        per_sample = []
        for b in batches:
            if isinstance(b, list):
                for br in b:
                    per_sample.extend(_expand_dict(br))
            else:
                per_sample.extend(_expand_dict(b))

        class PresetPtVLDataset(Dataset):
            """In-memory dataset over the expanded per-sample records."""

            def __init__(self, samples):
                self.samples = samples

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                return self.samples[idx]

        self.base.train_dataset = PresetPtVLDataset(per_sample)
        max_steps = getattr(self.base.args.train, "max_steps", None)
        if max_steps:
            self.base.state.max_steps = int(max_steps)
        self._log_rank0(
            "preset_pt VL dataset: %d samples loaded from %s", len(per_sample), train_path,
        )

    def _build_collate_fn(self):
        """Install the VL collate function on ``self.base.collate_fn``.

        The collate function stacks ``input_ids`` / ``labels`` /
        ``attention_mask`` (all three are required in every sample) and, when
        the first sample carries them, also ``mm_token_type_ids`` (stacked),
        ``pixel_values`` (concatenated along dim 0) and ``image_grid_thw``
        (stacked if 1-D per sample, concatenated along dim 0 otherwise).
        """
        def _vl_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            out = {
                "input_ids": torch.stack([x["input_ids"] for x in batch]),
                "labels": torch.stack([x["labels"] for x in batch]),
                "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
            }
            # Optional fields are detected from the first sample only.
            first = batch[0]
            if "mm_token_type_ids" in first:
                out["mm_token_type_ids"] = torch.stack([x["mm_token_type_ids"] for x in batch])
            if first.get("pixel_values") is not None:
                out["pixel_values"] = torch.cat([x["pixel_values"] for x in batch], dim=0)
            if first.get("image_grid_thw") is not None:
                grids = [x["image_grid_thw"] for x in batch]
                if grids[0].dim() == 1:
                    out["image_grid_thw"] = torch.stack(grids)
                else:
                    out["image_grid_thw"] = torch.cat(grids, dim=0)
            return out

        self.base.collate_fn = _vl_collate

    def train(self):
        """Run training by delegating to ``BaseTrainer.train``."""
        return self.base.train()