[integration] Add support for Transformers v4.46.0 (#3026)
tomaarsen authored Oct 29, 2024
1 parent 1912788 · commit 96052ad
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions sentence_transformers/trainer.py
@@ -7,9 +7,11 @@
 from typing import TYPE_CHECKING, Any, Callable
 
 import torch
+from packaging.version import parse as parse_version
 from torch import nn
 from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler
 from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
+from transformers import __version__ as transformers_version
 from transformers.data.data_collator import DataCollator
 from transformers.integrations import WandbCallback
 from transformers.trainer import TRAINING_ARGS_NAME
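The two added imports exist solely to support a runtime version check. A minimal sketch of how that comparison behaves, assuming `packaging` and `transformers` are installed:

```python
from packaging.version import parse as parse_version
from transformers import __version__ as transformers_version

# parse_version turns a string like "4.46.0" into a Version object,
# so comparisons order patch and pre-release versions correctly
# (unlike naive string comparison, where "4.9" > "4.46").
supports_processing_class = parse_version(transformers_version) >= parse_version("4.46.0")
print(supports_processing_class)  # True on Transformers v4.46.0 and newer
```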
@@ -202,19 +204,24 @@ def __init__(
             train_dataset = DatasetDict(train_dataset)
         if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
             eval_dataset = DatasetDict(eval_dataset)
-        super().__init__(
-            model=None if self.model_init else model,
-            args=args,
-            data_collator=data_collator,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            model_init=model_init,
-            compute_metrics=compute_metrics,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        )
+        super_kwargs = {
+            "model": None if self.model_init else model,
+            "args": args,
+            "data_collator": data_collator,
+            "train_dataset": train_dataset,
+            "eval_dataset": eval_dataset,
+            "model_init": model_init,
+            "compute_metrics": compute_metrics,
+            "callbacks": callbacks,
+            "optimizers": optimizers,
+            "preprocess_logits_for_metrics": preprocess_logits_for_metrics,
+        }
+        # Transformers v4.46.0 changed the `tokenizer` argument to a more general `processing_class` argument
+        if parse_version(transformers_version) >= parse_version("4.46.0"):
+            super_kwargs["processing_class"] = tokenizer
+        else:
+            super_kwargs["tokenizer"] = tokenizer
+        super().__init__(**super_kwargs)
         # Every Sentence Transformer model can always return a loss, so we set this to True
         # to avoid having to specify it in the data collator or model's forward
         self.can_return_loss = True
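The version-gated keyword routing above is the heart of the change: rather than calling `super().__init__()` directly, the kwargs are built in a dict so the `tokenizer` value can land under whichever name the installed Transformers release expects. The same pattern works for any `Trainer` subclass; a standalone sketch, where `build_super_kwargs` is a hypothetical helper for illustration:

```python
from packaging.version import parse as parse_version
from transformers import __version__ as transformers_version


def build_super_kwargs(tokenizer, **kwargs):
    # Hypothetical helper: route the tokenizer to whichever keyword
    # the installed Transformers release expects.
    if parse_version(transformers_version) >= parse_version("4.46.0"):
        kwargs["processing_class"] = tokenizer  # new name in v4.46.0+
    else:
        kwargs["tokenizer"] = tokenizer  # pre-v4.46.0 name
    return kwargs
```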
@@ -311,6 +318,7 @@ def compute_loss(
         model: SentenceTransformer,
         inputs: dict[str, torch.Tensor | Any],
         return_outputs: bool = False,
+        num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
         """
         Computes the loss for the SentenceTransformer model.
@@ -325,6 +333,7 @@ def compute_loss(
             model (SentenceTransformer): The SentenceTransformer model.
             inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model.
             return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.
+            num_items_in_batch (int, optional): The number of items in the batch. Defaults to None. Unused, but required by the transformers Trainer.
 
         Returns:
             Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.
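Transformers v4.46.0 also started passing a `num_items_in_batch` argument to `compute_loss`, so any override must accept it even when, as here, it goes unused. A minimal sketch of a forward-compatible override in a plain `Trainer` subclass (the `CustomTrainer` class is hypothetical):

```python
from transformers import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Accepting num_items_in_batch (with a default, for older releases)
        # keeps this override compatible with Transformers v4.46.0+, which
        # passes it; the value can safely be ignored.
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss
```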
