Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 1f35037

Browse files
tsochapostrational
authored andcommitted
[ONNX] Enable deselected supported opset domain when needed. (#2350)
1 parent 676f8d3 commit 1f35037

File tree

11 files changed

+164
-35
lines changed

11 files changed

+164
-35
lines changed

src/ngraph/frontend/onnx_import/core/attribute.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace ngraph
2222
{
2323
namespace onnx_import
2424
{
25-
std::vector<Graph> Attribute::get_graph_array(const Model& model) const
25+
std::vector<Graph> Attribute::get_graph_array(Model& model) const
2626
{
2727
std::vector<Graph> result;
2828
for (const auto& graph : m_attribute_proto->graphs())
@@ -32,7 +32,7 @@ namespace ngraph
3232
return result;
3333
}
3434

35-
Graph Attribute::get_graph(const Model& model) const
35+
Graph Attribute::get_graph(Model& model) const
3636
{
3737
return Graph{m_attribute_proto->g(), model};
3838
}

src/ngraph/frontend/onnx_import/core/attribute.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ namespace ngraph
278278
float get_float() const { return m_attribute_proto->f(); }
279279
int64_t get_integer() const { return m_attribute_proto->i(); }
280280
const std::string& get_string() const { return m_attribute_proto->s(); }
281-
Graph get_graph(const Model&) const;
281+
Graph get_graph(Model&) const;
282282

283283
std::vector<Tensor> get_tensor_array() const
284284
{
@@ -303,7 +303,7 @@ namespace ngraph
303303
std::end(m_attribute_proto->strings())};
304304
}
305305

306-
std::vector<Graph> get_graph_array(const Model&) const;
306+
std::vector<Graph> get_graph_array(Model&) const;
307307

308308
/* explicit */ operator onnx::AttributeProto_AttributeType() const
309309
{

src/ngraph/frontend/onnx_import/core/graph.cpp

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
// limitations under the License.
1515
//*****************************************************************************
1616

17+
#include <functional>
1718
#include <set>
1819

1920
#include "graph.hpp"
@@ -25,26 +26,40 @@ namespace ngraph
2526
{
2627
namespace detail
2728
{
28-
std::string to_string(const std::set<std::string>& set)
29+
static std::string to_string(
30+
const std::map<std::string, std::reference_wrapper<const onnx::NodeProto>>& map)
2931
{
3032
std::string result;
31-
for (auto it = std::begin(set); it != std::end(set); ++it)
33+
for (auto it = std::begin(map); it != std::end(map); ++it)
3234
{
33-
result += (it != std::begin(set) ? ", " : "") + *it;
35+
result += (it != std::begin(map) ? ", " : "") + it->first;
3436
}
3537
return result;
3638
}
3739

38-
inline std::string to_string(const onnx::NodeProto& node_proto)
40+
static std::string get_node_domain(const onnx::NodeProto& node_proto)
3941
{
40-
return (node_proto.domain().empty() ? "" : node_proto.domain() + ".") +
41-
node_proto.op_type();
42+
return (node_proto.domain().empty() ? "" : node_proto.domain());
4243
}
43-
}
4444

45-
Graph::Graph(const onnx::GraphProto& graph_proto,
46-
const Model& model,
47-
const Weights& weights)
45+
/// \brief Gets the operator represented by provided node unique identificator.
46+
///
47+
/// \param[in] node_proto The node protobuf representation object.
48+
///
49+
/// \note The operator is uniquely identified by the tuple (domain, op_type,
50+
/// since_version). The first two elements are stored in NodeProto object,
51+
/// thus we use only them.
52+
///
53+
/// \return The unique identificator.
54+
///
55+
static std::string get_op_domain_and_name(const onnx::NodeProto& node_proto)
56+
{
57+
std::string domain = get_node_domain(node_proto);
58+
return (domain.empty() ? "" : domain + ".") + node_proto.op_type();
59+
}
60+
} // namespace detail
61+
62+
Graph::Graph(const onnx::GraphProto& graph_proto, Model& model, const Weights& weights)
4863
: m_graph_proto{&graph_proto}
4964
, m_model{&model}
5065
{
@@ -70,17 +85,34 @@ namespace ngraph
7085
}
7186

7287
// Verify that ONNX graph contains only nodes of available operator types
73-
std::set<std::string> unknown_operator_types;
88+
std::map<std::string, std::reference_wrapper<const onnx::NodeProto>> unknown_operators;
7489
for (const auto& node_proto : m_graph_proto->node())
7590
{
7691
if (!m_model->is_operator_available(node_proto))
7792
{
78-
unknown_operator_types.emplace(detail::to_string(node_proto));
93+
unknown_operators.emplace(detail::get_op_domain_and_name(node_proto),
94+
node_proto);
95+
// Try adding missing domain
96+
m_model->enable_opset_domain(detail::get_node_domain(node_proto));
97+
}
98+
}
99+
100+
// Reverify wheter we still have any unavailable operators.
101+
auto it = std::begin(unknown_operators);
102+
while (it != std::end(unknown_operators))
103+
{
104+
if (m_model->is_operator_available(it->second))
105+
{
106+
it = unknown_operators.erase(it);
107+
}
108+
else
109+
{
110+
it++;
79111
}
80112
}
81113

82-
NGRAPH_ASSERT(unknown_operator_types.empty())
83-
<< "unknown operations: " << detail::to_string(unknown_operator_types);
114+
NGRAPH_ASSERT(unknown_operators.empty()) << "unknown operations: "
115+
<< detail::to_string(unknown_operators);
84116

85117
// Process ONNX graph nodes, convert to nGraph nodes
86118
for (const auto& node_proto : m_graph_proto->node())

src/ngraph/frontend/onnx_import/core/graph.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace ngraph
3333
class Graph
3434
{
3535
public:
36-
Graph(const onnx::GraphProto& proto, const Model& model, const Weights& weights = {});
36+
Graph(const onnx::GraphProto& proto, Model& model, const Weights& weights = {});
3737

3838
const std::vector<Node>& get_nodes() const { return m_nodes; }
3939
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
@@ -59,7 +59,7 @@ namespace ngraph
5959
ParameterVector m_parameters;
6060
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
6161
std::map<std::string, Tensor> m_initializers;
62-
const Model* m_model;
62+
Model* m_model;
6363
};
6464

6565
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)

src/ngraph/frontend/onnx_import/core/model.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <onnx-ml.pb.h>
1818

1919
#include "model.hpp"
20+
#include "ngraph/log.hpp"
2021
#include "ops_bridge.hpp"
2122

2223
namespace ngraph
@@ -33,14 +34,14 @@ namespace ngraph
3334
{
3435
m_opset.emplace(id.domain(),
3536
OperatorsBridge::get_operator_set(
36-
id.version(), (id.domain() == "ai.onnx" ? "" : id.domain())));
37+
(id.domain() == "ai.onnx" ? "" : id.domain()), id.version()));
3738
}
3839
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
3940
// implies the operator set that is defined as part of the ONNX specification.
4041
const auto dm = m_opset.find("");
4142
if (dm == std::end(m_opset))
4243
{
43-
m_opset.emplace("", OperatorsBridge::get_operator_set(ONNX_OPSET_VERSION, ""));
44+
m_opset.emplace("", OperatorsBridge::get_operator_set("", ONNX_OPSET_VERSION));
4445
}
4546
}
4647

