Skip to content

Commit

Permalink
[bug] Prevent `to` from being ignored (#2351)
Browse files Browse the repository at this point in the history
* Update _target_device on `to` call

+ test

* Fully replace `_target_device` with `device`

But try to preserve backwards compatibility

* Update test phrasing
  • Loading branch information
tomaarsen authored Dec 12, 2023
1 parent 331549c commit 135753b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
23 changes: 17 additions & 6 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self, model_name_or_path: Optional[str] = None,
device = get_device_name()
logger.info("Use pytorch device_name: {}".format(device))

self._target_device = torch.device(device)
self.to(device)

def encode(self, sentences: Union[str, List[str]],
batch_size: int = 32,
Expand Down Expand Up @@ -167,7 +167,7 @@ def encode(self, sentences: Union[str, List[str]],
input_was_string = True

if device is None:
device = self._target_device
device = self.device

self.to(device)

Expand Down Expand Up @@ -658,7 +658,7 @@ def fit(self,
from torch.cuda.amp import autocast
scaler = torch.cuda.amp.GradScaler()

self.to(self._target_device)
self.to(self.device)

dataloaders = [dataloader for dataloader, _ in train_objectives]

Expand All @@ -668,7 +668,7 @@ def fit(self,

loss_models = [loss for _, loss in train_objectives]
for loss_model in loss_models:
loss_model.to(self._target_device)
loss_model.to(self.device)

self.best_score = -9999999

Expand Down Expand Up @@ -724,8 +724,8 @@ def fit(self,
data = next(data_iterator)

features, labels = data
labels = labels.to(self._target_device)
features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
labels = labels.to(self.device)
features = list(map(lambda batch: batch_to_device(batch, self.device), features))

if use_amp:
with autocast():
Expand Down Expand Up @@ -949,3 +949,14 @@ def max_seq_length(self, value):
Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
"""
self._first_module().max_seq_length = value

@property
def _target_device(self) -> torch.device:
    """Deprecated alias for :attr:`device`, kept for backwards compatibility.

    Accessing it logs a warning and delegates to the canonical ``device``
    property instead of returning a separately stored target device.
    """
    deprecation_msg = (
        "`SentenceTransformer._target_device` has been removed, "
        "please use `SentenceTransformer.device` instead."
    )
    logger.warning(deprecation_msg)
    return self.device

@_target_device.setter
def _target_device(self, device: Optional[Union[int, str, torch.device]] = None) -> None:
    """Deprecated setter: assigning `_target_device` now simply moves the model."""
    self.to(device)
19 changes: 19 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tests general behaviour of the SentenceTransformer class
"""


from pathlib import Path
import tempfile

Expand Down Expand Up @@ -45,3 +46,21 @@ def test_load_with_safetensors(self):
torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)),
msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",
)

@unittest.skipUnless(torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_to(self):
    """Regression test: `to` must move the model and must not be silently undone."""
    # Start on CPU so the subsequent move to CUDA is observable.
    st_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")

    cuda_device = torch.device("cuda")
    self.assertEqual(st_model.device.type, "cpu")
    self.assertEqual(cuda_device.type, "cuda")

    # The bug under test: calling `to` used to be ignored by later operations.
    st_model.to(cuda_device)
    self.assertEqual(st_model.device.type, "cuda", msg="The model device should have updated")

    # Encoding without an explicit device argument must not move the model back.
    st_model.encode("Test sentence")
    self.assertEqual(st_model.device.type, "cuda", msg="Encoding shouldn't change the device")

    # Deprecated `_target_device` accessors must keep working (read and write).
    self.assertEqual(st_model._target_device, st_model.device, msg="Prevent backwards compatibility failure for _target_device")
    st_model._target_device = "cpu"
    self.assertEqual(st_model.device.type, "cpu", msg="Ensure that setting `_target_device` doesn't crash.")

0 comments on commit 135753b

Please sign in to comment.