
[Bug]: Running multiple inferences with the same compiled model and input produces different outputs #32045

@Mohamed-Ashraf273

Description

OpenVINO Version

openvino version: 2025.3.0

Operating System

Ubuntu 18.04 (LTS)

Device used for inference

CPU

Framework

Keras

Model used

Gemma

Issue description

Running multiple inferences with the same compiled model and input produces different outputs.
The issue was discovered in this PR: keras-team/keras-hub#2389.

Using Keras from source (https://github.com/keras-team/keras), modify the predict_step function in keras/src/backend/openvino/trainer.py to:

def predict_step(self, data):
    x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
    ov_compiled_model = self._get_compiled_model(x)
    flatten_x = tree.flatten(x)
    y_pred = ov_compiled_model(flatten_x)
    # Re-run inference on the same flattened inputs and report how much
    # the first output drifts between consecutive calls.
    for i in range(5):
        y_new = ov_compiled_model(flatten_x)
        max_diff = np.max(
            np.abs(np.array(y_pred.to_tuple()[0]) - np.array(y_new.to_tuple()[0]))
        )
        print("max_diff:", max_diff)
        y_pred = y_new
    # recover structure of the model output
    y_pred = self._unpack_singleton(
        tree.pack_sequence_as(self.struct_outputs, y_pred.to_tuple())
    )
    return y_pred

With keras-hub-nightly installed, run this test script:

import keras

import tensorflow as tf
import numpy as np

from keras import tree
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
    GemmaCausalLMPreprocessor,
)
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
tokenizer = GemmaTokenizer(
    proto="/home/mohamed-ashraf/Desktop/projects/keras-hub/keras_hub/src/tests/test_data/gemma_test_vocab.spm",
)
preprocessor = GemmaCausalLMPreprocessor(
    tokenizer,
    sequence_length=8,
)
# Test Gemma 2 like config, as it's the more complicated case.
backbone = GemmaBackbone(
    vocabulary_size=preprocessor.tokenizer.vocabulary_size(),
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=8,
    intermediate_dim=16,
    head_dim=2,
    sliding_window_size=3,
    use_sliding_window_attention=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    query_head_dim_normalize=False,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
)
init_kwargs = {
    "preprocessor": preprocessor,
    "backbone": backbone,
}
train_data = (["the quick brown fox", "the quick brown fox"],)
input_data = preprocessor(*train_data)[0]


task = GemmaCausalLM(**init_kwargs)
preprocessor = task.preprocessor
ds = tf.data.Dataset.from_tensor_slices(train_data).batch(2)
x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data)

# Test: the tree struct output by the
# preprocessor must match what model expects.
preprocessed_data = preprocessor(*train_data)[0]
tree.assert_same_structure(
    preprocessed_data,
    task._inputs_struct,
    check_types=False,
)

# Test predict.
output = task.predict(x)
output_shape = tree.map_structure(lambda x: x.shape, output)
if not np.allclose(output_shape, (2, 8, 11)):
    raise AssertionError("Output shape does not match expected output shape.")

output_ds = task.predict(ds)
if not np.allclose(output, output_ds):
    max_diff = np.max(np.abs(np.array(output) - np.array(output_ds)))
    raise AssertionError(f"Output and output_ds are not close. Max absolute difference: {max_diff}")

It gives:

max_diff: 2.2334716
max_diff: 2.666542
max_diff: 2.1285663
max_diff: 3.2441604
max_diff: 1.9508281
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 314ms/step
max_diff: 2.831503
max_diff: 2.5028913
max_diff: 2.7453814
max_diff: 3.031059
max_diff: 2.1415253
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 81ms/step
Traceback (most recent call last):
  File "/home/mohamed-ashraf/Desktop/projects/test.py", line 68, in <module>
    raise AssertionError(f"Output and output_ds are not close. Max absolute difference: {max_diff}")
AssertionError: Output and output_ds are not close. Max absolute difference: 2.4950499534606934

This means that, for the same input, repeated inference with the same compiled model produces different outputs.

Step-by-step reproduction

import openvino as ov
import numpy as np

core = ov.Core()
ov_model = core.read_model("model.xml")
ov_compiled_model = core.compile_model(ov_model, "CPU")

arr = np.random.rand(2, 8).astype(np.float32)

result1 = ov_compiled_model([arr])[0]
result2 = ov_compiled_model([arr])[0]

are_consistent = np.allclose(result1, result2)
max_difference = np.max(np.abs(result1 - result2))

print(f"Inference results are consistent: {are_consistent}")
print(f"Maximum difference between results: {max_difference}")
print(f"Result 1 shape: {result1.shape}, Result 2 shape: {result2.shape}")

Output:

Inference results are consistent: False
Maximum difference between results: 3.2739830017089844
Result 1 shape: (2, 8, 11), Result 2 shape: (2, 8, 11)

IR link: https://drive.google.com/drive/folders/16iGXbcUVe5FikiWh2YlGbuNTgUASyMoE?usp=sharing
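
As an extra data point, the same determinism check can be run through explicit infer requests. The sketch below assumes the same model.xml and (2, 8) float32 input as in the reproducer above; create_infer_request(), infer(), and CompiledModel.output() are standard OpenVINO Python API calls. If a single reused request were deterministic while repeated calls through the compiled model are not, that would point at per-request state rather than at the compiled model itself.

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model(core.read_model("model.xml"), "CPU")

arr = np.random.rand(2, 8).astype(np.float32)

# Reuse one explicit infer request for every call.
request = compiled.create_infer_request()
# Copy the reference result, since the returned array aliases the output tensor.
ref = request.infer([arr])[compiled.output(0)].copy()
for _ in range(5):
    out = request.infer([arr])[compiled.output(0)]
    print("same request, max diff:", np.max(np.abs(ref - out)))

# Compare against results from freshly created infer requests.
for _ in range(5):
    out = compiled.create_infer_request().infer([arr])[compiled.output(0)]
    print("fresh request, max diff:", np.max(np.abs(ref - out)))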

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
