StableDiffusion example doesn't currently work without adjustment #1890

Open
@mdvthu

Description

Issue Type

Documentation Bug

Source

binary

Keras Version

keras_cv 0.9.0

Custom Code

No

OS Platform and Distribution

macOS, Windows, Ubuntu

Python version

3.12.4

GPU model and memory

CPU, Nvidia 3080, and M1 (METAL)

Current Behavior?

Following the example verbatim from the Keras documentation, including https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/, produces a shape mismatch error:

ValueError: Exception encountered when calling DiffusionModelV2.call().

Invalid input shape for input Tensor("data_2:0", shape=(1, 77, 1024), dtype=float32). Expected shape (None, 96, 96, 4), but input has incompatible shape (1, 77, 1024)

Arguments received by DiffusionModelV2.call():
  • inputs={'latent': 'tf.Tensor(shape=(1, 96, 96, 4), dtype=float32)', 'timestep_embedding': 'tf.Tensor(shape=(1, 320), dtype=float32)', 'context': 'tf.Tensor(shape=(1, 77, 1024), dtype=float32)'}
  • training=False
  • mask={'latent': 'None', 'timestep_embedding': 'None', 'context': 'None'}

Standalone code to reproduce the issue or tutorial link

Follow the documentation on https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/


python3 -m venv venv
. ./venv/bin/activate
python3 -m pip install tensorflow keras_cv IPython
python3 -m IPython
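
The install command above does not pin versions, so the resolved combination of tensorflow, keras, and keras_cv may differ between machines and over time. As a minimal sketch for recording what was actually installed alongside the reproduction, the snippet below reads the versions with the standard library's `importlib.metadata`. The package list and the helper name `installed_versions` are assumptions chosen for illustration; this is not a confirmed cause of the error, only a way to capture the environment.

```python
from importlib import metadata

# Distribution names assumed to be relevant to this report; adjust as needed.
PACKAGES = ("tensorflow", "keras", "keras-cv")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not present in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the same virtual environment as the reproduction and pasting the output into the report makes it easier to compare environments across the listed platforms.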


```python
import keras_cv
model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=False
)
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
```

Relevant log output

```shell
ValueError: Exception encountered when calling DiffusionModelV2.call().

Invalid input shape for input Tensor("data_2:0", shape=(1, 77, 1024), dtype=float32). Expected shape (None, 96, 96, 4), but input has incompatible shape (1, 77, 1024)

Arguments received by DiffusionModelV2.call():
  • inputs={'latent': 'tf.Tensor(shape=(1, 96, 96, 4), dtype=float32)', 'timestep_embedding': 'tf.Tensor(shape=(1, 320), dtype=float32)', 'context': 'tf.Tensor(shape=(1, 77, 1024), dtype=float32)'}
  • training=False
  • mask={'latent': 'None', 'timestep_embedding': 'None', 'context': 'None'}
```
