Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class Dataset:
will be passed into the model as-is.
"""

# RESET_STATE_KEY is a special input key used by OpenVINO backend to control resetting of internal model state
# between model inferences. This key can be added to a dataset sample input dictionary with either
# `True` or `False` value. With `True` value, the model state will be reset before inference on the corresponding
# sample, and with `False` the state will not be reset.
RESET_STATE_KEY = "nncf_reset_state"

def __init__(self, data_source: Iterable[Any], transform_func: Optional[Callable[..., Any]] = None):
self._data_source = data_source
self._transform_func = transform_func
Expand Down
12 changes: 11 additions & 1 deletion src/nncf/openvino/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from openvino import Type
from openvino.properties.hint import inference_precision

import nncf
from nncf.common.engine import Engine
from nncf.openvino.graph.model_utils import model_has_state

Expand Down Expand Up @@ -43,7 +44,16 @@ def infer(
:param input_data: Inputs for the model.
:return output_data: Model's output.
"""
if self.reset_state:
if isinstance(input_data, dict) and nncf.Dataset.RESET_STATE_KEY in input_data:
# In this case state resetting is controlled by the input data flag
input_data = input_data.copy()
if input_data.pop(nncf.Dataset.RESET_STATE_KEY):
if self.reset_state:
self.infer_request.reset_state()
else:
msg = "Cannot reset state of a stateless model."
raise RuntimeError(msg)
elif self.reset_state:
self.infer_request.reset_state()

model_outputs = self.infer_request.infer(input_data, share_inputs=True)
Expand Down
46 changes: 46 additions & 0 deletions tests/openvino/native/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps

import numpy as np
import pytest

import nncf
from nncf.openvino.engine import OVNativeEngine
from tests.openvino.native.models import ConvModel
from tests.openvino.native.models import LinearModel
Expand Down Expand Up @@ -77,3 +79,47 @@ def test_compiled_model_engine_inference_stateful(stateful):
out = out["Result"]

assert np.array_equal(out[0], input_data[0])


def test_stateful_model_inference_with_controlled_resetting():
def wrap_reset_state(infer_request):
nonlocal reset_order
original_reset_state = infer_request.reset_state

@wraps(infer_request.reset_state)
def _reset_state():
reset_order.append("reset")
original_reset_state()

infer_request.reset_state = _reset_state

model = StatefulModel(True).ov_model
inp = model.get_parameters()[0]
input_data = [{"input_data": np.ones(inp.shape), nncf.Dataset.RESET_STATE_KEY: False} for _ in range(10)]
reset_ind = [2, 5, 7]
for ind in reset_ind:
input_data[ind][nncf.Dataset.RESET_STATE_KEY] = True

engine = OVNativeEngine(model)
reset_order = []
wrap_reset_state(engine.engine.infer_request)

for inp_data in input_data:
engine.infer(inp_data)
reset_order.append("infer")

assert reset_order == [
"infer",
"infer",
"reset",
"infer",
"infer",
"infer",
"reset",
"infer",
"infer",
"reset",
"infer",
"infer",
"infer",
]