@@ -71,6 +72,26 @@ namespace ngraph
7172
return (op != std::end(dm->second));
7273
}
7374

75+
void Model::enable_opset_domain(const std::string& domain)
76+
{
77+
// There is no need to 'update' already enabled domain.
78+
// Since this function may be called only during model import,
79+
// (maybe multiple times) the registered domain opset won't differ
80+
// between subsequent calls.
81+
if (m_opset.find(domain) == std::end(m_opset))
82+
{
83+
OperatorSet opset{OperatorsBridge::get_operator_set(domain)};
84+
if (opset.empty())
85+
{
86+
NGRAPH_WARN << "Couldn't enable domain: " << domain
87+
<< " since it hasn't any registered operators.";
88+
89+
return;
90+
}
91+
m_opset.emplace(domain, opset);
92+
}
93+
}
94+
7495
} // namespace onnx_import
7596

7697
} // namespace ngraph

src/ngraph/frontend/onnx_import/core/model.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ namespace ngraph
6161
/// \return `true` if the operator is available, otherwise it returns `false`.
6262
bool is_operator_available(const onnx::NodeProto& node_proto) const;
6363

64+
/// \brief Enable operators from provided domain to use by this model.
65+
///
66+
/// \note This function makes visible all currently registered in provided domain
67+
/// operators for use in this model.
68+
///
69+
/// \param[in] domain The domain name.
70+
///
71+
void enable_opset_domain(const std::string& domain);
72+
6473
private:
6574
const onnx::ModelProto* m_model_proto;
6675
std::unordered_map<std::string, OperatorSet> m_opset;

