Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 2e6d6a6

Browse files
ashokeidiyessi
authored andcommitted
add serialize_types and deserialize_types API (#3795)
* add serialize_attrs and deserialize_attrs API * add doc comments * update function signature
1 parent 42a3a0a commit 2e6d6a6

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

src/ngraph/serializer.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,33 @@ std::string ngraph::serialize(std::shared_ptr<ngraph::Function> func, size_t ind
458458
return ::serialize(func, indent, false);
459459
}
460460

461+
std::string
462+
ngraph::serialize_types(const std::vector<std::pair<PartialShape, element::Type>>& types)
463+
{
464+
json attrs = json::array();
465+
for (const auto& n : types)
466+
{
467+
json j;
468+
j["shape"] = write_partial_shape(n.first);
469+
j["type"] = write_element_type(n.second);
470+
attrs.push_back(j);
471+
}
472+
return attrs.dump();
473+
}
474+
475+
std::vector<std::pair<PartialShape, element::Type>>
476+
ngraph::deserialize_types(const std::string& str)
477+
{
478+
std::vector<std::pair<PartialShape, element::Type>> outs;
479+
json js = json::parse(str);
480+
for (auto& j : js)
481+
{
482+
auto s = read_partial_shape(j["shape"]);
483+
auto t = read_element_type(j["type"]);
484+
outs.emplace_back(s, t);
485+
}
486+
return outs;
487+
}
461488
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
462489
{
463490
shared_ptr<Function> rc;

src/ngraph/serializer.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ namespace ngraph
3030
/// indent level specified.
3131
std::string serialize(std::shared_ptr<ngraph::Function> func, size_t indent = 0);
3232

33+
/// \brief Serialize given vector of shapes/types
34+
/// \param types The vector of shape/types to serialize
35+
std::string serialize_types(const std::vector<std::pair<PartialShape, element::Type>>& types);
36+
/// \brief Deerialize a string into vector of shapes/types
37+
/// \param str The serialized string to deseriailze
38+
std::vector<std::pair<PartialShape, element::Type>> deserialize_types(const std::string& str);
39+
3340
/// \brief Serialize a Function to a json file
3441
/// \param path The path to the output file
3542
/// \param func The Function to serialize

test/serialize.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,42 @@ TEST(serialize, main)
9696
EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
9797
}
9898

99+
TEST(serialize, main_attrs)
100+
{
101+
// First create "f(A,B,C) = (A+B)*C".
102+
Shape shape{2, 2};
103+
auto A = make_shared<op::Parameter>(element::f32, shape);
104+
auto B = make_shared<op::Parameter>(element::f32, shape);
105+
auto C = make_shared<op::Parameter>(element::f32, shape);
106+
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C}, "f");
107+
108+
std::vector<std::pair<PartialShape, element::Type>> types;
109+
110+
auto results = f->get_results();
111+
for (auto& n : results)
112+
types.emplace_back(n->get_output_partial_shape(0), n->output(0).get_element_type());
113+
auto s_types = serialize_types(types);
114+
for (auto& attrs : deserialize_types(s_types))
115+
{
116+
EXPECT_EQ(size_t(attrs.first.rank()), shape.size());
117+
EXPECT_EQ(size_t(attrs.first[0]), 2);
118+
EXPECT_EQ(size_t(attrs.first[1]), 2);
119+
EXPECT_EQ(element::f32, attrs.second);
120+
}
121+
auto params = f->get_parameters();
122+
types.clear();
123+
for (auto& n : params)
124+
types.emplace_back(n->get_output_partial_shape(0), n->output(0).get_element_type());
125+
s_types = serialize_types(types);
126+
for (auto& attrs : deserialize_types(s_types))
127+
{
128+
EXPECT_EQ(size_t(attrs.first.rank()), shape.size());
129+
EXPECT_EQ(size_t(attrs.first[0]), 2);
130+
EXPECT_EQ(size_t(attrs.first[1]), 2);
131+
EXPECT_EQ(element::f32, attrs.second);
132+
}
133+
}
134+
99135
TEST(serialize, friendly_name)
100136
{
101137
// First create "f(A,B,C) = (A+B)*C".

0 commit comments

Comments
 (0)