diff --git a/CHANGELOG.md b/CHANGELOG.md index 721ffd06a..e343fd828 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- Adds `marian-adaptive` and `marian-adaptive-server` executables to enable self-adaptive translation (a.k.a., runtime domain adaptation). ### Fixed - Scripts using PyYAML now use `safe_load`; see https://msg.pyyaml.org/load diff --git a/CMakeLists.txt b/CMakeLists.txt index dbad75cb5..9eae892f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ option(COMPILE_CPU "Compile CPU version" ON) option(COMPILE_CUDA "Compile GPU version" ON) option(COMPILE_EXAMPLES "Compile examples" OFF) option(COMPILE_SERVER "Compile marian-server" OFF) +option(COMPILE_ADAPTIVE "Compile marian-adaptive. Set COMPILE_SERVER=ON to enable the server mode." OFF) option(COMPILE_TESTS "Compile tests" OFF) if(APPLE) option(USE_APPLE_ACCELERATE "Compile with Apple Accelerate" ON) diff --git a/scripts/self-adaptive/client_example.py b/scripts/self-adaptive/client_example.py new file mode 100644 index 000000000..e1fa52d37 --- /dev/null +++ b/scripts/self-adaptive/client_example.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# This is an example for using self-adaptive translation in server mode. +# +# To run: +# 1. Start self-adaptive Marian in server mode, e.g.: +# ./build/marian-adaptive-server -p 8080 -m model.npz -v vocab.yaml vocab.yaml \ +# --after-batches 10 --after-epochs 10 --learn-rate 0.1 --mini-batch 15 # other options +# 2. 
In a new shell, run this script: +# python3 ./scripts/self-adaptive/client_example.py -p 8080 +# +# For a more extensive example, see https://github.com/marian-cef/marian-examples/tree/master/adaptive +# or https://github.com/tilde-nlp/runtime-domain-adaptation-tutorial + +from __future__ import print_function, unicode_literals, division + +import sys +import time +import argparse +import json + +from websocket import create_connection + + +def translate(batch, port=8080): + ws = create_connection("ws://localhost:{}/translate".format(port)) + ws.send(batch) + result = ws.recv() + ws.close() + return result.rstrip() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--port", type=int, default=8080) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + # List of input sentences separated by a new line character + inputs = "this is an example\nthe second sentence\nno context provided" + # For each input sentence a list of parallel sentences can be provided as a + # list of source and target sentences. 
+ contexts = [ + # Source-side context for the first input sentence + ["this is a test\nthese are examples", + # Target-side context for the first input sentence + "das ist ein test\ndies sind Beispiele"], + # Only one example is given as a context for the second input sentence + ["the next sentence", + "der nächste Satz"], + # No context for the third input sentence + [] + ] + + input_data = {'input': inputs, 'context': contexts} + input_json = json.dumps(input_data) + + output_json = translate(input_json, port=args.port) + output_data = json.loads(output_json) + print(output_data['output']) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3718807a5..cc9a8345b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -104,6 +104,7 @@ set(MARIAN_SOURCES translator/nth_element.cpp translator/helpers.cpp translator/scorers.cpp + translator/swappable.cpp training/graph_group_async.cpp training/graph_group_sync.cpp @@ -129,6 +130,12 @@ set(MARIAN_SOURCES $ ) +if(COMPILE_ADAPTIVE) + set(MARIAN_SOURCES ${MARIAN_SOURCES} + data/adaptive_context.cpp + ) +endif(COMPILE_ADAPTIVE) + add_library(marian STATIC ${MARIAN_SOURCES}) target_compile_options(marian PRIVATE ${ALL_WARNINGS}) @@ -188,6 +195,7 @@ if(CUDA_FOUND) tensors/gpu/add_all.cu tensors/gpu/tensor_operators.cu tensors/gpu/cudnn_wrappers.cu + tensors/gpu/swap.cu translator/nth_element.cu translator/helpers.cu STATIC) @@ -274,6 +282,18 @@ if (NOT COMPILE_LIBRARY_ONLY) set(EXECUTABLES ${EXECUTABLES} marian_server) endif(COMPILE_SERVER) + if(COMPILE_ADAPTIVE) + add_executable(marian_adaptive command/marian_adaptive.cpp) + set_target_properties(marian_adaptive PROPERTIES OUTPUT_NAME marian-adaptive) + set(EXECUTABLES ${EXECUTABLES} marian_adaptive) + + if(COMPILE_SERVER) + add_executable(marian_adaptive_server command/marian_adaptive_server.cpp) + set_target_properties(marian_adaptive_server PROPERTIES OUTPUT_NAME marian-adaptive-server) + set(EXECUTABLES ${EXECUTABLES} marian_adaptive_server) + 
endif(COMPILE_SERVER) + endif(COMPILE_ADAPTIVE) + foreach(exec ${EXECUTABLES}) target_link_libraries(${exec} marian) if(CUDA_FOUND) diff --git a/src/command/marian_adaptive.cpp b/src/command/marian_adaptive.cpp new file mode 100644 index 000000000..a21d04a7d --- /dev/null +++ b/src/command/marian_adaptive.cpp @@ -0,0 +1,19 @@ +#include "marian.h" + +#include "common/timer.h" +#include "common/utils.h" +#include "training/training.h" +#include "translator/self_adaptive.h" + +using namespace marian; + +int main(int argc, char **argv) { + auto options = parseOptions(argc, argv, cli::mode::selfadaptive); + auto task = New(options); + + timer::Timer timer; + task->run(); + LOG(info, "Total time: {:.5f}s", timer.elapsed()); + + return 0; +} diff --git a/src/command/marian_adaptive_server.cpp b/src/command/marian_adaptive_server.cpp new file mode 100644 index 000000000..e2f03d999 --- /dev/null +++ b/src/command/marian_adaptive_server.cpp @@ -0,0 +1,11 @@ +#include "translator/self_adaptive.h" +#include "translator/server_common.h" + +int main(int argc, char **argv) { + using namespace marian; + + auto options = parseOptions(argc, argv, cli::mode::selfadaptiveServer); + auto task = New(options); + + return runServer(task, options); +} diff --git a/src/command/marian_server.cpp b/src/command/marian_server.cpp index d712e8389..ef62320b8 100644 --- a/src/command/marian_server.cpp +++ b/src/command/marian_server.cpp @@ -1,62 +1,11 @@ -#include "marian.h" -#include "translator/beam_search.h" +#include "translator/server_common.h" #include "translator/translator.h" -#include "common/timer.h" -#include "common/utils.h" - -#include "3rd_party/simple-websocket-server/server_ws.hpp" - -typedef SimpleWeb::SocketServer WSServer; int main(int argc, char **argv) { using namespace marian; - // Initialize translation task auto options = parseOptions(argc, argv, cli::mode::server, true); auto task = New>(options); - auto quiet = options->get("quiet-translation"); - - // Initialize web 
server - WSServer server; - server.config.port = (short)options->get("port", 8080); - - auto &translate = server.endpoint["^/translate/?$"]; - - translate.on_message = [&task, quiet](Ptr connection, - Ptr message) { - // Get input text - auto inputText = message->string(); - auto sendStream = std::make_shared(); - - // Translate - timer::Timer timer; - auto outputText = task->run(inputText); - *sendStream << outputText << std::endl; - if(!quiet) - LOG(info, "Translation took: {:.5f}s", timer.elapsed()); - - // Send translation back - connection->send(sendStream, [](const SimpleWeb::error_code &ec) { - if(ec) - LOG(error, "Error sending message: ({}) {}", ec.value(), ec.message()); - }); - }; - - // Error Codes for error code meanings - // http://www.boost.org/doc/libs/1_55_0/doc/html/boost_asio/reference.html - translate.on_error = [](Ptr /*connection*/, - const SimpleWeb::error_code &ec) { - LOG(error, "Connection error: ({}) {}", ec.value(), ec.message()); - }; - - // Start server thread - std::thread serverThread([&server]() { - server.start([](unsigned short port) { - LOG(info, "Server is listening on port {}", port); - }); - }); - - serverThread.join(); - return 0; + return runServer(task, options); } diff --git a/src/common/config.cpp b/src/common/config.cpp index 9878c70b0..3e03f8a6d 100644 --- a/src/common/config.cpp +++ b/src/common/config.cpp @@ -73,7 +73,7 @@ void Config::initialize(ConfigParser const& cp) { } // guess --tsv-fields, i.e. 
the number of fields in a TSV input, if not set - if(get("tsv") && get("tsv-fields") == 0) { + if(get("tsv", false) && get("tsv-fields") == 0) { size_t tsvFields = 0; // use the length of --input-types if given diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 0d9564953..aed9593a1 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -69,13 +69,29 @@ std::string const& ConfigParser::cmdLine() const { return cmdLine_; } -ConfigParser::ConfigParser(cli::mode mode) - : cli_(config_,"Marian: Fast Neural Machine Translation in C++", - "General options", "", 40), - mode_(mode == cli::mode::server ? cli::mode::translation : mode) { +/** + * Convert some special modes (currently, server-like modes) to their non-special counterparts. + */ +cli::mode convertSpecialModes(cli::mode mode) { + switch(mode) { + case cli::mode::server: + return cli::mode::translation; + case cli::mode::selfadaptiveServer: + return cli::mode::selfadaptive; + default: + return mode; + } +} +ConfigParser::ConfigParser(cli::mode mode) + : cli_(config_, "Marian: Fast Neural Machine Translation in C++", "General options", "", 40), + // Server-like modes should mostly act like their non-server counterparts + // when parsing options. We keep all special handling in the constructor + // but in the rest of the parsing code we just pretend that we have a + // non-server mode. 
+ mode_(convertSpecialModes(mode)) { addOptionsGeneral(cli_); - if (mode == cli::mode::server) + if (mode == cli::mode::server || mode == cli::mode::selfadaptiveServer) addOptionsServer(cli_); addOptionsModel(cli_); @@ -94,6 +110,10 @@ ConfigParser::ConfigParser(cli::mode mode) case cli::mode::embedding: addOptionsEmbedding(cli_); break; + case cli::mode::selfadaptive: + addOptionsTraining(cli_); + addOptionsTranslation(cli_); + break; default: ABORT("wrong CLI mode"); break; @@ -121,6 +141,15 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) { cli.add("--workspace,-w", "Preallocate arg MB of work space", defaultWorkspace); + // Self-adaptive translation uses a training graph and a translation graph. We + // want to be able to preallocate different amounts of memory for both (because + // translation usually needs less) so we add a dedicated option for + // translation if self-adaptive translation is used. + if (mode_ == cli::mode::selfadaptive) { + cli.add("--workspace-translate", + "Preallocate arg MB of work space for translation", + 512); + } cli.add("--log", "Log training process information to file given by arg"); cli.add("--log-level", @@ -159,9 +188,7 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) { void ConfigParser::addOptionsServer(cli::CLIWrapper& cli) { // clang-format off auto previous_group = cli.switchGroup("Server options"); - cli.add("--port,-p", - "Port number for web socket server", - 8080); + cli.add("--port,-p", "Port number for web socket server", 8080); cli.switchGroup(previous_group); // clang-format on } @@ -336,7 +363,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { {1, 2, 3, 4, 5, 6, 7, 8}); #endif - if(mode_ == cli::mode::training) { + if(mode_ == cli::mode::training || mode_ == cli::mode::selfadaptive) { // TODO: add ->range(0,1); cli.add("--dropout-rnn", "Scaling dropout along rnn layers and time (0 = no dropout)"); @@ -388,9 +415,13 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& 
cli) { #endif // scheduling options + // In self-adaptive mode users would typically want fewer updates to happen than in regular training + size_t defaultAfterEpochs = (mode_ == cli::mode::selfadaptive) ? 2 : 0; + std::string defaultDispFreq = (mode_ == cli::mode::selfadaptive) ? "1u" : "1000u"; + // @TODO: these should be re-defined as aliases for `--after` but the current framework matches on value, so not doable. cli.add("--after-epochs,-e", - "Finish after this many epochs, 0 is infinity (deprecated, '--after-epochs N' corresponds to '--after Ne')"); // @TODO: replace with alias + "Finish after this many epochs, 0 is infinity (deprecated, '--after-epochs N' corresponds to '--after Ne')", defaultAfterEpochs); // @TODO: replace with alias cli.add("--after-batches", "Finish after this many batch updates, 0 is infinity (deprecated, '--after-batches N' corresponds to '--after Nu')"); // @TODO: replace with alias @@ -399,7 +430,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { "0e"); cli.add("--disp-freq", "Display information every arg updates (append 't' for every arg target labels)", - "1000u"); + defaultDispFreq); cli.add("--disp-first", "Display information for the first arg updates"); cli.add("--disp-label-counts", @@ -416,34 +447,49 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { {"1e", "0"}); addSuboptionsInputLength(cli); - addSuboptionsTSV(cli); + // TSV inputs aren't currently supported for self-adaptive translation because + // self-adaptive translation uses a custom training data reader + // (`AdaptiveContextReader`) which doesn't yet support TSV. + if (mode_ != cli::mode::selfadaptive) + addSuboptionsTSV(cli); // data management options - cli.add("--shuffle", - "How to shuffle input data (data: shuffles data and sorted batches; batches: " - "data is read in order into batches, but batches are shuffled; none: no shuffling). 
" - "Use with '--maxi-batch-sort none' in order to achieve exact reading order", "data"); - cli.add("--no-shuffle", - "Shortcut for backwards compatiblity, equivalent to --shuffle none (deprecated)"); - cli.add("--no-restore-corpus", - "Skip restoring corpus state after training is restarted"); - cli.add("--tempdir,-T", - "Directory for temporary (shuffled) files and database", - "/tmp"); - cli.add("--sqlite", - "Use disk-based sqlite3 database for training corpus storage, default" - " is temporary with path creates persistent storage") - ->implicit_val("temporary"); - cli.add("--sqlite-drop", - "Drop existing tables in sqlite3 database"); + // + // These options are disabled for self-adaptive translation because they seem + // to not make much sense in that context, except for --shuffle, because they + // deal with the storage of training data, but, in self-adaptive translation, + // training data sets are small and they typically change for each input + // sentence. --shuffle isn't currently supported because we use `TextInput` + // for training data and shuffle is a no-op in that class. This might get + // implemented in the future. + if (mode_ != cli::mode::selfadaptive) { + cli.add("--shuffle", + "How to shuffle input data (data: shuffles data and sorted batches; batches: " + "data is read in order into batches, but batches are shuffled; none: no shuffling). 
" + "Use with '--maxi-batch-sort none' in order to achieve exact reading order", "data"); + cli.add("--no-shuffle", + "Shortcut for backwards compatiblity, equivalent to --shuffle none (deprecated)"); + cli.add("--no-restore-corpus", + "Skip restoring corpus state after training is restarted"); + cli.add("--tempdir,-T", + "Directory for temporary (shuffled) files and database", + "/tmp"); + cli.add("--sqlite", + "Use disk-based sqlite3 database for training corpus storage, default" + " is temporary with path creates persistent storage") + ->implicit_val("temporary"); + cli.add("--sqlite-drop", + "Drop existing tables in sqlite3 database"); + } addSuboptionsDevices(cli); addSuboptionsBatching(cli); // optimizer options - cli.add("--optimizer,-o", + auto defaultOptimizer = (mode_ == cli::mode::selfadaptive) ? "sgd" : "adam"; + cli.add("--optimizer", "Optimization algorithm: sgd, adagrad, adam", - "adam"); + defaultOptimizer); cli.add>("--optimizer-params", "Parameters for optimization algorithm, e.g. betas for Adam. 
" "Auto-adjusted to --mini-batch-words-ref if given"); @@ -658,8 +704,11 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { cli.add("--output,-o", "Path to output file, stdout by default", "stdout"); - cli.add>("--vocabs,-v", - "Paths to vocabulary files have to correspond to --input"); + // for self-adaptive mode these are already added via the training options + if(mode_ != cli::mode::selfadaptive) { + cli.add>("--vocabs,-v", + "Paths to vocabulary files have to correspond to --input"); + } // decoding options cli.add("--beam-size,-b", "Beam size used during search with validating translator", @@ -691,16 +740,22 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { "Keep the output segmented into SentencePiece subwords"); #endif - addSuboptionsInputLength(cli); - addSuboptionsTSV(cli); - addSuboptionsDevices(cli); - addSuboptionsBatching(cli); + // For self-adaptive translation these options are already added in + // `addOptionsTraining` + if(mode_ != cli::mode::selfadaptive) { + addSuboptionsInputLength(cli); + addSuboptionsTSV(cli); + addSuboptionsDevices(cli); + addSuboptionsBatching(cli); + } - cli.add("--fp16", - "Shortcut for mixed precision inference with float16, corresponds to: --precision float16"); - cli.add>("--precision", - "Mixed precision for inference, set parameter type in expression graph", - {"float32"}); + if(mode_ != cli::mode::selfadaptive) { + cli.add("--fp16", + "Shortcut for mixed precision inference with float16, corresponds to: --precision float16"); + cli.add>("--precision", + "Mixed precision for inference, set parameter type in expression graph", + {"float32"}); + } cli.add("--skip-cost", "Ignore model cost during translation, not recommended for beam-size > 1"); @@ -727,7 +782,8 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { #if 0 // @TODO: Ask Hany if there are any decoding-time options // add ULR settings - addSuboptionsULR(cli); + if(mode_ != cli::mode::selfadaptive) + 
addSuboptionsULR(cli); #endif cli.switchGroup(previous_group); @@ -860,8 +916,9 @@ void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) { } void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) { - int defaultMiniBatch = (mode_ == cli::mode::translation) ? 1 : 64; - int defaultMaxiBatch = (mode_ == cli::mode::translation) ? 1 : 100; + bool transMode = mode_ == cli::mode::translation || mode_ == cli::mode::selfadaptive; + int defaultMiniBatch = transMode ? 1 : 64; + int defaultMaxiBatch = transMode ? 1 : 100; std::string defaultMaxiBatchSort = (mode_ == cli::mode::translation) ? "none" : "trg"; // clang-format off @@ -893,7 +950,7 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) { "Sorting strategy for maxi-batch: none, src, trg (not available for decoder)", defaultMaxiBatchSort); - if(mode_ == cli::mode::training) { + if(mode_ == cli::mode::training || mode_ == cli::mode::selfadaptive) { cli.add("--shuffle-in-ram", "Keep shuffled corpus in RAM, do not write to temp file"); @@ -936,13 +993,25 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) { } void ConfigParser::addSuboptionsInputLength(cli::CLIWrapper& cli) { - size_t defaultMaxLength = (mode_ == cli::mode::training) ? 50 : 1000; + size_t defaultMaxLength = + (mode_ == cli::mode::training || mode_ == cli::mode::selfadaptive) + ? 50 + : 1000; // clang-format off cli.add("--max-length", "Maximum length of a sentence in a training sentence pair", defaultMaxLength); cli.add("--max-length-crop", "Crop a sentence to max-length instead of omitting it if longer than max-length"); + // In self-adaptive translation, the user might want to be able to set + // different max lengths for training and translation. In that case, + // --max-length is assumed to be meant for training (as per the help message) + // and we add a --max-length-translate parameter for translation. 
+ if (mode_ == cli::mode::selfadaptive) { + cli.add("--max-length-translate", + "Maximum input sentence length for translation", + 1000); + } // clang-format on } @@ -1088,7 +1157,7 @@ Ptr ConfigParser::parseOptions(int argc, char** argv, bool doValidate) // (or --data-weighting and 'weight'). // // Note: this may modify the config, so it is safer to do it after --dump-config. - if(mode_ == cli::mode::training || get("tsv")) { + if(mode_ == cli::mode::training || get("tsv", false)) { auto inputTypes = get>("input-types"); if(!inputTypes.empty()) { bool seenAligns = false; diff --git a/src/common/config_parser.h b/src/common/config_parser.h index 18b6eccb7..b0b4f9386 100644 --- a/src/common/config_parser.h +++ b/src/common/config_parser.h @@ -14,7 +14,7 @@ namespace marian { namespace cli { -enum struct mode { training, translation, scoring, server, embedding }; + enum struct mode { training, translation, scoring, server, embedding, selfadaptive, selfadaptiveServer }; } // namespace cli /** @@ -122,6 +122,16 @@ class ConfigParser { return config_[key].as(); } + // Return value for given option key cast to given type. Return the supplied + // default value if option is not set. + template + T get(const std::string& key, T defaultValue) const { + if(has(key)) + return config_[key].as(); + else + return defaultValue; + } + void addOptionsGeneral(cli::CLIWrapper&); void addOptionsServer(cli::CLIWrapper&); void addOptionsModel(cli::CLIWrapper&); diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp index b0230da99..254de51c5 100644 --- a/src/common/config_validator.cpp +++ b/src/common/config_validator.cpp @@ -37,6 +37,14 @@ void ConfigValidator::validateOptions(cli::mode mode) const { validateOptionsParallelData(); validateOptionsTraining(); break; + case cli::mode::selfadaptive: + validateOptionsVocabularies(); + // Check that we're not running in server mode. 
In server mode, training + // data are passed in via the request not CLI options + if (!has("port")) + validateOptionsParallelData(); + validateOptionsTraining(); + break; default: ABORT("wrong CLI mode"); break; @@ -62,6 +70,12 @@ void ConfigValidator::validateOptionsTranslation() const { ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile); } + validateOptionsVocabularies(); +} + +// Other validation methods already do vocabulary validation but we need this +// functionality separately for self-adaptive translation option validation +void ConfigValidator::validateOptionsVocabularies() const { auto vocabs = get>("vocabs"); ABORT_IF(vocabs.empty(), "Translating, but vocabularies are not given"); @@ -80,11 +94,14 @@ void ConfigValidator::validateOptionsParallelData() const { ABORT_IF(trainSets.empty(), "No train sets given in config file or on command line"); auto numVocabs = get>("vocabs").size(); - ABORT_IF(!get("tsv") && numVocabs > 0 && numVocabs != trainSets.size(), + // The "tsv" option isn't present in self-adaptive translation options so we + // have to explicitly default to false for the option + auto tsv = get("tsv", false); + ABORT_IF(!tsv && numVocabs > 0 && numVocabs != trainSets.size(), "There should be as many vocabularies as training files"); // disallow, for example --tsv --train-sets file1.tsv file2.tsv - ABORT_IF(get("tsv") && trainSets.size() != 1, + ABORT_IF(tsv && trainSets.size() != 1, "A single file must be provided with --train-sets (or stdin) for a tab-separated input"); // disallow, for example --train-sets stdin stdin or --train-sets stdin file.tsv @@ -126,7 +143,9 @@ void ConfigValidator::validateOptionsTraining() const { "Model directory does not exist"); std::string errorMsg = "There should be as many validation files as training files"; - if(get("tsv")) + // The "tsv" option isn't present in self-adaptive translation options so we + // have to explicitly default to false for the option + if(get("tsv", 
false)) errorMsg += ". If the training set is in the TSV format, validation sets have to also be a single TSV file"; ABORT_IF(has("valid-sets") @@ -134,10 +153,13 @@ void ConfigValidator::validateOptionsTraining() const { && !get>("valid-sets").empty(), errorMsg); - // check if --early-stopping-on has proper value - std::set supportedStops = {"first", "all", "any"}; - ABORT_IF(supportedStops.find(get("early-stopping-on")) == supportedStops.end(), - "Supported options for --early-stopping-on are: first, all, any"); + // "early-stopping" also isn't present for self-adaptive translation + if (has("early-stopping")) { + // check if --early-stopping-on has proper value + std::set supportedStops = {"first", "all", "any"}; + ABORT_IF(supportedStops.find(get("early-stopping-on")) == supportedStops.end(), + "Supported options for --early-stopping-on are: first, all, any"); + } // validations for learning rate decaying ABORT_IF(get("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual"); diff --git a/src/common/config_validator.h b/src/common/config_validator.h index 0e73a9e39..e31188532 100644 --- a/src/common/config_validator.h +++ b/src/common/config_validator.h @@ -14,12 +14,22 @@ class ConfigValidator { T get(const std::string& key) const { return config_[key].as(); } + // Return value for given option key cast to given type. Return the supplied + // default value if option is not set. + template + T get(const std::string& key, T defaultValue) const { + if(has(key)) + return config_[key].as(); + else + return defaultValue; + } // The option --dump-config is used, so alleviate some constraints, e.g. 
we don't want to require // --train-sets or --vocabs bool dumpConfigOnly_{false}; void validateOptionsTranslation() const; + void validateOptionsVocabularies() const; void validateOptionsParallelData() const; void validateOptionsScoring() const; void validateOptionsTraining() const; diff --git a/src/data/adaptive_context.cpp b/src/data/adaptive_context.cpp new file mode 100644 index 000000000..4e626ae88 --- /dev/null +++ b/src/data/adaptive_context.cpp @@ -0,0 +1,96 @@ +#include "data/adaptive_context.h" + +namespace marian { +namespace data { + +AdaptiveContextIterator::AdaptiveContextIterator(AdaptiveContextReader* trainSetReader) + : trainSetReader_(trainSetReader) { + if(trainSetReader) { + currentSamples_ = trainSetReader_->getSamples(); + } +} + +bool AdaptiveContextIterator::equal(const AdaptiveContextIterator& other) const { + return other.trainSetReader_ == trainSetReader_; +} + +const std::vector& AdaptiveContextIterator::dereference() const { + return currentSamples_; +} + +void AdaptiveContextIterator::increment() { + // If the previous increment has exhausted the file, we must indicate that we've reached + // the iterator's end. NOTE: the null check must come first so that we never + // dereference the null reader of the end iterator (&& short-circuits). + if(trainSetReader_ != nullptr && trainSetReader_->eof()) { + trainSetReader_ = nullptr; + return; + } + // If we're at the end of the iterator and increment has been called yet another time, there's + // a bug in the calling code + ABORT_IF(trainSetReader_ == nullptr, "Incrementing past the end of the iterator isn't allowed"); + + currentSamples_ = trainSetReader_->getSamples(); +} + + +AdaptiveContextReader::AdaptiveContextReader(std::vector paths) { + for(auto& path : paths) + files_.emplace_back(new io::InputFileStream(path)); +} + +AdaptiveContextIterator AdaptiveContextReader::begin() { + return AdaptiveContextIterator(this); +} + +AdaptiveContextIterator AdaptiveContextReader::end() { + return AdaptiveContextIterator(nullptr); +} + +bool AdaptiveContextReader::eof() { + return eof_; +} + +std::vector 
AdaptiveContextReader::getSamples() { + // extracted lines for source and target corpora + std::vector samples; + // counters of number of lines extracted for source and target + std::vector counts; + + // Early exit if input files are exhausted + if (eof_) return samples; + + for(auto const& file : files_) { + size_t currCount = 0; + std::string lines; + std::string line; + bool fileEnded = true; + while(io::getline(*file, line)) { + if(line.empty()) { + fileEnded = false; + break; + } + + if(currCount) + lines += "\n"; + lines += line; + currCount += 1; + } + eof_ = fileEnded; + + if(!lines.empty()) + samples.emplace_back(lines); + counts.push_back(currCount); + + // check if the same number of lines is extracted for source and target + size_t prevCount = counts[0]; + for(size_t i = 1; i < counts.size(); ++i) { + ABORT_IF(prevCount != counts[i], + "An empty source or target sentence has been encountered!"); + prevCount = counts[i]; + } + } + + return samples; +} +} // namespace data +} // namespace marian diff --git a/src/data/adaptive_context.h b/src/data/adaptive_context.h new file mode 100644 index 000000000..dc7ebee5a --- /dev/null +++ b/src/data/adaptive_context.h @@ -0,0 +1,93 @@ +#pragma once + +#include "common/file_stream.h" +#include "data/iterator_facade.h" + +namespace marian { +namespace data { + + +class AdaptiveContextReader; + + +/** + * An iterator for easier access of the context sentences produced by + * `AdaptiveContextReader::getSamples()` + */ +class AdaptiveContextIterator + : public IteratorFacade> { + + AdaptiveContextReader* trainSetReader_; + std::vector currentSamples_; + +public: + // TODO: should we use a smart pointer here instead? 
The TrainSetReader::begin() method + // would make it difficult + AdaptiveContextIterator(AdaptiveContextReader* trainSetReader); + + bool equal(const AdaptiveContextIterator& other) const override; + + const std::vector& dereference() const override; + + void increment() override; +}; + + +/** + * Reads the context sentences, that are used for on-the-fly training in + * the self-adaptive translation mode, from files. + */ +class AdaptiveContextReader { + + std::vector> files_; + /// Indicates whether the input files have been exhausted. + bool eof_ = false; + +public: + /** + * Initializes a new reader by supplying paths to the files with + * context sentences + * + * @param paths paths to the input files. The input files contain + * newline-separated parallel sentence pairs (as usual for MT). Sentences are + * grouped by the translatable sentences (which are provided in a different + * file). Each group is delimited by a single empty line. The sentence group + * can be empty (no context is provided for the respective translatable + * sentence) in which case it is also represented by a single empty line. + */ + AdaptiveContextReader(std::vector paths); + + /** + * Returns an iterator over the sets of context sentences produced by + * `getSamples()` + * + * @return the beginning of the iterator. + */ + AdaptiveContextIterator begin(); + + AdaptiveContextIterator end(); + + bool eof(); + + /** + * Reads the next set of samples -- the context sentences -- for + * on-the-fly training in the self-adaptive translation mode. + * + * @details The input files contain newline-separated parallel sentence pairs + * (as usual for MT). Sentences are grouped by the translatable sentences + * (which are provided in a different file). Each group is delimited by a + * single empty line. The sentence group can be empty (no context is provided + * for the respective translatable sentence) in which case it is also + * represented by a single empty line. 
+ * + * @return a vector representing a single group of context sentences. Each + * element in the vector contains newline separated input lines coming from a + * single file, e.g., [0] could contain 3 newline separated sentences in + * English and [1] would contain their 3 respective translations in Latvian. + */ + std::vector getSamples(); +}; + + +} // namespace data +} // namespace marian diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 71c9f9908..a6fd4e3fc 100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -47,9 +47,9 @@ const SentenceTuple& CorpusIterator::dereference() const { return tup_; } -// These types of corpus constructors are used in in-training validators -// (only?), so do not load additional files for guided alignment or data -// weighting. +// These types of corpus constructors are used in in-training validators (only? +// (also in self-adaptive translation)), so do not load additional files for +// guided alignment or data weighting. 
CorpusBase::CorpusBase(const std::vector& paths, const std::vector>& vocabs, Ptr options, @@ -71,9 +71,15 @@ CorpusBase::CorpusBase(const std::vector& paths, } for(auto path : paths_) { - UPtr strm(new io::InputFileStream(path)); - ABORT_IF(strm->empty(), "File '{}' is empty", path); - files_.emplace_back(std::move(strm)); + // This constructor is also used in self-adaptive translation and it needs + // support for reading translation inputs from stdin + if(path == "stdin" || path == "-") + files_.emplace_back(new std::istream(std::cin.rdbuf())); + else { + UPtr strm(new io::InputFileStream(path)); + ABORT_IF(strm->empty(), "File '{}' is empty", path); + files_.emplace_back(std::move(strm)); + } } initEOS(/*training=*/true); diff --git a/src/data/text_input.h b/src/data/text_input.h index 98d991bcb..e3ebc42eb 100644 --- a/src/data/text_input.h +++ b/src/data/text_input.h @@ -43,7 +43,12 @@ class TextInput : public DatasetBase { SentenceTuple next() override; void shuffle() override {} - void reset() override {} + void reset() override { + for (auto& file : files_) { + file->clear(); + file->seekg(0); + } + } iterator begin() override { return iterator(*this); } iterator end() override { return iterator(); } diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h index 7e2a57040..5c3632b6f 100644 --- a/src/graph/expression_graph.h +++ b/src/graph/expression_graph.h @@ -184,6 +184,15 @@ class ExpressionGraph : public std::enable_shared_from_this { kvParams.second->clear(); } + + /** + * Call `clear()` on each of the parameters in the graph + */ + void clearParams() { + for(auto kvParams : paramsByElementType_) + kvParams.second->clear(); + } + /** * Set device options used to run the graph. * @param deviceId a struct type which stores device no. (size_t) @@ -230,6 +239,22 @@ class ExpressionGraph : public std::enable_shared_from_this { namespace_ = newNamespace; } + /** + * Extract graph parameters into a named map. 
+ * @return A map with parameter names as keys and the corresponding graph elements as values
+ */
+ const std::unordered_map & getParamsNamedMap() const {
+ if (paramsByElementType_.size() != 1) {
+ ABORT("Expected exactly one parameter datatype, got {}", paramsByElementType_.size());
+ }
+ for(auto&& kvParams : paramsByElementType_) {
+ auto cur_param = kvParams.second;
+ return cur_param->getMap();
+ }
+ ABORT("We should never get here"); // Just to satisfy compiler warnings;
+ return paramsByElementType_.find(Type::float32)->second->getMap();
+ }
+
 /**
 * Copy all parameter objects from one graph to current graph.
 * @param graph a pointer to a graph object
@@ -738,11 +763,22 @@ class ExpressionGraph : public std::enable_shared_from_this {
 bool getThrowNaN() { return throwNaN_; }
 
 public:
- /** Load model (mainly parameter objects) from array of io::Items */
- void load(const std::vector& ioItems, bool markReloaded = true) {
+ /**
+ * Load model (mainly parameter objects) from array of io::Items
+ *
+ * @param dropF0prefix modify the `io::Item` names upon loading by removing
+ * "F0::" prefixes. "F*::" prefixes are used to distinguish parameters from
+ * different scorers in the translation graph. This option is used by
+ * self-adaptive translation to support loading these `io::Item`s for
+ * training. 
+ */ + void load(const std::vector& ioItems, bool markReloaded = true, bool dropF0prefix = false) { setReloaded(false); for(auto& item : ioItems) { std::string pName = item.name; + if (dropF0prefix && pName.substr(0, 4) == "F0::") { + pName = pName.substr(4); + } // skip over special parameters starting with "special:" if(pName.substr(0, 8) == "special:") continue; diff --git a/src/graph/parameters.h b/src/graph/parameters.h index 8b4af9dd5..1d7808c92 100644 --- a/src/graph/parameters.h +++ b/src/graph/parameters.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "common/definitions.h" @@ -20,9 +21,9 @@ class Parameters { protected: Type acceptedElementType_; // this parameter object only takes paramters of this type - /** @brief List of all parameter nodes of this expression graph. */ + /** List of all parameter nodes of this expression graph. */ std::vector params_; - std::map named_; + std::unordered_map named_; Ptr vals_; Ptr grads_; @@ -44,6 +45,22 @@ class Parameters { LOG(debug, "Destroyed parameter object of type {}", acceptedElementType_); } + /** + * Retrieves the memory corresponding to the parameter values. 
+ * @return A vector of memorypieces each corresponding to a single parameter + */ + std::vector toMemoryPieces() { + std::vector res; + res.reserve(params_.size()); + auto read_it = begin(); + int i = 0; + for(; read_it != end(); ++read_it) { + i++; + res.push_back((*read_it)->val()->memory()); + } + return res; + } + auto begin() -> decltype(params_.begin()) { return params_.begin(); } auto end() -> decltype(params_.begin()) { return params_.end(); } diff --git a/src/models/amun.h b/src/models/amun.h index 135ce3597..fe9fcc670 100644 --- a/src/models/amun.h +++ b/src/models/amun.h @@ -35,9 +35,7 @@ class Amun : public EncoderDecoder { "use --type s2s"); } - void load(Ptr graph, - const std::vector& items, - bool /*markedReloaded*/ = true) override { + static void remapIoItems(std::vector &ioItems, bool tiedEmbeddinsSrcOrAll) { std::map nameMap = {{"decoder_U", "decoder_cell1_U"}, {"decoder_Ux", "decoder_cell1_Ux"}, @@ -86,10 +84,9 @@ class Amun : public EncoderDecoder { {"encoder_r_gamma1", "encoder_bi_r_gamma1"}, {"encoder_r_gamma2", "encoder_bi_r_gamma2"}}; - if(opt("tied-embeddings-src") || opt("tied-embeddings-all")) + if (tiedEmbeddinsSrcOrAll) nameMap["Wemb"] = "Wemb"; - auto ioItems = items; // map names and remove a dummy matrices for(auto it = ioItems.begin(); it != ioItems.end();) { // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size @@ -114,6 +111,14 @@ class Amun : public EncoderDecoder { it++; } } + } + + void load(Ptr graph, + const std::vector& items, + bool /*markedReloaded*/ = true) override { + auto ioItems = items; + // remap item names and remove dummy matrices + remapIoItems(ioItems, opt("tied-embeddings-src") || opt("tied-embeddings-all")); // load items into the graph graph->load(ioItems); } diff --git a/src/models/nematus.h b/src/models/nematus.h index aee8e3b04..0f3455db8 100644 --- a/src/models/nematus.h +++ b/src/models/nematus.h 
@@ -8,7 +8,7 @@ namespace marian { class Nematus : public EncoderDecoder { public: - Nematus(Ptr graph, Ptr options) : EncoderDecoder(graph, options), nameMap_(createNameMap()) { + Nematus(Ptr graph, Ptr options) : EncoderDecoder(graph, options), nameMap_(createNameMap(options)) { ABORT_IF(options_->get("enc-type") != "bidirectional", "--type nematus does not support other encoder " "type than bidirectional, use --type s2s"); @@ -25,34 +25,18 @@ class Nematus : public EncoderDecoder { "--dec-cell-high-depth > 1, use --type s2s"); } + static void remapIoItems(std::vector& ioItems, Ptr options) { + remapIoItems(ioItems, createNameMap(options), options); + } + void load(Ptr graph, const std::vector& items, bool /*markReloaded*/ = true) override { auto ioItems = items; - // map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node - for(auto it = ioItems.begin(); it != ioItems.end();) { - // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size - // @TODO: consider dropping support for Nematus models - if(it->shape.size() == 1) { - int dim = it->shape[-1]; - it->shape.resize(2); - it->shape.set(0, 1); - it->shape.set(1, dim); - } - if(it->name == "decoder_c_tt") { - it = ioItems.erase(it); - } else if(it->name == "uidx") { - it = ioItems.erase(it); - } else if(it->name == "history_errs") { - it = ioItems.erase(it); - } else { - auto pair = nameMap_.find(it->name); - if(pair != nameMap_.end()) - it->name = pair->second; - it++; - } - } + // map names and remove a dummy matrices + remapIoItems(ioItems, nameMap_, options_); + // load items into the graph graph->load(ioItems); } @@ -65,9 +49,9 @@ class Nematus : public EncoderDecoder { load(graph, ioItems); } - void save(Ptr graph, - const std::string& name, - bool saveTranslatorConfig = false) override { + void save( + Ptr graph, const std::string& name, bool saveTranslatorConfig = 
false) + override { LOG(info, "Saving model to {}", name); // prepare reversed map @@ -86,7 +70,7 @@ class Nematus : public EncoderDecoder { } // add a dummy matrix 'decoder_c_tt' required for Amun and Nematus ioItems.emplace_back(); - ioItems.back().name = "decoder_c_tt"; + ioItems.back().name = "decoder_c_tt"; ioItems.back().shape = Shape({1, 0}); ioItems.back().bytes.emplace_back((char)0); @@ -103,56 +87,83 @@ class Nematus : public EncoderDecoder { std::map nameMap_; std::map nameMapRev_; - std::map createNameMap() { + static void remapIoItems(std::vector& ioItems, std::map nameMap, Ptr options) { + // map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node + for(auto it = ioItems.begin(); it != ioItems.end();) { + // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size + // @TODO: consider dropping support for Nematus models + if(it->shape.size() == 1) { + int dim = it->shape[-1]; + it->shape.resize(2); + it->shape.set(0, 1); + it->shape.set(1, dim); + } + + if(it->name == "decoder_c_tt") { + it = ioItems.erase(it); + } else if(it->name == "uidx") { + it = ioItems.erase(it); + } else if(it->name == "history_errs") { + it = ioItems.erase(it); + } else { + auto pair = nameMap.find(it->name); + if(pair != nameMap.end()) + it->name = pair->second; + it++; + } + } + } + + static std::map createNameMap(Ptr options) { std::map nameMap = {{"decoder_U", "decoder_cell1_U"}, - {"decoder_Ux", "decoder_cell1_Ux"}, - {"decoder_W", "decoder_cell1_W"}, - {"decoder_Wx", "decoder_cell1_Wx"}, - {"decoder_b", "decoder_cell1_b"}, - {"decoder_bx", "decoder_cell1_bx"}, - {"decoder_U_nl", "decoder_cell2_U"}, - {"decoder_Ux_nl", "decoder_cell2_Ux"}, - {"decoder_Wc", "decoder_cell2_W"}, - {"decoder_Wcx", "decoder_cell2_Wx"}, - {"decoder_b_nl", "decoder_cell2_b"}, - {"decoder_bx_nl", "decoder_cell2_bx"}, - {"ff_logit_prev_W", 
"decoder_ff_logit_l1_W0"}, - {"ff_logit_lstm_W", "decoder_ff_logit_l1_W1"}, - {"ff_logit_ctx_W", "decoder_ff_logit_l1_W2"}, - {"ff_logit_prev_b", "decoder_ff_logit_l1_b0"}, - {"ff_logit_lstm_b", "decoder_ff_logit_l1_b1"}, - {"ff_logit_ctx_b", "decoder_ff_logit_l1_b2"}, - {"ff_logit_W", "decoder_ff_logit_l2_W"}, - {"ff_logit_b", "decoder_ff_logit_l2_b"}, - {"ff_state_W", "decoder_ff_state_W"}, - {"ff_state_b", "decoder_ff_state_b"}, - {"Wemb_dec", "decoder_Wemb"}, - {"Wemb", "encoder_Wemb"}, - {"encoder_U", "encoder_bi_U"}, - {"encoder_Ux", "encoder_bi_Ux"}, - {"encoder_W", "encoder_bi_W"}, - {"encoder_Wx", "encoder_bi_Wx"}, - {"encoder_b", "encoder_bi_b"}, - {"encoder_bx", "encoder_bi_bx"}, - {"encoder_r_U", "encoder_bi_r_U"}, - {"encoder_r_Ux", "encoder_bi_r_Ux"}, - {"encoder_r_W", "encoder_bi_r_W"}, - {"encoder_r_Wx", "encoder_bi_r_Wx"}, - {"encoder_r_b", "encoder_bi_r_b"}, - {"encoder_r_bx", "encoder_bi_r_bx"}, - {"ff_state_ln_s", "decoder_ff_state_ln_s"}, - {"ff_state_ln_b", "decoder_ff_state_ln_b"}, - {"ff_logit_prev_ln_s", "decoder_ff_logit_l1_ln_s0"}, - {"ff_logit_lstm_ln_s", "decoder_ff_logit_l1_ln_s1"}, - {"ff_logit_ctx_ln_s", "decoder_ff_logit_l1_ln_s2"}, - {"ff_logit_prev_ln_b", "decoder_ff_logit_l1_ln_b0"}, - {"ff_logit_lstm_ln_b", "decoder_ff_logit_l1_ln_b1"}, - {"ff_logit_ctx_ln_b", "decoder_ff_logit_l1_ln_b2"}}; + {"decoder_Ux", "decoder_cell1_Ux"}, + {"decoder_W", "decoder_cell1_W"}, + {"decoder_Wx", "decoder_cell1_Wx"}, + {"decoder_b", "decoder_cell1_b"}, + {"decoder_bx", "decoder_cell1_bx"}, + {"decoder_U_nl", "decoder_cell2_U"}, + {"decoder_Ux_nl", "decoder_cell2_Ux"}, + {"decoder_Wc", "decoder_cell2_W"}, + {"decoder_Wcx", "decoder_cell2_Wx"}, + {"decoder_b_nl", "decoder_cell2_b"}, + {"decoder_bx_nl", "decoder_cell2_bx"}, + {"ff_logit_prev_W", "decoder_ff_logit_l1_W0"}, + {"ff_logit_lstm_W", "decoder_ff_logit_l1_W1"}, + {"ff_logit_ctx_W", "decoder_ff_logit_l1_W2"}, + {"ff_logit_prev_b", "decoder_ff_logit_l1_b0"}, + {"ff_logit_lstm_b", 
"decoder_ff_logit_l1_b1"}, + {"ff_logit_ctx_b", "decoder_ff_logit_l1_b2"}, + {"ff_logit_W", "decoder_ff_logit_l2_W"}, + {"ff_logit_b", "decoder_ff_logit_l2_b"}, + {"ff_state_W", "decoder_ff_state_W"}, + {"ff_state_b", "decoder_ff_state_b"}, + {"Wemb_dec", "decoder_Wemb"}, + {"Wemb", "encoder_Wemb"}, + {"encoder_U", "encoder_bi_U"}, + {"encoder_Ux", "encoder_bi_Ux"}, + {"encoder_W", "encoder_bi_W"}, + {"encoder_Wx", "encoder_bi_Wx"}, + {"encoder_b", "encoder_bi_b"}, + {"encoder_bx", "encoder_bi_bx"}, + {"encoder_r_U", "encoder_bi_r_U"}, + {"encoder_r_Ux", "encoder_bi_r_Ux"}, + {"encoder_r_W", "encoder_bi_r_W"}, + {"encoder_r_Wx", "encoder_bi_r_Wx"}, + {"encoder_r_b", "encoder_bi_r_b"}, + {"encoder_r_bx", "encoder_bi_r_bx"}, + {"ff_state_ln_s", "decoder_ff_state_ln_s"}, + {"ff_state_ln_b", "decoder_ff_state_ln_b"}, + {"ff_logit_prev_ln_s", "decoder_ff_logit_l1_ln_s0"}, + {"ff_logit_lstm_ln_s", "decoder_ff_logit_l1_ln_s1"}, + {"ff_logit_ctx_ln_s", "decoder_ff_logit_l1_ln_s2"}, + {"ff_logit_prev_ln_b", "decoder_ff_logit_l1_ln_b0"}, + {"ff_logit_lstm_ln_b", "decoder_ff_logit_l1_ln_b1"}, + {"ff_logit_ctx_ln_b", "decoder_ff_logit_l1_ln_b2"}}; // add mapping for deep encoder cells std::vector suffixes = {"_U", "_Ux", "_b", "_bx"}; - for(int i = 1; i < options_->get("enc-cell-depth"); ++i) { + for(int i = 1; i < options->get("enc-cell-depth"); ++i) { std::string num1 = std::to_string(i); std::string num2 = std::to_string(i + 1); for(auto suf : suffixes) { @@ -161,7 +172,7 @@ class Nematus : public EncoderDecoder { } } // add mapping for deep decoder cells - for(int i = 3; i <= options_->get("dec-cell-base-depth"); ++i) { + for(int i = 3; i <= options->get("dec-cell-base-depth"); ++i) { std::string num1 = std::to_string(i - 2); std::string num2 = std::to_string(i); for(auto suf : suffixes) @@ -186,20 +197,20 @@ class Nematus : public EncoderDecoder { // Amun has only CPU decoder for deep Nematus models amun["cpu-threads"] = 16; amun["gpu-threads"] = 0; - amun["maxi-batch"] = 
1; - amun["mini-batch"] = 1; - - auto vocabs = options_->get>("vocabs"); - amun["source-vocab"] = vocabs[0]; - amun["target-vocab"] = vocabs[1]; - amun["devices"] = options_->get>("devices"); - amun["normalize"] = true; - amun["beam-size"] = 5; + amun["maxi-batch"] = 1; + amun["mini-batch"] = 1; + + auto vocabs = options_->get>("vocabs"); + amun["source-vocab"] = vocabs[0]; + amun["target-vocab"] = vocabs[1]; + amun["devices"] = options_->get>("devices"); + amun["normalize"] = true; + amun["beam-size"] = 5; amun["relative-paths"] = false; amun["scorers"]["F0"]["path"] = name; amun["scorers"]["F0"]["type"] = "nematus2"; - amun["weights"]["F0"] = 1.0f; + amun["weights"]["F0"] = 1.0f; io::OutputFileStream out(name + ".amun.yml"); out << amun; diff --git a/src/tensors/gpu/swap.cu b/src/tensors/gpu/swap.cu new file mode 100644 index 000000000..1528f3860 --- /dev/null +++ b/src/tensors/gpu/swap.cu @@ -0,0 +1,15 @@ +#include "cuda_helpers.h" +#include "swap.h" +void copyCpuToGpu(const char * in, char * gpuOut); +void copyGpuToGpu(const char * in, char * gpuOut); + +namespace marian { +namespace swapper { + +void copyCpuToGpu(char * gpuOut, const char * in, size_t count, const marian::DeviceId& deviceId) { + CUDA_CHECK(cudaSetDevice(deviceId.no)); + CUDA_CHECK(cudaMemcpy(gpuOut, in, count, cudaMemcpyHostToDevice)); +} + +} +} diff --git a/src/tensors/gpu/swap.h b/src/tensors/gpu/swap.h new file mode 100644 index 000000000..9de46e9e9 --- /dev/null +++ b/src/tensors/gpu/swap.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include "common/definitions.h" +#include "common/logging.h" + +namespace marian { +namespace swapper { + +#ifdef CUDA_FOUND +void copyCpuToGpu(char * gpuOut, const char * in, size_t count, const marian::DeviceId& deviceId); +#else +inline void copyCpuToGpu(char * gpuOut, const char * in, size_t count, const marian::DeviceId& deviceId) { + ABORT("Copy from CPU to GPU memory is only available with CUDA."); +} +#endif + +} +} diff --git a/src/training/scheduler.h 
b/src/training/scheduler.h index 3cc3b2076..96dd31467 100644 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h @@ -534,10 +534,18 @@ class Scheduler : public TrainingObserver { } void actAfterEpoch(TrainingState& state) override { - // stop if data streaming from STDIN is stopped for a TSV input - std::string firstPath = options_->get>("train-sets")[0]; - if(options_->get("tsv", false) && (firstPath == "stdin" || firstPath == "-")) - endOfStdin_ = true; + // When running self-adaptive marian in server mode the "train-sets" + // option isn't present because the training sentences are passed in via the + // request body + if (options_->has("train-sets")) { + // Stop if data streaming from STDIN is stopped for a TSV input. + auto trainingSets = options_->get>("train-sets"); + if (trainingSets.size() > 0) { + std::string firstPath = options_->get>("train-sets")[0]; + if(options_->get("tsv", false) && (firstPath == "stdin" || firstPath == "-")) + endOfStdin_ = true; + } + } float factor = options_->get("lr-decay"); diff --git a/src/translator/output_collector.cpp b/src/translator/output_collector.cpp index 078be232b..eec63ff8a 100644 --- a/src/translator/output_collector.cpp +++ b/src/translator/output_collector.cpp @@ -6,6 +6,8 @@ namespace marian { +CollectorBase::~CollectorBase(){}; + OutputCollector::OutputCollector() : nextId_(0), printing_(new DefaultPrinting()) {} @@ -81,6 +83,10 @@ void OutputCollector::Write(long sourceId, StringCollector::StringCollector(bool quiet /*=false*/) : maxId_(-1), quiet_(quiet) {} +void StringCollector::Write(long sourceId, const std::string &best1, const std::string &bestn, bool) { + StringCollector::add(sourceId, best1, bestn); +} + void StringCollector::add(long sourceId, const std::string& best1, const std::string& bestn) { diff --git a/src/translator/output_collector.h b/src/translator/output_collector.h index 0e6bfc9f8..106ecbf26 100644 --- a/src/translator/output_collector.h +++ b/src/translator/output_collector.h 
@@ -44,7 +44,13 @@ class GeometricPrinting : public PrintingStrategy { long next_{10}; }; -class OutputCollector { +struct CollectorBase { + virtual ~CollectorBase() = 0; + virtual void Write(long sourceId, const std::string& best1, const std::string& bestn, bool nbest) + = 0; +}; + +class OutputCollector : public CollectorBase { public: OutputCollector(); OutputCollector(std::string outFile); @@ -57,7 +63,7 @@ class OutputCollector { void Write(long sourceId, const std::string& best1, const std::string& bestn, - bool nbest); + bool nbest) override; void setPrintingStrategy(Ptr strategy) { printing_ = strategy; @@ -72,11 +78,15 @@ class OutputCollector { std::mutex mutex_; }; -class StringCollector { +class StringCollector : public CollectorBase { public: StringCollector(bool quiet = false); StringCollector(const StringCollector&) = delete; + void Write(long sourceId, + const std::string& best1, + const std::string& bestn, + bool nbest) override; void add(long sourceId, const std::string& best1, const std::string& bestn); std::vector collect(bool nbest); diff --git a/src/translator/self_adaptive.h b/src/translator/self_adaptive.h new file mode 100644 index 000000000..85b4f3041 --- /dev/null +++ b/src/translator/self_adaptive.h @@ -0,0 +1,202 @@ +#pragma once + +#include "common/config.h" +#include "common/file_stream.h" +#include "data/batch_generator.h" +#include "data/text_input.h" +#include "models/model_task.h" +#include "training/scheduler.h" +#include "training/validator.h" +#include "translator/swappable.h" +#include "data/adaptive_context.h" + +namespace marian { + +using namespace data; + +/** + * Implementation of the self-adaptive translation mode. + * Self-adaptive translation means optionally using a set of context sentences + * (e.g., provided by a translation memory), that are similar to the + * translatable sentence, to train the model for a few iterations to fine-tune + * it before translating the given sentence. 
+ */ +class TrainSelfAdaptive : public ModelTask, public ModelServiceTask { +public: + TrainSelfAdaptive(Ptr options) : options_(options) { + options_->set("shuffle", "none"); + // Validation options are disabled for self-adaptive marian because + // typically training would happen for only a few iterations and it seems to + // not make much sense to run validation metrics on the validation dataset + // then (especially if you care about translation performance). However, we + // have to manually set the early-stopping option as disabled because the + // scheduler crashes if it's not present. + options_->set("early-stopping", 0); + // Set up translator options + optionsTrans_ = New(options_->clone()); + // We will only ever translate a single sentence at a time because dynamic + // adaptation happens per sentence + optionsTrans_->set("mini-batch", 1); + optionsTrans_->set("maxi-batch", 1); + auto maxTranslationInput = options_->get("max-length-translate"); + optionsTrans_->set("max-length", maxTranslationInput); + auto translationWorkspace = options_->get("workspace-translate"); + optionsTrans_->set("workspace", translationWorkspace); + optionsTrans_->set("shuffle", "none"); + + auto modelFilename = options_->get("model"); + // Training has a single "model", translation can have multiple "models" in + // the general case. Adaptive options also take only a single "model" so we + // have to adapt translation options manually. + optionsTrans_->set>("models", {modelFilename}); + + // We mask the alignment option for training so that the alignment loss + // nodes (self-attention heads) don't get added to the graph (for + // transformers). Adding the alignment loss nodes and not supplying guided + // alignments during training results in a crash with "There are more (n) + // than one top most nodes for the backward pass". 
In self-adaptive
+ // translation we don't support training the alignments because they are
+ // likely to remain good enough after the few self-adaptive updates.
+ //
+ // TODO: regarding the above, make the alignment heads non-trainable; afaik,
+ // they are treated like regular attention heads currently which might
+ // decrease alignment precision.
+ options_->set("alignment", "");
+
+ auto vocabPaths = options_->get>("vocabs");
+ std::vector srcVocabPaths(vocabPaths.begin(), vocabPaths.end() - 1);
+ cpuModel_ = New(options_, modelFilename, srcVocabPaths, vocabPaths.back());
+ auto translateEngine = New(optionsTrans_, 0);
+ translateSlot_ = New(translateEngine);
+ auto trainEngine = New(options_, 0);
+ trainSlot_ = New(trainEngine);
+ }
+
+ /**
+ * Implementation for self-adaptive translation where data come from a
+ * web request.
+ *
+ * @param json Input data in JSON. An "input" array of strings is expected to
+ * contain translatable sentences, each of which has a corresponding set of
+ * context sentences as a sub-array in the "context" array. 
+ * + * @return JSON-encoded translations + */ + std::string run(const std::string& json) override { + // Check if input is in JSON + YAML::Node yaml = YAML::Load(json); + if(!yaml["input"]) { + LOG(warn, "No 'input' node found in the request"); + return ""; + } + + // Get input sentences + auto input = yaml["input"].as(); + auto testSet = New(std::vector({input}), cpuModel_->SrcVocabs(), optionsTrans_); + + // Prepare batches + auto testBatches = New>(testSet, optionsTrans_); + testBatches->prepare(); + + // Initialize output printing + auto collector = New(); + + // Get training sentences + std::vector> contexts; + if(yaml["context"]) + contexts = yaml["context"].as>>(); + + LOG(info, "Running..."); + + adaptAndTranslate(testBatches, contexts.begin(), contexts.end(), collector); + + auto translations = collector->collect(options_->get("n-best")); + YAML::Emitter output; + output << YAML::DoubleQuoted << YAML::Flow << utils::join(translations, "\\n"); + return "{\"output\":" + std::string(output.c_str()) + "}"; + } + + /** + * Implementation for self-adaptive translation where inputs and + * outputs are specified in CLI options. 
+ */ + void run() override { + // Initialize input data + auto srcPaths = options_->get>("input"); + auto testSet = New(srcPaths, cpuModel_->SrcVocabs(), optionsTrans_); + + // Prepare batches + auto testBatches = New>(testSet, optionsTrans_); + testBatches->prepare(); + + // Initialize output printing + auto collector = New(options_->get("output")); + if(options_->get("quiet-translation")) + collector->setPrintingStrategy(New()); + + // Initialize train data + auto trainPaths = options_->get>("train-sets"); + auto trainSets = New(trainPaths); + + LOG(info, "Running..."); + + adaptAndTranslate(testBatches, trainSets->begin(), trainSets->end(), collector); + } + +private: + Ptr options_; // Options for training + Ptr optionsTrans_; // Options for translator + Ptr cpuModel_; // Holds model parameters and vocabularies + Ptr trainSlot_; // Performs model training + Ptr translateSlot_; // Performs translation with the model + bool needsSwitching_ = true; // Tracks whether translate slot's model needs to be reset + + template + void adaptAndTranslate( + Ptr> testBatches, + Iterator trainBegin, + Iterator trainEnd, + Ptr collector) { + auto printer = New(optionsTrans_, cpuModel_->TrgVocab()); + + for(auto testBatch : *testBatches) { + ABORT_IF(trainBegin == trainEnd, "Context batches ran out before test batches"); + + auto trainSet = *trainBegin; + ++trainBegin; + + if(!trainSet.empty()) { + LOG(info, "Got {} context sentences", trainSet.size()); + trainSlot_->SetModel(cpuModel_); + trainSlot_->Train(trainSet); + translateSlot_->PointToParams(*trainSlot_); + translate(testBatch, collector, printer); + needsSwitching_ = true; + } else { + LOG(info, "No context"); + if(needsSwitching_) { + translateSlot_->Load(*cpuModel_); + needsSwitching_ = false; + } + translate(testBatch, collector, printer); + } + } + } + + void translate(Ptr batch, + Ptr collector, + Ptr printer) { + auto histories = translateSlot_->Translate(batch); + + for(auto history : histories) { + 
std::stringstream best1; + std::stringstream bestn; + printer->print(history, best1, bestn); + collector->Write(history->getLineNum(), + best1.str(), + bestn.str(), + options_->get("n-best")); + } + } +}; +} diff --git a/src/translator/server_common.h b/src/translator/server_common.h new file mode 100644 index 000000000..94cec33f4 --- /dev/null +++ b/src/translator/server_common.h @@ -0,0 +1,60 @@ +#include "marian.h" +#include "translator/beam_search.h" +#include "translator/translator.h" +#include "common/timer.h" +#include "common/utils.h" + +#include "3rd_party/simple-websocket-server/server_ws.hpp" + +typedef SimpleWeb::SocketServer WSServer; + +namespace marian { + +int runServer(Ptr task, Ptr options) { + auto quiet = options->get("quiet-translation"); + + // Initialize web server + WSServer server; + server.config.port = (short)options->get("port", 8080); + + auto &translate = server.endpoint["^/translate/?$"]; + + translate.on_message = [&task, quiet](Ptr connection, + Ptr message) { + // Get input text + auto inputText = message->string(); + auto sendStream = std::make_shared(); + + // Translate + timer::Timer timer; + auto outputText = task->run(inputText); + *sendStream << outputText << std::endl; + if(!quiet) + LOG(info, "Translation took: {:.5f}s", timer.elapsed()); + + // Send translation back + connection->send(sendStream, [](const SimpleWeb::error_code &ec) { + if(ec) + LOG(error, "Error sending message: ({}) {}", ec.value(), ec.message()); + }); + }; + + // Error Codes for error code meanings + // http://www.boost.org/doc/libs/1_55_0/doc/html/boost_asio/reference.html + translate.on_error = [](Ptr /*connection*/, + const SimpleWeb::error_code &ec) { + LOG(error, "Connection error: ({}) {}", ec.value(), ec.message()); + }; + + // Start server thread + std::thread serverThread([&server]() { + server.start([](unsigned short port) { + LOG(info, "Server is listening on port {}", port); + }); + }); + + serverThread.join(); + + return 0; +} +} // 
namespace marian diff --git a/src/translator/swappable.cpp b/src/translator/swappable.cpp new file mode 100644 index 000000000..35c3cb3f1 --- /dev/null +++ b/src/translator/swappable.cpp @@ -0,0 +1,271 @@ +#include "translator/swappable.h" +#include +#include "common/io.h" +#include "common/logging.h" +#include "common/timer.h" +#include "data/corpus.h" +#include "data/text_input.h" +#include "marian.h" +#include "models/amun.h" +#include "models/nematus.h" +#include "tensors/gpu/swap.h" +#include "translator/beam_search.h" +#include "translator/translator.h" + +namespace marian { + +namespace { + DeviceId LookupGPU(const Ptr options, size_t deviceIdx) { + auto devices = Config::getDevices(options); + ABORT_IF(deviceIdx >= devices.size(), "GPU device index higher than configured."); + return devices[deviceIdx]; + } +} // namespace + +// For debugging memory +void get(std::vector &out, MemoryPiece::PtrType mem, Ptr backend) { + out.resize(mem->size()); +#ifdef CUDA_FOUND + gpu::copy(backend, mem->data(), mem->data() + mem->size(), out.data()); +#endif +} + +GPUEngineTrain::GPUEngineTrain(Ptr options, size_t deviceIdx) + : options_(options), myDeviceId_(LookupGPU(options, deviceIdx)) { + ABORT_IF(myDeviceId_.type == DeviceType::cpu, "Swappable slot only works for GPU devices."); + options_->set("inference", false); + options_->set("shuffle", "none"); + + // There is no need to initialize the graph or builder here because that's done before + // each Train() invokation +} + +void GPUEngineTrain::RecreateGraphAndBuilder() { + graph_ = New(); + auto prec = options_->get>("precision", {"float32"}); + graph_->setDefaultElementType(typeFromString(prec[0])); + graph_->setDevice(myDeviceId_); + graph_->reserveWorkspaceMB(options_->get("workspace")); + + builder_ = models::createCriterionFunctionFromOptions(options_, models::usage::training); +} + +GPUEngineTrain::~GPUEngineTrain() {} + +SwappableModelTrainer::SwappableModelTrainer(Ptr gpu) : engine_(gpu) { +} + 
+SwappableModelTrainer::~SwappableModelTrainer() { +} + +void SwappableModelTrainer::SetModel(Ptr from) { + srcVocabs_ = from->SrcVocabs(); + trgVocab_ = from->TrgVocab(); + cpuModel_ = from; +} + +std::vector SwappableModelTrainer::Parameters() const { + return engine_->graph_->params()->toMemoryPieces(); +} + +void SwappableModelTrainer::Train(const std::vector &input) { + ABORT_IF(!trgVocab_, "GPULoadedModelTrain needs to be overwritten by a CPU model first."); + + auto state = New(engine_->options_->get("learn-rate")); + auto scheduler = New(engine_->options_, state); + auto optimizer = Optimizer(engine_->options_); + scheduler->registerTrainingObserver(scheduler); + scheduler->registerTrainingObserver(optimizer); + + std::vector> allVocabs; + allVocabs.reserve(srcVocabs_.size() + 1); + allVocabs.insert(allVocabs.end(), srcVocabs_.begin(), srcVocabs_.end()); + allVocabs.emplace_back(trgVocab_); + auto corpus = New(input, allVocabs, engine_->options_); + data::BatchGenerator batchGenerator(corpus, engine_->options_, nullptr, false); + + // We reset the training graph to the original model parameters to prepare + // for adapting it to the new inputs + engine_->RecreateGraphAndBuilder(); + engine_->graph_->load(cpuModel_->Parameters(), true, true); + + scheduler->started(); + while(scheduler->keepGoing()) { + batchGenerator.prepare(); + + // LOG(info, "## NEW BATCHES"); + for(auto&& batch : batchGenerator) { + if(!scheduler->keepGoing()) + break; + + // LOG(info, "### NEW BATCH"); + // Make an update step on the copy of the model + auto lossNode = engine_->builder_->build(engine_->graph_, batch); + engine_->graph_->forward(); + StaticLoss loss = *lossNode; + engine_->graph_->backward(); + + // Notify optimizer and scheduler + optimizer->update(engine_->graph_, 1); + scheduler->update(loss, batch); + } + if(scheduler->keepGoing()) + scheduler->increaseEpoch(); + } + scheduler->finished(); +} + + + + + // ##### ^ above is stuff for runtime domain adaptation + + + + 
+
+// Swap the graph's live parameter memory with the externally held buffers in
+// `with`, element-wise. Calling it twice restores the original pointers.
+void GPUEngineTranslate::SwapPointers(std::vector<MemoryPiece::PtrType> &with) {
+  auto write_it = graph_->params()->begin();
+  auto read_it = with.begin();
+  for (; read_it != with.end(); ++write_it, ++read_it) {
+    std::swap(*(*write_it)->val()->memory(), **read_it);
+  }
+}
+
+// Build an inference-only expression graph with scorers on the selected GPU.
+GPUEngineTranslate::GPUEngineTranslate(Ptr<Options> options, size_t deviceIdx)
+  : options_(options), graph_(New<ExpressionGraph>(true)), myDeviceId_(LookupGPU(options, deviceIdx)), allocator_(myDeviceId_, 0, 128 * 1048576) {
+  ABORT_IF(myDeviceId_.type == DeviceType::cpu, "Swappable slot only works for GPU devices.");
+  options_->set("inference", true);
+  options_->set("shuffle", "none");
+
+  // Create graph
+  auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
+  graph_->setDefaultElementType(typeFromString(prec[0]));
+  graph_->setDevice(myDeviceId_);
+  graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
+
+  scorers_ = createScorers(options_);
+  for (auto scorer : scorers_) {
+    scorer->init(graph_);
+    // TODO lexical shortlists are not supported yet.
+  }
+  graph_->forward();
+  // TODO: reach into graph_->params() private members and free the parameter memory.
+}
+
+GPUEngineTranslate::~GPUEngineTranslate() {}
+
+// Allocate a GPU buffer for every parameter of the engine's graph, recording
+// parameter names so later loads can be sanity-checked against them.
+GPULoadedModel::GPULoadedModel(Ptr<GPUEngineTranslate> gpu) : engine_(gpu) {
+  for (auto &param : *engine_->graph_->params()) {
+    names_.push_back(param->name());
+    parameters_.push_back(engine_->allocator_.alloc(param->val()->memory()->size()));
+  }
+}
+
+GPULoadedModel::~GPULoadedModel() {
+  for (MemoryPiece::PtrType &p : parameters_) {
+    engine_->allocator_.free(p);
+  }
+}
+
+// Share the trainer's parameter buffers and vocabularies without copying.
+// Only valid when both sides live on the same GPU.
+void GPULoadedModel::PointToParams(const SwappableModelTrainer &from) {
+  ABORT_IF(engine_->myDeviceId_ != from.engine_->myDeviceId_, "TODO: copy across GPUs.");
+  srcVocabs_ = from.srcVocabs_;
+  trgVocab_ = from.trgVocab_;
+  parameters_ = from.Parameters();
+}
+
+// Copy the CPU-resident parameters into this model's GPU buffers, aborting on
+// name mismatches or undersized destination buffers.
+void GPULoadedModel::Load(const CPULoadedModel &from) {
+  srcVocabs_ = from.SrcVocabs();
+  trgVocab_ = from.TrgVocab();
+  // Bind by const reference: copying would duplicate every parameter tensor.
+  const auto &fromParams = from.Parameters();
+
+  auto printParamsAndExit = [&]() {
+    std::ostringstream paramNames;
+    for(size_t i = 0; i < parameters_.size(); ++i) {
+      paramNames << " TO (" << names_[i] << ") size: " << parameters_[i]->size() << "\n";
+    }
+    for(size_t i = 0; i < fromParams.size(); ++i) {
+      paramNames << " FROM (" << fromParams[i].name << ") size: " << fromParams[i].size() << "\n";
+    }
+    LOG(error,
+        "Attempting to load parameters with mismatched names or sizes:\n{}",
+        paramNames.str());
+    ABORT("Attempting to load parameters with mismatched names or sizes.");
+  };
+
+  // Sanity check
+  if (parameters_.size() != fromParams.size())
+    printParamsAndExit();
+
+  for(size_t i = 0; i < parameters_.size(); ++i) {
+    // Sanity check
+    // Not sure if that's ok, but we don't check for size equality because for
+    // some reason the target memory location sometimes can be bigger
+    if (names_[i] != fromParams[i].name || parameters_[i]->size() < fromParams[i].size())
+      printParamsAndExit();
+
+    swapper::copyCpuToGpu(reinterpret_cast<char *>(parameters_[i]->data()),
+                          fromParams[i].data(),
+                          fromParams[i].size(),
+                          engine_->myDeviceId_);
+  }
+}
+
+// Translate a batch with this model's parameters: swap them into the engine's
+// graph, run beam search, sort histories by input line number, swap back.
+Histories GPULoadedModel::Translate(const Ptr<data::CorpusBatch> batch) {
+  ABORT_IF(!trgVocab_, "GPULoadedModel needs to be overwritten by a CPU model first.");
+  // std::vector<float> outvec;
+  // get(outvec, parameters_[0], engine_->graph_->getBackend());
+  engine_->SwapPointers(parameters_);
+
+  BeamSearch search(engine_->options_, engine_->scorers_, trgVocab_);
+  Histories ret;
+  ret.reserve(batch->size()); // TODO: input.size() was here previously, this is likely wrong
+
+  auto result = search.search(engine_->graph_, batch);
+  ret.insert(ret.end(), result.begin(), result.end());
+
+  std::sort(ret.begin(), ret.end(), [](marian::Ptr<History> a, marian::Ptr<History> b) { return a->getLineNum() < b->getLineNum(); });
+
+  engine_->SwapPointers(parameters_);
+  return ret;
+}
+
+// Load model parameters and vocabularies from disk into host memory, applying
+// legacy name remapping and sorting parameters to match the graph's order.
+CPULoadedModel::CPULoadedModel(Ptr<Options> options, const std::string &parameters, const std::vector<std::string> &sourceVocabPaths, const std::string &targetVocabPath)
+  : parameters_(io::loadItems(parameters)) {
+  // Remap the parameter names if the model uses an older naming convention
+  if (options->get<std::string>("type") == "amun") {
+    bool tied = options->get<bool>("tied-embeddings-src") || options->get<bool>("tied-embeddings-all");
+    Amun::remapIoItems(parameters_, tied);
+  } else if (options->get<std::string>("type") == "nematus") {
+    Nematus::remapIoItems(parameters_, options);
+  }
+
+  // Find the special element and remove it:
+  auto pred = [](const io::Item &item) { return item.name == "special:model.yml"; };
+  auto special_it = std::find_if(parameters_.begin(), parameters_.end(), pred);
+  if (special_it != parameters_.end()) {
+    parameters_.erase(special_it);
+  }
+
+  // Prepare the name so that it matches the named map
+  for (auto &&item : parameters_) {
+    item.name = "F0::" + item.name;
+  }
+  // Sort by name to match params order.
+  std::sort(parameters_.begin(), parameters_.end(), [](const io::Item &a, const io::Item &b) { return a.name < b.name; });
+
+  // Load source vocabs.
+  const std::vector<int> &maxVocabs = options->get<std::vector<int>>("dim-vocabs");
+  for(size_t i = 0; i < sourceVocabPaths.size(); ++i) {
+    Ptr<Vocab> vocab = New<Vocab>(options, i);
+    vocab->load(sourceVocabPaths[i], maxVocabs[i]);
+    srcVocabs_.emplace_back(vocab);
+  }
+
+  // Load target vocab.
+  trgVocab_ = New<Vocab>(options, sourceVocabPaths.size());
+  trgVocab_->load(targetVocabPath);
+}
+
+} // namespace marian
diff --git a/src/translator/swappable.h b/src/translator/swappable.h
new file mode 100644
index 000000000..e6db24280
--- /dev/null
+++ b/src/translator/swappable.h
@@ -0,0 +1,201 @@
+#pragma once
+/**
+ * Support for swapping and resetting models for the self-adaptive translation
+ * mode. The intended use case is to store a read-only copy of the model in
+ * `CPULoadedModel`, optionally train on a copy of the parameters using
+ * `SwappableModelTrainer` and then transfer either the trained or original
+ * model parameters into `GPULoadedModel` for translation. `GPUEngineTrain` and
+ * `GPUEngineTranslate` are used for storing the expression graphs for training
+ * and translation, respectively, and other related things. Translation on the
+ * CPU currently isn't supported.
+ *
+ * Originally this code was intended to allow multiple models to share a single
+ * GPU for translation and be swapped into GPU memory only when needed. However,
+ * parts of it, that weren't needed for self-adaptive translation, have been
+ * trimmed down since then. Look here
+ * https://github.com/kpu/marian-dev/blob/90e161fa9fcb3e3ba1467c76a10b1fc7f9390b6d/src/translator/swappable.h
+ * if you want to revive this functionality.
+ */
+#include "common/io.h"
+#include "data/vocab.h"
+#include "marian.h"
+#include "training/scheduler.h"
+#include "translator/history.h"
+
+#include <string>
+#include <vector>
+namespace marian {
+
+class SwappableModelTrainer;
+
+class Scorer;
+
+class GPULoadedModel;
+class CPULoadedModel;
+
+
+/**
+ * The class wraps an expression graph and a model builder that are used by
+ * `SwappableModelTrainer` for training a model.
+ */
+class GPUEngineTrain {
+private:
+  friend class SwappableModelTrainer;
+  friend class GPULoadedModel;
+  Ptr<Options> options_;
+  Ptr<ExpressionGraph> graph_;
+  // NOTE(review): template argument reconstructed as Marian's training-builder
+  // interface -- confirm against the matching swappable.cpp definition.
+  Ptr<models::ICriterionFunction> builder_;
+  const DeviceId myDeviceId_;
+
+  void RecreateGraphAndBuilder();
+
+public:
+  /**
+   * @param options The marian options object
+   * @param deviceNum The index of the device you want to use for this slot.
+   * Note that this is not the deviceID but the index of the device in the
+   * array of supplied devices. Eg if you provide -d 0 3 5 and you want the
+   * Slot to run on GPU 3, you provide deviceNum=1.
+   */
+  explicit GPUEngineTrain(Ptr<Options> options, size_t deviceNum);
+
+  ~GPUEngineTrain();
+};
+
+/**
+ * Wraps a `GPUEngineTrain` and a `CPULoadedModel` and performs model
+ * training.
+ *
+ * This class is created with self-adaptive translation in mind. Each invocation
+ * of Train() resets the model parameters at the start of training.
+ */
+class SwappableModelTrainer {
+  private:
+    friend class GPULoadedModel;
+
+    Ptr<GPUEngineTrain> engine_;
+
+    Ptr<CPULoadedModel> cpuModel_;
+    std::vector<Ptr<Vocab>> srcVocabs_;
+    Ptr<Vocab> trgVocab_;
+
+  public:
+    SwappableModelTrainer(Ptr<GPUEngineTrain> gpu);
+
+    ~SwappableModelTrainer();
+
+    const std::vector<Ptr<Vocab>> &SrcVocabs() const { return srcVocabs_; }
+
+    Ptr<Vocab> TrgVocab() const { return trgVocab_; }
+
+    /// Change the internal pointers to vocabularies and CPULoadedModel to
+    /// different ones
+    void SetModel(Ptr<CPULoadedModel> from);
+
+    std::vector<MemoryPiece::PtrType> Parameters() const;
+
+    /**
+     * Resets the training graph, reloads the model parameters and trains
+     * the model on the provided inputs.
+     *
+     * Intended to be used in the self-adaptive translation mode -- training is
+     * always performed on the original model parameters, each training
+     * invocation resets previous changes.
+     *
+     * @param input Training data. A vector representing a parallel corpus --
+     * vector elements are the different sides of a parallel corpus, each is a
+     * newline separated set of sentences in a single language.
+     */
+    void Train(const std::vector<std::string> &input);
+};
+
+/**
+ * The class wraps an expression graph and scorers that are used by
+ * `GPULoadedModel` for translation.
+ */
+class GPUEngineTranslate {
+private:
+  friend class GPULoadedModel;
+  Ptr<Options> options_;
+  Ptr<ExpressionGraph> graph_;
+  std::vector<Ptr<Scorer>> scorers_;
+  const DeviceId myDeviceId_;
+  Allocator allocator_;
+
+  void SwapPointers(std::vector<MemoryPiece::PtrType> &with);
+
+public:
+  /**
+   * @param options The marian options object
+   * @param deviceNum The index of the device you want to use for this slot.
+   * Note that this is not the deviceID but the index of the device in the
+   * array of supplied devices. Eg if you provide -d 0 3 5 and you want the
+   * Slot to run on GPU 3, you provide deviceNum=1.
+   */
+  explicit GPUEngineTranslate(Ptr<Options> options, size_t deviceNum);
+
+  ~GPUEngineTranslate();
+};
+
+/** A model loaded on the GPU that can be overwritten from CPU. Facilitates
+ * translation with the model.
+ */
+class GPULoadedModel {
+  private:
+    Ptr<GPUEngineTranslate> engine_;
+
+    std::vector<std::string> names_;
+    std::vector<MemoryPiece::PtrType> parameters_;
+    std::vector<Ptr<Vocab>> srcVocabs_;
+    Ptr<Vocab> trgVocab_;
+
+  public:
+    GPULoadedModel(Ptr<GPUEngineTranslate> gpu);
+
+    ~GPULoadedModel();
+
+    const std::vector<Ptr<Vocab>> &SrcVocabs() const { return srcVocabs_; }
+
+    Ptr<Vocab> TrgVocab() const { return trgVocab_; }
+
+    /// Overwrite this model with parameters from a different one.
+    void Load(const CPULoadedModel &from);
+    /**
+     * Set the internal shared pointers to model parameters and
+     * vocabularies to different ones
+     *
+     * The effect is similar to `Load()` but nothing is copied in the process.
+     *
+     * @param from Swappable model trainer from which to take the shared
+     * pointers to model parameters and vocabularies.
+     */
+    void PointToParams(const SwappableModelTrainer &from);
+
+    Histories Translate(const Ptr<data::CorpusBatch> batch);
+};
+
+/**
+ * A model loaded on the CPU. Holds model parameters and vocabularies.
+ */
+class CPULoadedModel {
+  private:
+    std::vector<io::Item> parameters_;
+    std::vector<Ptr<Vocab>> srcVocabs_;
+    Ptr<Vocab> trgVocab_;
+
+  public:
+    // The parts of Options that relate to model and vocab are ignored. The
+    // files provided will be loaded.
+    CPULoadedModel(Ptr<Options> options,
+                   const std::string &parameters,
+                   const std::vector<std::string> &sourceVocabPaths,
+                   const std::string &targetVocabPath);
+
+    const std::vector<io::Item> &Parameters() const { return parameters_; }
+
+    const std::vector<Ptr<Vocab>> &SrcVocabs() const { return srcVocabs_; }
+
+    Ptr<Vocab> TrgVocab() const { return trgVocab_; }
+};
+
+} // namespace marian