diff --git a/src/agents/query_engine/query_element/BUILD b/src/agents/query_engine/query_element/BUILD index 60b431de..c0cffdeb 100644 --- a/src/agents/query_engine/query_element/BUILD +++ b/src/agents/query_engine/query_element/BUILD @@ -7,6 +7,7 @@ cc_library( includes = ["."], deps = [ ":and", + ":chain", ":iterator", ":link_template", ":operator", @@ -44,6 +45,19 @@ cc_library( ], ) +cc_library( + name = "chain", + srcs = ["Chain.cc"], + hdrs = ["Chain.h"], + deps = [ + ":operator", + "//atomdb:atomdb_singleton", + "//commons:commons_lib", + "//commons/atoms:atoms_lib", + "//commons/processor:processor_lib", + ], +) + cc_library( name = "terminal", srcs = ["Terminal.cc"], diff --git a/src/agents/query_engine/query_element/Chain.cc b/src/agents/query_engine/query_element/Chain.cc new file mode 100644 index 00000000..9cca531a --- /dev/null +++ b/src/agents/query_engine/query_element/Chain.cc @@ -0,0 +1,414 @@ +#include "Chain.h" + +#include "AtomDBSingleton.h" +#include "Hasher.h" +#include "Logger.h" +#include "ThreadSafeHeap.h" + +using namespace query_element; +using namespace atomdb; +using namespace commons; + +string Chain::ORIGIN_VARIABLE_NAME = "origin"; +string Chain::DESTINY_VARIABLE_NAME = "destiny"; + +static string convert_handle(const string& handle) { +#if LOG_LEVEL >= DEBUG_LEVEL + shared_ptr node = + dynamic_pointer_cast(AtomDBSingleton::get_instance()->get_atom(handle)); + if (node != nullptr) { + return node->name; + } +#endif + return handle; +} + +// ------------------------------------------------------------------------------------------------- +// Public methods + +Chain::Chain(const array, 1>& clauses, + const string& source_handle, + const string& target_handle) + : Operator<1>(clauses), source_handle(source_handle), target_handle(target_handle) { + initialize(clauses); +} + +Chain::~Chain() { + LOG_DEBUG("Chain::~Chain() BEGIN"); + graceful_shutdown(); + delete this->forward_path_finder; + delete this->backward_path_finder; + LOG_DEBUG("~Chain::Chain() END"); +} + +shared_ptr Chain::get_source_index(const string& key) { + lock_guard semaphore(this->source_index_mutex); + auto it = this->source_index.find(key); + if (it == this->source_index.end()) { + return nullptr; + } else { + return it->second; + } +} + +shared_ptr Chain::get_target_index(const string& key) { + lock_guard semaphore(this->target_index_mutex); + auto it = this->target_index.find(key); + if (it == this->target_index.end()) { + return nullptr; + } else { + return it->second; + } +} + +// -------------------------------------------------------------------------------------------- +// QueryElement API + +void Chain::setup_buffers() { + LOG_DEBUG("Chain::setup_buffers() BEGIN"); + Operator<1>::setup_buffers(); + this->operator_thread = make_shared(this->id + ":main_thread", this); + this->operator_thread->setup(); + this->operator_thread->start(); + this->forward_thread = + make_shared(this->id + ":forward_thread", this->forward_path_finder); + this->forward_thread->setup(); + this->forward_thread->start(); + this->backward_thread = + make_shared(this->id + ":backward_thread", this->backward_path_finder); + this->backward_thread->setup(); + this->backward_thread->start(); + LOG_DEBUG("Chain::setup_buffers() END"); +} + +void Chain::graceful_shutdown() { + LOG_DEBUG("Chain::graceful_shutdown() BEGIN"); + if (!this->forward_thread->is_finished()) { + this->forward_thread->stop(); + } + if (!this->backward_thread->is_finished()) { + this->backward_thread->stop(); + } + if (!this->operator_thread->is_finished()) { + this->operator_thread->stop(); + } + Operator<1>::graceful_shutdown(); + LOG_DEBUG("Chain::graceful_shutdown() END"); +} + +// -------------------------------------------------------------------------------------------- +// ThreadMethod API + +bool Chain::PathFinder::conditional_refeed(Path& path, + shared_ptr& candidates_heap, + unsigned int count_cycles) { + if (this->chain_operator->all_input_acknowledged() && + (candidates_heap->empty() || (count_cycles == candidates_heap->size()))) { + LOG_DEBUG("[PATH_FINDER] " + << "All input is acknowledged. Discarding dead-end path: " << path.to_string()); + return false; + } else { + LOG_DEBUG("[PATH_FINDER] " + << "Still acknowledging input. Pushing " << path.to_string() + << " back to refeeding buffer."); + this->chain_operator->refeeding_buffer.push(path); + return true; + } +} + +bool Chain::PathFinder::thread_one_step() { +#if LOG_LEVEL >= DEBUG_LEVEL + lock_guard semaphore(this->chain_operator->thread_debug_mutex); +#endif + if (this->chain_operator->all_paths_explored()) { + return false; + } + LOG_DEBUG("[PATH_FINDER] " << (this->forward_flag ? "FORWARD" : "BACKWARD") << " PathFinder STEP"); + shared_ptr base_heap = this->forward_flag + ? this->chain_operator->get_source_index(this->origin) + : this->chain_operator->get_target_index(this->origin); + + if (base_heap->empty()) { + LOG_DEBUG("[PATH_FINDER] " + << "Empty base_heap. Trying to refeed paths."); + this->chain_operator->refeed_paths(); + if (base_heap->empty()) { + LOG_DEBUG("[PATH_FINDER] " + << "No paths to refeed."); + if (this->chain_operator->all_input_acknowledged()) { + // double check is required to avoid race condition + shared_ptr check_heap = + this->forward_flag ? this->chain_operator->get_source_index(this->origin) + : this->chain_operator->get_target_index(this->origin); + if (check_heap == nullptr || check_heap->empty()) { + this->chain_operator->set_all_paths_explored(true); + LOG_DEBUG("[PATH_FINDER] " + << "All paths has been explored"); + } + } + return false; + } else { + LOG_DEBUG("[PATH_FINDER] " + << "Paths has been refed."); + } + } + + Path previous_path = base_heap->top_and_pop(); + LOG_DEBUG("[PATH_FINDER] " + << "Popped: " + previous_path.to_string()); + if (previous_path.end_point() == this->destiny) { + LOG_DEBUG("[PATH_FINDER] " + << "Found complete path: " << previous_path.to_string()); + this->chain_operator->report_path(previous_path); + return true; + } + + LOG_DEBUG("[PATH_FINDER] " + << "Searching candidate paths " << (this->forward_flag ? "FROM " : "TO ") + << convert_handle(previous_path.end_point())); + shared_ptr candidates_heap = + this->forward_flag ? this->chain_operator->get_source_index(previous_path.end_point()) + : this->chain_operator->get_target_index(previous_path.end_point()); + if (candidates_heap->empty()) { + LOG_DEBUG("[PATH_FINDER] " + << "Found no candidates."); + return !conditional_refeed(previous_path, candidates_heap, 0); + } else { + LOG_DEBUG("[PATH_FINDER] " + << "Found " << candidates_heap->size()); + } + + vector candidates; + candidates_heap->snapshot(candidates); + Path new_path(this->forward_flag); + Path best_path(this->forward_flag); + double best_sti = -1; + unsigned int count_cycles = 0; + for (Path candidate : candidates) { + LOG_DEBUG("[PATH_FINDER] " + << "Candidate: " << candidate.to_string()); + if (previous_path.allow_concatenation(candidate)) { + new_path = previous_path; + new_path.concatenate(candidate); + if (candidate.path_sti > best_sti) { + LOG_DEBUG("[PATH_FINDER] " + << "Candidate is the best so far. Resulting path is: " + << new_path.to_string()); + best_sti = candidate.path_sti; + best_path = new_path; + } + LOG_DEBUG("[PATH_FINDER] " + << "Pushing new path: " << new_path.to_string()); + base_heap->push(new_path, new_path.path_sti); + } else { + count_cycles++; + } + } + if (best_sti >= 0) { + LOG_DEBUG("[PATH_FINDER] " + << "Best path: " << best_path.to_string()); + this->chain_operator->report_path(best_path); + return true; + } else { + LOG_DEBUG("[PATH_FINDER] " + << "No suitable candidate."); + return !conditional_refeed(previous_path, candidates_heap, count_cycles); + } +} + +void Chain::refeed_paths() { + while (!this->refeeding_buffer.empty()) { + Path path = refeeding_buffer.front_and_pop(); + if (path.forward_flag) { + this->source_index[this->source_handle]->push(path, path.path_sti); + } else { + this->target_index[this->target_handle]->push(path, path.path_sti); + } + } +} + +bool Chain::thread_one_step() { + QueryAnswer* answer; + + if (all_paths_explored()) { + if (!this->forward_thread->is_finished()) { + LOG_DEBUG("[CHAIN OPERATOR] " + << "All paths explored. Stopping path finders..."); + this->forward_thread->stop(); + this->backward_thread->stop(); + LOG_DEBUG("[CHAIN OPERATOR] " + << "All paths explored. Stopping path finders. DONE"); + LOG_DEBUG("[CHAIN OPERATOR] " + << "All paths explored. Notifying output buffer..."); + this->output_buffer->query_answers_finished(); + LOG_DEBUG("[CHAIN OPERATOR] " + << "All paths explored. Notifying output buffer. DONE"); + } + return false; + } + if (all_input_acknowledged()) { + return false; + } +#if LOG_LEVEL >= DEBUG_LEVEL + { + lock_guard semaphore(this->thread_debug_mutex); +#endif + if ((answer = dynamic_cast(this->input_buffer[0]->pop_query_answer())) != NULL) { + LOG_DEBUG("[CHAIN OPERATOR] " + << "New query answer: " << answer->to_string()); + for (string handle : answer->handles) { + auto iterator = this->known_links.find(handle); + if (iterator == this->known_links.end()) { + this->known_links.insert(iterator, handle); + shared_ptr link = + dynamic_pointer_cast(AtomDBSingleton::get_instance()->get_atom(handle)); + if (link == nullptr) { + Utils::error("Invalid query answer in Chain operator."); + } else { + LOG_DEBUG("[CHAIN OPERATOR] " + << "Valid link"); + } + LOG_DEBUG("[CHAIN OPERATOR] " + << "New link: " << link->to_string()); + if (link->arity() == 3) { + { + lock_guard semaphore(this->source_index_mutex); + for (unsigned int i = 1; i <= 2; i++) { + if (this->source_index.find(link->targets[i]) == + this->source_index.end()) { + this->source_index[link->targets[i]] = make_shared(); + } + } + this->source_index[link->targets[1]]->push(Path(link, answer, true), + answer->importance); + } + { + lock_guard semaphore(this->target_index_mutex); + for (unsigned int i = 1; i <= 2; i++) { + if (this->target_index.find(link->targets[i]) == + this->target_index.end()) { + this->target_index[link->targets[i]] = make_shared(); + } + } + this->target_index[link->targets[2]]->push( + Path(link, QueryAnswer::copy(answer), false), answer->importance); + } + } else { + Utils::error("Invalid Link " + link->to_string() + " with arity " + + std::to_string(link->arity()) + " in CHAIN operator."); + break; + } + } else { + LOG_DEBUG("[CHAIN OPERATOR] " + << "Discarding already inserted handle: " << convert_handle(handle)); + } + } + refeed_paths(); + return true; + } else { + if (this->input_buffer[0]->is_query_answers_finished() && + this->input_buffer[0]->is_query_answers_empty()) { + LOG_DEBUG("[CHAIN OPERATOR] " + << "All input has been acknowledged"); + this->set_all_input_acknowledged(true); + } + return false; + } +#if LOG_LEVEL >= DEBUG_LEVEL + } +#endif +} + +void Chain::report_path(Path& path) { + QueryAnswer* query_answer = new QueryAnswer(path.path_sti); + if (path.forward_flag) { + for (auto pair : path.links) { + query_answer->add_handle(pair.first->handle()); // TODO change to use handle in query_answer + if (!query_answer->merge(pair.second.get())) { + Utils::error("Incompatible assignments in Chain operator answer: " + + query_answer->to_string() + " + " + pair.second->to_string()); + } + } + } else { + for (auto pair = path.links.rbegin(); pair != path.links.rend(); ++pair) { + query_answer->add_handle(pair->first->handle()); + if (!query_answer->merge(pair->second.get())) { + Utils::error("Incompatible assignments in Chain operator answer: " + + query_answer->to_string() + " + " + pair->second->to_string()); + } + } + } + string answer_hash = Hasher::composite_handle(query_answer->handles); + if (this->reported_answers.find(answer_hash) == this->reported_answers.end()) { + this->reported_answers.insert(answer_hash); + query_answer->assignment.assign(ORIGIN_VARIABLE_NAME, path.start_point()); + query_answer->assignment.assign(DESTINY_VARIABLE_NAME, path.end_point()); + LOG_INFO("Reporting path: " << path.to_string()); + this->output_buffer->add_query_answer(query_answer); + } else { + delete query_answer; + } +} + +void Chain::set_all_input_acknowledged(bool flag) { + lock_guard semaphore(this->all_input_acknowledged_mutex); + this->all_input_acknowledged_flag = flag; +} + +bool Chain::all_input_acknowledged() { + lock_guard semaphore(this->all_input_acknowledged_mutex); + return this->all_input_acknowledged_flag; +} + +void Chain::set_all_paths_explored(bool flag) { + lock_guard semaphore(this->all_paths_explored_mutex); + this->all_paths_explored_flag = flag; +} + +bool Chain::all_paths_explored() { + lock_guard semaphore(this->all_paths_explored_mutex); + return this->all_paths_explored_flag; +} + +// -------------------------------------------------------------------------------------------- +// Private stuff + +void Chain::initialize(const array, 1>& clauses) { + if (clauses.size() != 1) { + Utils::error("Invalid Chain operator with " + std::to_string(clauses.size()) + " clauses."); + } + this->id = "CHAIN(" + clauses[0]->id + ", " + this->source_handle + ", " + this->target_handle + ")"; + this->all_input_acknowledged_flag = false; + this->all_paths_explored_flag = false; + this->forward_path_finder = new PathFinder(this, true); + this->backward_path_finder = new PathFinder(this, false); + this->source_index[this->source_handle] = make_shared(); + this->source_index[this->target_handle] = make_shared(); + this->target_index[this->source_handle] = make_shared(); + this->target_index[this->target_handle] = make_shared(); +} + +string Chain::Path::to_string() { + string answer = ""; + bool first = true; + string last_handle = ""; + string check_handle = ""; + for (auto pair : this->links) { + if (first) { + first = false; + last_handle = + convert_handle(this->forward_flag ? pair.first->targets[1] : pair.first->targets[2]); + answer = last_handle; + } + check_handle = + convert_handle(this->forward_flag ? pair.first->targets[1] : pair.first->targets[2]); + if (check_handle != last_handle) { + LOG_ERROR("Invalid Path"); + } + last_handle = + convert_handle(this->forward_flag ? pair.first->targets[2] : pair.first->targets[1]); + answer += this->forward_flag ? " -> " : " <- "; + answer += last_handle; + } + return answer; +} diff --git a/src/agents/query_engine/query_element/Chain.h b/src/agents/query_engine/query_element/Chain.h new file mode 100644 index 00000000..600ed038 --- /dev/null +++ b/src/agents/query_engine/query_element/Chain.h @@ -0,0 +1,327 @@ +#pragma once + +#include "DedicatedThread.h" +#include "Link.h" +#include "Operator.h" +#include "ThreadSafeHeap.h" +#include "ThreadSafeQueue.h" +#include "map" +#include "mutex" +#include "set" + +using namespace std; +using namespace atoms; +using namespace processor; + +namespace query_element { + +/** + * This operator takes as input a single query element and two handles (SOURCE and TARGET) and + * outputs QueryAnswers which represent paths between SOURCE and TARGET. + * + * Each QueryAnswer in the input is supposed to have exatcly 1 handle, + * i.e. query_answer->handles.size() == 1. In addition to this, the handle is supposed to be the + * handle of a ternary link such as + * + * LINK + * TARGET1 + * TARGET2 + * TARGET3 + * + * TARGET1 is disregarded. TARGET2 and TARGET3 are used to connect the paths. Just to make it + * easy to write an example, lets assume the input handles represent links like these: + * + * (Similarity H1 H2) + * + * Each QueryAnswer in the Chain Operator output will contain N handles, representing a path + * with N links connecting SOURCE and TARGET. Suppose we have a QueryAnswer with N=4, the handles + * in query_answer->handles will point to links like these: + * + * (Similarity SOURCE H1) + * (Similarity H1 H2) + * (Similarity H2 H3) + * (Similarity H3 TARGET) + * + * Chained to form a path between SOURCE and TARGET. Note that the first link target is + * disregarded so you may have something like: + * + * CHAIN SOURCE TARGET + * OR 2 + * LinkTemplate 3 + * Node Equivalence + * Variable v1 + * Variable v2 + * LinkTemplate 3 + * Node Similarity + * Variable v1 + * Variable v2 + * + * This cold produce QueryAnswers paths chaining Similarity and Equivalence links in the same path. + * E.g.: + * + * (Similarity SOURCE H1) + * (Similarity H1 H2) + * (Equivalence H2 H3) + * (Similarity H3 H4) + * (Equivalence H4 H5) + * (Similarrity H5 TARGET) + * + * Also note that, because input QueryAnswer are supposed to have exatcly 1 handle, the following + * query IS NOT valid: + * + * CHAIN SOURCE TARGET + * AND 2 + * LinkTemplate 3 + * Node Equivalence + * Variable v1 + * Variable v2 + * LinkTemplate 3 + * Node Equivalence + * Variable v2 + * Variable v3 + * + * Optionally, ALLOW_INCOMPLETE_CHAIN_PATH can be set true to determine that the Chain Operator + * should output incomplete paths as well as complete ones (the same prioritizartion by + * STI applies to incomplete * paths). All reported incomplete paths will contain either the + * SOURCE as its first element or the TARGET in its end. So, the QueryAnswers of a Chain operator + * may produce paths like these: + * + * (Similarity SOURCE H1) + * (Similarity H1 H2) + * + * (Similarity SOURCE H1) + * (Similarity H1 H2) + * (Similarity H2 H3) + * + * ... + * + * (Similarity H2 H1) + * (Similarity H1 TARGET) + * + * (Similarity H3 H2) + * (Similarity H2 H1) + * (Similarity H1 TARGET) + */ +class Chain : public Operator<1>, public ThreadMethod { + public: + // -------------------------------------------------------------------------------------------- + // Inner types + + class Path { + public: + vector, shared_ptr>> links; + double path_sti; + bool forward_flag; + Path(shared_ptr link, QueryAnswer* answer, bool forward_flag) { + if (link->targets[1] == link->targets[2]) { + Utils::error("Invalid cyclic link: " + link->to_string()); + } + this->links.push_back({link, shared_ptr(answer)}); + this->path_sti = answer->importance; + this->forward_flag = forward_flag; + } + Path(const Path& other) { + this->links = other.links; + this->path_sti = other.path_sti; + this->forward_flag = other.forward_flag; + } + Path(bool forward_flag) { + this->path_sti = 0; + this->forward_flag = forward_flag; + } + Path& operator=(const Path& other) { + this->links = other.links; + this->path_sti = other.path_sti; + this->forward_flag = other.forward_flag; + return *this; + } + inline bool empty() { return this->links.size() == 0; } + inline unsigned int size() { return this->links.size(); } + inline void clear() { + this->links.clear(); + this->path_sti = 0; + } + inline void concatenate(const Path& other) { + if (this->forward_flag != other.forward_flag) { + Utils::error("Invalid attempt to merge incompatible HeapElements"); + } + this->links.insert(this->links.end(), other.links.begin(), other.links.end()); + this->path_sti = max(this->path_sti, other.path_sti); + } + inline string end_point() { + if (this->forward_flag) { + return this->links.back().first->targets[2]; + } else { + return this->links.back().first->targets[1]; + } + } + inline string start_point() { + if (this->forward_flag) { + return this->links.front().first->targets[1]; + } else { + return this->links.front().first->targets[2]; + } + } + inline bool contains(string handle) { + for (auto pair : this->links) { + if ((pair.first->targets[1] == handle) || (pair.first->targets[2] == handle)) { + return true; + } + } + return false; + } + inline bool allow_concatenation(Path& other) { + if ((this->size() == 0) || (other.size() == 0)) { + return true; + } else if (this->end_point() != other.start_point()) { + return false; + } + unsigned int this_index = (this->forward_flag ? 1 : 2); + unsigned int other_index = (this->forward_flag ? 2 : 1); + for (auto pair_other : other.links) { + for (auto pair_this : this->links) { + if (pair_other.first->targets[other_index] == pair_this.first->targets[this_index]) { + return false; + } + } + } + return true; + } + string to_string(); + }; + + typedef ThreadSafeHeap HeapType; + + // -------------------------------------------------------------------------------------------- + // Static variables + + static string ORIGIN_VARIABLE_NAME; + static string DESTINY_VARIABLE_NAME; + + // -------------------------------------------------------------------------------------------- + // Public methods + + /** + * Constructor. + */ + Chain(const array, 1>& clauses, + const string& source_handle, + const string& target_handle); + + /** + * Destructor. + */ + ~Chain(); + + /** + * Thread-safe access to the source_index map. + */ + shared_ptr get_source_index(const string& key); + + /** + * Thread-safe access to the target_index map. + */ + shared_ptr get_target_index(const string& key); + + /** + * Chain Operator thread. + * Report a QueryAnswer to the next query element in the query tree. + */ + void report_path(Path& path); + + /** + * Chain Operator thread. + * Called after antecedent query element notifies that input has been ended and after + * all input has already been acknowledged and properly turned into elementary Paths. + */ + void set_all_input_acknowledged(bool flag); + + /** + * Chain Operator AND Path Finder threads. + * Check if all input has already been aknowledged (see set_all_input_acknowledged()). + */ + bool all_input_acknowledged(); + + /** + * Path Finder thread. + * Called after all possible paths between source and targetr has been explored. Basically, + * notifies that the Chain Operator has ended the job of trying to find paths + * (complete or incomplete). + */ + void set_all_paths_explored(bool flag); + + /** + * Chain Operator AND Path Finder threads. + * Check if the Chain Operator has ended the search for new paths + * (see set_all_paths_explored()). + */ + bool all_paths_explored(); + + /** + * Chain Operator AND Path Finder threads. + * Empties the refeed_buffer, a buffer that stores paths which are supposed to get back to be + * re-evaluated by Pathg Finder when new (so yet unseen) input is read in the Chain Operator. + */ + void refeed_paths(); + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + virtual void setup_buffers(); + virtual void graceful_shutdown(); + + // -------------------------------------------------------------------------------------------- + // ThreadMethod API + + virtual bool thread_one_step(); + + mutex thread_debug_mutex; + + private: + class PathFinder : public ThreadMethod { + public: + Chain* chain_operator; + bool forward_flag; + string origin; + string destiny; + PathFinder(Chain* chain_operator, bool forward_flag) { + this->chain_operator = chain_operator; + this->forward_flag = forward_flag; + if (forward_flag) { + origin = chain_operator->source_handle; + destiny = chain_operator->target_handle; + } else { + origin = chain_operator->target_handle; + destiny = chain_operator->source_handle; + } + } + ~PathFinder() {} + bool thread_one_step(); + bool conditional_refeed(Path& path, + shared_ptr& candidates_heap, + unsigned int count_cycles); + }; + + void initialize(const array, 1>& clauses); + + string source_handle; + string target_handle; + PathFinder* forward_path_finder; + PathFinder* backward_path_finder; + shared_ptr operator_thread; + shared_ptr forward_thread; + shared_ptr backward_thread; + ThreadSafeQueue refeeding_buffer; + set known_links; + set reported_answers; + map> source_index; + map> target_index; + bool all_input_acknowledged_flag; + bool all_paths_explored_flag; + mutex source_index_mutex; + mutex target_index_mutex; + mutex all_input_acknowledged_mutex; + mutex all_paths_explored_mutex; +}; + +} // namespace query_element diff --git a/src/commons/BUILD b/src/commons/BUILD index 584ba253..53a9403e 100644 --- a/src/commons/BUILD +++ b/src/commons/BUILD @@ -24,6 +24,8 @@ cc_library( "StoppableThread.h", "ThreadPool.h", "ThreadSafeHashmap.h", + "ThreadSafeHeap.h", + "ThreadSafeQueue.h", "Utils.h", ], includes = ["."], diff --git a/src/commons/ThreadSafeHeap.h b/src/commons/ThreadSafeHeap.h new file mode 100644 index 00000000..5eaf8f3d --- /dev/null +++ b/src/commons/ThreadSafeHeap.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include + +using namespace std; + +namespace commons { + +/** + * This class provides an abstraction for a thread-safe heap which uses STL priority-queue + * as container. The idea of a heap is to provide O(1) top() and O(log(N)) push() and + * pop(). + * + * Differently from standard STL priority_queue, the key used to sort elements are separated + * from the actual element being "heap-ed". So we have two types passed as template parameters, + * one for the "heap-ed" element and another one for the comparisson key. + * + * So when pushing something into the heapm one need to pass both, the element to be heap-ed + * and its key. When popping (actually top()'ing) just the heap-ed element is returned. + */ +template +class ThreadSafeHeap { + private: + class HeapElement { + public: + T element; + V value; + HeapElement(const HeapElement& other) : element(other.element), value(other.value) {} + HeapElement(const T& element, const V& value) : element(element), value(value) {} + bool operator<(const HeapElement& other) const { return this->value < other.value; } + HeapElement& operator=(const HeapElement& other) { + this->element = other.element; + this->value = other.value; + return *this; + } + }; + + public: + ThreadSafeHeap() {} + ~ThreadSafeHeap() {} + + const T& top() { + lock_guard semaphore(this->api_mutex); + return queue.top().element; + } + + const V& top_value() { + lock_guard semaphore(this->api_mutex); + return queue.top().value; + } + + T top_and_pop() { + lock_guard semaphore(this->api_mutex); + T element = queue.top().element; + queue.pop(); + return element; + } + + void push(const T& element, V value) { + lock_guard semaphore(this->api_mutex); + queue.push(HeapElement(element, value)); + } + + void pop() { + lock_guard semaphore(this->api_mutex); + queue.pop(); + } + + unsigned int size() { + lock_guard semaphore(this->api_mutex); + return queue.size(); + } + + bool empty() { + lock_guard semaphore(this->api_mutex); + return queue.size() == 0; + } + + void snapshot(vector& output) { + priority_queue> copy = this->queue; + while (!copy.empty()) { + output.push_back(copy.top().element); + copy.pop(); + } + } + + private: + mutex api_mutex; + priority_queue> queue; +}; + +} // namespace commons diff --git a/src/commons/ThreadSafeQueue.h b/src/commons/ThreadSafeQueue.h new file mode 100644 index 00000000..7fcef5f9 --- /dev/null +++ b/src/commons/ThreadSafeQueue.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +#include "Utils.h" + +namespace commons { + +template +/** + * This class is a wrapper around std::queue to provide thread-safe access. + */ +class ThreadSafeQueue { + private: + queue _queue; + mutex api_mutex; + + public: + ThreadSafeQueue() { lock_guard semaphore(this->api_mutex); } + + ~ThreadSafeQueue() { lock_guard semaphore(this->api_mutex); } + + void push(const T& element) { + lock_guard semaphore(this->api_mutex); + this->_queue.push(element); + } + + void pop() { + lock_guard semaphore(this->api_mutex); + this->_queue.pop(); + } + + const T& front() { + lock_guard semaphore(this->api_mutex); + return this->_queue.front(); + } + + T front_and_pop() { + lock_guard semaphore(this->api_mutex); + T element = this->_queue.front(); + this->_queue.pop(); + return element; + } + + bool empty() { + lock_guard semaphore(this->api_mutex); + return this->_queue.empty(); + } + + void clear() { + lock_guard semaphore(this->api_mutex); + this->_queue.clear(); + } +}; +} // namespace commons diff --git a/src/commons/processor/DedicatedThread.cc b/src/commons/processor/DedicatedThread.cc index ac7e636e..0cb0f313 100644 --- a/src/commons/processor/DedicatedThread.cc +++ b/src/commons/processor/DedicatedThread.cc @@ -16,6 +16,7 @@ DedicatedThread::DedicatedThread(const string& id, ThreadMethod* job) : Processo this->job = job; this->start_flag = false; this->stop_flag = false; + this->thread_object = NULL; } DedicatedThread::~DedicatedThread() {} @@ -39,10 +40,13 @@ void DedicatedThread::stop() { this->stop_flag = true; this->api_mutex.unlock(); Processor::stop(); - LOG_DEBUG("Joining DedicatedThread " + this->to_string() + "..."); - this->thread_object->join(); - LOG_DEBUG("Joined DedicatedThread " + this->to_string() ". Deleting thread object."); - delete this->thread_object; + if (this->thread_object != NULL) { + LOG_DEBUG("Joining DedicatedThread " + this->to_string() + "..."); + this->thread_object->join(); + LOG_DEBUG("Joined DedicatedThread " + this->to_string() ". Deleting thread object."); + delete this->thread_object; + this->thread_object = NULL; + } } // ------------------------------------------------------------------------------------------------- diff --git a/src/commons/processor/DedicatedThread.h b/src/commons/processor/DedicatedThread.h index d5cafd43..2c36ceb1 100644 --- a/src/commons/processor/DedicatedThread.h +++ b/src/commons/processor/DedicatedThread.h @@ -11,6 +11,7 @@ namespace processor { class ThreadMethod { public: + virtual ~ThreadMethod(){}; virtual bool thread_one_step() = 0; }; diff --git a/src/tests/cpp/BUILD b/src/tests/cpp/BUILD index caf8b328..af0791d5 100644 --- a/src/tests/cpp/BUILD +++ b/src/tests/cpp/BUILD @@ -304,6 +304,37 @@ cc_test( ], ) +cc_test( + name = "chain_operator_test", + size = "small", + srcs = [ + "chain_operator_test.cc", + "test_utils.cc", + "test_utils.h", + ], + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + linkopts = [ + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1, + deps = [ + "//agents/query_engine:query_engine_lib", + "//atomdb:atomdb_singleton", + "//atomdb/inmemorydb:inmemorydb_lib", + "//commons:commons_lib", + "//hasher:hasher_lib", + "@com_github_google_googletest//:gtest_main", + "@mbedtls", + ], +) + cc_test( name = "unique_assignment_filter_test", size = "medium", diff --git a/src/tests/cpp/chain_operator_test.cc b/src/tests/cpp/chain_operator_test.cc new file mode 100644 index 00000000..bbef0842 --- /dev/null +++ b/src/tests/cpp/chain_operator_test.cc @@ -0,0 +1,588 @@ +#include +#include + +#include "AtomDBSingleton.h" +#include "Chain.h" +#include "Hasher.h" +#include "InMemoryDB.h" +#include "Logger.h" +#include "QueryAnswer.h" +#include "Sink.h" +#include "Source.h" +#include "gtest/gtest.h" +#include "test_utils.h" + +using namespace commons; +using namespace query_engine; +using namespace query_element; +using namespace atomdb; +using namespace commons; + +#define SLEEP_DURATION ((unsigned int) 1000) +#define NODE_TYPE "Node" +#define LINK_TYPE "Link" +#define EVALUATION "EVALUATION" +#define UNKNOWN_NODE "UNKNOWN_NODE"; +#define UNKNOWN_LINK "UNKNOWN_LINK"; +#define NODE_COUNT ((unsigned int) 20) + +// Just to help in debugging +#define RUN_allow_concatenation ((bool) true) +#define RUN_allow_concatenation_reverse ((bool) true) +#define RUN_back_after_dead_end ((bool) true) +#define RUN_basics ((bool) true) + +static string EVALUATION_HANDLE = Hasher::node_handle(NODE_TYPE, EVALUATION); + +static set ALL_LINKS; + +static string node_name(unsigned int n) { + if (n == 0) { + return "S"; + } else if (n == (NODE_COUNT + 1)) { + return "T"; + } else { + return std::to_string(n); + } +} + +class ChainOperatorTestEnvironment : public ::testing::Environment { + public: + void load_data() { + auto db = AtomDBSingleton::get_instance(); + atoms::Node *node1, *node2; + atoms::Link* link; + node1 = new atoms::Node(NODE_TYPE, EVALUATION); + LOG_DEBUG("Add node: " + node1->handle() + " " + node1->to_string()); + db->add_node(node1, false); + for (unsigned int i = 0; i <= (NODE_COUNT + 1); i++) { + node1 = new atoms::Node(NODE_TYPE, node_name(i)); + db->add_node(node1, false); + LOG_DEBUG("Add node: " + node1->handle() + " " + node1->to_string()); + for (unsigned int j = 0; j <= (NODE_COUNT + 1); j++) { + node2 = new atoms::Node(NODE_TYPE, node_name(j)); + LOG_DEBUG("Add node: " + node2->handle() + " " + node2->to_string()); + db->add_node(node2, false); + link = new atoms::Link( + LINK_TYPE, {EVALUATION_HANDLE, node1->handle(), node2->handle()}, true); + LOG_DEBUG("Add link: " + link->handle() + " " + link->to_string()); + db->add_link(link, false); + } + } + } + + void SetUp() override { + auto atomdb = new InMemoryDB("chain_operator_test_"); + AtomDBSingleton::provide(shared_ptr(atomdb)); + this->load_data(); + } + + void TearDown() override {} +}; + +class ChainOperatorTest : public ::testing::Test { + protected: + void SetUp() override { + auto atomdb = AtomDBSingleton::get_instance(); + db = dynamic_pointer_cast(atomdb); + ASSERT_NE(db, nullptr) << "Failed to cast AtomDB to InMemoryDB"; + } + + void TearDown() override {} + shared_ptr db; +}; + +static string get_node_string(string handle) { + for (unsigned int i = 0; i <= (NODE_COUNT + 1); i++) { + string node_handle = Hasher::node_handle(NODE_TYPE, node_name(i)); + if (node_handle == handle) { + return node_name(i); + } + } + return UNKNOWN_NODE; +} + +static string get_link_string(string handle, unsigned int select = 0) { + for (unsigned int i = 0; i <= (NODE_COUNT + 1); i++) { + string node1_handle = Hasher::node_handle(NODE_TYPE, node_name(i)); + for (unsigned int j = 0; j <= (NODE_COUNT + 1); j++) { + string node2_handle = Hasher::node_handle(NODE_TYPE, node_name(j)); + string link_handle = + Hasher::link_handle(LINK_TYPE, {EVALUATION_HANDLE, node1_handle, node2_handle}); + if (link_handle == handle) { + if (select == 0) { + return get_node_string(node1_handle) + " -> " + get_node_string(node2_handle); + } else if (select == 1) { + return get_node_string(node1_handle); + } else { + return get_node_string(node2_handle); + } + } + } + } + return UNKNOWN_LINK; +} + +static string link(const string& node1_name, const string& node2_name) { + Node node1(NODE_TYPE, node1_name); + Node node2(NODE_TYPE, node2_name); + Link link(LINK_TYPE, {EVALUATION_HANDLE, node1.handle(), node2.handle()}); + ALL_LINKS.insert(link.handle()); + return link.handle(); +} + +static string link(unsigned int node1, unsigned int node2) { + return link(node_name(node1), node_name(node2)); +} + +class TestSource : public Source { + public: + TestSource(unsigned int count) { this->id = "TestSource_" + std::to_string(count); } + + ~TestSource() {} + + void add(const string& handle, + double importance, + const array& labels, + const array& values, + bool sleep_flag = true) { + QueryAnswer* query_answer = new QueryAnswer(handle, importance); + for (unsigned int i = 0; i < labels.size(); i++) { + query_answer->assignment.assign(labels[i], values[i]); + } + LOG_INFO("Feeding answer in the source element: " + get_link_string(handle)); + this->output_buffer->add_query_answer(query_answer); + if (sleep_flag) { + Utils::sleep(SLEEP_DURATION); + } + } + + void query_answers_finished() { return this->output_buffer->query_answers_finished(); } +}; + +class TestSink : public Sink { + public: + TestSink(shared_ptr precedent) : Sink(precedent, "TestSink(" + precedent->id + ")") {} + ~TestSink() {} + bool empty() { return this->input_buffer->is_query_answers_empty(); } + bool finished() { return this->input_buffer->is_query_answers_finished(); } + QueryAnswer* pop() { return this->input_buffer->pop_query_answer(); } +}; + +static bool check_answer(QueryAnswer* query_answer) { + string origin = get_node_string(query_answer->get(Chain::ORIGIN_VARIABLE_NAME)); + string destiny = get_node_string(query_answer->get(Chain::DESTINY_VARIABLE_NAME)); + EXPECT_TRUE((origin == "S") || (origin == "T")); + if (origin == "S") { + EXPECT_TRUE(get_link_string(query_answer->handles.front(), 1) == "S"); + } else { + EXPECT_TRUE(get_link_string(query_answer->handles.back(), 2) == "T"); + } + bool first = true; + string cursor; + for (string handle : query_answer->handles) { + EXPECT_TRUE(ALL_LINKS.find(handle) != ALL_LINKS.end()); + if (first) { + first = false; + cursor = get_link_string(handle, 1); + } + EXPECT_EQ(get_link_string(handle, 1), cursor); + cursor = get_link_string(handle, 2); + } + return (((origin == "S") && (destiny == "T")) || ((origin == "T") && (destiny == "S"))); +} + +static string answer_path_to_string(QueryAnswer* query_answer) { + bool first = true; + string answer = ""; + string cursor; + for (string handle : query_answer->handles) { + if (first) { + first = false; + answer = cursor = get_link_string(handle, 1); + } + if (get_link_string(handle, 1) != cursor) { + return "to_string() + " " + answer + " + " + + get_link_string(handle) + ">"; + } + answer += " -> "; + cursor = get_link_string(handle, 2); + answer += cursor; + } + return answer; +} + +TEST(ChainOperatorTest, allow_concatenation) { + if (!RUN_allow_concatenation) return; + + shared_ptr ab_link(new Link(LINK_TYPE, {" ", "a", "b"})); + shared_ptr ba_link(new Link(LINK_TYPE, {" ", "b", "a"})); + shared_ptr bc_link(new Link(LINK_TYPE, {" ", "b", "c"})); + shared_ptr ca_link(new Link(LINK_TYPE, {" ", "c", "a"})); + shared_ptr cd_link(new Link(LINK_TYPE, {" ", "c", "d"})); + shared_ptr da_link(new Link(LINK_TYPE, {" ", "d", "a"})); + shared_ptr db_link(new Link(LINK_TYPE, {" ", "d", "b"})); + shared_ptr dc_link(new Link(LINK_TYPE, {" ", "d", "c"})); + shared_ptr dd_link(new Link(LINK_TYPE, {" ", "d", "d"})); + shared_ptr xa_link(new Link(LINK_TYPE, {" ", "x", "a"})); + shared_ptr xb_link(new Link(LINK_TYPE, {" ", "x", "b"})); + shared_ptr xc_link(new Link(LINK_TYPE, {" ", "x", "c"})); + shared_ptr xe_link(new Link(LINK_TYPE, {" ", "x", "e"})); + shared_ptr dx_link(new Link(LINK_TYPE, {" ", "d", "x"})); + + Chain::Path base(true); + Chain::Path new_path(true); + Chain::Path ab(ab_link, new QueryAnswer(0), true); + Chain::Path ba(ba_link, new QueryAnswer(0), true); + Chain::Path bc(bc_link, new QueryAnswer(0), true); + Chain::Path ca(ca_link, new QueryAnswer(0), true); + Chain::Path cd(cd_link, new QueryAnswer(0), true); + Chain::Path da(da_link, new QueryAnswer(0), true); + Chain::Path db(db_link, new QueryAnswer(0), true); + Chain::Path dc(dc_link, new QueryAnswer(0), true); + EXPECT_THROW(Chain::Path dd(dd_link, new QueryAnswer(0), true), runtime_error); + Chain::Path xa(xa_link, new QueryAnswer(0), true); + Chain::Path xb(xb_link, new QueryAnswer(0), true); + Chain::Path xc(xc_link, new QueryAnswer(0), true); + Chain::Path xe(xe_link, new QueryAnswer(0), true); + Chain::Path dx(dx_link, new QueryAnswer(0), true); + + base.clear(); + EXPECT_TRUE(base.allow_concatenation(ab)); + EXPECT_TRUE(base.allow_concatenation(ba)); + base.concatenate(ab); + EXPECT_FALSE(base.allow_concatenation(ab)); + EXPECT_FALSE(base.allow_concatenation(ba)); + + base.clear(); + EXPECT_TRUE(base.allow_concatenation(ab)); + EXPECT_TRUE(base.allow_concatenation(bc)); + EXPECT_TRUE(base.allow_concatenation(ca)); + base.concatenate(ab); + EXPECT_FALSE(base.allow_concatenation(ab)); + EXPECT_TRUE(base.allow_concatenation(bc)); + EXPECT_FALSE(base.allow_concatenation(ca)); + base.concatenate(bc); + EXPECT_FALSE(base.allow_concatenation(ab)); + EXPECT_FALSE(base.allow_concatenation(bc)); + EXPECT_FALSE(base.allow_concatenation(ca)); + + base.clear(); + base.concatenate(ab); + base.concatenate(bc); + base.concatenate(cd); + EXPECT_FALSE(base.allow_concatenation(da)); + EXPECT_FALSE(base.allow_concatenation(db)); + EXPECT_FALSE(base.allow_concatenation(dc)); + + base.clear(); + base.concatenate(ab); + base.concatenate(bc); + base.concatenate(cd); + for (auto hop : {xa, xb, xc}) { + new_path.clear(); + new_path.concatenate(dx); + new_path.concatenate(hop); + EXPECT_FALSE(base.allow_concatenation(new_path)); + } + new_path.clear(); + new_path.concatenate(dx); + new_path.concatenate(xe); + EXPECT_TRUE(base.allow_concatenation(new_path)); + + EXPECT_TRUE(base.contains("a")); + EXPECT_TRUE(base.contains("b")); + EXPECT_TRUE(base.contains("c")); + EXPECT_TRUE(base.contains("d")); + EXPECT_FALSE(base.contains("x")); + EXPECT_FALSE(base.contains("e")); + base.concatenate(new_path); + EXPECT_TRUE(base.contains("a")); + EXPECT_TRUE(base.contains("b")); + EXPECT_TRUE(base.contains("c")); + EXPECT_TRUE(base.contains("d")); + EXPECT_TRUE(base.contains("x")); + EXPECT_TRUE(base.contains("e")); +} + +TEST(ChainOperatorTest, allow_concatenation_reverse) { + if (!RUN_allow_concatenation_reverse) return; + + return; + shared_ptr ab_link(new Link(LINK_TYPE, {" ", "a", "b"})); + shared_ptr bc_link(new Link(LINK_TYPE, {" ", "b", "c"})); + shared_ptr cd_link(new Link(LINK_TYPE, {" ", "c", "d"})); + shared_ptr dd_link(new Link(LINK_TYPE, {" ", "d", "d"})); + shared_ptr dx_link(new Link(LINK_TYPE, {" ", "d", "x"})); + shared_ptr cx_link(new Link(LINK_TYPE, {" ", "c", "x"})); + shared_ptr bx_link(new Link(LINK_TYPE, {" ", "b", "x"})); + shared_ptr ex_link(new Link(LINK_TYPE, {" ", "e", "x"})); + shared_ptr xa_link(new Link(LINK_TYPE, {" ", "x", "a"})); + + Chain::Path base(false); + Chain::Path new_path(false); + Chain::Path ba(ab_link, new QueryAnswer(0), false); + Chain::Path cb(bc_link, new QueryAnswer(0), false); + Chain::Path dc(cd_link, new QueryAnswer(0), false); + EXPECT_THROW(Chain::Path dd(dd_link, new QueryAnswer(0), false), runtime_error); + Chain::Path xd(dx_link, new QueryAnswer(0), false); + Chain::Path xc(cx_link, new QueryAnswer(0), false); + Chain::Path xb(bx_link, new QueryAnswer(0), false); + Chain::Path xe(ex_link, new QueryAnswer(0), false); + Chain::Path ax(xa_link, new QueryAnswer(0), false); + + base.clear(); + base.concatenate(dc); + base.concatenate(cb); + base.concatenate(ba); + for (auto hop : {xd, xc, xb}) { + new_path.clear(); + new_path.concatenate(ax); + new_path.concatenate(hop); + EXPECT_FALSE(base.allow_concatenation(new_path)); + } + new_path.clear(); + new_path.concatenate(ax); + new_path.concatenate(xe); + EXPECT_TRUE(base.allow_concatenation(new_path)); + + EXPECT_TRUE(base.contains("a")); + EXPECT_TRUE(base.contains("b")); + EXPECT_TRUE(base.contains("c")); + EXPECT_TRUE(base.contains("d")); + EXPECT_FALSE(base.contains("x")); + EXPECT_FALSE(base.contains("e")); + base.concatenate(new_path); + EXPECT_TRUE(base.contains("a")); + EXPECT_TRUE(base.contains("b")); + EXPECT_TRUE(base.contains("c")); + EXPECT_TRUE(base.contains("d")); + EXPECT_TRUE(base.contains("x")); + EXPECT_TRUE(base.contains("e")); +} + +TEST(ChainOperatorTest, back_after_dead_end) { + if (!RUN_back_after_dead_end) return; + auto source = make_shared(10); + auto chain_operator = make_shared(array, 1>({source}), + Hasher::node_handle(NODE_TYPE, "S"), + Hasher::node_handle(NODE_TYPE, "T")); + TestSink sink(chain_operator); + + EXPECT_TRUE(sink.empty()); + EXPECT_FALSE(sink.finished()); + + vector node; + unsigned int S = 0; + unsigned int node_count = 20; + unsigned int T = node_count + 1; + node.push_back("S"); + for (unsigned int cursor = S + 1; cursor <= node_count; cursor++) { + node.push_back(std::to_string(cursor)); + } + node.push_back("T"); + + // clang-format off + // + // +----- 1 -- 2 -- 3 --------+ + // | | | | + // +- 4 --+ | | + // | | | + // +-----------+ | + // | | + // +- 5 --+-- 7 | + // | | | + // + +-- 6 -- 8 | + // | | + // S -----+--------------------------+--+----- T + // | | | + // | | | + // +X 9 --+-- 10 -+- 13 | | + // | | | | | + // | | +- 14 X 19 -+ | + // | | | + // | +-- 11 -+- 15 ---------+ + // | | | | + // | | +- 16 | + // | | | + // | +-- 12 -+- 17 -> (9) | + // | | | + // | +- 18 -> (12) | + // | | + // +------------------------ 20 -+ + + Utils::sleep(1000); + source->add(link(S, 1), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 2), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 4), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 5), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 20), 0.5, {"v1"}, {"h1"}, false); + source->add(link(1, 2), 0.5, {"v1"}, {"h1"}, false); + source->add(link(2, 3), 0.5, {"v1"}, {"h1"}, false); + source->add(link(3, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(4, 1), 0.5, {"v1"}, {"h1"}, false); + source->add(link(5, 7), 0.5, {"v1"}, {"h1"}, false); + source->add(link(5, 6), 0.5, {"v1"}, {"h1"}, false); + source->add(link(6, 8), 0.5, {"v1"}, {"h1"}, false); + source->add(link(9, 10), 0.5, {"v1"}, {"h1"}, false); + source->add(link(9, 11), 0.5, {"v1"}, {"h1"}, false); + source->add(link(9, 12), 0.5, {"v1"}, {"h1"}, false); + source->add(link(10, 13), 0.5, {"v1"}, {"h1"}, false); + source->add(link(10, 14), 0.5, {"v1"}, {"h1"}, false); + source->add(link(11, 15), 0.5, {"v1"}, {"h1"}, false); + source->add(link(11, 16), 0.5, {"v1"}, {"h1"}, false); + source->add(link(12, 17), 0.5, {"v1"}, {"h1"}, false); + source->add(link(12, 18), 0.5, {"v1"}, {"h1"}, false); + source->add(link(15, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(17, 9), 0.5, {"v1"}, {"h1"}, false); + source->add(link(18, 12), 0.5, {"v1"}, {"h1"}, false); + source->add(link(19, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(20, T), 0.5, {"v1"}, {"h1"}, false); + Utils::sleep(3000); // TODO remove this + // clang-format on + QueryAnswer* answer; + unsigned int complete_path = 0; + while (complete_path < 5) { + while ((answer = sink.pop()) != NULL) { + LOG_INFO("[" + std::to_string(answer->importance) + "]: " + answer_path_to_string(answer)); + if (check_answer(answer)) { + complete_path++; + } + } + Utils::sleep(500); + } + EXPECT_FALSE(sink.finished()); + EXPECT_EQ(complete_path, 5); + source->add(link(S, 9), 0.5, {"v1"}, {"h1"}, false); + while (complete_path < 6) { + while ((answer = sink.pop()) != NULL) { + LOG_INFO("[" + std::to_string(answer->importance) + "]: " + answer_path_to_string(answer)); + if (check_answer(answer)) { + complete_path++; + } + } + Utils::sleep(500); + } + EXPECT_FALSE(sink.finished()); + EXPECT_EQ(complete_path, 6); + source->add(link(14, 19), 0.5, {"v1"}, {"h1"}, false); + source->query_answers_finished(); + while (!sink.empty() || !sink.finished()) { + while ((answer = sink.pop()) != NULL) { + LOG_INFO("[" + std::to_string(answer->importance) + "]: " + answer_path_to_string(answer)); + if (check_answer(answer)) { + complete_path++; + } + } + Utils::sleep(500); + } + EXPECT_EQ(complete_path, 7); + EXPECT_TRUE(sink.empty()); + EXPECT_TRUE(sink.finished()); +} + +TEST(ChainOperatorTest, basics) { + if (!RUN_basics) return; + auto source = make_shared(10); + auto chain_operator = make_shared(array, 1>({source}), + Hasher::node_handle(NODE_TYPE, "S"), + Hasher::node_handle(NODE_TYPE, "T")); + TestSink sink(chain_operator); + + EXPECT_TRUE(sink.empty()); + EXPECT_FALSE(sink.finished()); + + vector node; + unsigned int S = 0; + unsigned int node_count = 20; + unsigned int T = node_count + 1; + node.push_back("S"); + for (unsigned int cursor = S + 1; cursor <= node_count; cursor++) { + node.push_back(std::to_string(cursor)); + } + node.push_back("T"); + + // clang-format off + // + // +----- 1 -- 2 -- 3 --------+ + // | | | | + // +- 4 --+ | | + // | | | + // +-----------+ | + // | | + // +- 5 --+-- 7 | + // | | | + // + +-- 6 -- 8 | + // | | + // S -----+--------------------------+--+----- T + // | | | + // | | | + // +- 9 --+-- 10 -+- 13 | | + // | | | | | + // | | +- 14 - 19 -+ | + // | | | + // | +-- 11 -+- 15 ---------+ + // | | | | + // | | +- 16 | + // | | | + // | +-- 12 -+- 17 -> (9) | + // | | | + // | +- 18 -> (12) | + // | | + // +------------------------ 20 -+ + + Utils::sleep(1000); + source->add(link(S, 1), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 2), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 4), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 5), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 9), 0.5, {"v1"}, {"h1"}, false); + source->add(link(S, 20), 0.5, {"v1"}, {"h1"}, false); + source->add(link(1, 2), 0.5, {"v1"}, {"h1"}, false); + source->add(link(2, 3), 0.5, {"v1"}, {"h1"}, false); + source->add(link(3, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(4, 1), 0.5, {"v1"}, {"h1"}, false); + source->add(link(5, 7), 0.5, {"v1"}, {"h1"}, false); + source->add(link(5, 6), 0.5, {"v1"}, {"h1"}, false); + source->add(link(6, 8), 0.5, {"v1"}, {"h1"}, false); + source->add(link(9, 10), 0.5, {"v1"}, {"h1"}, false); + source->add(link(9, 11), 0.5, {"v1"}, {"h1"}, false); + source->add(link(9, 12), 0.5, {"v1"}, {"h1"}, false); + source->add(link(10, 13), 0.5, {"v1"}, {"h1"}, false); + source->add(link(10, 14), 0.5, {"v1"}, {"h1"}, false); + source->add(link(11, 15), 0.5, {"v1"}, {"h1"}, false); + source->add(link(11, 16), 0.5, {"v1"}, {"h1"}, false); + source->add(link(12, 17), 0.5, {"v1"}, {"h1"}, false); + source->add(link(12, 18), 0.5, {"v1"}, {"h1"}, false); + source->add(link(14, 19), 0.5, {"v1"}, {"h1"}, false); + source->add(link(15, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(17, 9), 0.5, {"v1"}, {"h1"}, false); + source->add(link(18, 12), 0.5, {"v1"}, {"h1"}, false); + source->add(link(19, T), 0.5, {"v1"}, {"h1"}, false); + source->add(link(20, T), 0.5, {"v1"}, {"h1"}, false); + Utils::sleep(3000); // TODO remove this + // clang-format on + source->query_answers_finished(); + QueryAnswer* answer; + unsigned int complete_path = 0; + while (!sink.empty() || !sink.finished()) { + while ((answer = sink.pop()) != NULL) { + LOG_INFO("[" + std::to_string(answer->importance) + "]: " + answer_path_to_string(answer)); + if (check_answer(answer)) { + complete_path++; + } + } + Utils::sleep(500); + } + EXPECT_EQ(complete_path, 7); + EXPECT_TRUE(sink.empty()); + EXPECT_TRUE(sink.finished()); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + ::testing::AddGlobalTestEnvironment(new ChainOperatorTestEnvironment()); + return RUN_ALL_TESTS(); +}