from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoProcessor

from mteb._requires_package import requires_image_dependencies
from mteb.models.abs_encoder import AbsEncoder
from mteb.models.model_meta import ModelMeta, ScoringFunction
from mteb.types import OutputDType

if TYPE_CHECKING:
    from torch.utils.data import DataLoader

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.types import Array, BatchedInput, PromptType


class ColVec1Wrapper(AbsEncoder):
    """MTEB wrapper for ColVec1 (ColQwen3.5-based) retrieval models.

    Loads the model via AutoModel/AutoProcessor with trust_remote_code=True,
    so no external library beyond transformers is required.
    """

    def __init__(
        self,
        model_name: str,
        revision: str | None = None,
        device: str | None = None,
        trust_remote_code: bool = True,
        torch_dtype: torch.dtype | None = torch.bfloat16,
        **kwargs,
    ):
        requires_image_dependencies()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.model = AutoModel.from_pretrained(
            model_name,
            revision=revision,
            dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            **kwargs,
        ).to(self.device)
        self.model.eval()

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

    def encode(
        self,
        inputs: DataLoader[BatchedInput],
        *,
        task_metadata: TaskMetadata,
        hf_split: str,
        hf_subset: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> Array:
        text_embeddings = None
        image_embeddings = None

        if "text" in inputs.dataset.features:
            text_embeddings = self.get_text_embeddings(inputs, **kwargs)
        if "image" in inputs.dataset.features:
            image_embeddings = self.get_image_embeddings(inputs, **kwargs)

        if text_embeddings is not None and image_embeddings is not None:
            if len(text_embeddings) != len(image_embeddings):
                raise ValueError(
                    "The number of texts and images must be the same."
                )
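            # Fused text+image inputs: element-wise sum of the two padded
            # multi-vector tensors (mirroring other mteb multimodal wrappers;
            # the padded shapes must match for this to be valid).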
            return text_embeddings + image_embeddings
        elif text_embeddings is not None:
            return text_embeddings
        elif image_embeddings is not None:
            return image_embeddings
        raise ValueError("No text or image features found in inputs.")

    def _encode_inputs(self, encoded_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
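        """Forward pass that first clears any cached rope deltas.

        Qwen-VL-style backbones cache ``rope_deltas`` between forward passes;
        a stale value computed for a previous batch with different sequence
        lengths can interfere with position ids, so it is reset defensively.
        """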
        vlm = getattr(self.model, "vlm", None)
        if vlm is not None:
            base = getattr(vlm, "model", vlm)
            if hasattr(base, "rope_deltas"):
                base.rope_deltas = None
        return self.model(**encoded_inputs)

    def get_image_embeddings(
        self, images, batch_size=32, show_progress_bar=True, **kwargs
    ):
        import torchvision.transforms.functional as F
        from PIL import Image

        all_embeds = []
        with torch.no_grad():
            for batch in tqdm(
                images, disable=not show_progress_bar, desc="Encoding images"
            ):
                # to_pil_image runs on the CPU, so moving tensors to the GPU
                # first is unnecessary.
                imgs = [
                    b if isinstance(b, Image.Image) else F.to_pil_image(b)
                    for b in batch["image"]
                ]
                imgs = [img.convert("RGB") for img in imgs]
                inputs = self.processor.process_images(imgs)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outs = self._encode_inputs(inputs)
                # One (seq_len, dim) multi-vector embedding per image.
                all_embeds.extend(outs.cpu().to(torch.float32))
        # Pad the variable-length sequences into one (n, max_len, dim) tensor.
        return torch.nn.utils.rnn.pad_sequence(
            all_embeds, batch_first=True, padding_value=0
        )

    def get_text_embeddings(
        self, texts, batch_size=32, show_progress_bar=True, **kwargs
    ):
        all_embeds = []
        with torch.no_grad():
            for batch in tqdm(
                texts, disable=not show_progress_bar, desc="Encoding texts"
            ):
                inputs = self.processor.process_queries(batch["text"])
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outs = self._encode_inputs(inputs)
                all_embeds.extend(outs.cpu().to(torch.float32))
        return torch.nn.utils.rnn.pad_sequence(
            all_embeds, batch_first=True, padding_value=0
        )

    def similarity(self, a, b):
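        """Late-interaction MaxSim scoring via the processor.

        score_multi_vector is the ColPali-style helper exposed by the remote
        processor code; it computes, for each query/document pair, the maximum
        dot product of every query token against the document tokens, summed
        over query tokens.
        """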
        a = [torch.as_tensor(x) for x in a]
        b = [torch.as_tensor(x) for x in b]
        return self.processor.score_multi_vector(a, b, device=self.device)


COLWEBAI_TRAINING_DATA = {
    "VidoreDocVQARetrieval",
    "VidoreInfoVQARetrieval",
    "VidoreTatdqaRetrieval",
    "VidoreArxivQARetrieval",
    "VisRAGRetArxivQA",
    "VisRAGRetChartQA",
    "VisRAGRetInfoVQA",
    "VisRAGRetPlotQA",
    "VisRAGRetMPDocVQA",
    "VisRAGRetSlideVQA",
    "VDRMultilingualRetrieval",
    "VidoreTabfquadRetrieval",
}

COLWEBAI_CITATION = """
@misc{webAI-ColVec1,
    title={webAI-ColVec1: Late-Interaction Multi-Vector Embedding Model for Visual Document Retrieval},
    author={webAI},
    year={2026},
    url={https://huggingface.co/webAI-Official/webAI-ColVec1-4b}
}
"""


colvec1_4b = ModelMeta(
    loader=ColVec1Wrapper,
    loader_kwargs=dict(torch_dtype=torch.bfloat16),
    name="webAI-Official/webAI-ColVec1-4b",
    revision="dce73882e6b89a01e702891a593f775dc5711929",
    release_date="2026-04-05",
    model_type=["late-interaction"],
    languages=["eng-Latn", "fra-Latn"],
    modalities=["image", "text"],
    n_parameters=4540904576,
    n_embedding_parameters=1639040,
    n_active_parameters_override=None,
    memory_usage_mb=8661,
    max_tokens=262144,
    embed_dim=640,
    license="multiple",
    open_weights=True,
    public_training_code=None,
    public_training_data=None,
    framework=["PyTorch", "Transformers", "safetensors"],
    reference="https://huggingface.co/webAI-Official/webAI-ColVec1-4b",
    similarity_fn_name=ScoringFunction.MAX_SIM,
    use_instructions=False,
    training_datasets=COLWEBAI_TRAINING_DATA,
    adapted_from="Qwen/Qwen3.5-4B",
    superseded_by=None,
    citation=COLWEBAI_CITATION,
    contacts=["psam-ai"],
    output_dtypes=OutputDType.FLOAT16,
)


colvec1_9b = ModelMeta(
    loader=ColVec1Wrapper,
    loader_kwargs=dict(torch_dtype=torch.bfloat16),
    name="webAI-Official/webAI-ColVec1-9b",
    revision="3767539920b9132abb24cef2c88d42d81817e50b",
    release_date="2026-04-05",
    model_type=["late-interaction"],
    languages=["eng-Latn", "fra-Latn"],
    modalities=["image", "text"],
    n_parameters=9420302064,
    n_embedding_parameters=10488320,
    n_active_parameters_override=None,
    memory_usage_mb=17968,
    max_tokens=262144,
    embed_dim=2560,
    license="multiple",
    open_weights=True,
    public_training_code=None,
    public_training_data=None,
    framework=["PyTorch", "Transformers", "safetensors"],
    reference="https://huggingface.co/webAI-Official/webAI-ColVec1-9b",
    similarity_fn_name=ScoringFunction.MAX_SIM,
    use_instructions=False,
    training_datasets=COLWEBAI_TRAINING_DATA,
    adapted_from="Qwen/Qwen3.5-9B",
    superseded_by=None,
    citation=COLWEBAI_CITATION,
    contacts=["psam-ai"],
    output_dtypes=OutputDType.FLOAT16,
)
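

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the MTEB evaluation flow).
    # Assumes the 4b checkpoint above is reachable on the Hugging Face Hub
    # and that the wrapper's image dependencies are installed.
    wrapper = ColVec1Wrapper("webAI-Official/webAI-ColVec1-4b")
    batches = [{"text": ["What revenue does the chart report for 2023?"]}]
    embeddings = wrapper.get_text_embeddings(batches, show_progress_bar=False)
    print(embeddings.shape)  # (1, n_query_tokens, embed_dim)
    print(wrapper.similarity(embeddings, embeddings))  # MaxSim self-score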