src/ngraph/frontend/onnx_import/onnx.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ namespace ngraph
9090
std::set<std::string> get_supported_operators(std::int64_t version,
9191
const std::string& domain)
9292
{
93-
OperatorSet op_set{OperatorsBridge::get_operator_set(version, domain)};
93+
OperatorSet op_set{
94+
OperatorsBridge::get_operator_set(domain == "ai.onnx" ? "" : domain, version)};
9495
std::set<std::string> op_list{};
9596
for (const auto& op : op_set)
9697
{

src/ngraph/frontend/onnx_import/ops_bridge.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ namespace ngraph
110110
find(std::int64_t version, const std::map<std::int64_t, Operator>& map)
111111
{
112112
std::map<std::int64_t, Operator>::const_iterator it{};
113+
// Get the latest version.
114+
if (version == -1)
115+
{
116+
return map.empty() ? std::end(map) : --std::end(map);
117+
}
113118
while (version > 0)
114119
{
115120
it = map.find(version--);
@@ -127,23 +132,29 @@ namespace ngraph
127132
const std::string& domain,
128133
Operator fn)
129134
{
130-
m_map[domain][name].emplace(version, std::move(fn));
135+
auto result = m_map[domain][name].emplace(version, std::move(fn));
136+
if (result.second)
137+
{
138+
NGRAPH_WARN << "Overwriting existing operator: "
139+
<< domain + "." + name + ":" + std::to_string(version);
140+
}
131141
}
132142

133-
OperatorSet OperatorsBridge::_get_operator_set(std::int64_t version,
134-
const std::string& domain)
143+
OperatorSet OperatorsBridge::_get_operator_set(const std::string& domain,
144+
std::int64_t version)
135145
{
136146
OperatorSet result;
147+
137148
auto dm = m_map.find(domain);
138149
if (dm == std::end(m_map))
139150
{
140151
throw error::UnknownDomain{domain};
141152
}
142-
if (version > OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION)
153+
if (domain == "" && version > OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION)
143154
{
144-
NGRAPH_WARN << "Currently operator set version: " << version << " is unsupported."
145-
<< " Falling back to: "
146-
<< OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION;
155+
NGRAPH_WARN << "Currently ONNX operator set version: " << version
156+
<< " is unsupported. Falling back to: "
157+
<< OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION;
147158
}
148159
for (const auto& op : dm->second)
149160
{

src/ngraph/frontend/onnx_import/ops_bridge.hpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,17 @@ namespace ngraph
6262
class OperatorsBridge
6363
{
6464
public:
65-
static constexpr const int LATEST_SUPPORTED_OPSET_VERSION = ONNX_OPSET_VERSION;
65+
static constexpr const int LATEST_SUPPORTED_ONNX_OPSET_VERSION = ONNX_OPSET_VERSION;
6666

6767
OperatorsBridge(const OperatorsBridge&) = delete;
6868
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
6969
OperatorsBridge(OperatorsBridge&&) = delete;
7070
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
7171

72-
static OperatorSet get_operator_set(std::int64_t version, const std::string& domain)
72+
static OperatorSet get_operator_set(const std::string& domain,
73+
std::int64_t version = -1)
7374
{
74-
return instance()._get_operator_set(version, domain);
75+
return instance()._get_operator_set(domain, version);
7576
}
7677

7778
static void register_operator(const std::string& name,
@@ -90,6 +91,20 @@ namespace ngraph
9091
}
9192

9293
private:
94+
// Registered operators structure
95+
// {
96+
// domain_1: {
97+
// op_type_1: {
98+
// version_1: {func_handle},
99+
// version_2: {func_handle},
100+
// ...
101+
// },
102+
// op_type_2: { ... }
103+
// ...
104+
// },
105+
// domain_2: { ... },
106+
// ...
107+
// }
93108
std::unordered_map<std::string,
94109
std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
95110
m_map;
@@ -106,7 +121,8 @@ namespace ngraph
106121
std::int64_t version,
107122
const std::string& domain,
108123
Operator fn);
109-
OperatorSet _get_operator_set(std::int64_t version, const std::string& domain);
124+
OperatorSet _get_operator_set(const std::string& domain, std::int64_t version);
125+
110126
bool _is_operator_registered(const std::string& name,
111127
std::int64_t version,
112128
const std::string& domain);

test/models/onnx/missing_op_domain.onnx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
ONNXnGraphImporter:o
2+

3+
A
4+
BC" CustomAdd: custom.opcompute_graphZ
5+
A
6+

7+

8+
Z
9+
B
10+

11+

12+
b
13+
C
14+

15+

16+
B

0 commit comments

Comments
 (0)