Skip to content

Commit 5eb2a1b

Browse files
authored
[feat] Update the trackio default project if not already defined (#3467)
* Update the trackio default project if not already defined * Specify the exact min. version
1 parent 267a5f6 commit 5eb2a1b

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

sentence_transformers/trainer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@
4848
if TYPE_CHECKING:
4949
from sentence_transformers.SentenceTransformer import SentenceTransformer
5050

51+
# The TrackioCallback is only available in the v4.54+ of transformers, but I'd like to keep Sentence Transformers
52+
# compatible with older versions of transformers as well, so we import it conditionally
53+
try:
54+
from transformers.integrations import TrackioCallback
55+
except ImportError:
56+
TrackioCallback = None
57+
5158

5259
class SentenceTransformerTrainer(Trainer):
5360
"""
@@ -273,9 +280,13 @@ def __init__(
273280
self.model: SentenceTransformer
274281
self.args: SentenceTransformerTrainingArguments
275282
self.data_collator: SentenceTransformerDataCollator
276-
# Set the W&B project via environment variables if it's not already set
283+
# Set the W&B or Trackio project via environment variables if it's not already set
277284
if any([isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks]):
278285
os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")
286+
if TrackioCallback is not None and any(
287+
[isinstance(callback, TrackioCallback) for callback in self.callback_handler.callbacks]
288+
):
289+
os.environ.setdefault("TRACKIO_PROJECT", "sentence-transformers")
279290

280291
if loss is None:
281292
logger.info("No `loss` passed, using `losses.CoSENTLoss` as a default option.")

0 commit comments

Comments
 (0)