Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,19 +2742,30 @@ def round(x, decimals=0):

def tile(x, repeats):
x = convert_to_tensor(x)
repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
repeats_size = tf.size(repeats)
repeats = tf.pad(
repeats,
[[tf.maximum(x.shape.rank - repeats_size, 0), 0]],
constant_values=1,
)
x_shape = tf.pad(
tf.shape(x),
[[tf.maximum(repeats_size - x.shape.rank, 0), 0]],
constant_values=1,
)
x = tf.reshape(x, x_shape)

# Convert repeats to a list (works for both sequences and 1D tensors)
repeats = [v for v in repeats]

# Process list elements: convert concrete scalar tensors to Python ints
processed_repeats = []
for r in repeats:
if hasattr(r, "numpy") and r.shape == ():
processed_repeats.append(int(r.numpy()))
else:
processed_repeats.append(r)
repeats = processed_repeats

# Get x rank
x_rank = x.shape.rank

# Pad repeats if needed
if len(repeats) < x_rank:
repeats = [1] * (x_rank - len(repeats)) + repeats

# Add dimensions to x if needed using tf.expand_dims
while len(repeats) > x.shape.rank:
x = tf.expand_dims(x, 0)

return tf.tile(x, repeats)


Expand Down
9 changes: 6 additions & 3 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6411,17 +6411,20 @@ def compute_output_spec(self, x):
repeats = self.repeats
if isinstance(repeats, int):
repeats = [repeats]
else:
repeats = list(repeats)

if len(x_shape) > len(repeats):
repeats = [1] * (len(x_shape) - len(repeats)) + repeats
else:
x_shape = [1] * (len(repeats) - len(x_shape)) + x_shape

output_shape = []
for x_size, repeat in zip(x_shape, repeats):
if x_size is None:
output_shape.append(None)
else:
if isinstance(repeat, int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was incorrect in the previous review, this should be if isinstance(x_size, int):

output_shape.append(x_size * repeat)
else:
output_shape.append(None)
return KerasTensor(output_shape, dtype=x.dtype)


Expand Down
17 changes: 17 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,6 +1821,10 @@ def test_tile(self):
self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6))
self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6))

# Test with multi-dimensional input
x = KerasTensor((None, 3, 2, 2))
self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2))

def test_trace(self):
x = KerasTensor((None, 3, None, 5))
self.assertEqual(knp.trace(x).shape, (None, 5))
Expand Down Expand Up @@ -9507,3 +9511,16 @@ def call(self, x):
model.compile(jit_compile=jit_compile)

model.predict(np.random.randn(1, 8))


class TileTest(testing.TestCase):
def test_tile_shape_inference_in_layer(self):
class TileLayer(keras.layers.Layer):
def call(self, x):
repeats = [1, 2, 1, 1]
return knp.tile(x, repeats)

inputs = keras.Input(shape=(3, 2, 2))
output = TileLayer()(inputs)

self.assertEqual(output.shape, (None, 6, 2, 2))
Loading