diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index 50115f6b5..3fd20eb1f 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -7,9 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable
 
 import torch
+from packaging.version import parse as parse_version
 from torch import nn
 from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler
 from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
+from transformers import __version__ as transformers_version
 from transformers.data.data_collator import DataCollator
 from transformers.integrations import WandbCallback
 from transformers.trainer import TRAINING_ARGS_NAME
 
@@ -202,19 +204,24 @@ def __init__(
             train_dataset = DatasetDict(train_dataset)
         if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
             eval_dataset = DatasetDict(eval_dataset)
-        super().__init__(
-            model=None if self.model_init else model,
-            args=args,
-            data_collator=data_collator,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            model_init=model_init,
-            compute_metrics=compute_metrics,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        )
+        super_kwargs = {
+            "model": None if self.model_init else model,
+            "args": args,
+            "data_collator": data_collator,
+            "train_dataset": train_dataset,
+            "eval_dataset": eval_dataset,
+            "model_init": model_init,
+            "compute_metrics": compute_metrics,
+            "callbacks": callbacks,
+            "optimizers": optimizers,
+            "preprocess_logits_for_metrics": preprocess_logits_for_metrics,
+        }
+        # Transformers v4.46.0 changed the `tokenizer` argument to a more general `processing_class` argument
+        if parse_version(transformers_version) >= parse_version("4.46.0"):
+            super_kwargs["processing_class"] = tokenizer
+        else:
+            super_kwargs["tokenizer"] = tokenizer
+        super().__init__(**super_kwargs)
         # Every Sentence Transformer model can always return a loss, so we set this to True
         # to avoid having to specify it in the data collator or model's forward
         self.can_return_loss = True
@@ -311,6 +318,7 @@ def compute_loss(
         model: SentenceTransformer,
         inputs: dict[str, torch.Tensor | Any],
         return_outputs: bool = False,
+        num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
         """
         Computes the loss for the SentenceTransformer model.
@@ -325,6 +333,7 @@
             model (SentenceTransformer): The SentenceTransformer model.
             inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model.
             return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.
+            num_items_in_batch (int, optional): The number of items in the batch. Defaults to None. Unused, but required by the transformers Trainer.
 
         Returns:
             Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.
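
For reference, the version gate from the second hunk as a minimal standalone sketch; the helper name `pick_tokenizer_kwarg` is hypothetical and only the `tokenizer`/`processing_class` keywords and the `4.46.0` cutoff come from the actual change:

from packaging.version import parse as parse_version

from transformers import __version__ as transformers_version


def pick_tokenizer_kwarg(tokenizer) -> dict:
    # transformers v4.46.0 renamed the Trainer's `tokenizer` argument to the
    # more general `processing_class`; choose the right keyword at runtime so
    # a single code path supports releases on both sides of the rename.
    if parse_version(transformers_version) >= parse_version("4.46.0"):
        return {"processing_class": tokenizer}
    return {"tokenizer": tokenizer}

With this, `Trainer(**pick_tokenizer_kwarg(tokenizer), ...)` works regardless of the installed transformers version, which is the same idea as building `super_kwargs` above.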
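
Similarly, the `num_items_in_batch` parameter added in the last two hunks is needed because newer transformers releases (v4.46.0 and later) pass this extra argument when calling `compute_loss`. A sketch of the minimal signature an overriding trainer needs, with the subclass name hypothetical:

import torch
from transformers import Trainer


class PatchedTrainer(Trainer):  # hypothetical subclass, for illustration only
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Newer transformers versions pass `num_items_in_batch` to compute_loss;
        # accepting (and here ignoring) it keeps the override compatible with
        # both the old and new Trainer call signatures.
        return super().compute_loss(model, inputs, return_outputs=return_outputs)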