Skip to content

Commit 5f7518b

Browse files
Introduce DatasetWithSeqId
1 parent 98615bc commit 5f7518b

File tree

6 files changed

+54
-20
lines changed

6 files changed

+54
-20
lines changed

src/nncf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from nncf.common.strip import strip as strip
1919
from nncf.config import NNCFConfig as NNCFConfig
2020
from nncf.data import Dataset as Dataset
21+
from nncf.data import DatasetWithSeqId as DatasetWithSeqId
2122
from nncf.errors import BufferFullError as BufferFullError
2223
from nncf.errors import InstallationError as InstallationError
2324
from nncf.errors import InternalError as InternalError

src/nncf/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
# limitations under the License.
1111

1212
from nncf.data.dataset import Dataset as Dataset
13+
from nncf.data.dataset import DatasetWithSeqId as DatasetWithSeqId
1314
from nncf.data.generators import generate_text_data as generate_text_data

src/nncf/data/dataset.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,6 @@ class Dataset:
3838
will be passed into the model as-is.
3939
"""
4040

41-
# RESET_STATE_KEY is a special input key used by OpenVINO backend to control resetting of internal model state
42-
# between model inferences. This key can be added to a dataset sample input dictionary with either
43-
# `True` or `False` value. With `True` value, the model state will be reset before inference on the corresponding
44-
# sample, and with `False` the state will not be reset.
45-
RESET_STATE_KEY = "reset_state"
46-
4741
def __init__(self, data_source: Iterable[Any], transform_func: Optional[Callable[..., Any]] = None):
4842
self._data_source = data_source
4943
self._transform_func = transform_func
@@ -94,6 +88,41 @@ def get_batch_size(self) -> Optional[int]:
9488
return None
9589

9690

91+
@api(canonical_alias="nncf.DatasetWithSeqId")
92+
class DatasetWithSeqId(Dataset):
93+
def __init__(
94+
self,
95+
data_source: Iterable[Any],
96+
sequence_ids: Iterable[int],
97+
transform_func: Optional[Callable[..., Any]] = None,
98+
):
99+
"""
100+
A dataset wrapper that associates each data item with a sequence ID. Sequence IDs are used to reset state
101+
of stateful models between different sequences during inference.
102+
:param data_source: The iterable object serving as the source of data items.
103+
:param sequence_ids: The iterable of sequence IDs corresponding to each data item in the data source. Must
104+
have the same length as the data source and contain only integers.
105+
:param transform_func: The function that is used to extract the model's input
106+
from the data item. The data item here is the data item that is returned from
107+
the data source per iteration. This function should be passed when
108+
the data item cannot be directly used as model's input. If this is not specified, then the data item
109+
will be passed into the model as-is.
110+
"""
111+
if any(not isinstance(it, int) for it in sequence_ids):
112+
msg = "All sequence IDs must be integers."
113+
raise ValueError(msg)
114+
if transform_func is not None:
115+
transform_func = lambda x, seq_id: (transform_func(x), seq_id)
116+
super().__init__(data_source, transform_func)
117+
self._sequence_ids = sequence_ids
118+
119+
def get_data(self, indices: Optional[list[int]] = None) -> Iterable[Any]:
120+
return DataProvider(zip(self._data_source, self._sequence_ids), None, indices)
121+
122+
def get_inference_data(self, indices: Optional[list[int]] = None) -> Iterable[Any]:
123+
return DataProvider(zip(self._data_source, self._sequence_ids), self._transform_func, indices)
124+
125+
97126
class DataProvider:
98127
def __init__(
99128
self,

src/nncf/openvino/engine.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from openvino import Type
1717
from openvino.properties.hint import inference_precision
1818

19-
import nncf
2019
from nncf.common.engine import Engine
2120
from nncf.openvino.graph.model_utils import model_has_state
2221

@@ -35,7 +34,11 @@ def __init__(self, compiled_model: ov.CompiledModel, stateful: bool):
3534
self.reset_state = stateful and hasattr(self.infer_request, "reset_state")
3635

3736
def infer(
38-
self, input_data: Union[np.ndarray, list[np.ndarray], tuple[np.ndarray], dict[str, np.ndarray]]
37+
self,
38+
input_data: Union[
39+
Union[np.ndarray, list[np.ndarray], tuple[np.ndarray], dict[str, np.ndarray]],
40+
tuple[Union[np.ndarray, list[np.ndarray], tuple[np.ndarray], dict[str, np.ndarray]], int],
41+
],
3942
) -> dict[str, np.ndarray]:
4043
"""
4144
Runs model on the provided input via OpenVINO Runtime.
@@ -44,10 +47,11 @@ def infer(
4447
:param input_data: Inputs for the model.
4548
:return output_data: Model's output.
4649
"""
47-
if isinstance(input_data, dict) and nncf.Dataset.RESET_STATE_KEY in input_data:
48-
# In this case state resetting is controlled by the input data flag
49-
input_data = input_data.copy()
50-
if input_data.pop(nncf.Dataset.RESET_STATE_KEY):
50+
if isinstance(input_data, tuple) and len(input_data) == 2 and isinstance(input_data[1], int):
51+
# A dataset with sequence ids is provided
52+
input_data, seq_id = input_data
53+
if seq_id == 0:
54+
# Reset state only at the beginning of a new sequence
5155
if self.reset_state:
5256
self.infer_request.reset_state()
5357
else:

src/nncf/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
__version__ = "3.0.0"
12+
__version__ = "3.0.0.dev0+37b9d3650dirty"
1313

1414

1515
BKC_TORCH_SPEC = "==2.9.*"

tests/openvino/native/test_engine.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,31 +95,30 @@ def _reset_state():
9595

9696
model = StatefulModel(True).ov_model
9797
inp = model.get_parameters()[0]
98-
input_data = [{"input_data": np.ones(inp.shape), nncf.Dataset.RESET_STATE_KEY: False} for _ in range(10)]
99-
reset_ind = [2, 5, 7]
100-
for ind in reset_ind:
101-
input_data[ind][nncf.Dataset.RESET_STATE_KEY] = True
98+
input_data = [{"input_data": np.ones(inp.shape)} for _ in range(10)]
99+
# sequence lengths are [2, 4, 4]
100+
dataset = nncf.DatasetWithSeqId(input_data, [0, 1, 0, 1, 2, 3, 0, 1, 2, 3])
102101

103102
engine = OVNativeEngine(model)
104103
reset_order = []
105104
wrap_reset_state(engine.engine.infer_request)
106105

107-
for inp_data in input_data:
106+
for inp_data in dataset.get_inference_data():
108107
engine.infer(inp_data)
109108
reset_order.append("infer")
110109

111110
assert reset_order == [
111+
"reset",
112112
"infer",
113113
"infer",
114114
"reset",
115115
"infer",
116116
"infer",
117117
"infer",
118-
"reset",
119-
"infer",
120118
"infer",
121119
"reset",
122120
"infer",
123121
"infer",
124122
"infer",
123+
"infer",
125124
]

0 commit comments

Comments
 (0)