Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 8bab36f

Browse files
authored
Merge descriptor::TensorView into descriptor::Tensor (#1536)
* Merge descriptor::TensorView into descriptor::Tensot * fix GPU build
1 parent 62e470b commit 8bab36f

16 files changed

+92
-234
lines changed

src/ngraph/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ set (SRC
2929
descriptor/layout/tensor_view_layout.cpp
3030
descriptor/output.cpp
3131
descriptor/tensor.cpp
32-
descriptor/tensor_view.cpp
3332
file_util.cpp
3433
function.cpp
3534
log.cpp

src/ngraph/descriptor/layout/dense_tensor_view_layout.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace ngraph
2525
{
2626
namespace descriptor
2727
{
28-
class TensorView;
28+
class Tensor;
2929

3030
namespace layout
3131
{
@@ -36,7 +36,7 @@ namespace ngraph
3636
{
3737
public:
3838
~DenseTensorViewLayout() override {}
39-
DenseTensorViewLayout(const TensorView& tensor_view);
39+
DenseTensorViewLayout(const Tensor& tensor);
4040

4141
virtual size_t get_size() override { return m_size; }
4242
size_t get_offset() const { return m_offset; }

src/ngraph/descriptor/layout/tensor_view_layout.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
//*****************************************************************************
1616

1717
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
18-
#include "ngraph/descriptor/tensor_view.hpp"
18+
#include "ngraph/descriptor/tensor.hpp"
1919
#include "ngraph/type/element_type.hpp"
2020

2121
using namespace ngraph;

src/ngraph/descriptor/layout/tensor_view_layout.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <memory>
2020
#include <vector>
2121

22-
#include "ngraph/descriptor/tensor_view.hpp"
22+
#include "ngraph/descriptor/tensor.hpp"
2323

2424
namespace ngraph
2525
{

src/ngraph/descriptor/output.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
using namespace std;
2222
using namespace ngraph;
2323

24-
descriptor::Output::Output(Node* node, size_t index, const shared_ptr<TensorView>& tensor_view)
24+
descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& tensor)
2525
: m_node(node)
2626
, m_index(index)
27-
, m_tensor_view(tensor_view)
27+
, m_tensor(tensor)
2828
{
2929
}
3030

@@ -46,15 +46,15 @@ shared_ptr<Node> descriptor::Output::get_node() const
4646

4747
descriptor::Tensor& descriptor::Output::get_tensor() const
4848
{
49-
return m_tensor_view->get_tensor();
49+
return *m_tensor;
5050
}
5151

5252
const Shape& descriptor::Output::get_shape() const
5353
{
54-
return m_tensor_view->get_shape();
54+
return m_tensor->get_shape();
5555
}
5656

5757
const element::Type& descriptor::Output::get_element_type() const
5858
{
59-
return m_tensor_view->get_element_type();
59+
return m_tensor->get_element_type();
6060
}

src/ngraph/descriptor/output.hpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include <set>
2121

2222
#include "ngraph/descriptor/input.hpp"
23-
#include "ngraph/descriptor/tensor_view.hpp"
23+
#include "ngraph/descriptor/tensor.hpp"
2424

2525
namespace ngraph
2626
{
@@ -39,16 +39,13 @@ namespace ngraph
3939
public:
4040
/// \param node Node that owns this output.
4141
/// \param index Position of the output tensor in all output tensors
42-
/// \param tensor_view The view of this tensor; where the value will be written
43-
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
42+
/// \param tensor The view of this tensor; where the value will be written
43+
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
4444

4545
std::shared_ptr<Node> get_node() const;
4646
size_t get_index() const { return m_index; }
47-
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
48-
void set_tensor_view(const std::shared_ptr<TensorView>& tensor_view)
49-
{
50-
m_tensor_view = tensor_view;
51-
}
47+
std::shared_ptr<Tensor> get_tensor_view() const { return m_tensor; }
48+
void set_tensor_view(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
5249
void add_input(Input* input);
5350
void remove_input(Input* input);
5451
const std::set<Input*>& get_inputs() const { return m_inputs; }
@@ -62,7 +59,7 @@ namespace ngraph
6259
protected:
6360
Node* m_node;
6461
size_t m_index;
65-
std::shared_ptr<TensorView> m_tensor_view;
62+
std::shared_ptr<Tensor> m_tensor;
6663
std::set<Input*> m_inputs;
6764

6865
private:

src/ngraph/descriptor/tensor.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,22 @@ using namespace ngraph;
2222
using namespace std;
2323

2424
descriptor::Tensor::Tensor(const element::Type& element_type,
25-
TensorView* tensor_view,
26-
const string& name)
25+
const Shape& shape,
26+
const std::string& name)
2727
: m_element_type(element_type)
28-
, m_tensor_view(tensor_view)
29-
, m_name{name}
30-
, m_next_view_id{0}
28+
, m_shape(shape)
29+
, m_name(name)
3130
{
3231
}
3332

34-
string descriptor::Tensor::make_tensor_name(const Node* node, size_t value_index)
33+
void descriptor::Tensor::set_tensor_view_type(const element::Type& element_type, const Shape& shape)
3534
{
36-
return node->get_name() + "_" + to_string(value_index);
37-
}
38-
39-
string descriptor::Tensor::get_next_view_name()
40-
{
41-
return m_name + "_TV" + to_string(m_next_view_id++);
35+
m_shape = shape;
36+
m_element_type = element_type;
37+
if (nullptr != m_tensor_view_layout)
38+
{
39+
m_tensor_view_layout->set_tensor_view_type(element_type, shape);
40+
}
4241
}
4342

4443
void descriptor::Tensor::set_pool_offset(size_t offset)
@@ -53,21 +52,16 @@ size_t descriptor::Tensor::get_pool_offset() const
5352

5453
size_t descriptor::Tensor::size() const
5554
{
56-
if (auto tvl = m_tensor_view->get_tensor_view_layout())
55+
if (auto tvl = get_tensor_view_layout())
5756
{
5857
return tvl->get_allocated_size();
5958
}
6059
else
6160
{
62-
return shape_size(m_tensor_view->get_shape()) * m_element_type.size();
61+
return shape_size(get_shape()) * m_element_type.size();
6362
}
6463
}
6564

66-
void descriptor::Tensor::set_element_type(const element::Type& element_type)
67-
{
68-
m_element_type = element_type;
69-
}
70-
7165
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
7266
{
7367
out << "Tensor(" << tensor.get_name() << ")";

src/ngraph/descriptor/tensor.hpp

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,55 +16,69 @@
1616

1717
#pragma once
1818

19-
#include <iostream>
2019
#include <memory>
21-
#include <vector>
20+
#include <string>
2221

22+
#include "ngraph/descriptor/tensor.hpp"
2323
#include "ngraph/shape.hpp"
2424
#include "ngraph/type/element_type.hpp"
2525

2626
namespace ngraph
2727
{
2828
class Node;
2929

30-
namespace element
31-
{
32-
class Type;
33-
}
34-
3530
namespace descriptor
3631
{
37-
class TensorView;
38-
class Tensor;
39-
}
40-
}
32+
namespace layout
33+
{
34+
class TensorViewLayout;
35+
}
4136

42-
class ngraph::descriptor::Tensor
43-
{
44-
friend class TensorView;
37+
/// \brief Compile-time descriptor of a first-class value that is a view of a tensor.
38+
class Tensor
39+
{
40+
Tensor(const Tensor&) = delete;
41+
Tensor& operator=(const Tensor&) = delete;
4542

46-
private:
47-
Tensor(const Tensor&) = delete;
48-
Tensor& operator=(const Tensor&) = delete;
43+
public:
44+
Tensor(const element::Type& element_type, const Shape& shape, const std::string& name);
4945

50-
Tensor(const element::Type& element_type, TensorView* tensor_view, const std::string& name);
51-
std::string get_next_view_name();
46+
const std::string& get_name() const { return m_name; }
47+
void set_tensor_view_type(const element::Type& element_type, const Shape& shape);
5248

53-
public:
54-
const std::string& get_name() const { return m_name; }
55-
void set_pool_offset(size_t);
56-
size_t size() const;
57-
size_t get_pool_offset() const;
58-
const element::Type& get_element_type() const { return m_element_type; }
59-
void set_element_type(const element::Type& element_type);
60-
static std::string make_tensor_name(const Node* node, size_t value_index);
49+
const element::Type& get_element_type() const { return m_element_type; }
50+
const Shape& get_shape() const { return m_shape; }
51+
const std::shared_ptr<layout::TensorViewLayout>& get_tensor_view_layout() const
52+
{
53+
return m_tensor_view_layout;
54+
}
6155

62-
protected:
63-
element::Type m_element_type;
64-
TensorView* m_tensor_view;
65-
std::string m_name;
66-
size_t m_next_view_id;
67-
size_t m_pool_offset;
68-
};
56+
void set_tensor_view_layout(
57+
const std::shared_ptr<layout::TensorViewLayout>& tensor_view_layout)
58+
{
59+
m_tensor_view_layout = tensor_view_layout;
60+
}
6961

70-
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
62+
void set_pool_offset(size_t);
63+
size_t get_pool_offset() const;
64+
65+
size_t size() const;
66+
67+
const Tensor& get_tensor() const { return *this; }
68+
Tensor& get_tensor() { return *this; }
69+
const Tensor& get_tensor_view() const { return *this; }
70+
Tensor& get_tensor_view() { return *this; }
71+
protected:
72+
element::Type m_element_type;
73+
Shape m_shape;
74+
std::string m_name;
75+
std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout;
76+
size_t m_pool_offset;
77+
};
78+
79+
using TensorView = Tensor;
80+
81+
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
82+
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
83+
}
84+
}

src/ngraph/descriptor/tensor_view.cpp

Lines changed: 0 additions & 65 deletions
This file was deleted.

0 commit comments

Comments
 (0)