Skip to content

Commit 98615bc

Browse files
Add test
1 parent 49261da commit 98615bc

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

src/nncf/data/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class Dataset:
3838
will be passed into the model as-is.
3939
"""
4040

41+
# RESET_STATE_KEY is a special input key used by OpenVINO backend to control resetting of internal model state
42+
# between model inferences. This key can be added to a dataset sample input dictionary with either
43+
# `True` or `False` value. With `True` value, the model state will be reset before inference on the corresponding
44+
# sample, and with `False` the state will not be reset.
4145
RESET_STATE_KEY = "reset_state"
4246

4347
def __init__(self, data_source: Iterable[Any], transform_func: Optional[Callable[..., Any]] = None):

src/nncf/openvino/engine.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,14 @@ def infer(
4545
:return output_data: Model's output.
4646
"""
4747
if isinstance(input_data, dict) and nncf.Dataset.RESET_STATE_KEY in input_data:
48+
# In this case state resetting is controlled by the input data flag
4849
input_data = input_data.copy()
4950
if input_data.pop(nncf.Dataset.RESET_STATE_KEY):
50-
self.infer_request.reset_state()
51+
if self.reset_state:
52+
self.infer_request.reset_state()
53+
else:
54+
msg = "Cannot reset state of a stateless model."
55+
raise RuntimeError(msg)
5156
elif self.reset_state:
5257
self.infer_request.reset_state()
5358

tests/openvino/native/test_engine.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
from functools import wraps
1112

1213
import numpy as np
1314
import pytest
1415

16+
import nncf
1517
from nncf.openvino.engine import OVNativeEngine
1618
from tests.openvino.native.models import ConvModel
1719
from tests.openvino.native.models import LinearModel
@@ -77,3 +79,47 @@ def test_compiled_model_engine_inference_stateful(stateful):
7779
out = out["Result"]
7880

7981
assert np.array_equal(out[0], input_data[0])
82+
83+
84+
def test_stateful_model_inference_with_controlled_resetting():
85+
def wrap_reset_state(infer_request):
86+
nonlocal reset_order
87+
original_reset_state = infer_request.reset_state
88+
89+
@wraps(infer_request.reset_state)
90+
def _reset_state():
91+
reset_order.append("reset")
92+
original_reset_state()
93+
94+
infer_request.reset_state = _reset_state
95+
96+
model = StatefulModel(True).ov_model
97+
inp = model.get_parameters()[0]
98+
input_data = [{"input_data": np.ones(inp.shape), nncf.Dataset.RESET_STATE_KEY: False} for _ in range(10)]
99+
reset_ind = [2, 5, 7]
100+
for ind in reset_ind:
101+
input_data[ind][nncf.Dataset.RESET_STATE_KEY] = True
102+
103+
engine = OVNativeEngine(model)
104+
reset_order = []
105+
wrap_reset_state(engine.engine.infer_request)
106+
107+
for inp_data in input_data:
108+
engine.infer(inp_data)
109+
reset_order.append("infer")
110+
111+
assert reset_order == [
112+
"infer",
113+
"infer",
114+
"reset",
115+
"infer",
116+
"infer",
117+
"infer",
118+
"reset",
119+
"infer",
120+
"infer",
121+
"reset",
122+
"infer",
123+
"infer",
124+
"infer",
125+
]

0 commit comments

Comments
 (0)