import torch
from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel


class MCLIPConfig(XLMRobertaConfig):
    """XLM-RoBERTa configuration extended with the text and image embedding dimensions."""

    model_type = "M-CLIP"

    def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
        # Width of the text transformer's hidden states and of the target
        # image-encoder embedding space, respectively.
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)


class MultilingualCLIP(PreTrainedModel):
    """Multilingual text encoder: XLM-RoBERTa followed by a linear projection
    into the image-encoder embedding space."""

    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        # Maps the pooled text representation into the image embedding space.
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions, out_features=config.numDims
        )

    def forward(self, input_ids, attention_mask):
        # Token-level hidden states from the last transformer layer:
        # shape (batch_size, seq_len, transformerDimensions).
        embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Mean-pool over tokens, ignoring padding positions via the attention mask.
        embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        # Return both the projected sentence embedding and the raw token embeddings.
        return self.LinearTransformation(embs2), embs
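

# --- Usage sketch (added; not part of the original file) ---
# A minimal, self-contained smoke test with a randomly initialised model,
# assuming no pretrained checkpoint is available. In practice the weights
# would be loaded with MultilingualCLIP.from_pretrained() and the text
# tokenised with the matching tokenizer; the dummy token ids below are
# purely illustrative.
if __name__ == "__main__":
    # transformerDimSize must match the transformer's hidden_size (768 by default here).
    config = MCLIPConfig(transformerDimSize=768, imageDimSize=640, num_hidden_layers=2)
    model = MultilingualCLIP(config)
    input_ids = torch.tensor([[0, 581, 2]])        # dummy token sequence
    attention_mask = torch.ones_like(input_ids)    # no padding in this toy batch
    projected, token_embs = model(input_ids, attention_mask)
    print(projected.shape)   # (1, config.numDims)
    print(token_embs.shape)  # (1, seq_len, config.transformerDimensions)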