21 | 21 | "convert_sharegpt_data", |
22 | 22 | "convert_ultrachat_data", |
23 | 23 | "DataCollatorWithPadding", |
| 24 | + "VLMDataCollatorWithPadding", |
24 | 25 | ] |
25 | 26 |
26 | 27 |
@@ -100,38 +101,92 @@ def process_token_dict_to_mappings( |
100 | 101 | return d2t, t2d |
101 | 102 |
102 | 103 |
| 104 | +def paddingtensor(intensors, N): |
| 105 | + B, n, S = intensors.shape |
| 106 | + # right-pad along dim 1 with zeros of matching dtype up to length N |
| 107 | + padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype) |
| 108 | + outtensors = torch.cat((intensors, padding_tensor), dim=1) |
| 109 | + return outtensors |
| 110 | + |
| 111 | + |
| 112 | +def paddingtensor2D(intensors, N): |
| 113 | + B, n = intensors.shape |
| 114 | + padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype) |
| 115 | + outtensors = torch.cat((intensors, padding_tensor), dim=1) |
| 116 | + return outtensors |
| 117 | + |
| 118 | + |
| 119 | +def paddingtensor3D(tensor_list): |
| 120 | + max_h = max(tensor.shape[-2] for tensor in tensor_list) |
| 121 | + max_w = max(tensor.shape[-1] for tensor in tensor_list) |
| 122 | + out_tensor_list = [] |
| 123 | + for tensor in tensor_list: |
| 124 | + if tensor.ndim == 2: |
| 125 | + tensor = tensor.unsqueeze(0) |
| 126 | + b, h, w = tensor.shape |
| 127 | + outtensor = torch.zeros(b, max_h, max_w, dtype=tensor.dtype) |
| 128 | + outtensor[:, :h, :w] = tensor |
| 129 | + out_tensor_list.append(outtensor) |
| 130 | + return torch.cat(out_tensor_list) |
| 131 | + |
| 132 | + |
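The three helpers above are hoisted to module level so both collators can share them. A minimal sketch of their behavior, assuming the helpers are in scope; the shapes are illustrative, not values from the training pipeline:

```python
import torch

# paddingtensor: zero-pad a (B, n, S) tensor along dim 1 up to length N
h = paddingtensor(torch.ones(1, 2, 4), 5)
assert h.shape == (1, 5, 4)
assert h[:, 2:].sum() == 0            # the appended slots are zeros

# paddingtensor2D: same idea for a (B, n) tensor of ids or masks
ids = paddingtensor2D(torch.tensor([[7, 8, 9]]), 5)
assert ids.tolist() == [[7, 8, 9, 0, 0]]

# paddingtensor3D: pad a list of 2D/3D tensors to a shared (max_h, max_w),
# then concatenate along the batch dimension
grids = paddingtensor3D([torch.ones(4, 4), torch.ones(2, 6)])
assert grids.shape == (2, 4, 6)
```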
103 | 133 | class DataCollatorWithPadding: |
104 | | - def paddingtensor(self, intensors, N): |
105 | | - B, n, S = intensors.shape |
106 | | - # padding_tensor = torch.zeros(B, N - n, S,dtype=intensors.dtype) |
107 | | - padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype) |
108 | | - outtensors = torch.cat((intensors, padding_tensor), dim=1) |
109 | | - return outtensors |
110 | | - |
111 | | - def paddingtensor2D(self, intensors, N): |
112 | | - B, n = intensors.shape |
113 | | - padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype) |
114 | | - outtensors = torch.cat((intensors, padding_tensor), dim=1) |
115 | | - return outtensors |
116 | 134 |
117 | 135 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
118 | 136 | max_length = max(item["input_ids"].shape[1] for item in features) |
119 | 137 | batch_input_ids = torch.cat( |
120 | | - [self.paddingtensor2D(item["input_ids"], max_length) for item in features] |
| 138 | + [paddingtensor2D(item["input_ids"], max_length) for item in features] |
| 139 | + ) |
| 140 | + batch_attention_mask = torch.cat( |
| 141 | + [paddingtensor2D(item["attention_mask"], max_length) for item in features] |
| 142 | + ) |
| 143 | + batch_loss_mask = torch.cat( |
| 144 | + [paddingtensor2D(item["loss_mask"], max_length) for item in features] |
| 145 | + ) |
| 146 | + |
| 147 | + batch = { |
| 148 | + "input_ids": batch_input_ids, |
| 149 | + "attention_mask": batch_attention_mask, |
| 150 | + "loss_mask": batch_loss_mask, |
| 151 | + } |
| 152 | + return batch |
| 153 | + |
| 154 | + |
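With the helpers extracted, `DataCollatorWithPadding` reduces to padding the three text fields to the longest sequence in the batch. A hypothetical call; each feature tensor carries a leading batch dim of 1, which is what the `item["input_ids"].shape[1]` lookup assumes:

```python
import torch

feats = [
    {"input_ids": torch.tensor([[1, 2, 3]]),
     "attention_mask": torch.ones(1, 3, dtype=torch.long),
     "loss_mask": torch.ones(1, 3, dtype=torch.long)},
    {"input_ids": torch.tensor([[4, 5]]),
     "attention_mask": torch.ones(1, 2, dtype=torch.long),
     "loss_mask": torch.ones(1, 2, dtype=torch.long)},
]
batch = DataCollatorWithPadding()(feats)
assert batch["input_ids"].shape == (2, 3)   # second row is zero-padded
```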
| 155 | +class VLMDataCollatorWithPadding: |
| 156 | + |
| 157 | + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
| 158 | + max_length = max(item["input_ids"].shape[1] for item in features) |
| 159 | + batch_input_ids = torch.cat( |
| 160 | + [paddingtensor2D(item["input_ids"], max_length) for item in features] |
121 | 161 | ) |
122 | 162 | batch_attention_mask = torch.cat( |
123 | | - [ |
124 | | - self.paddingtensor2D(item["attention_mask"], max_length) |
125 | | - for item in features |
126 | | - ] |
| 163 | + [paddingtensor2D(item["attention_mask"], max_length) for item in features] |
127 | 164 | ) |
128 | 165 | batch_loss_mask = torch.cat( |
129 | | - [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] |
| 166 | + [paddingtensor2D(item["loss_mask"], max_length) for item in features] |
130 | 167 | ) |
131 | 168 |
132 | 169 | batch = { |
133 | 170 | "input_ids": batch_input_ids, |
134 | 171 | "attention_mask": batch_attention_mask, |
135 | 172 | "loss_mask": batch_loss_mask, |
136 | 173 | } |
| 174 | + |
| 175 | + if "pixel_values" in features[0]: |
| 176 | + batch["pixel_values"] = paddingtensor3D( |
| 177 | + [item["pixel_values"] for item in features] |
| 178 | + ) |
| 179 | + if "video_pixel_values" in features[0]: |
| 180 | + batch["video_pixel_values"] = paddingtensor3D( |
| 181 | + [item["video_pixel_values"] for item in features] |
| 182 | + ) |
| 183 | + if "image_grid_thw" in features[0]: |
| 184 | + batch["image_grid_thw"] = paddingtensor3D( |
| 185 | + [item["image_grid_thw"] for item in features] |
| 186 | + ) |
| 187 | + if "video_grid_thw" in features[0]: |
| 188 | + batch["video_grid_thw"] = paddingtensor3D( |
| 189 | + [item["video_grid_thw"] for item in features] |
| 190 | + ) |
| 191 | + |
137 | 192 | return batch |
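`VLMDataCollatorWithPadding` repeats the text-side padding and additionally pads whichever optional vision tensors the first feature carries. A hedged usage sketch; the patch counts and feature width below are made-up illustration values:

```python
import torch

feats = [
    {"input_ids": torch.tensor([[1, 2, 3]]),
     "attention_mask": torch.ones(1, 3, dtype=torch.long),
     "loss_mask": torch.ones(1, 3, dtype=torch.long),
     "pixel_values": torch.randn(16, 1176),        # 16 patches (hypothetical)
     "image_grid_thw": torch.tensor([[1, 4, 4]])},
    {"input_ids": torch.tensor([[4, 5]]),
     "attention_mask": torch.ones(1, 2, dtype=torch.long),
     "loss_mask": torch.ones(1, 2, dtype=torch.long),
     "pixel_values": torch.randn(8, 1176),         # 8 patches (hypothetical)
     "image_grid_thw": torch.tensor([[1, 2, 4]])},
]
batch = VLMDataCollatorWithPadding()(feats)
assert batch["input_ids"].shape == (2, 3)
assert batch["pixel_values"].shape == (2, 16, 1176)   # padded to the max patch count
assert batch["image_grid_thw"].shape == (2, 1, 3)
```

One caveat worth flagging in review: `paddingtensor3D` pads `pixel_values` with zero rows, so downstream consumers presumably need `image_grid_thw` rather than the tensor shape to count real patches.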