diff --git a/src/nncf/data/dataset.py b/src/nncf/data/dataset.py index 9e607fd1fbb..24d7bcf5d31 100644 --- a/src/nncf/data/dataset.py +++ b/src/nncf/data/dataset.py @@ -38,6 +38,12 @@ class Dataset: will be passed into the model as-is. """ + # RESET_STATE_KEY is a special input key used by OpenVINO backend to control resetting of internal model state + # between model inferences. This key can be added to a dataset sample input dictionary with either + # `True` or `False` value. With `True` value, the model state will be reset before inference on the corresponding + # sample, and with `False` the state will not be reset. + RESET_STATE_KEY = "nncf_reset_state" + def __init__(self, data_source: Iterable[Any], transform_func: Optional[Callable[..., Any]] = None): self._data_source = data_source self._transform_func = transform_func diff --git a/src/nncf/openvino/engine.py b/src/nncf/openvino/engine.py index e1c5e04a4aa..a60df924785 100644 --- a/src/nncf/openvino/engine.py +++ b/src/nncf/openvino/engine.py @@ -16,6 +16,7 @@ from openvino import Type from openvino.properties.hint import inference_precision +import nncf from nncf.common.engine import Engine from nncf.openvino.graph.model_utils import model_has_state @@ -43,7 +44,16 @@ def infer( :param input_data: Inputs for the model. :return output_data: Model's output. """ - if self.reset_state: + if isinstance(input_data, dict) and nncf.Dataset.RESET_STATE_KEY in input_data: + # In this case state resetting is controlled by the input data flag + input_data = input_data.copy() + if input_data.pop(nncf.Dataset.RESET_STATE_KEY): + if self.reset_state: + self.infer_request.reset_state() + else: + msg = "Cannot reset state of a stateless model." + raise RuntimeError(msg) + elif self.reset_state: self.infer_request.reset_state() model_outputs = self.infer_request.infer(input_data, share_inputs=True) diff --git a/tests/openvino/native/test_engine.py b/tests/openvino/native/test_engine.py index b58608c5b63..df3f4d50516 100644 --- a/tests/openvino/native/test_engine.py +++ b/tests/openvino/native/test_engine.py @@ -8,10 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import wraps import numpy as np import pytest +import nncf from nncf.openvino.engine import OVNativeEngine from tests.openvino.native.models import ConvModel from tests.openvino.native.models import LinearModel @@ -77,3 +79,47 @@ def test_compiled_model_engine_inference_stateful(stateful): out = out["Result"] assert np.array_equal(out[0], input_data[0]) + + +def test_stateful_model_inference_with_controlled_resetting(): + def wrap_reset_state(infer_request): + nonlocal reset_order + original_reset_state = infer_request.reset_state + + @wraps(infer_request.reset_state) + def _reset_state(): + reset_order.append("reset") + original_reset_state() + + infer_request.reset_state = _reset_state + + model = StatefulModel(True).ov_model + inp = model.get_parameters()[0] + input_data = [{"input_data": np.ones(inp.shape), nncf.Dataset.RESET_STATE_KEY: False} for _ in range(10)] + reset_ind = [2, 5, 7] + for ind in reset_ind: + input_data[ind][nncf.Dataset.RESET_STATE_KEY] = True + + engine = OVNativeEngine(model) + reset_order = [] + wrap_reset_state(engine.engine.infer_request) + + for inp_data in input_data: + engine.infer(inp_data) + reset_order.append("infer") + + assert reset_order == [ + "infer", + "infer", + "reset", + "infer", + "infer", + "infer", + "reset", + "infer", + "infer", + "reset", + "infer", + "infer", + "infer", + ]