diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index c35be3d35bc..2af5c4613cf 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -1359,11 +1359,7 @@ def deg2rad(x): def diag(x, k=0): x = convert_to_tensor(x) if len(x.shape) == 1: - return tf.cond( - tf.equal(tf.size(x), 0), - lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype), - lambda: tf.linalg.diag(x, k=k), - ) + return tf.linalg.diag(x, k=k) elif len(x.shape) == 2: return diagonal(x, offset=k) else: