Skip to content

Commit 148c3f1

Browse files
[OpenVINO] Introduce a way to control resetting of the state of stateful OV models (#3714)
### Changes Added `nncf.definitions.NNCF_DATASET_RESET_STATE_KEY` constant to specify when to reset model state. This constant is used by OpenVINO backend to control resetting of internal model state between model inferences. This key can be added to a dataset sample input dictionary with either `True` or `False` value. With `True` value, the model state will be reset before inference on the corresponding sample, and with `False` the state will not be reset. For an example of usage please see huggingface/optimum-intel#1505. ### Reason for changes Without this logic static quantization quality of stateful Whisper models is poor because a state of a stateful model must be cleared with the same schedule as it is done during calibration input data collection. ### Related tickets 172705 ### Tests Added `tests/openvino/native/test_engine.py::test_stateful_model_inference_with_controlled_resetting`.
1 parent 141caed commit 148c3f1

File tree

4 files changed

+68
-1
lines changed

4 files changed

+68
-1
lines changed

src/nncf/data/dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class Dataset:
3030
usually contains both examples and labels. So transformation function should extract
3131
the examples from the data item.
3232
33+
A special input key nncf.definitions.NNCF_DATASET_RESET_STATE_KEY can be used by OpenVINO backend to control
34+
resetting of internal model state between model inferences. This key can be added to a dataset sample input
35+
dictionary with either `True` or `False` value. With `True` value, the model state will be reset before inference
36+
on the corresponding sample, and with `False` the state will not be reset.
37+
3338
:param data_source: The iterable object serving as the source of data items.
3439
:param transform_func: The function that is used to extract the model's input
3540
from the data item. The data item here is the data item that is returned from

src/nncf/definitions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,9 @@
2323
# debug dumps, are performed or not performed
2424
NNCF_CI_ENV_VAR_NAME = "NNCF_CI" # Must be set in CI environments
2525
NNCF_DEV_ENV_VAR_NAME = "NNCF_DEV" # Must be set in environments of the NNCF dev team machines
26+
27+
# This is a special input key used by OpenVINO backend to control resetting of internal model state
28+
# between model inferences. This key can be added to a dataset sample input dictionary with either
29+
# `True` or `False` value. With `True` value, the model state will be reset before inference on the corresponding
30+
# sample, and with `False` the state will not be reset.
31+
NNCF_DATASET_RESET_STATE_KEY = "nncf_reset_state"

src/nncf/openvino/engine.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openvino.properties.hint import inference_precision
1818

1919
from nncf.common.engine import Engine
20+
from nncf.definitions import NNCF_DATASET_RESET_STATE_KEY
2021
from nncf.openvino.graph.model_utils import model_has_state
2122

2223

@@ -43,7 +44,16 @@ def infer(
4344
:param input_data: Inputs for the model.
4445
:return output_data: Model's output.
4546
"""
46-
if self.reset_state:
47+
if isinstance(input_data, dict) and NNCF_DATASET_RESET_STATE_KEY in input_data:
48+
# In this case state resetting is controlled by the input data flag
49+
input_data = input_data.copy()
50+
if input_data.pop(NNCF_DATASET_RESET_STATE_KEY):
51+
if self.reset_state:
52+
self.infer_request.reset_state()
53+
else:
54+
msg = "Cannot reset state of a stateless model."
55+
raise RuntimeError(msg)
56+
elif self.reset_state:
4757
self.infer_request.reset_state()
4858

4959
model_outputs = self.infer_request.infer(input_data, share_inputs=True)

tests/openvino/native/test_engine.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
from functools import wraps
1112

1213
import numpy as np
1314
import pytest
1415

16+
from nncf.definitions import NNCF_DATASET_RESET_STATE_KEY
1517
from nncf.openvino.engine import OVNativeEngine
1618
from tests.openvino.native.models import ConvModel
1719
from tests.openvino.native.models import LinearModel
@@ -77,3 +79,47 @@ def test_compiled_model_engine_inference_stateful(stateful):
7779
out = out["Result"]
7880

7981
assert np.array_equal(out[0], input_data[0])
82+
83+
84+
def test_stateful_model_inference_with_controlled_resetting():
85+
def wrap_reset_state(infer_request):
86+
nonlocal reset_order
87+
original_reset_state = infer_request.reset_state
88+
89+
@wraps(infer_request.reset_state)
90+
def _reset_state():
91+
reset_order.append("reset")
92+
original_reset_state()
93+
94+
infer_request.reset_state = _reset_state
95+
96+
model = StatefulModel(True).ov_model
97+
inp = model.get_parameters()[0]
98+
input_data = [{"input_data": np.ones(inp.shape), NNCF_DATASET_RESET_STATE_KEY: False} for _ in range(10)]
99+
reset_ind = [2, 5, 7]
100+
for ind in reset_ind:
101+
input_data[ind][NNCF_DATASET_RESET_STATE_KEY] = True
102+
103+
engine = OVNativeEngine(model)
104+
reset_order = []
105+
wrap_reset_state(engine.engine.infer_request)
106+
107+
for inp_data in input_data:
108+
engine.infer(inp_data)
109+
reset_order.append("infer")
110+
111+
assert reset_order == [
112+
"infer",
113+
"infer",
114+
"reset",
115+
"infer",
116+
"infer",
117+
"infer",
118+
"reset",
119+
"infer",
120+
"infer",
121+
"reset",
122+
"infer",
123+
"infer",
124+
"infer",
125+
]

0 commit comments

Comments
 (0)