diff --git a/python/fasttext_module/fasttext/FastText.py b/python/fasttext_module/fasttext/FastText.py index 64b5f4cac..fcb3ba5eb 100644 --- a/python/fasttext_module/fasttext/FastText.py +++ b/python/fasttext_module/fasttext/FastText.py @@ -516,6 +516,7 @@ def train_supervised(*kargs, **kwargs): 'model': "supervised" }) + callback = kwargs.pop("callback", None) arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount', 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket', 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors', @@ -525,7 +526,10 @@ def train_supervised(*kargs, **kwargs): supervised_default) a = _build_args(args, manually_set_args) ft = _FastText(args=a) - fasttext.train(ft.f, a) + if callback: + fasttext.train_with_callback(ft.f, a, callback) + else: + fasttext.train(ft.f, a) ft.set_args(ft.f.getArgs()) return ft @@ -544,6 +548,7 @@ def train_unsupervised(*kargs, **kwargs): dataset pulled by the example script word-vector-example.sh, which is part of the fastText repository. """ + callback = kwargs.pop("callback", None) arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount', 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket', 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors'] @@ -551,7 +556,10 @@ def train_unsupervised(*kargs, **kwargs): unsupervised_default) a = _build_args(args, manually_set_args) ft = _FastText(args=a) - fasttext.train(ft.f, a) + if callback: + fasttext.train_with_callback(ft.f, a, callback) + else: + fasttext.train(ft.f, a) ft.set_args(ft.f.getArgs()) return ft diff --git a/python/fasttext_module/fasttext/pybind/fasttext_pybind.cc b/python/fasttext_module/fasttext/pybind/fasttext_pybind.cc index 4cd1d3728..9cf27a69b 100644 --- a/python/fasttext_module/fasttext/pybind/fasttext_pybind.cc +++ b/python/fasttext_module/fasttext/pybind/fasttext_pybind.cc @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -166,6 +167,13 @@ PYBIND11_MODULE(fasttext_pybind, m) { } }, py::call_guard()); + + m.def( + "train_with_callback", + [](fasttext::FastText& ft, fasttext::Args& a, fasttext::FastText::TrainCallback& c) { + ft.train(a, c); + }, + py::call_guard()); py::class_(m, "Vector", py::buffer_protocol()) .def(py::init())