diff --git a/CMakeLists.txt b/CMakeLists.txt
index d065fb4..aa4712a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,8 +1,12 @@
 project(lstm-parser)
 cmake_minimum_required(VERSION 2.8 FATAL_ERROR)
 
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE RelWithDebInfo)
+endif(NOT CMAKE_BUILD_TYPE)
+
 set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
-set(CMAKE_CXX_FLAGS "-Wall -std=c++11 -O3 -g")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -std=c++14")
 
 enable_testing()
 
@@ -16,7 +20,7 @@ if(DEFINED ENV{BOOST_ROOT})
   set(Boost_NO_SYSTEM_PATHS ON)
 endif()
 set(Boost_REALPATH ON)
-find_package(Boost COMPONENTS program_options serialization iostreams REQUIRED)
+find_package(Boost COMPONENTS program_options serialization iostreams regex filesystem REQUIRED)
 include_directories(${Boost_INCLUDE_DIR})
 set(LIBS ${LIBS} ${Boost_LIBRARIES})
 
@@ -26,6 +30,6 @@
 include_directories(${EIGEN3_INCLUDE_DIR})
 
 #configure_file(${CMAKE_CURRENT_SOURCE_DIR}/config.h.cmake ${CMAKE_CURRENT_BINARY_DIR}/config.h)
-add_subdirectory(cnn/cnn)
+add_subdirectory(cnn)
 # add_subdirectory(cnn/examples)
-add_subdirectory(parser)
+add_subdirectory(parser)
\ No newline at end of file
diff --git a/README.md b/README.md
index 1635877..a82d88e 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ Given a `training.conll` file and a `development.conll` formatted according to t
     java -jar ParserOracleArcStdWithSwap.jar -t -1 -l 1 -c training.conll > trainingOracle.txt
     java -jar ParserOracleArcStdWithSwap.jar -t -1 -l 1 -c development.conll > devOracle.txt
 
-    parser/lstm-parse -P -t trainingOracle.txt -d devOracle.txt --hidden_dim 100 --lstm_input_dim 100 -w sskip.100.vectors --rel_dim 20 --action_dim 20
+    parser/lstm-parse --train -t trainingOracle.txt -d devOracle.txt --hidden_dim 100 --lstm_input_dim 100 --words sskip.100.vectors --rel_dim 20 --action_dim 20 --use_pos_tags
 
 Link to the word vectors used in the ACL 2015 paper for English: [sskip.100.vectors](https://drive.google.com/file/d/0B8nESzOdPhLsdWF2S1Ayb1RkTXc/view?usp=sharing).
 
@@ -43,7 +43,7 @@ There is a pretrained model for English [here](http://www.cs.cmu.edu/~jdunietz/h
 
 Given a `test.conll` file formatted according to the [CoNLL data format](http://ilk.uvt.nl/conll/#dataformat):
 
-    parser/lstm-parse -m english_pos_2_32_100_20_100_12_20.params -t test.conll
+    parser/lstm-parse -m english_pos_2_32_100_20_100_12_20.params -T test.conll -s
 
 If you are not using the pretrained model, you will need to replace the `.params` argument with the name of your own trained model file.
diff --git a/cnn/CMakeLists.txt b/cnn/CMakeLists.txt
index e8408b4..17fc1ec 100644
--- a/cnn/CMakeLists.txt
+++ b/cnn/CMakeLists.txt
@@ -1,6 +1,10 @@
 project(cnn)
 cmake_minimum_required(VERSION 2.8 FATAL_ERROR)
 
+if(NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "Debug")
+  set(CMAKE_BUILD_TYPE RelWithDebInfo)
+endif(NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "Debug")
+
 set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
 
 # CNN uses Eigen which exploits modern CPU architectures. To get the
@@ -10,7 +14,7 @@ set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
 #    3. try compiler options like -march=native or other architecture
 #       flags (the compiler does not always make the best configuration
 #       decisions without help)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -funroll-loops -Wall -std=c++11 -Ofast -g -DEIGEN_FAST_MATH -march=native")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -funroll-loops -Wall -std=c++14 -Ofast -g -DEIGEN_FAST_MATH -march=native")
 
 enable_testing()
 
@@ -63,9 +67,11 @@ else()
 endif()
 
 if(BACKEND MATCHES "^eigen$")
-  set(WITH_EIGEN_BACKEND 1)
+  set(WITH_CUDA_BACKEND 0 CACHE INTERNAL "" FORCE)
+  set(WITH_EIGEN_BACKEND 1 CACHE INTERNAL "" FORCE)
 elseif(BACKEND MATCHES "^cuda$")
-  set(WITH_CUDA_BACKEND 1)
+  set(WITH_CUDA_BACKEND 1 CACHE INTERNAL "" FORCE)
+  set(WITH_EIGEN_BACKEND 0 CACHE INTERNAL "" FORCE)
 else()
   message(SEND_ERROR "BACKEND must be eigen or cuda")
 endif()
 
@@ -93,8 +99,12 @@ set(LIBS ${LIBS} ${CMAKE_THREAD_LIBS_INIT})
 configure_file(${CMAKE_CURRENT_SOURCE_DIR}/config.h.cmake ${CMAKE_CURRENT_BINARY_DIR}/config.h)
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
 
+option(CNN_CORE_ONLY "If on, won't build extra dirs like tests and examples" ON)
+
 add_subdirectory(cnn)
-add_subdirectory(tests)
-add_subdirectory(examples)
-add_subdirectory(rnnlm)
-enable_testing()
+if(NOT CNN_CORE_ONLY)
+  add_subdirectory(tests)
+  add_subdirectory(examples)
+  add_subdirectory(rnnlm)
+  enable_testing()
+endif(NOT CNN_CORE_ONLY)
\ No newline at end of file
diff --git a/cnn/cnn/CMakeLists.txt b/cnn/cnn/CMakeLists.txt
index bfa85d0..6f66321 100644
--- a/cnn/cnn/CMakeLists.txt
+++ b/cnn/cnn/CMakeLists.txt
@@ -69,6 +69,8 @@ set(cnn_library_HDRS
   training.h
 )
 
+option(CNN_SHARED "Whether to build CNN shared libs" OFF)
+
 if(WITH_CUDA_BACKEND)
   list(APPEND cnn_library_SRCS
     cuda.cc)
@@ -99,20 +101,24 @@ file(GLOB TEST_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} tests/*.cc)
 # actual target:
 add_library(cnn STATIC ${cnn_library_SRCS} ${cnn_library_HDRS})
 target_link_libraries(cnn ${LIBS})
-if(WITH_CUDA_BACKEND)
+if(CNN_SHARED)
+  if(WITH_CUDA_BACKEND)
   add_library(gcnn_shared SHARED ${cnn_library_SRCS} ${cnn_library_HDRS})
   target_link_libraries(gcnn_shared ${LIBS})
-else()
+  else()
   add_library(cnn_shared SHARED ${cnn_library_SRCS} ${cnn_library_HDRS})
   target_link_libraries(cnn_shared ${LIBS})
-endif(WITH_CUDA_BACKEND)
+  endif(WITH_CUDA_BACKEND)
+endif(CNN_SHARED)
 #add_library(cnn ${cnn_library_SRCS} ${cnn_library_HDRS} ${LIBS})
 
 if(WITH_CUDA_BACKEND)
   set(CUDA_SEPARABLE_COMPILATION ON)
   list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_20,code=sm_20;-gencode;arch=compute_30,code=sm_30;-gencode;arch=compute_35,code=sm_35;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_52,code=sm_52;-gencode;arch=compute_52,code=compute_52;-std=c++11;-O2;-DVERBOSE;-Xcompiler;-fpic")
   SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
   cuda_add_library(cnncuda STATIC gpu-ops.cu)
-  cuda_add_library(cnncuda_shared SHARED gpu-ops.cu)
+  if(CNN_SHARED)
+    cuda_add_library(cnncuda_shared SHARED gpu-ops.cu)
+  endif(CNN_SHARED)
 endif(WITH_CUDA_BACKEND)
 
 install(FILES ${cnn_library_HDRS} DESTINATION include/cnn)
diff --git a/cnn/cnn/aligned-mem-pool.h b/cnn/cnn/aligned-mem-pool.h
index 9a087b0..fa616f0 100644
--- a/cnn/cnn/aligned-mem-pool.h
+++ b/cnn/cnn/aligned-mem-pool.h
@@ -8,6 +8,8 @@ namespace cnn {
 
 class AlignedMemoryPool {
  public:
+  typedef size_t PoolState;
+
   explicit AlignedMemoryPool(size_t cap, MemAllocator* a) : a(a) {
     sys_alloc(cap);
     zero_all();
@@ -36,6 +38,14 @@ class AlignedMemoryPool {
   bool is_shared() {
     return shared;
   }
+
+  PoolState get_state() const {
+    return used;
+  }
+
+  void restore_state(const PoolState& state) {
+    used = state;
+  }
  private:
   void sys_alloc(size_t cap) {
     capacity = a->round_up_align(cap);
diff --git a/cnn/cnn/exec.cc b/cnn/cnn/exec.cc
index bc8b799..4005ad4 100644
--- a/cnn/cnn/exec.cc
+++ b/cnn/cnn/exec.cc
@@ -10,6 +10,7 @@ ExecutionEngine::~ExecutionEngine() {}
 
 void SimpleExecutionEngine::invalidate() {
   num_nodes_evaluated = 0;
+  fxs->free();
 }
 
 const Tensor& SimpleExecutionEngine::forward() {
diff --git a/cnn/cnn/init.cc b/cnn/cnn/init.cc
index 4e0a1a3..915a246 100644
--- a/cnn/cnn/init.cc
+++ b/cnn/cnn/init.cc
@@ -30,7 +30,7 @@ static void RemoveArgs(int& argc, char**& argv, int& argi, int n) {
   assert(argc >= 0);
 }
 
-void Initialize(int& argc, char**& argv, unsigned random_seed, bool shared_parameters) {
+unsigned Initialize(int& argc, char**& argv, unsigned random_seed, bool shared_parameters) {
   vector<Device*> gpudevices;
 #if HAVE_CUDA
   cerr << "[cnn] initializing CUDA\n";
@@ -88,6 +88,8 @@ void Initialize(int& argc, char**& argv, unsigned random_seed, bool shared_param
   kSCALAR_ONE = default_device->kSCALAR_ONE;
   kSCALAR_ZERO = default_device->kSCALAR_ZERO;
   cerr << "[cnn] memory allocation done.\n";
+
+  return random_seed;
 }
 
 void Cleanup() {
diff --git a/cnn/cnn/init.h b/cnn/cnn/init.h
index e9e8fef..80a4b28 100644
--- a/cnn/cnn/init.h
+++ b/cnn/cnn/init.h
@@ -3,7 +3,7 @@
 
 namespace cnn {
 
-void Initialize(int& argc, char**& argv, unsigned random_seed = 0, bool shared_parameters = false);
+unsigned Initialize(int& argc, char**& argv, unsigned random_seed = 0, bool shared_parameters = false);
 void Cleanup();
 
 } // namespace cnn
diff --git a/cnn/cnn/model.cc b/cnn/cnn/model.cc
index 4bd35d4..7179e4d 100644
--- a/cnn/cnn/model.cc
+++ b/cnn/cnn/model.cc
@@ -160,6 +160,7 @@ void LookupParameters::clear() {
 
 Model::~Model() {
   for (auto p : all_params) delete p;
+  if (gradient_norm_scratch) default_device->mem->free(gradient_norm_scratch);
 }
 
 void Model::project_weights(float radius) {
diff --git a/cnn/cnn/model.h b/cnn/cnn/model.h
index 2e76194..b27dec9 100644
--- a/cnn/cnn/model.h
+++ b/cnn/cnn/model.h
@@ -61,6 +61,7 @@ struct LookupParameters : public ParametersBase {
   void squared_l2norm(float* sqnorm) const override;
   void g_squared_l2norm(float* sqnorm) const override;
   size_t size() const override;
+  size_t num_values() const { return values.size(); }
 
   void Initialize(unsigned index, const std::vector<float>& val);
   void copy(const LookupParameters & val);
@@ -103,6 +104,15 @@ struct LookupParameters : public ParametersBase {
 class Model {
  public:
   Model() : gradient_norm_scratch() {}
+  Model(const Model&) = delete;
+  Model(Model&& m) : gradient_norm_scratch(m.gradient_norm_scratch) {
+    all_params = std::move(m.all_params);
+    lookup_params = std::move(m.lookup_params);
+    params = std::move(m.params);
+    // Null out the other model's pointer so its destructor doesn't free the
+    // scratch memory this model now owns.
+    m.gradient_norm_scratch = nullptr;
+  }
   ~Model();
   float gradient_l2_norm() const;
   void reset_gradient();
diff --git a/cnn/cnn/tensor.h b/cnn/cnn/tensor.h
index 0516fe7..599b97a 100644
--- a/cnn/cnn/tensor.h
+++ b/cnn/cnn/tensor.h
@@ -7,6 +7,7 @@
 #include "cnn/dim.h"
 #include "cnn/random.h"
 #include "cnn/aligned-mem-pool.h"
+#include "devices.h"
 
 #if HAVE_CUDA
 #include <cuda.h>
@@ -26,6 +27,7 @@ namespace cnn {
 #define EIGEN_BACKEND 1
 
 typedef float real;
+extern Device* default_device;  // for allocating memory on a load
 
 struct Tensor {
   Tensor() = default;
@@ -160,8 +162,12 @@ struct Tensor {
     float* vc = static_cast<float*>(std::malloc(d.size() * sizeof(float)));
     ar & boost::serialization::make_array(vc, d.size());
     CUDA_CHECK(cudaMemcpyAsync(v, vc, d.size() * sizeof(float), cudaMemcpyHostToDevice));
+    free(vc);
 #else
-    v = static_cast<float*>(_mm_malloc(d.size() * sizeof(float), 32));
+    // UGLY HACK to avoid memory leak: node values and gradients don't get
+    // stored to disk; only parameters. So allocate memory for loading from the
+    // parameters pool.
+    v = static_cast<float*>(default_device->ps->allocate(d.size() * sizeof(float)));
     ar & boost::serialization::make_array(v, d.size());
 #endif
   }
diff --git a/parser/CMakeLists.txt b/parser/CMakeLists.txt
index 3ac3352..80fee68 100644
--- a/parser/CMakeLists.txt
+++ b/parser/CMakeLists.txt
@@ -1,8 +1,19 @@
 PROJECT(lstm-parser:parser)
 CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
 
-ADD_LIBRARY(lstm-parser-core lstm-parser.cc corpus.cc)
-target_link_libraries(lstm-parser-core cnn ${Boost_LIBRARIES})
-
+add_library(lstm-parser-core STATIC lstm-parser.cc corpus.cc
+            neural-transition-tagger.cpp)
 ADD_EXECUTABLE(lstm-parse lstm-parser-driver.cc)
-target_link_libraries(lstm-parse lstm-parser-core ${Boost_LIBRARIES})
+
+if(WITH_CUDA_BACKEND)
+  add_dependencies(lstm-parser-core cnncuda)
+  target_link_libraries(lstm-parser-core cnncuda)
+  CUDA_ADD_CUBLAS_TO_TARGET(lstm-parser-core)
+
+  add_dependencies(lstm-parse cnncuda)
+  target_link_libraries(lstm-parse cnncuda)
+  CUDA_ADD_CUBLAS_TO_TARGET(lstm-parse)
+endif(WITH_CUDA_BACKEND)
+
+target_link_libraries(lstm-parser-core cnn ${Boost_LIBRARIES})
+target_link_libraries(lstm-parse lstm-parser-core ${Boost_LIBRARIES})
\ No newline at end of file
diff --git a/parser/corpus.cc b/parser/corpus.cc
index 1928cfa..639963e 100644
--- a/parser/corpus.cc
+++ b/parser/corpus.cc
@@ -16,17 +16,28 @@ constexpr unsigned Corpus::ROOT_TOKEN_ID;
 
 const string CorpusVocabulary::BAD0 = "<BAD0>";
 const string CorpusVocabulary::UNK = "<UNK>";
 const string CorpusVocabulary::ROOT = "<ROOT>";
+// We assume that actions with arcs will be of the form
+// "action-name(arc-label)". Allow any non-paren characters, followed by the
+// label name in parens. (Group 1 is the label name.)
+const boost::regex CorpusVocabulary::ARC_ACTION_REGEX(
+    {"[^\\(\\)]+\\(([^\\(\\)]+)\\)"});
 
 const string ORACLE_ROOT_POS = "ROOT";
 
 void ConllUCorpusReader::ReadSentences(const string& file,
                                        Corpus* corpus) const {
   string next_line;
-  map<unsigned, string> current_sentence_unk_surface_forms;
-  map<unsigned, unsigned> current_sentence;
-  map<unsigned, unsigned> current_sentence_pos;
+  // TODO: Replace this code with simpler Sentence-based code.
+  Sentence::SentenceUnkMap current_sentence_unk_surface_forms;
+  Sentence::SentenceMap current_sentence;
+  Sentence::SentenceMap current_sentence_pos;
 
   ifstream conll_file(file);
+  if (!conll_file) {
+    cerr << "Unable to open corpus file " << file << "; aborting" << endl;
+    abort();
+  }
+
   unsigned unk_word_symbol = corpus->vocab->GetWord(CorpusVocabulary::UNK);
   unsigned root_symbol = corpus->vocab->GetWord(CorpusVocabulary::ROOT);
   unsigned root_pos_symbol = corpus->vocab->GetPOS(CorpusVocabulary::ROOT);
@@ -38,15 +49,11 @@ void ConllUCorpusReader::ReadSentences(const string& file,
       current_sentence_pos[Corpus::ROOT_TOKEN_ID] = root_pos_symbol;
       current_sentence_unk_surface_forms[Corpus::ROOT_TOKEN_ID] = "";
 
-      corpus->sentences.push_back(move(current_sentence));
-      current_sentence.clear();
-
-      corpus->sentences_pos.push_back(move(current_sentence_pos));
-      current_sentence_pos.clear();
-
-      corpus->sentences_unk_surface_forms.push_back(
-          move(current_sentence_unk_surface_forms));
-      current_sentence_unk_surface_forms.clear();
+      corpus->sentences.emplace_back(*corpus->vocab);
+      corpus->sentences.back().words.swap(current_sentence);
+      corpus->sentences.back().poses.swap(current_sentence_pos);
+      corpus->sentences.back().unk_surface_forms.swap(
+          current_sentence_unk_surface_forms);
     }
     continue;
   } else if (next_line[0] == '#') {
@@ -75,15 +82,16 @@ void ConllUCorpusReader::ReadSentences(const string& file,
     current_sentence[token_index] = word_id;
     current_sentence_pos[token_index] = corpus->vocab->GetPOS(pos);
   }
+
+  corpus->sentences.shrink_to_fit();
 }
 
 
-void TrainingCorpus::CountSingletons() {
+void ParserTrainingCorpus::CountSingletons() {
   // compute the singletons in the parser's training data
   map<unsigned, unsigned> counts;
   for (const auto& sent : sentences) {
-    for (const auto& index_and_word_id : sent) {
+    for (const auto& index_and_word_id : sent.words) {
       counts[index_and_word_id.second]++;
     }
   }
@@ -94,22 +102,130 @@
 }
 
 
-void TrainingCorpus::OracleTransitionsCorpusReader::LoadCorrectActions(
-    const string& file, TrainingCorpus* corpus) const {
-  // TODO: break up this function?
-  cerr << "Loading " << (is_training ? "training" : "dev")
+void TrainingCorpus::OracleTransitionsCorpusReader::RecordWord(
+    const string& word, const string& pos, unsigned next_token_index,
+    TrainingCorpus* corpus, Sentence::SentenceMap* sentence,
+    Sentence::SentenceMap* sentence_pos,
+    Sentence::SentenceUnkMap* sentence_unk_surface_forms) const {
+  // We assume that we'll have seen all POS tags in training, so don't
+  // worry about OOV tags.
+  CorpusVocabulary* vocab = corpus->vocab;
+  unsigned pos_id = vocab->GetOrAddEntry(pos, &vocab->pos_to_int,
+                                         &vocab->int_to_pos);
+
+  unsigned word_id;
+  if (is_training) {
+    unsigned num_words = vocab->CountWords();  // store for later check
+    word_id = vocab->GetOrAddWord(word, true);
+    if (vocab->CountWords() > num_words) {
+      // A new word was added; add its chars, too.
+      unsigned j = 0;
+      while (j < word.length()) {
+        unsigned char_utf8_len = UTF8Len(word[j]);
+        string next_utf8_char = word.substr(j, char_utf8_len);
+        vocab->GetOrAddEntry(next_utf8_char, &vocab->chars_to_int,
+                             &vocab->int_to_chars);
+        j += char_utf8_len;
+      }
+    } else {
+      // It's an old word. Make sure it's marked as present in training.
+      vocab->int_to_training_word[word_id] = true;
+    }
+  } else {
+    // add an empty string for any token except OOVs (it is easy to
+    // recover the surface form of non-OOV using intToWords(id)).
+    // OOV word
+    if (corpus->USE_SPELLING) {
+      word_id = vocab->GetOrAddWord(word);  // don't record as training
+      (*sentence_unk_surface_forms)[next_token_index] = "";
+    } else {
+      auto word_iter = vocab->words_to_int.find(word);
+      if (word_iter == vocab->words_to_int.end()) {
+        // Save the surface form of this OOV.
+        (*sentence_unk_surface_forms)[next_token_index] = word;
+        word_id = vocab->words_to_int[vocab->UNK];
+      } else {
+        (*sentence_unk_surface_forms)[next_token_index] = "";
+        word_id = word_iter->second;
+      }
+    }
+  }
+
+  (*sentence)[next_token_index] = word_id;
+  (*sentence_pos)[next_token_index] = pos_id;
+}
+
+
+void TrainingCorpus::OracleTransitionsCorpusReader::RecordAction(
+    const string& action, TrainingCorpus* corpus,
+    vector<unsigned>* correct_actions) const {
+  CorpusVocabulary* vocab = corpus->vocab;
+  auto action_iter = find(vocab->action_names.begin(),
+                          vocab->action_names.end(), action);
+  if (action_iter != vocab->action_names.end()) {
+    unsigned action_index = distance(vocab->action_names.begin(), action_iter);
+    correct_actions->push_back(action_index);
+  } else {  // A not-previously-seen action
+    if (is_training) {
+      vocab->action_names.push_back(action);
+      unsigned action_index = vocab->action_names.size() - 1;
+      correct_actions->push_back(action_index);
+      vocab->actions_to_arc_labels.push_back(vocab->GetLabelForAction(action));
+    } else {
+      // TODO: right now, new actions which haven't been observed in
+      // training are not added to correct_act_sent. In dev, this may
+      // be a problem if there is little training data.
+      cerr << "WARNING: encountered unknown transition in dev corpus: "
+           << action << endl;
+    }
+  }
+}
+
+
+void TrainingCorpus::OracleTransitionsCorpusReader::RecordSentence(
+    TrainingCorpus* corpus, Sentence::SentenceMap* words,
+    Sentence::SentenceMap* sentence_pos,
+    Sentence::SentenceUnkMap* sentence_unk_surface_forms,
+    vector<unsigned>* correct_actions,
+    Sentence::SentenceMetadata* metadata) const {
+  // Store the sentence variables and clear them for the next sentence.
+  corpus->sentences.emplace_back(*corpus->vocab);
+  Sentence* sentence = &corpus->sentences.back();
+  sentence->words.swap(*words);
+  sentence->poses.swap(*sentence_pos);
+  sentence->metadata.reset(metadata);
+  corpus->correct_act_sent.push_back({});
+  corpus->correct_act_sent.back().swap(*correct_actions);
+
+  if (!is_training) {
+    sentence->unk_surface_forms.swap(*sentence_unk_surface_forms);
+  }
+
+  assert(corpus->correct_act_sent.size() == corpus->sentences.size());
+}
+
+
+void ParserTrainingCorpus::OracleParseTransitionsReader::LoadCorrectActions(
+    const string& file, ParserTrainingCorpus* corpus) const {
+  cerr << "Loading " << (is_training ? "training" : "dev/test")
        << " corpus from " << file << "..." << endl;
-  ifstream actionsFile(file);
-  string lineS;
+  ifstream actions_file(file);
+  if (!actions_file) {
+    cerr << "Unable to open actions file " << file << "; aborting" << endl;
+    abort();
+  }
+
+  string line;
 
   CorpusVocabulary* vocab = corpus->vocab;
   bool next_is_action_line = false;
   bool start_of_sentence = false;
   bool first = true;
-  map<unsigned, unsigned> sentence;
-  map<unsigned, unsigned> sentence_pos;
-  map<unsigned, string> sentence_unk_surface_forms;
+  // TODO: replace this code with simpler Sentence-based code.
+  Sentence::SentenceMap sentence;
+  Sentence::SentenceMap sentence_pos;
+  Sentence::SentenceUnkMap sentence_unk_surface_forms;
+  vector<unsigned> correct_actions;
 
   // We'll need to make sure ROOT token has a consistent ID.
   // (Should get inlined; defined here for DRY purposes.)
@@ -133,24 +249,16 @@ void TrainingCorpus::OracleTransitionsCorpusReader::LoadCorrectActions(
     }
   };
 
-  while (getline(actionsFile, lineS)) {
-    ReplaceStringInPlace(lineS, "-RRB-", "_RRB_");
-    ReplaceStringInPlace(lineS, "-LRB-", "_LRB_");
+  while (getline(actions_file, line)) {
+    ReplaceStringInPlace(&line, "-RRB-", "_RRB_");
+    ReplaceStringInPlace(&line, "-LRB-", "_LRB_");
     // An empty line marks the end of a sentence.
-    if (lineS.empty()) {
+    if (line.empty()) {
      next_is_action_line = false;
       if (!first) {  // if first, first line is blank, but no sentence yet
         FixRootID();
-        // Store the sentence variables and clear them for the next sentence.
-        corpus->sentences.push_back({});
-        corpus->sentences.back().swap(sentence);
-        corpus->sentences_pos.push_back({});
-        corpus->sentences_pos.back().swap(sentence_pos);
-        if (!is_training) {
-          corpus->sentences_unk_surface_forms.push_back({});
-          corpus->sentences_unk_surface_forms.back().swap(
-              sentence_unk_surface_forms);
-        }
+        RecordSentence(corpus, &sentence, &sentence_pos,
+                       &sentence_unk_surface_forms, &correct_actions);
       }
       start_of_sentence = true;
       continue;  // don't update next_is_action_line
@@ -163,9 +271,9 @@
       // the initial line in each sentence should look like:
       // [][the-det, cat-noun, is-verb, on-adp, the-det, mat-noun, ,-punct, ROOT-ROOT]
      // first, get rid of the square brackets.
-      lineS = lineS.substr(3, lineS.size() - 4);
+      line = line.substr(3, line.size() - 4);
      // read the initial line, token by token "the-det," "cat-noun," ...
-      istringstream iss(lineS);
+      istringstream iss(line);
       do {
         string word;
         iss >> word;
@@ -177,14 +285,14 @@
           word = word.substr(0, word.size() - 1);
         }
         // split the string (at '-') into word and POS tag.
-        size_t posIndex = word.rfind('-');
-        if (posIndex == string::npos) {
+        size_t pos_index = word.rfind('-');
+        if (pos_index == string::npos) {
           cerr << "can't find the dash in '" << word << "'" << endl;
         }
-        assert(posIndex != string::npos);
-        string pos = word.substr(posIndex + 1);
-        word = word.substr(0, posIndex);
+        assert(pos_index != string::npos);
+        string pos = word.substr(pos_index + 1);
+        word = word.substr(0, pos_index);
 
         if (pos == ORACLE_ROOT_POS) {
           // Prevent any confusion with the actual word "ROOT".
@@ -192,84 +300,14 @@
           pos = CorpusVocabulary::ROOT;
         }
 
-        // We assume that we'll have seen all POS tags in training, so don't
-        // worry about OOV tags.
-        unsigned pos_id = vocab->GetOrAddEntry(pos, &vocab->pos_to_int,
-                                               &vocab->int_to_pos);
         // Use 1-indexed token IDs to leave room for ROOT in position 0.
         unsigned next_token_index = sentence.size() + 1;
-        unsigned word_id;
-        if (is_training) {
-          unsigned num_words = vocab->CountWords();  // store for later check
-          word_id = vocab->GetOrAddWord(word, true);
-          if (vocab->CountWords() > num_words) {
-            // A new word was added; add its chars, too.
-            unsigned j = 0;
-            while (j < word.length()) {
-              unsigned char_utf8_len = UTF8Len(word[j]);
-              string next_utf8_char = word.substr(j, char_utf8_len);
-              vocab->GetOrAddEntry(next_utf8_char, &vocab->chars_to_int,
-                                   &vocab->int_to_chars);
-              j += char_utf8_len;
-            }
-          } else {
-            // It's an old word. Make sure it's marked as present in training.
-            vocab->int_to_training_word[word_id] = true;
-          }
-        } else {
-          // add an empty string for any token except OOVs (it is easy to
-          // recover the surface form of non-OOV using intToWords(id)).
-          // OOV word
-          if (corpus->USE_SPELLING) {
-            word_id = vocab->GetOrAddWord(word);  // don't record as training
-            sentence_unk_surface_forms[next_token_index] = "";
-          } else {
-            auto word_iter = vocab->words_to_int.find(word);
-            if (word_iter == vocab->words_to_int.end()) {
-              // Save the surface form of this OOV.
-              sentence_unk_surface_forms[next_token_index] = word;
-              word_id = vocab->words_to_int[vocab->UNK];
-            } else {
-              sentence_unk_surface_forms[next_token_index] = "";
-              word_id = word_iter->second;
-            }
-          }
-        }
-
-        sentence[next_token_index] = word_id;
-        sentence_pos[next_token_index] = pos_id;
+        RecordWord(word, pos, next_token_index, corpus, &sentence,
+                   &sentence_pos, &sentence_unk_surface_forms);
       } while (iss);
     }
-  } else if (next_is_action_line) {
-    auto action_iter = find(vocab->actions.begin(), vocab->actions.end(),
-                            lineS);
-    if (action_iter != vocab->actions.end()) {
-      unsigned action_index = distance(vocab->actions.begin(), action_iter);
-      if (start_of_sentence)
-        corpus->correct_act_sent.push_back({action_index});
-      else
-        corpus->correct_act_sent.back().push_back(action_index);
-    } else {  // A not-previously-seen action
-      if (is_training) {
-        vocab->actions.push_back(lineS);
-        vocab->actions_to_arc_labels.push_back(
-            vocab->GetLabelForAction(lineS));
-
-        unsigned action_index = vocab->actions.size() - 1;
-        if (start_of_sentence)
-          corpus->correct_act_sent.push_back({action_index});
-        else
-          corpus->correct_act_sent.back().push_back(action_index);
-      } else {
-        // TODO: right now, new actions which haven't been observed in
-        // training are not added to correct_act_sent. In dev, this may
-        // be a problem if there is little training data.
-        cerr << "WARNING: encountered unknown transition in dev corpus: "
-             << lineS << endl;
-        if (start_of_sentence)
-          corpus->correct_act_sent.push_back({});
-      }
-    }
+  } else {  // next_is_action_line
+    RecordAction(line, corpus, &correct_actions);
     start_of_sentence = false;
   }
@@ -279,19 +317,16 @@
   // Add the last sentence.
   if (sentence.size() > 0) {
     FixRootID();
-    corpus->sentences.push_back(move(sentence));
-    corpus->sentences_pos.push_back(move(sentence_pos));
-    if (!is_training) {
-      corpus->sentences_unk_surface_forms.push_back(
-          move(sentence_unk_surface_forms));
-    }
+    RecordSentence(corpus, &sentence, &sentence_pos,
+                   &sentence_unk_surface_forms, &correct_actions);
   }
 
-  actionsFile.close();
+  actions_file.close();
   cerr << "done."
        << "\n";
 
   if (is_training) {
-    for (auto a : vocab->actions) {
+    for (auto a : vocab->action_names) {
       cerr << a << "\n";
     }
   }
 }
diff --git a/parser/corpus.h b/parser/corpus.h
index 9be1a9d..b6aa496 100644
--- a/parser/corpus.h
+++ b/parser/corpus.h
@@ -2,6 +2,7 @@
 #define CORPUS_H
 
 #include <map>
+#include <memory>
 #include <set>
 #include <string>
 #include <vector>
@@ -34,24 +35,26 @@ class CorpusVocabulary {
   StrToIntMap chars_to_int;
   std::vector<std::string> int_to_chars;
 
-  std::vector<std::string> actions;
+  std::vector<std::string> action_names;
   std::vector<std::string> actions_to_arc_labels;
 
+  unsigned kUNK;
+
   CorpusVocabulary() : int_to_training_word({true, true}) {
     AddEntry(BAD0, &words_to_int, &int_to_words);
-    AddEntry(UNK, &words_to_int, &int_to_words);
+    kUNK = AddEntry(UNK, &words_to_int, &int_to_words);
     AddEntry(BAD0, &chars_to_int, &int_to_chars);
   }
 
   inline unsigned CountPOS() { return pos_to_int.size(); }
   inline unsigned CountWords() { return words_to_int.size(); }
   inline unsigned CountChars() { return chars_to_int.size(); }
-  inline unsigned CountActions() { return actions.size(); }
+  inline unsigned CountActions() { return action_names.size(); }
 
   inline unsigned GetWord(const std::string& word) const {
     auto word_iter = words_to_int.find(word);
     if (word_iter == words_to_int.end()) {
-      return words_to_int.find(CorpusVocabulary::UNK)->second;
+      return kUNK;
     } else {
       return word_iter->second;
     }
@@ -92,12 +95,9 @@ class CorpusVocabulary {
   }
 
   static inline std::string GetLabelForAction(const std::string& action) {
-    if (boost::starts_with(action, "RIGHT-ARC") ||
-        boost::starts_with(action, "LEFT-ARC")) {
-      size_t first_char_in_rel = action.find('(') + 1;
-      size_t last_char_in_rel = action.rfind(')') - 1;
-      return action.substr(
-          first_char_in_rel, last_char_in_rel - first_char_in_rel + 1);
+    boost::smatch match;
+    if (boost::regex_search(action, match, ARC_ACTION_REGEX)) {
+      return match[1];
     } else {
       return "NONE";
     }
@@ -106,6 +106,8 @@ class CorpusVocabulary {
 private:
   friend class boost::serialization::access;
 
+  static const boost::regex ARC_ACTION_REGEX;
+
   template <class Archive>
   // Shared code: serialize the number-to-string mappings, from which the
   // reverse mappings can be reconstructed.
@@ -115,7 +117,7 @@ class CorpusVocabulary {
     ar & vocab->int_to_pos;
     ar & vocab->int_to_chars;
     ar & vocab->int_to_training_word;
-    ar & vocab->actions;
+    ar & vocab->action_names;
   }
 
   template <class Archive>
@@ -149,7 +151,7 @@ class CorpusVocabulary {
       chars_to_int[int_to_chars[i]] = i;
 
     // ...and the arc labels.
-    for (const std::string& action : actions) {
+    for (const std::string& action : action_names) {
       actions_to_arc_labels.push_back(GetLabelForAction(action));
     }
   }
@@ -186,15 +188,69 @@ class ConllUCorpusReader : public CorpusReader {
 };
 
 
+class Sentence;
+inline std::ostream& operator<<(std::ostream& os, const Sentence& sent);
+
+class ParseTree;  // forward declaration
+
+class Sentence {
+public:
+  typedef std::map<unsigned, unsigned> SentenceMap;
+  typedef std::map<unsigned, std::string> SentenceUnkMap;
+
+  // TODO: move correct_act_sent from corpus-level to here
+  struct SentenceMetadata {};
+
+  Sentence(const CorpusVocabulary& vocab) : vocab(&vocab), tree(nullptr) {}
+
+  SentenceMap words;
+  SentenceMap poses;
+  SentenceUnkMap unk_surface_forms;
+  const CorpusVocabulary* vocab;
+  ParseTree* tree;
+  std::unique_ptr<SentenceMetadata> metadata;
+
+  size_t Size() const {
+    return words.size();
+  }
+
+  const std::string& WordForToken(unsigned token_id) const {
+    return WordForToken(words.find(token_id), token_id);
+  }
+
+  const std::string& WordForToken(SentenceMap::const_iterator words_iter,
+                                  unsigned token_id) const {
+    unsigned word_id = words_iter->second;
+    return word_id == vocab->kUNK ? unk_surface_forms.at(token_id)
+                                  : vocab->int_to_words[word_id];
+  }
+};
+
+inline std::ostream& operator<<(std::ostream& os, const Sentence& sent) {
+  for (auto& index_and_word_id : sent.words) {
+    unsigned index = index_and_word_id.first;
+    unsigned word_id = index_and_word_id.second;
+    unsigned pos_id = sent.poses.at(index);
+    auto unk_iter = sent.unk_surface_forms.find(index);
+    os << (unk_iter == sent.unk_surface_forms.end() || unk_iter->second == ""
+               ? sent.vocab->int_to_words.at(word_id)
+               : unk_iter->second)
+       << '/' << sent.vocab->int_to_pos.at(pos_id);
+    if (index != sent.words.rbegin()->first) {
+      os << ' ';
+    }
+  }
+  return os;
+}
+
+
 class Corpus {
 public:
   // Store root tokens with unsigned ID -1 internally to make root come last
   // when iterating over a list of tokens in order of IDs.
   static constexpr unsigned ROOT_TOKEN_ID = -1;
 
-  std::vector<std::map<unsigned, unsigned>> sentences;
-  std::vector<std::map<unsigned, unsigned>> sentences_pos;
-  std::vector<std::map<unsigned, std::string>> sentences_unk_surface_forms;
+  std::vector<Sentence> sentences;
 
   CorpusVocabulary* vocab;
 
   Corpus(CorpusVocabulary* vocab, const CorpusReader& reader,
@@ -207,38 +263,49 @@ class Corpus {
   // Corpus for subclasses to inherit and use. Subclasses are then responsible
   // for doing any corpus-reading or setup.
   Corpus(CorpusVocabulary* vocab) : vocab(vocab) {}
-
 };
 
 
 class TrainingCorpus : public Corpus {
 public:
-  friend class OracleTransitionsCorpusReader;
-
-  bool USE_SPELLING = false;
-  std::vector<std::vector<unsigned>> correct_act_sent;
-  std::set<unsigned> singletons;
-
-  TrainingCorpus(CorpusVocabulary* vocab, const std::string& file,
-                 bool is_training) :
-      Corpus(vocab) {
-    OracleTransitionsCorpusReader reader(is_training);
-    reader.ReadSentences(file, this);
-  }
+  bool USE_SPELLING = false;
 
-private:
+  std::vector<std::vector<unsigned>> correct_act_sent;
+
+protected:
   class OracleTransitionsCorpusReader : public CorpusReader {
   public:
     OracleTransitionsCorpusReader(bool is_training) :
-        is_training(is_training) {}
+        is_training(is_training) {
+    }
 
-    virtual void ReadSentences(const std::string& file, Corpus* corpus) const {
-      TrainingCorpus* training_corpus = static_cast<TrainingCorpus*>(corpus);
-      LoadCorrectActions(file, training_corpus);
+    static inline void ReplaceStringInPlace(std::string* subject,
+                                            const std::string& search,
+                                            const std::string& replace) {
+      size_t pos = 0;
+      while ((pos = subject->find(search, pos)) != std::string::npos) {
+        subject->replace(pos, search.length(), replace);
+        pos += replace.length();
+      }
     }
 
-    virtual ~OracleTransitionsCorpusReader() {};
+  protected:
+    bool is_training;  // can be dev rather than actual training
+
+    void RecordWord(
+        const std::string& word, const std::string& pos,
+        unsigned next_token_index, TrainingCorpus* corpus,
+        Sentence::SentenceMap* sentence,
+        Sentence::SentenceMap* sentence_pos,
+        Sentence::SentenceUnkMap* sentence_unk_surface_forms) const;
+
+    void RecordAction(const std::string& action, TrainingCorpus* corpus,
+                      std::vector<unsigned>* correct_actions) const;
+
+    void RecordSentence(TrainingCorpus* corpus, Sentence::SentenceMap* words,
+                        Sentence::SentenceMap* sentence_pos,
+                        Sentence::SentenceUnkMap* sentence_unk_surface_forms,
+                        std::vector<unsigned>* correct_actions,
+                        Sentence::SentenceMetadata* metadata = nullptr) const;
 
     static inline unsigned UTF8Len(unsigned char x) {
       if (x < 0x80) return 1;
@@ -249,25 +316,57 @@ class TrainingCorpus : public Corpus {
       else if ((x >> 1) == 0x7e) return 6;
       else return 0;
     }
-   private:
-    bool is_training;
-
-    void LoadCorrectActions(const std::string& file,
-                            TrainingCorpus* corpus) const;
   };
 
-  static inline void ReplaceStringInPlace(std::string& subject,
-                                          const std::string& search,
-                                          const std::string& replace) {
-    size_t pos = 0;
-    while ((pos = subject.find(search, pos)) != std::string::npos) {
-      subject.replace(pos, search.length(), replace);
-      pos += replace.length();
-    }
+  // Don't provide access to reader constructor -- object won't be fully
+  // constructed yet, so it would segfault.
+  TrainingCorpus(CorpusVocabulary* vocab) : Corpus(vocab) {}
+};
+
+
+class ParserTrainingCorpus : public TrainingCorpus {
+public:
+  friend class OracleTransitionsCorpusReader;
+
+  std::set<unsigned> singletons;
+
+  ParserTrainingCorpus(CorpusVocabulary* vocab, const std::string& file,
+                       bool is_training) :
+      TrainingCorpus(vocab) {
+    OracleParseTransitionsReader(is_training).ReadSentences(file, this);
   }
 
+private:
+  class OracleParseTransitionsReader : public OracleTransitionsCorpusReader {
+  public:
+    OracleParseTransitionsReader(bool is_training) :
+        OracleTransitionsCorpusReader(is_training) {}
+
+    virtual void ReadSentences(const std::string& file, Corpus* corpus) const {
+      ParserTrainingCorpus* training_corpus =
+          static_cast<ParserTrainingCorpus*>(corpus);
+      LoadCorrectActions(file, training_corpus);
+      training_corpus->sentences.shrink_to_fit();
+      training_corpus->correct_act_sent.shrink_to_fit();
+    }
+
+    virtual ~OracleParseTransitionsReader() {};
+
+  private:
+    void LoadCorrectActions(const std::string& file,
+                            ParserTrainingCorpus* corpus) const;
+  };
+
   void CountSingletons();
 };
 
 } // namespace lstm_parser
 
+
+inline void swap(lstm_parser::Sentence& s1, lstm_parser::Sentence& s2) {
+  lstm_parser::Sentence tmp = std::move(s1);
+  s1 = std::move(s2);
+  s2 = std::move(tmp);
+}
+
 #endif
diff --git a/parser/lstm-parser-driver.cc b/parser/lstm-parser-driver.cc
index 8d6b235..09a6fb3 100644
--- a/parser/lstm-parser-driver.cc
+++ b/parser/lstm-parser-driver.cc
@@ -54,7 +54,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   po::options_description dcmdline_options;
   dcmdline_options.add(opts);
   po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
-  if (conf->count("help")) {
+  if (conf->count("help") || argc == 1) {
     cerr << dcmdline_options << endl;
     exit(0);
   }
@@ -119,8 +119,16 @@ int main(int argc, char** argv) {
     cerr << "No model specified for testing!" << endl;
     abort();
   }
+  if (train && !load_model) {
+    if (!conf.count("training_data")) {
+      cerr << "Can't train without training data! Please provide"
+              " --training_data" << endl;
+      abort();
+    }
+  }
 
-  const string words = load_model ? "" : conf["words"].as<string>();
+  const string words =
+      load_model || !conf.count("words") ? "" : conf["words"].as<string>();
   unique_ptr<LSTMParser> parser;
   if (load_model) {
     parser.reset(new LSTMParser(conf["model"].as<string>()));
@@ -135,7 +143,7 @@ int main(int argc, char** argv) {
     parser.reset(new LSTMParser(cmd_options, words, false));
   }
 
-  unique_ptr<TrainingCorpus> dev_corpus;  // shared by train/evaluate
+  unique_ptr<ParserTrainingCorpus> dev_corpus;  // shared by train/evaluate
 
   if (train) {
     if (!conf.count("training_data") || !conf.count("dev_data")) {
@@ -145,14 +153,16 @@ int main(int argc, char** argv) {
     }
     signal(SIGINT, signal_callback_handler);
 
-    TrainingCorpus training_corpus(&parser->vocab,
-                                   conf["training_data"].as<string>(), true);
+    ParserTrainingCorpus training_corpus(parser->GetVocab(),
+                                         conf["training_data"].as<string>(),
+                                         true);
     parser->FinalizeVocab();
     cerr << "Total number of words: " << training_corpus.vocab->CountWords()
          << endl;
     // OOV words will be replaced by UNK tokens
-    dev_corpus.reset(new TrainingCorpus(&parser->vocab,
-                                        conf["dev_data"].as<string>(), false));
+    dev_corpus.reset(
+        new ParserTrainingCorpus(parser->GetVocab(),
+                                 conf["dev_data"].as<string>(), false));
 
     ostringstream os;
     os << "parser_" << (parser->options.use_pos ? "pos" : "nopos")
@@ -180,8 +190,8 @@ int main(int argc, char** argv) {
     cerr << "Evaluating model on " << conf["dev_data"].as<string>() << endl;
     if (!train) {  // Didn't already load dev corpus for training
       dev_corpus.reset(
-          new TrainingCorpus(&parser->vocab, conf["dev_data"].as<string>(),
-                             false));
+          new ParserTrainingCorpus(parser->GetVocab(),
+                                   conf["dev_data"].as<string>(), false));
     }
     parser->Evaluate(*dev_corpus);
   }
@@ -203,7 +213,8 @@ int main(int argc, char** argv) {
          << endl;
     abort();
   }
-    Corpus test_corpus(&parser->vocab, *reader, conf["test_data"].as<string>());
+    Corpus test_corpus(parser->GetVocab(), *reader,
+                       conf["test_data"].as<string>());
     parser->Test(test_corpus);
   }
 }
diff --git a/parser/lstm-parser.cc b/parser/lstm-parser.cc
index 790aba0..e33ed09 100644
--- a/parser/lstm-parser.cc
+++ b/parser/lstm-parser.cc
@@ -14,7 +14,6 @@
 #include "cnn/model.h"
 #include "cnn/tensor.h"
-#include "eos/portable_archive.hpp"
 
 using namespace cnn::expr;
@@ -24,7 +23,7 @@ using namespace std;
 
 namespace lstm_parser {
 
-string ParseTree::NO_LABEL = "ERROR";
+const string ParseTree::NO_LABEL("ERROR");
 
 
 void LSTMParser::LoadPretrainedWords(const string& words_path) {
@@ -65,10 +64,7 @@ void LSTMParser::LoadPretrainedWords(const string& words_path) {
 }
 
 
-void LSTMParser::FinalizeVocab() {
-  if (finalized)
-    return;
-
+void LSTMParser::InitializeNetworkParameters() {
   // Now that the vocab is ready to be finalized, we can set all the network
   // parameters.
   unsigned action_size = vocab.CountActions() + 1;
@@ -78,57 +74,54 @@
 
   if (!pretrained.empty()) {
     unsigned pretrained_dim = pretrained.begin()->second.size();
-    p_t = model.add_lookup_parameters(vocab_size, {pretrained_dim});
+    p_t = model->add_lookup_parameters(vocab_size, {pretrained_dim});
     for (const auto& it : pretrained)
       p_t->Initialize(it.first, it.second);
-    p_t2l = model.add_parameters({options.lstm_input_dim, pretrained_dim});
+    p_t2l = model->add_parameters({options.lstm_input_dim, pretrained_dim});
   } else {
     p_t = nullptr;
     p_t2l = nullptr;
   }
 
-  p_w = model.add_lookup_parameters(vocab_size, {options.input_dim});
-  p_a = model.add_lookup_parameters(action_size, {options.action_dim});
-  p_r = model.add_lookup_parameters(action_size, {options.rel_dim});
-  p_pbias = model.add_parameters({options.hidden_dim});
-  p_A = model.add_parameters({options.hidden_dim, options.hidden_dim});
-  p_B = model.add_parameters({options.hidden_dim, options.hidden_dim});
-  p_S = model.add_parameters({options.hidden_dim, options.hidden_dim});
-  p_H = model.add_parameters({options.lstm_input_dim, options.lstm_input_dim});
-  p_D = model.add_parameters({options.lstm_input_dim, options.lstm_input_dim});
-  p_R = model.add_parameters({options.lstm_input_dim, options.rel_dim});
-  p_w2l = model.add_parameters({options.lstm_input_dim, options.input_dim});
-  p_ib = model.add_parameters({options.lstm_input_dim});
-  p_cbias = model.add_parameters({options.lstm_input_dim});
-  p_p2a = model.add_parameters({action_size, options.hidden_dim});
-  p_action_start = model.add_parameters({options.action_dim});
-  p_abias = model.add_parameters({action_size});
-  p_buffer_guard = model.add_parameters({options.lstm_input_dim});
-  p_stack_guard = model.add_parameters({options.lstm_input_dim});
+  p_w = model->add_lookup_parameters(vocab_size, {options.input_dim});
+  p_a = model->add_lookup_parameters(action_size, {options.action_dim});
+  p_r = model->add_lookup_parameters(action_size, {options.rel_dim});
+  p_pbias = model->add_parameters({options.hidden_dim});
+  p_A = model->add_parameters({options.hidden_dim, options.hidden_dim});
+  p_B = model->add_parameters({options.hidden_dim, options.hidden_dim});
+  p_S = model->add_parameters({options.hidden_dim, options.hidden_dim});
+  p_H = model->add_parameters({options.lstm_input_dim, options.lstm_input_dim});
+  p_D = model->add_parameters({options.lstm_input_dim, options.lstm_input_dim});
+  p_R = model->add_parameters({options.lstm_input_dim, options.rel_dim});
+  p_w2l = model->add_parameters({options.lstm_input_dim, options.input_dim});
+  p_ib = model->add_parameters({options.lstm_input_dim});
+  p_cbias = model->add_parameters({options.lstm_input_dim});
+  p_p2a = model->add_parameters({action_size, options.hidden_dim});
+  p_action_start = model->add_parameters({options.action_dim});
+  p_abias = model->add_parameters({action_size});
+  p_buffer_guard = model->add_parameters({options.lstm_input_dim});
+  p_stack_guard = model->add_parameters({options.lstm_input_dim});
 
   if (options.use_pos) {
-    p_p = model.add_lookup_parameters(pos_size, {options.pos_dim});
-    p_p2l = model.add_parameters({options.lstm_input_dim, options.pos_dim});
+    p_p = model->add_lookup_parameters(pos_size, {options.pos_dim});
+    p_p2l = model->add_parameters({options.lstm_input_dim, options.pos_dim});
   } else {
     p_p = nullptr;
     p_p2l = nullptr;
   }
-
-  finalized = true;
 }
 
 
 LSTMParser::LSTMParser(const ParserOptions& poptions,
                        const string& pretrained_words_path, bool finalize) :
     options(poptions),
-    kUNK(vocab.GetOrAddWord(vocab.UNK)),
     kROOT_SYMBOL(vocab.GetOrAddWord(vocab.ROOT)),
     stack_lstm(options.layers, options.lstm_input_dim, options.hidden_dim,
-               &model),
+               model.get()),
     buffer_lstm(options.layers, options.lstm_input_dim, options.hidden_dim,
-                &model),
+                model.get()),
     action_lstm(options.layers, options.action_dim, options.hidden_dim,
-                &model) {
+                model.get()) {
   // First load words if needed before creating network parameters.
   // That will ensure that the vocab has the final number of words.
   if (!pretrained_words_path.empty()) {
@@ -143,18 +136,23 @@ LSTMParser::LSTMParser(const ParserOptions& poptions,
 }
 
 
-bool LSTMParser::IsActionForbidden(const string& a, unsigned bsize,
-                                   unsigned ssize, const vector<int>& stacki) {
-  if (a[1] == 'W' && ssize < 3)
+bool LSTMParser::IsActionForbidden(const unsigned action,
+                                   TaggerState* state) const {
+  const string& action_name = vocab.action_names[action];
+  const ParserState& real_state = static_cast<const ParserState&>(*state);
+  unsigned ssize = real_state.stack.size();
+  unsigned bsize = real_state.buffer.size();
+
+  if (action_name[1] == 'W' && ssize < 3)
     return true;
 
-  if (a[1] == 'W') {
-    int top = stacki[stacki.size() - 1];
-    int sec = stacki[stacki.size() - 2];
+  if (action_name[1] == 'W') {
+    int top = real_state.stacki[real_state.stacki.size() - 1];
+    int sec = real_state.stacki[real_state.stacki.size() - 2];
     if (sec > top)
       return true;
   }
 
-  bool is_shift = (a[0] == 'S' && a[1] == 'H');
+  bool is_shift = (action_name[0] == 'S' && action_name[1] == 'H');
   bool is_reduce = !is_shift;
   if (is_shift && bsize == 1)
     return true;
@@ -165,27 +163,27 @@ bool LSTMParser::IsActionForbidden(const string& a, unsigned bsize,
       is_shift)
     return true;
   // only attach left to ROOT
-  if (bsize == 1 && ssize == 3 && a[0] == 'R')
+  if (bsize == 1 && ssize == 3 && action_name[0] == 'R')
     return true;
   return false;
 }
 
 
 ParseTree LSTMParser::RecoverParseTree(
-    const map<unsigned, unsigned>& sentence, const vector<unsigned>& actions,
-    const vector<string>& action_names,
-    const vector<string>& actions_to_arc_labels, bool labeled) {
+    const Sentence& sentence, const vector<unsigned>& actions, double logprob,
+    bool labeled) const {
   ParseTree tree(sentence, labeled);
-  vector<int> bufferi(sentence.size() + 1);
+  vector<int> bufferi(sentence.Size() + 1);
   bufferi[0] = -999;
   vector<int> stacki(1, -999);
   unsigned added_to_buffer = 0;
 
-  for (const auto& index_and_word_id : sentence) {
+  for (const auto& index_and_word_id : sentence.words) {
     // ROOT is set to -1, so it'll come last in a sequence of unsigned ints.
-    bufferi[sentence.size() - added_to_buffer++] = index_and_word_id.first;
+    bufferi[sentence.Size() - added_to_buffer++] =
+        index_and_word_id.first;
   }
 
   for (auto action : actions) {  // loop over transitions for sentence
-    const string& action_string = action_names[action];
+    const string& action_string = vocab.action_names[action];
     const char ac = action_string[0];
     const char ac2 = action_string[1];
     if (ac == 'S' && ac2 == 'H') {  // SHIFT
@@ -212,250 +210,185 @@ ParseTree LSTMParser::RecoverParseTree(
       (ac == 'R' ? headi : depi) = stacki.back();
       stacki.pop_back();
       stacki.push_back(headi);
-      tree.SetParent(depi, headi, actions_to_arc_labels[action]);
+      tree.SetParent(depi, headi, vocab.actions_to_arc_labels[action]);
     }
   }
   assert(bufferi.size() == 1);
   //assert(stacki.size() == 2);
+
+  tree.logprob = logprob;
   return tree;
 }
 
 
-vector<unsigned> LSTMParser::LogProbParser(
-    ComputationGraph* hg,
-    const map<unsigned, unsigned>& raw_sent,  // raw sentence
-    const map<unsigned, unsigned>& sent,  // sentence with OOVs replaced
-    const map<unsigned, unsigned>& sent_pos,
-    const vector<unsigned>& correct_actions, const vector<string>& action_names,
-    const vector<string>& int_to_words, double* correct) {
-  // TODO: break up this function?
-  assert(finalized);
-  vector<unsigned> results;
-  const bool build_training_graph = correct_actions.size() > 0;
-
-  stack_lstm.new_graph(*hg);
-  buffer_lstm.new_graph(*hg);
-  action_lstm.new_graph(*hg);
+Expression LSTMParser::GetActionProbabilities(TaggerState* state) {
+  // p_t = pbias + S * slstm + B * blstm + A * alstm
+  Expression p_t = affine_transform(
+      {GetParamExpr(p_pbias), GetParamExpr(p_S), stack_lstm.back(),
+       GetParamExpr(p_B), buffer_lstm.back(), GetParamExpr(p_A),
+       action_lstm.back()});
+  Expression nlp_t = rectify(p_t);
+  // r_t = abias + p2a * nlp
+  Expression r_t = affine_transform(
+      {GetParamExpr(p_abias), GetParamExpr(p_p2a), nlp_t});
+  return r_t;
+}
+
+
+void LSTMParser::DoAction(unsigned action, TaggerState* state,
+                          ComputationGraph* cg,
+                          map<string, Expression>* states_to_expose) {
+  ParserState* real_state = static_cast<ParserState*>(state);
+  // add current action to action LSTM
+  Expression action_e = lookup(*cg, p_a, action);
+  action_lstm.add_input(action_e);
+
+  // get relation embedding from action (TODO: convert to rel from action?)
+  Expression relation = lookup(*cg, p_r, action);
+
+  // do action
+  const string& action_string = vocab.action_names[action];
+  const char ac = action_string[0];
+  const char ac2 = action_string[1];
+
+  if (ac == 'S' && ac2 == 'H') {  // SHIFT
+    assert(real_state->buffer.size() > 1);  // dummy symbol means > 1 (not >= 1)
+    real_state->stack.push_back(real_state->buffer.back());
+    stack_lstm.add_input(real_state->buffer.back());
+    real_state->buffer.pop_back();
+    buffer_lstm.rewind_one_step();
+    real_state->stacki.push_back(real_state->bufferi.back());
+    real_state->bufferi.pop_back();
+  } else if (ac == 'S' && ac2 == 'W') {  // SWAP --- Miguel
+    assert(real_state->stack.size() > 2);  // dummy symbol means > 2 (not >= 2)
+
+    Expression toki, tokj;
+    unsigned ii = 0, jj = 0;
+    tokj = real_state->stack.back();
+    jj = real_state->stacki.back();
+    real_state->stack.pop_back();
+    real_state->stacki.pop_back();
+
+    toki = real_state->stack.back();
+    ii = real_state->stacki.back();
+    real_state->stack.pop_back();
+    real_state->stacki.pop_back();
+
+    real_state->buffer.push_back(toki);
+    real_state->bufferi.push_back(ii);
+
+    stack_lstm.rewind_one_step();
+    stack_lstm.rewind_one_step();
+
+    buffer_lstm.add_input(real_state->buffer.back());
+
+    real_state->stack.push_back(tokj);
+    real_state->stacki.push_back(jj);
+
+    stack_lstm.add_input(real_state->stack.back());
+  } else {  // LEFT or RIGHT
+    assert(real_state->stack.size() > 2);  // dummy symbol means > 2 (not >= 2)
+    assert(ac == 'L' || ac == 'R');
+    Expression dep, head;
+    unsigned depi = 0, headi = 0;
+    (ac == 'R' ? dep : head) = real_state->stack.back();
+    (ac == 'R' ? depi : headi) = real_state->stacki.back();
+    real_state->stack.pop_back();
+    real_state->stacki.pop_back();
+    (ac == 'R' ? head : dep) = real_state->stack.back();
+    (ac == 'R' ? headi : depi) = real_state->stacki.back();
+    real_state->stack.pop_back();
+    real_state->stacki.pop_back();
+    // composed = cbias + H * head + D * dep + R * relation
+    Expression composed = affine_transform(
+        {GetParamExpr(p_cbias),
+         GetParamExpr(p_H), head,
+         GetParamExpr(p_D), dep,
+         GetParamExpr(p_R), relation});
+    Expression nlcomposed = tanh(composed);
+    stack_lstm.rewind_one_step();
+    stack_lstm.rewind_one_step();
+    stack_lstm.add_input(nlcomposed);
+    real_state->stack.push_back(nlcomposed);
+    real_state->stacki.push_back(headi);
+    if (states_to_expose) {
+      // Once something is attached as a dependent, it will never again be
+      // modified, so cache its expression.
+ (*states_to_expose)[to_string(depi)] = dep; + } + } + + // After the last action, record the final tree state, if requested. + if (states_to_expose && ShouldTerminate(real_state)) { + (*states_to_expose)["Tree"] = real_state->stack.back(); + } +} + + +NeuralTransitionTagger::TaggerState* LSTMParser::InitializeParserState( + ComputationGraph* cg, + const Sentence& raw_sent, + const Sentence::SentenceMap& sent, // sentence with OOVs replaced + const vector& correct_actions) { + stack_lstm.new_graph(*cg); + buffer_lstm.new_graph(*cg); + action_lstm.new_graph(*cg); stack_lstm.start_new_sequence(); buffer_lstm.start_new_sequence(); action_lstm.start_new_sequence(); - // variables in the computation graph representing the parameters - Expression pbias = parameter(*hg, p_pbias); - Expression H = parameter(*hg, p_H); - Expression D = parameter(*hg, p_D); - Expression R = parameter(*hg, p_R); - Expression cbias = parameter(*hg, p_cbias); - Expression S = parameter(*hg, p_S); - Expression B = parameter(*hg, p_B); - Expression A = parameter(*hg, p_A); - Expression ib = parameter(*hg, p_ib); - Expression w2l = parameter(*hg, p_w2l); - Expression p2l; - if (options.use_pos) - p2l = parameter(*hg, p_p2l); - Expression t2l; - if (p_t2l) - t2l = parameter(*hg, p_t2l); - Expression p2a = parameter(*hg, p_p2a); - Expression abias = parameter(*hg, p_abias); - Expression action_start = parameter(*hg, p_action_start); - - action_lstm.add_input(action_start); - - // variables representing word embeddings (possibly including POS info) - vector buffer(sent.size() + 1); - vector bufferi(sent.size() + 1); // position of the words in the sentence - // precompute buffer representation from left to right + Expression stack_guard = GetParamExpr(p_stack_guard); + ParserState* state = new ParserState(raw_sent, sent, stack_guard); + action_lstm.add_input(GetParamExpr(p_action_start)); + stack_lstm.add_input(stack_guard); + + // precompute buffer representation from left to right unsigned added_to_buffer = 0; for (const auto& index_and_word_id : sent) { unsigned token_index = index_and_word_id.first; unsigned word_id = index_and_word_id.second; assert(word_id < vocab.CountWords()); - Expression w = lookup(*hg, p_w, word_id); - - vector args = {ib, w2l, w}; // learn embeddings - if (options.use_pos) { // learn POS tag? - unsigned pos_id = sent_pos.find(token_index)->second; - Expression p = lookup(*hg, p_p, pos_id); - args.push_back(p2l); + Expression w = lookup(*cg, p_w, word_id); + + vector args = {GetParamExpr(p_ib), GetParamExpr(p_w2l), + w}; // learn embeddings + if (options.use_pos) { // learn POS tag? + unsigned pos_id = raw_sent.poses.at(token_index); + Expression p = lookup(*cg, p_p, pos_id); + args.push_back(GetParamExpr(p_p2l)); args.push_back(p); } - unsigned raw_word_id = raw_sent.find(token_index)->second; - if (p_t && pretrained.count(raw_word_id)) { // include pretrained vectors? - Expression t = const_lookup(*hg, p_t, raw_word_id); - args.push_back(t2l); + unsigned raw_word_id = raw_sent.words.at(token_index); + if (p_t && pretrained.count(raw_word_id)) { // include pretrained vectors? 
+ Expression t = const_lookup(*cg, p_t, raw_word_id); + args.push_back(GetParamExpr(p_t2l)); args.push_back(t); } - buffer[sent.size() - added_to_buffer] = rectify(affine_transform(args)); - bufferi[sent.size() - added_to_buffer] = token_index; + state->buffer[sent.size() - added_to_buffer] = rectify( + affine_transform(args)); + state->bufferi[sent.size() - added_to_buffer] = token_index; added_to_buffer++; } // dummy symbol to represent the empty buffer - buffer[0] = parameter(*hg, p_buffer_guard); - bufferi[0] = -999; - for (auto& b : buffer) + state->buffer[0] = parameter(*cg, p_buffer_guard); + state->bufferi[0] = -999; + for (auto& b : state->buffer) buffer_lstm.add_input(b); - vector stack; // variables representing subtree embeddings - vector stacki; // position of words in the sentence of head of subtree - stack.push_back(parameter(*hg, p_stack_guard)); - stacki.push_back(-999); // not used for anything - // drive dummy symbol on stack through LSTM - stack_lstm.add_input(stack.back()); - vector log_probs; - unsigned action_count = 0; // incremented at each prediction - while (stack.size() > 2 || buffer.size() > 1) { - // get list of possible actions for the current parser state - vector current_valid_actions; - for (unsigned action = 0; action < n_possible_actions; ++action) { - if (IsActionForbidden(action_names[action], buffer.size(), stack.size(), - stacki)) - continue; - current_valid_actions.push_back(action); - } - - // p_t = pbias + S * slstm + B * blstm + A * almst - Expression p_t = affine_transform( - {pbias, S, stack_lstm.back(), B, buffer_lstm.back(), A, - action_lstm.back()}); - Expression nlp_t = rectify(p_t); - // r_t = abias + p2a * nlp - Expression r_t = affine_transform({abias, p2a, nlp_t}); - - // adist = log_softmax(r_t, current_valid_actions) - Expression adiste = log_softmax(r_t, current_valid_actions); - vector adist = as_vector(hg->incremental_forward()); - double best_score = adist[current_valid_actions[0]]; - unsigned best_a = current_valid_actions[0]; - for (unsigned i = 1; i < current_valid_actions.size(); ++i) { - if (adist[current_valid_actions[i]] > best_score) { - best_score = adist[current_valid_actions[i]]; - best_a = current_valid_actions[i]; - } - } - unsigned action = best_a; - // If we have reference actions (for training), use the reference action. - if (build_training_graph) { - action = correct_actions[action_count]; - if (best_a == action) { - (*correct)++; - } - } - ++action_count; - log_probs.push_back(pick(adiste, action)); - results.push_back(action); - - // add current action to action LSTM - Expression actione = lookup(*hg, p_a, action); - action_lstm.add_input(actione); - - // get relation embedding from action (TODO: convert to rel from action?) 
- Expression relation = lookup(*hg, p_r, action); - - // do action - const string& actionString = action_names[action]; - const char ac = actionString[0]; - const char ac2 = actionString[1]; - - if (ac == 'S' && ac2 == 'H') { // SHIFT - assert(buffer.size() > 1); // dummy symbol means > 1 (not >= 1) - stack.push_back(buffer.back()); - stack_lstm.add_input(buffer.back()); - buffer.pop_back(); - buffer_lstm.rewind_one_step(); - stacki.push_back(bufferi.back()); - bufferi.pop_back(); - } else if (ac == 'S' && ac2 == 'W') { //SWAP --- Miguel - assert(stack.size() > 2); // dummy symbol means > 2 (not >= 2) - - Expression toki, tokj; - unsigned ii = 0, jj = 0; - tokj = stack.back(); - jj = stacki.back(); - stack.pop_back(); - stacki.pop_back(); - - toki = stack.back(); - ii = stacki.back(); - stack.pop_back(); - stacki.pop_back(); - - buffer.push_back(toki); - bufferi.push_back(ii); - - stack_lstm.rewind_one_step(); - stack_lstm.rewind_one_step(); - - buffer_lstm.add_input(buffer.back()); - - stack.push_back(tokj); - stacki.push_back(jj); - - stack_lstm.add_input(stack.back()); - } else { // LEFT or RIGHT - assert(stack.size() > 2); // dummy symbol means > 2 (not >= 2) - assert(ac == 'L' || ac == 'R'); - Expression dep, head; - unsigned depi = 0, headi = 0; - (ac == 'R' ? dep : head) = stack.back(); - (ac == 'R' ? depi : headi) = stacki.back(); - stack.pop_back(); - stacki.pop_back(); - (ac == 'R' ? head : dep) = stack.back(); - (ac == 'R' ? headi : depi) = stacki.back(); - stack.pop_back(); - stacki.pop_back(); - // composed = cbias + H * head + D * dep + R * relation - Expression composed = affine_transform({cbias, H, head, D, dep, R, - relation}); - Expression nlcomposed = tanh(composed); - stack_lstm.rewind_one_step(); - stack_lstm.rewind_one_step(); - stack_lstm.add_input(nlcomposed); - stack.push_back(nlcomposed); - stacki.push_back(headi); - } - } - assert(stack.size() == 2); // guard symbol, root - assert(stacki.size() == 2); - assert(buffer.size() == 1); // guard symbol - assert(bufferi.size() == 1); - Expression tot_neglogprob = -sum(log_probs); - assert(tot_neglogprob.pg != nullptr); - return results; + return state; } -void LSTMParser::SaveModel(const string& model_fname, bool softlink_created) { - ofstream out_file(model_fname); - eos::portable_oarchive archive(out_file); - archive << *this; - cerr << "Model saved." << endl; - // Create a soft link to the most recent model in order to make it - // easier to refer to it in a shell script. - if (!softlink_created) { - string softlink = "latest_model.params"; - - if (system((string("rm -f ") + softlink).c_str()) == 0 - && system(("ln -s " + model_fname + " " + softlink).c_str()) == 0) { - cerr << "Created " << softlink << " as a soft link to " << model_fname - << " for convenience." 
-    }
-  }
-}
-
-
-void LSTMParser::Train(const TrainingCorpus& corpus,
-                       const TrainingCorpus& dev_corpus, const double unk_prob,
-                       const string& model_fname,
+void LSTMParser::Train(const ParserTrainingCorpus& corpus,
+                       const ParserTrainingCorpus& dev_corpus,
+                       const double unk_prob, const string& model_fname,
                        const volatile bool* requested_stop) {
   bool softlink_created = false;
   int best_correct_heads = 0;
   unsigned status_every_i_iterations = 100;
-  SimpleSGDTrainer sgd(&model);
-  //MomentumSGDTrainer sgd(model);
+  SimpleSGDTrainer sgd(model.get());
+  //MomentumSGDTrainer sgd(model.get());
   sgd.eta_decay = 0.08;
   //sgd.eta_decay = 0.05;
   unsigned num_sentences = corpus.sentences.size();
@@ -490,30 +423,26 @@ void LSTMParser::Train(const TrainingCorpus& corpus,
       random_shuffle(order.begin(), order.end());
     }
     tot_seen += 1;
-    const map<unsigned, unsigned>& sentence = corpus.sentences[order[si]];
-    map<unsigned, unsigned> tsentence(sentence);
+    const Sentence& sentence = corpus.sentences[order[si]];
+    Sentence::SentenceMap tsentence(sentence.words);
     if (options.unk_strategy == 1) {
       for (auto& index_and_id : tsentence) {  // use reference to overwrite
         if (corpus.singletons.count(index_and_id.second)
             && cnn::rand01() < unk_prob) {
-          index_and_id.second = kUNK;
+          index_and_id.second = vocab.kUNK;
         }
       }
     }
-    const map<unsigned, unsigned>& sentence_pos =
-        corpus.sentences_pos[order[si]];
     const vector<unsigned>& actions = corpus.correct_act_sent[order[si]];
-    ComputationGraph hg;
-    LogProbParser(&hg, sentence, tsentence, sentence_pos, actions,
-                  corpus.vocab->actions, corpus.vocab->int_to_words,
-                  &correct);
-    double lp = as_scalar(hg.incremental_forward());
+    ComputationGraph cg;
+    LogProbTagger(&cg, sentence, tsentence, true, actions, &correct);
+    double lp = as_scalar(cg.incremental_forward());
     if (lp < 0) {
       cerr << "Log prob < 0 on sentence " << order[si] << ": lp=" << lp
           << endl;
       assert(lp >= 0.0);
     }
-    hg.backward();
+    cg.backward();
     sgd.update(1.0);
     llh += lp;
     ++si;
@@ -535,34 +464,31 @@ void LSTMParser::Train(const TrainingCorpus& corpus,
       // dev_size = 100;
       double llh = 0;
       double trs = 0;
-      double correct = 0;
       double correct_heads = 0;
       double total_heads = 0;
       auto t_start = chrono::high_resolution_clock::now();
       for (unsigned sii = 0; sii < dev_size; ++sii) {
-        const map<unsigned, unsigned>& sentence = dev_corpus.sentences[sii];
-        const map<unsigned, unsigned>& sentence_pos =
-            dev_corpus.sentences_pos[sii];
-        ParseTree hyp = Parse(sentence, sentence_pos, vocab, false, &correct);
+        const Sentence& sentence = dev_corpus.sentences[sii];
+
+        ParseTree hyp = Parse(sentence, vocab, false);
+        llh += hyp.logprob;
-        double lp = 0;
-        llh -= lp;
         const vector<unsigned>& actions = dev_corpus.correct_act_sent[sii];
-        ParseTree ref = RecoverParseTree(
-            sentence, actions, dev_corpus.vocab->actions,
-            dev_corpus.vocab->actions_to_arc_labels);
+        ParseTree ref = RecoverParseTree(sentence, actions);
         trs += actions.size();
         correct_heads += ComputeCorrect(ref, hyp);
-        total_heads += sentence.size() - 1;  // -1 to account for ROOT
+        total_heads += sentence.Size() - 1;  // -1 to account for ROOT
       }
+      auto t_end = chrono::high_resolution_clock::now();
+      auto ms = chrono::duration<double, milli>(t_end - t_start).count();
       cerr << "  **dev (iter=" << iter << " epoch="
-           << (tot_seen / num_sentences) << ")\tllh=" << llh << " ppl: "
-           << exp(llh / trs) << " err: " << (trs - correct) / trs << " uas: "
-           << (correct_heads / total_heads) << "\t[" << dev_size << " sents in "
-           << chrono::duration<double, milli>(t_end - t_start).count() << " ms]"
-           << endl;
+           << (tot_seen / num_sentences) << ")\tllh=" << llh
+           << " ppl: " << exp(llh / trs)
+           << " uas: " << (correct_heads / total_heads)
+           << "\t[" << dev_size << " sents in " << ms << " ms]" << endl;
+
       if (correct_heads > best_correct_heads) {
         best_correct_heads = correct_heads;
         SaveModel(model_fname, softlink_created);
@@ -573,31 +499,12 @@ void LSTMParser::Train(const TrainingCorpus& corpus,
 }


-vector<unsigned> LSTMParser::LogProbParser(
-    const map<unsigned, unsigned>& sentence,
-    const map<unsigned, unsigned>& sentence_pos, const CorpusVocabulary& vocab,
-    ComputationGraph *cg, double* correct) {
-  map<unsigned, unsigned> tsentence(sentence);  // sentence with OOVs replaced
-  for (auto& index_and_id : tsentence) {  // use reference to overwrite
-    if (!vocab.int_to_training_word[index_and_id.second]) {
-      index_and_id.second = kUNK;
-    }
-  }
-  return LogProbParser(cg, sentence, tsentence, sentence_pos,
-                       vector<unsigned>(), vocab.actions,
-                       vocab.int_to_words, correct);
-}
-
-
-ParseTree LSTMParser::Parse(const map<unsigned, unsigned>& sentence,
-                            const map<unsigned, unsigned>& sentence_pos,
-                            const CorpusVocabulary& vocab,
-                            bool labeled, double* correct) {
+ParseTree LSTMParser::Parse(const Sentence& sentence,
+                            const CorpusVocabulary& vocab, bool labeled) {
   ComputationGraph cg;
-  vector<unsigned> pred = LogProbParser(sentence, sentence_pos, vocab, &cg,
-                                        correct);
-  return RecoverParseTree(sentence, pred, vocab.actions,
-                          vocab.actions_to_arc_labels, labeled);
+  vector<unsigned> pred = LogProbTagger(&cg, sentence);
+  double lp = as_scalar(cg.incremental_forward());
+  return RecoverParseTree(sentence, pred, labeled, lp);
 }


@@ -609,42 +516,38 @@ void LSTMParser::DoTest(const Corpus& corpus, bool evaluate,
   }
   double llh = 0;
   double trs = 0;
-  double correct = 0;
   double correct_heads = 0;
   double total_heads = 0;
   auto t_start = chrono::high_resolution_clock::now();
   unsigned corpus_size = corpus.sentences.size();
   for (unsigned sii = 0; sii < corpus_size; ++sii) {
-    const map<unsigned, unsigned>& sentence = corpus.sentences[sii];
-    const map<unsigned, unsigned>& sentence_pos = corpus.sentences_pos[sii];
-    const map<unsigned, string>& sentence_unk_str =
-        corpus.sentences_unk_surface_forms[sii];
-    ParseTree hyp = Parse(sentence, sentence_pos, vocab, true, &correct);
+    const Sentence& sentence = corpus.sentences[sii];
+    ParseTree hyp = Parse(sentence, vocab, true);
     if (output_parses) {
-      OutputConll(sentence, sentence_pos, sentence_unk_str,
-                  corpus.vocab->int_to_words, corpus.vocab->int_to_pos,
-                  corpus.vocab->words_to_int, hyp);
+      OutputConll(sentence, corpus.vocab->int_to_words,
+                  corpus.vocab->int_to_pos, corpus.vocab->words_to_int, hyp);
     }
     if (evaluate) {
-      // Downcast to TrainingCorpus to get gold-standard data. We can only get
-      // here if this function was called by Evaluate, which statically checks
-      // that the corpus is in fact a TrainingCorpus, so this cast is safe.
-      const TrainingCorpus& training_corpus =
-          static_cast<const TrainingCorpus&>(corpus);
+      // Downcast to ParserTrainingCorpus to get gold-standard data. We can only
+      // get here if this function was called by Evaluate, which statically
+      // checks that the corpus is in fact a ParserTrainingCorpus, so this cast
+      // is safe.
+      const ParserTrainingCorpus& training_corpus =
+          static_cast<const ParserTrainingCorpus&>(corpus);
       const vector<unsigned>& actions = training_corpus.correct_act_sent[sii];
-      ParseTree ref = RecoverParseTree(sentence, actions, corpus.vocab->actions,
-                                       corpus.vocab->actions_to_arc_labels,
-                                       true);
+      ParseTree ref = RecoverParseTree(sentence, actions, true);
       trs += actions.size();
+      llh += hyp.logprob;
       correct_heads += ComputeCorrect(ref, hyp);
-      total_heads += sentence.size() - 1;  // -1 to account for ROOT
+      total_heads += sentence.Size() - 1;  // -1 to account for ROOT
     }
   }
+  auto t_end = chrono::high_resolution_clock::now();
   if (evaluate) {
-    cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: "
-         << (trs - correct) / trs << " uas: " << (correct_heads / total_heads)
+    cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs)
+         << " uas: " << (correct_heads / total_heads)
         << "\t[" << corpus_size << " sents in "
         << chrono::duration<double, milli>(t_end - t_start).count() << " ms]"
         << endl;
@@ -656,29 +559,26 @@ void LSTMParser::DoTest(const Corpus& corpus, bool evaluate,
 }


-void LSTMParser::OutputConll(const map<unsigned, unsigned>& sentence,
-                             const map<unsigned, unsigned>& pos,
-                             const map<unsigned, string>& sentence_unk_strings,
+void LSTMParser::OutputConll(const Sentence& sentence,
                             const vector<string>& int_to_words,
                             const vector<string>& int_to_pos,
                             const map<string, unsigned>& words_to_int,
                             const ParseTree& tree) {
-  const unsigned int unk_word =
-      words_to_int.find(CorpusVocabulary::UNK)->second;
-  for (const auto& token_index_and_word : sentence) {
+  const unsigned int unk_word = words_to_int.at(CorpusVocabulary::UNK);
+  for (const auto& token_index_and_word : sentence.words) {
     unsigned token_index = token_index_and_word.first;
     unsigned word_id = token_index_and_word.second;
     if (token_index == Corpus::ROOT_TOKEN_ID)  // don't output anything for ROOT
       continue;
-    auto unk_strs_iter = sentence_unk_strings.find(token_index);
-    assert(unk_strs_iter != sentence_unk_strings.end() &&
+    auto unk_strs_iter = sentence.unk_surface_forms.find(token_index);
+    assert(unk_strs_iter != sentence.unk_surface_forms.end() &&
           ((word_id == unk_word && unk_strs_iter->second.size() > 0) ||
            (word_id != unk_word && unk_strs_iter->second.size() == 0 &&
            int_to_words.size() > word_id)));
     string wit = (unk_strs_iter->second.size() > 0) ?
        unk_strs_iter->second : int_to_words[word_id];
-    const string& pos_tag = int_to_pos[pos.find(token_index)->second];
+    const string& pos_tag = int_to_pos[sentence.poses.at(token_index)];
     unsigned parent = tree.GetParent(token_index);
     if (parent == Corpus::ROOT_TOKEN_ID)
       parent = 0;
diff --git a/parser/lstm-parser.h b/parser/lstm-parser.h
index 5d89922..e6a7e09 100644
--- a/parser/lstm-parser.h
+++ b/parser/lstm-parser.h
@@ -20,6 +20,7 @@
 #include "cnn/rnn.h"
 #include "corpus.h"
 #include "eos/portable_archive.hpp"
+#include "neural-transition-tagger.h"

 namespace lstm_parser {

@@ -66,26 +67,44 @@ struct ParserOptions {
 };


+// Barebones representation of a parse tree.
 class ParseTree {
 public:
-  static std::string NO_LABEL;
-  // Barebones representation of a parse tree.
-  const std::map<unsigned, unsigned>& sentence;
+  static const std::string NO_LABEL;

-  ParseTree(const std::map<unsigned, unsigned>& sentence, bool labeled = true) :
-      sentence(sentence),
-      arc_labels(labeled ? new std::map<unsigned, std::string> : nullptr) {
-  }
+  double logprob;
+
+  ParseTree(const Sentence& sentence, bool labeled = true) :
+      logprob(0),
+      arc_labels(labeled ? new std::map<unsigned, std::string> : nullptr),
+      sentence(sentence), root_child(-1) {}
+
+  ParseTree(const ParseTree& other)
+      : logprob(other.logprob), parents(other.parents),
+        arc_labels(other.IsLabeled() ?
+            new std::map<unsigned, std::string>(*other.arc_labels) : nullptr),
+        sentence(other.sentence), root_child(other.root_child) {}
+
+  ParseTree(ParseTree&& other) = default;

-  inline void SetParent(unsigned child_index, unsigned parent_index,
+  ParseTree& operator=(ParseTree&& other) = default;
+
+  void SetParent(unsigned child_index, unsigned parent_index,
                 const std::string& arc_label="") {
    parents[child_index] = parent_index;
-    if (arc_labels) {
+    if (IsLabeled()) {
      (*arc_labels)[child_index] = arc_label;
    }
+    if (parent_index == Corpus::ROOT_TOKEN_ID) {
+      root_child = child_index;
+    }
+  }
+
+  const Sentence& GetSentence() const {
+    return sentence.get();
   }

-  const inline unsigned GetParent(unsigned child) const {
+  const unsigned GetParent(unsigned child) const {
    auto parent_iter = parents.find(child);
    if (parent_iter == parents.end()) {
      return Corpus::ROOT_TOKEN_ID;  // This is the best guess we've got.
@@ -94,8 +113,8 @@ class ParseTree {
    }
  }

-  const inline std::string& GetArcLabel(unsigned child) const {
-    if (!arc_labels)
+  const std::string& GetArcLabel(unsigned child) const {
+    if (!IsLabeled())
      return NO_LABEL;
    auto arc_label_iter = arc_labels->find(child);
    if (arc_label_iter == arc_labels->end()) {
@@ -105,23 +124,24 @@ class ParseTree {
    }
  }

-private:
+  const unsigned GetRootChild() const { return root_child; }
+
+  bool IsLabeled() const { return arc_labels.get(); }
+
+protected:
  std::map<unsigned, unsigned> parents;
  std::unique_ptr<std::map<unsigned, std::string>> arc_labels;
+  std::reference_wrapper<const Sentence> sentence;
+  unsigned root_child;
 };


-class LSTMParser {
+class LSTMParser : public NeuralTransitionTagger {
 public:
-  // TODO: make some of these members non-public
  ParserOptions options;
-  CorpusVocabulary vocab;
-  cnn::Model model;
-  bool finalized;
  std::unordered_map<unsigned, std::vector<float>> pretrained;
  unsigned n_possible_actions;
-  const unsigned kUNK;
  const unsigned kROOT_SYMBOL;

  cnn::LSTMBuilder stack_lstm;  // (layers, input, hidden, trainer)
@@ -155,89 +175,112 @@ class LSTMParser {
                     bool finalize=true);

  explicit LSTMParser(const std::string& model_path) :
-      kUNK(vocab.GetOrAddWord(vocab.UNK)),
      kROOT_SYMBOL(vocab.GetOrAddWord(vocab.ROOT)) {
-    std::cerr << "Loading model from " << model_path << "...";
+    std::cerr << "Loading parser model from " << model_path << "...";
    auto t_start = std::chrono::high_resolution_clock::now();
    std::ifstream model_file(model_path.c_str(), std::ios::binary);
+    if (!model_file) {
+      std::cerr << "Unable to open model file; aborting" << std::endl;
+      abort();
+    }
    eos::portable_iarchive archive(model_file);
    archive >> *this;
    auto t_end = std::chrono::high_resolution_clock::now();
    auto ms_passed = std::chrono::duration<double, std::milli>(t_end - t_start).count();
-    std::cerr << "done. (Loading took " << ms_passed << " milliseconds.)" << std::endl;
+    std::cerr << "done. (Loading took " << ms_passed << " milliseconds.)"
+              << std::endl;
  }

  template <class Archive>
  explicit LSTMParser(Archive* archive) :
-      kUNK(vocab.GetOrAddWord(vocab.UNK)),
      kROOT_SYMBOL(vocab.GetOrAddWord(vocab.ROOT)) {
    *archive >> *this;
  }

-  static bool IsActionForbidden(const std::string& a, unsigned bsize,
-                                unsigned ssize, const std::vector<int>& stacki);
-
-  ParseTree Parse(const std::map<unsigned, unsigned>& sentence,
-                  const std::map<unsigned, unsigned>& sentence_pos,
-                  const CorpusVocabulary& vocab, bool labeled, double* correct);
+  ParseTree Parse(const Sentence& sentence,
+                  const CorpusVocabulary& vocab, bool labeled);

  // take a vector of actions and return a parse tree
  ParseTree RecoverParseTree(
-      const std::map<unsigned, unsigned>& sentence,
-      const std::vector<unsigned>& actions,
-      const std::vector<std::string>& action_names,
-      const std::vector<std::string>& actions_to_arc_labels,
-      bool labeled = false);
-
-  void Train(const TrainingCorpus& corpus, const TrainingCorpus& dev_corpus,
-             const double unk_prob, const std::string& model_fname,
+      const Sentence& sentence, const std::vector<unsigned>& actions,
+      bool labeled = false, double logprob = 0) const;
+
+  void Train(const ParserTrainingCorpus& corpus,
+             const ParserTrainingCorpus& dev_corpus, const double unk_prob,
+             const std::string& model_fname,
             const volatile bool* requested_stop = nullptr);

  void Test(const Corpus& corpus) {
    DoTest(corpus, false, true);
  }

-  void Evaluate(const TrainingCorpus& corpus, bool output_parses=false) {
+  void Evaluate(const ParserTrainingCorpus& corpus, bool output_parses=false) {
    DoTest(corpus, true, output_parses);
  }

-  // Used for testing. Replaces OOV with UNK.
-  std::vector<unsigned> LogProbParser(
-      const std::map<unsigned, unsigned>& sentence,
-      const std::map<unsigned, unsigned>& sentence_pos,
-      const CorpusVocabulary& vocab, cnn::ComputationGraph *cg,
-      double* correct);
-
  void LoadPretrainedWords(const std::string& words_path);

-  void FinalizeVocab();
-
 protected:
-  // *** if correct_actions is empty, this runs greedy decoding ***
-  // returns parse actions for input sentence (in training just returns the
-  // reference)
-  // OOV handling: raw_sent will have the actual words
-  // sent will have words replaced by appropriate UNK tokens
-  // this lets us use pretrained embeddings, when available, for words that were
-  // OOV in the parser training data.
-  std::vector<unsigned> LogProbParser(
-      cnn::ComputationGraph* hg,
-      const std::map<unsigned, unsigned>& raw_sent,  // raw sentence
-      const std::map<unsigned, unsigned>& sent,  // sentence with OOVs replaced
-      const std::map<unsigned, unsigned>& sentPos,
-      const std::vector<unsigned>& correct_actions,
-      const std::vector<std::string>& action_names,
-      const std::vector<std::string>& int_to_words, double* right);
-
-  void SaveModel(const std::string& model_fname, bool softlink_created);
+  struct ParserState : public TaggerState {
+    std::vector<Expression> buffer;
+    std::vector<int> bufferi;  // position of the words in the sentence
+    std::vector<Expression> stack;  // subtree embeddings
+    std::vector<int> stacki;  // word position in sentence of head of subtree
+
+    ParserState(const Sentence& raw_sentence,
+                const Sentence::SentenceMap& sentence, Expression stack_guard)
+        : TaggerState(raw_sentence, sentence), buffer(raw_sentence.Size() + 1),
+          bufferi(raw_sentence.Size() + 1), stack({stack_guard}),
+          stacki({-999}) {}
+
+    ~ParserState() {
+      assert(stack.size() == 2);  // guard symbol, root
+      assert(stacki.size() == 2);
+      assert(buffer.size() == 1);  // guard symbol
+      assert(bufferi.size() == 1);
+    }
+  };
+
+  virtual std::vector<cnn::Parameters*> GetParameters() override {
+    std::vector<cnn::Parameters*> all_params {p_pbias, p_H, p_D, p_R, p_cbias,
+        p_S, p_B, p_A, p_ib, p_w2l, p_p2a, p_abias, p_action_start,
+        p_stack_guard};
+    if (options.use_pos)
+      all_params.push_back(p_p2l);
+    if (p_t2l)
+      all_params.push_back(p_t2l);
+    return all_params;
+  }
+
+  virtual TaggerState* InitializeParserState(
+      cnn::ComputationGraph* cg, const Sentence& raw_sent,
+      const Sentence::SentenceMap& sent,  // sentence with OOVs replaced
+      const std::vector<unsigned>& correct_actions) override;
+
+  virtual void InitializeNetworkParameters() override;
+
+  virtual bool ShouldTerminate(TaggerState* state) const override {
+    const ParserState& real_state = static_cast<const ParserState&>(*state);
+    return real_state.stack.size() <= 2 && real_state.buffer.size() <= 1;
+  }
+
+  virtual bool IsActionForbidden(const unsigned action,
+                                 TaggerState* state) const override;
+
+  virtual cnn::expr::Expression GetActionProbabilities(TaggerState* state)
+      override;
+
+  virtual void DoAction(
+      unsigned action, TaggerState* state, cnn::ComputationGraph* cg,
+      std::map<std::string, cnn::expr::Expression>* states_to_expose) override;

  inline unsigned ComputeCorrect(const ParseTree& ref,
                                 const ParseTree& hyp) const {
-    assert(ref.sentence.size() == hyp.sentence.size());
+    assert(ref.GetSentence().Size() == hyp.GetSentence().Size());
    unsigned correct_count = 0;
-    for (const auto& token_index_and_word : ref.sentence) {
+    for (const auto& token_index_and_word : ref.GetSentence().words) {
      unsigned i = token_index_and_word.first;
      if (i != Corpus::ROOT_TOKEN_ID && ref.GetParent(i) == hyp.GetParent(i))
        ++correct_count;
@@ -245,6 +288,10 @@ class LSTMParser {
    }
    return correct_count;
  }

+  virtual void DoSave(eos::portable_oarchive& archive) override {
+    archive << *this;
+  }
+
 private:
  friend class boost::serialization::access;

@@ -253,7 +300,7 @@ class LSTMParser {
    ar & options;
    ar & vocab;
    ar & pretrained;
-    ar & model;
+    ar & *model;
  }

  template <class Archive>
@@ -266,27 +313,25 @@ class LSTMParser {
    ar & pretrained;

    // Don't finalize yet...we want to finalize once our model is initialized.
-    model = cnn::Model();
+    model.reset(new cnn::Model);
    // Reset the LSTMs *before* reading in the network model, to make sure the
    // model knows how big it's supposed to be.
    stack_lstm = cnn::LSTMBuilder(options.layers, options.lstm_input_dim,
-                                  options.hidden_dim, &model);
+                                  options.hidden_dim, model.get());
    buffer_lstm = cnn::LSTMBuilder(options.layers, options.lstm_input_dim,
-                                  options.hidden_dim, &model);
+                                  options.hidden_dim, model.get());
    action_lstm = cnn::LSTMBuilder(options.layers, options.action_dim,
-                                  options.hidden_dim, &model);
+                                  options.hidden_dim, model.get());

-    FinalizeVocab();  // OK, now finalize. :)
+    FinalizeVocab();  // OK, now finalize. :) (Also initializes network params.)

-    ar & model;
+    ar & *model;
  }
  BOOST_SERIALIZATION_SPLIT_MEMBER();

  void DoTest(const Corpus& corpus, bool evaluate, bool output_parses);

-  static void OutputConll(const std::map<unsigned, unsigned>& sentence,
-                          const std::map<unsigned, unsigned>& pos,
-                          const std::map<unsigned, std::string>& sentence_unk_strings,
+  static void OutputConll(const Sentence& sentence,
                          const std::vector<std::string>& int_to_words,
                          const std::vector<std::string>& int_to_pos,
                          const std::map<std::string, unsigned>& words_to_int,
diff --git a/parser/neural-transition-tagger.cpp b/parser/neural-transition-tagger.cpp
new file mode 100644
index 0000000..11f8ec1
--- /dev/null
+++ b/parser/neural-transition-tagger.cpp
@@ -0,0 +1,157 @@
+#include "neural-transition-tagger.h"
+
+#include <boost/filesystem.hpp>
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+
+#include "cnn/expr.h"
+#include "cnn/model.h"
+#include "eos/portable_archive.hpp"
+
+using namespace std;
+using namespace cnn;
+using namespace cnn::expr;
+
+namespace lstm_parser {
+
+const cnn::expr::Expression NeuralTransitionTagger::USE_ORACLE(
+    nullptr, cnn::VariableIndex(static_cast<unsigned>(-1)));
+
+
+void NeuralTransitionTagger::SaveModel(const string& model_fname,
+                                       bool softlink_created) {
+  boost::filesystem::path model_dir_path(model_fname);
+  model_dir_path.remove_filename();
+  if (boost::filesystem::create_directories(model_dir_path)) {
+    cerr << "Created directory " << model_dir_path << endl;
+  }
+
+  ofstream out_file(model_fname);
+  eos::portable_oarchive archive(out_file);
+  DoSave(archive);
+  cerr << "Model saved." << endl;
+  // Create a soft link to the most recent model in order to make it
+  // easier to refer to it in a shell script.
+  if (!softlink_created) {
+    string softlink = "latest_model.params";
+
+    if (system((string("rm -f ") + softlink).c_str()) == 0
+        && system(("ln -s " + model_fname + " " + softlink).c_str()) == 0) {
+      cerr << "Created " << softlink << " as a soft link to " << model_fname
+           << " for convenience." << endl;
+    }
+  }
+}
+
+
+void NeuralTransitionTagger::FinalizeVocab() {
+  if (finalized)
+    return;
+  if (!model.get())
+    model.reset(new Model);
+  InitializeNetworkParameters();
+  // Give up memory we don't need.
+  vocab.action_names.shrink_to_fit();
+  vocab.actions_to_arc_labels.shrink_to_fit();
+  vocab.int_to_chars.shrink_to_fit();
+  vocab.int_to_pos.shrink_to_fit();
+  vocab.int_to_training_word.shrink_to_fit();
+  vocab.int_to_words.shrink_to_fit();
+  finalized = true;
+}
+
+
+Sentence::SentenceMap NeuralTransitionTagger::ReplaceUnknowns(
+    const Sentence& sentence) {
+  Sentence::SentenceMap tsentence(sentence.words);  // sentence w/ OOVs replaced
+  for (auto& index_and_id : tsentence) {
+    // use reference to overwrite
+    if (index_and_id.second >= vocab.int_to_training_word.size()
+        || !vocab.int_to_training_word[index_and_id.second]) {
+      index_and_id.second = vocab.kUNK;
+    }
+  }
+  return tsentence;
+}
+
+
+vector<unsigned> NeuralTransitionTagger::LogProbTagger(
+    ComputationGraph* cg,
+    const Sentence& raw_sent,  // raw sentence
+    const Sentence::SentenceMap& sent,  // sentence with OOVs replaced
+    bool training,
+    const vector<unsigned>& correct_actions, double* correct,
+    map<string, Expression>* states_to_expose) {
+  in_training = training;
+  if (training)
+    assert(!correct_actions.empty());
+  assert(finalized);
+  vector<unsigned> results;
+
+  // variables in the computation graph representing the parameters
+  for (Parameters *params : GetParameters()) {
+    param_expressions[params] = parameter(*cg, params);
+  }
+
+  unique_ptr<TaggerState> state(
+      InitializeParserState(cg, raw_sent, sent, correct_actions));
+
+  vector<Expression> log_probs;
+  unsigned action_count = 0;  // incremented at each prediction
+  while (!ShouldTerminate(state.get())) {
+    // Get list of possible actions for the current parser state.
+    vector<unsigned> current_valid_actions;
+    for (unsigned action = 0; action < vocab.action_names.size(); ++action) {
+      if (IsActionForbidden(action, state.get()))
+        continue;
+      current_valid_actions.push_back(action);
+    }
+
+    Expression r_t = GetActionProbabilities(state.get());
+    unsigned action;
+    if (r_t.pg == USE_ORACLE.pg && r_t.i == USE_ORACLE.i) {
+      assert(!correct_actions.empty() && action_count < correct_actions.size());
+      action = correct_actions[action_count];
+      // cerr << "Using oracle action: " << vocab.action_names[action] << endl;
+    } else {
+      // adist = log_softmax(r_t, current_valid_actions)
+      Expression adiste = log_softmax(r_t, current_valid_actions);
+      vector<float> adist = as_vector(cg->incremental_forward());
+      double best_score = adist[current_valid_actions[0]];
+      unsigned best_a = current_valid_actions[0];
+      for (unsigned i = 1; i < current_valid_actions.size(); ++i) {
+        if (adist[current_valid_actions[i]] > best_score) {
+          best_score = adist[current_valid_actions[i]];
+          best_a = current_valid_actions[i];
+        }
+      }
+      action = best_a;
+
+      if (!correct_actions.empty()) {
+        assert(action_count < correct_actions.size() || !training);
+        unsigned correct_action = correct_actions[action_count];
+        if (correct && best_a == correct_action) {
+          (*correct)++;
+        }
+        // If we're training, use the reference action.
+        if (training)
+          action = correct_action;
+      }
+      log_probs.push_back(pick(adiste, action));
+    }
+    ++action_count;
+    results.push_back(action);
+
+    DoAction(action, state.get(), cg, states_to_expose);
+  }
+
+  Expression tot_neglogprob = -sum(log_probs);
+  assert(tot_neglogprob.pg != nullptr);
+
+  param_expressions.clear();
+  return results;
+}
+
+
+}  /* namespace lstm_parser */
diff --git a/parser/neural-transition-tagger.h b/parser/neural-transition-tagger.h
new file mode 100644
index 0000000..b4afa5c
--- /dev/null
+++ b/parser/neural-transition-tagger.h
@@ -0,0 +1,115 @@
+#ifndef LSTM_PARSER_PARSER_NEURAL_TRANSITION_TAGGER_H_
+#define LSTM_PARSER_PARSER_NEURAL_TRANSITION_TAGGER_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "cnn/expr.h"
+#include "cnn/model.h"
+#include "corpus.h"
+
+namespace eos {
+class portable_oarchive;
+}
+
+namespace lstm_parser {
+
+class NeuralTransitionTagger {
+public:
+  NeuralTransitionTagger() : finalized(false), in_training(false),
+                             model(new cnn::Model) {}
+  virtual ~NeuralTransitionTagger() {}
+
+  void FinalizeVocab();
+
+  // Used for testing. Replaces OOV with UNK.
+  std::vector<unsigned> LogProbTagger(
+      cnn::ComputationGraph *cg,
+      const Sentence& sentence,
+      bool replace_unknowns = true,
+      std::map<std::string, cnn::expr::Expression>* states_to_expose =
+          nullptr) {
+    return LogProbTagger(
+        cg, sentence,
+        replace_unknowns ? ReplaceUnknowns(sentence) : sentence.words,
+        false, std::vector<unsigned>(), nullptr, states_to_expose);
+  }
+
+  // *** if correct_actions is empty, this runs greedy decoding ***
+  // returns actions for input sentence (in training just returns the reference)
+  // OOV handling: raw_sent will have the actual words
+  // sent will have words replaced by appropriate UNK tokens
+  // this lets us use pretrained embeddings, when available, for words that were
+  // OOV in the training data.
+  std::vector<unsigned> LogProbTagger(
+      cnn::ComputationGraph* cg,
+      const Sentence& sentence,  // raw sentence
+      const Sentence::SentenceMap& sent,  // sentence with OOVs replaced
+      bool training = false,
+      const std::vector<unsigned>& correct_actions = std::vector<unsigned>(),
+      double* correct = nullptr,
+      std::map<std::string, cnn::expr::Expression>* states_to_expose = nullptr);
+
+  const CorpusVocabulary& GetVocab() const { return vocab; }
+
+  // TODO: arrange things such that we don't need to expose this?
+  CorpusVocabulary* GetVocab() { return &vocab; }
+
+protected:
+  struct TaggerState {
+    TaggerState(const Sentence& raw_sentence,
+                const Sentence::SentenceMap& sentence)
+        : raw_sentence(raw_sentence), sentence(sentence) {}
+    const Sentence& raw_sentence;
+    const Sentence::SentenceMap& sentence;
+    virtual ~TaggerState() {}
+  };
+
+  // Special network pseudo-node for signaling that an oracle action should
+  // be used.
+  static const cnn::expr::Expression USE_ORACLE;
+
+  bool finalized;
+  bool in_training;  // expose to virtual fns whether we're doing training
+  std::map<cnn::Parameters*, cnn::expr::Expression> param_expressions;
+
+  // Store the model as a smart ptr so we can call its destructor when needed.
+  std::unique_ptr<cnn::Model> model;
+  CorpusVocabulary vocab;
+
+  inline cnn::expr::Expression GetParamExpr(cnn::Parameters* params) {
+    return param_expressions.at(params);
+  }
+
+  virtual std::vector<cnn::Parameters*> GetParameters() = 0;
+
+  virtual TaggerState* InitializeParserState(
+      cnn::ComputationGraph* hg, const Sentence& raw_sent,
+      const Sentence::SentenceMap& sent,  // sentence with OOVs replaced
+      const std::vector<unsigned>& correct_actions) = 0;
+
+  virtual cnn::expr::Expression GetActionProbabilities(
+      TaggerState* state) = 0;
+
+  virtual bool ShouldTerminate(TaggerState* state) const = 0;
+
+  virtual bool IsActionForbidden(const unsigned action,
+                                 TaggerState* state) const = 0;
+
+  virtual void DoAction(
+      unsigned action, TaggerState* state, cnn::ComputationGraph* cg,
+      std::map<std::string, cnn::expr::Expression>* states_to_expose) = 0;
+
+  virtual void DoSave(eos::portable_oarchive& archive) = 0;
+
+  virtual void InitializeNetworkParameters() = 0;
+
+  void SaveModel(const std::string& model_fname, bool softlink_created);
+
+  Sentence::SentenceMap ReplaceUnknowns(const Sentence& sentence);
+};
+
+}  /* namespace lstm_parser */
+
+#endif  /* LSTM_PARSER_PARSER_NEURAL_TRANSITION_TAGGER_H_ */
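
For orientation, a minimal usage sketch of the refactored interface follows; it is illustrative only and not part of the patch. ReadOneConllSentence is a hypothetical stand-in for the caller's CoNLL-reading code (see corpus.h), and the .params filename is simply the pretrained English model referenced in the README; only the LSTMParser constructor, Parse, GetVocab, and the ParseTree accessors are taken from the code above.

// Usage sketch (illustrative only; not part of the diff above).
#include <iostream>
#include <string>

#include "lstm-parser.h"

using namespace lstm_parser;

// Hypothetical helper: stands in for whatever CoNLL-reading code the caller
// already has (see corpus.h); nothing in this patch defines it.
Sentence ReadOneConllSentence(const std::string& conll_path,
                              CorpusVocabulary* vocab);

int main() {
  // Deserializing the archive restores options, vocabulary, pretrained
  // embeddings, and network weights, then finalizes the vocabulary.
  LSTMParser parser("english_pos_2_32_100_20_100_12_20.params");

  Sentence sentence = ReadOneConllSentence("test.conll", parser.GetVocab());

  // With no reference actions supplied, LogProbTagger (called inside Parse)
  // greedily takes the highest-scoring legal transition at each step, after
  // mapping OOV words to UNK via ReplaceUnknowns().
  ParseTree tree = parser.Parse(sentence, *parser.GetVocab(),
                                /* labeled = */ true);

  // tree.logprob holds the score accumulated during decoding;
  // GetRootChild() returns the token attached to ROOT.
  std::cout << "score: " << tree.logprob
            << "  root token: " << tree.GetRootChild() << std::endl;
  return 0;
}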