Skip to content

fix(generate_image): updating generate_image tool to support additional models in Amazon Bedrock #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 26, 2025
Merged
210 changes: 125 additions & 85 deletions src/strands_tools/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
Key Features:

1. Image Generation:
• Text-to-image conversion using Stable Diffusion
• Support for multiple model variants (primarily stable-diffusion-xl-v1)
• Customizable generation parameters (seed, steps, cfg_scale)
• Style preset selection for consistent aesthetics
• Text-to-image conversion using Stable Diffusion models
• Support for the following models:
• stability.sd3-5-large-v1:0
• stability.stable-image-core-v1:1
• stability.stable-image-ultra-v1:1
• Customizable generation parameters (seed, aspect_ratio, output_format, negative_prompt)

2. Output Management:
• Automatic local saving with intelligent filename generation
Expand All @@ -36,14 +38,22 @@
# Basic usage with default parameters
agent.tool.generate_image(prompt="A steampunk robot playing chess")

# Advanced usage with custom parameters
# Advanced usage with Stable Diffusion
agent.tool.generate_image(
prompt="A futuristic city with flying cars",
model_id="stability.stable-diffusion-xl-v1",
seed=42,
steps=50,
cfg_scale=12,
style_preset="cinematic"
model_id="stability.sd3-5-large-v1:0",
aspect_ratio="5:4",
output_format="jpeg",
negative_prompt="bad lighting, harsh lighting, abstract, surreal, twisted, multiple levels",
)

# Using another Stable Diffusion model
agent.tool.generate_image(
prompt="A photograph of a cup of coffee from the side",
model_id="stability.stable-image-ultra-v1:1",
aspect_ratio="1:1",
output_format="png",
negative_prompt="blurry, distorted",
)
```

Expand All @@ -60,9 +70,16 @@
import boto3
from strands.types.tools import ToolResult, ToolUse

# Stable Diffusion model IDs on Amazon Bedrock that this tool supports.
# NOTE(review): the name is singular but the value is a list — consider
# renaming to STABLE_DIFFUSION_MODEL_IDS in a follow-up (callers elsewhere
# may reference this name, so keep it unchanged here).
STABLE_DIFFUSION_MODEL_ID = [
    "stability.sd3-5-large-v1:0",
    "stability.stable-image-core-v1:1",
    "stability.stable-image-ultra-v1:1",
]


TOOL_SPEC = {
"name": "generate_image",
"description": "Generates an image using Stable Diffusion based on a given prompt",
"description": "Generates an image using Stable Diffusion models based on a given prompt",
"inputSchema": {
"json": {
"type": "object",
Expand All @@ -73,23 +90,32 @@
},
"model_id": {
"type": "string",
"description": "Model id for image model, stability.stable-diffusion-xl-v1.",
"description": "Model id for image model, stability.sd3-5-large-v1:0, \
stability.stable-image-core-v1:1, or stability.stable-image-ultra-v1:1",
},
"region": {
"type": "string",
"description": "AWS region for the image generation model (default: us-west-2)",
},
"seed": {
"type": "integer",
"description": "Optional: Seed for random number generation (default: random)",
},
"steps": {
"type": "integer",
"description": "Optional: Number of steps for image generation (default: 30)",
"aspect_ratio": {
"type": "string",
"description": "Optional: Controls the aspect ratio of the generated image for \
Stable Diffusion models. Default 1:1. Enum: 16:9, 1:1, 21:9, 2:3, 3:2, 4:5, 5:4, 9:16, 9:21",
},
"cfg_scale": {
"type": "number",
"description": "Optional: CFG scale for image generation (default: 10)",
"output_format": {
"type": "string",
"description": "Optional: Specifies the format of the output image for Stable Diffusion models. \
Supported formats: JPEG, PNG.",
},
"style_preset": {
"negative_prompt": {
"type": "string",
"description": "Optional: Style preset for image generation (default: 'photographic')",
"description": "Optional: Keywords of what you do not wish to see in the output image. \
Default: bad lighting, harsh lighting. \
Max: 10,000 characters.",
},
},
"required": ["prompt"],
Expand All @@ -98,19 +124,28 @@
}


# Create a filename based on the prompt
def create_filename(prompt: str) -> str:
    """Derive a short, filesystem-safe filename from the prompt text."""
    # Keep only the first five word tokens of the lowercased prompt.
    tokens = re.findall(r"\w+", prompt.lower())[:5]
    joined = "_".join(tokens)
    # Defensive sanitation: replace anything outside the safe character set.
    safe = re.sub(r"[^\w\-_\.]", "_", joined)
    # Cap the length so the resulting path stays portable.
    return safe[:100]


def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
"""
Generate images from text prompts using Stable Diffusion via Amazon Bedrock.
Generate images from text prompts using Stable Diffusion models via Amazon Bedrock.

This function transforms textual descriptions into high-quality images using
Stable Diffusion models available through Amazon Bedrock. It provides extensive
image generation models available through Amazon Bedrock. It provides extensive
customization options and handles the complete process from API interaction to
image storage and result formatting.

How It Works:
------------
1. Extracts and validates parameters from the tool input
2. Configures the request payload with appropriate parameters
2. Configures the request payload with appropriate parameters based on model type
3. Invokes the Bedrock image generation model through AWS SDK
4. Processes the response to extract the base64-encoded image
5. Creates an appropriate filename based on the prompt content
Expand All @@ -120,11 +155,13 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
Generation Parameters:
--------------------
- prompt: The textual description of the desired image
- model_id: Specific model to use (defaults to stable-diffusion-xl-v1)
- model_id: Specific model to use (defaults to stability.stable-image-core-v1:1)
- seed: Controls randomness for reproducible results
- style_preset: Artistic style to apply (e.g., photographic, cinematic)
- cfg_scale: Controls how closely the image follows the prompt
- steps: Number of diffusion steps (higher = more refined but slower)
- aspect_ratio: Controls the aspect ratio of the generated image
- output_format: Specifies the format of the output image (e.g., png or jpeg)
- negative_prompt: Keywords of what you do not wish to see in the output image



Common Usage Scenarios:
---------------------
Expand All @@ -137,11 +174,8 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
Args:
tool: ToolUse object containing the parameters for image generation.
- prompt: The text prompt describing the desired image.
- model_id: Optional model identifier (default: "stability.stable-diffusion-xl-v1").
- seed: Optional random seed (default: random integer).
- style_preset: Optional style preset name (default: "photographic").
- cfg_scale: Optional CFG scale value (default: 10).
- steps: Optional number of diffusion steps (default: 30).
- model_id: Optional model identifier.
- Additional parameters specific to the chosen model type.
**kwargs: Additional keyword arguments (unused).

Returns:
Expand All @@ -161,78 +195,84 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
tool_use_id = tool["toolUseId"]
tool_input = tool["input"]

# Extract input parameters
# Extract common and Stable Diffusion input parameters
aspect_ratio = tool_input.get("aspect_ratio", "1:1")
output_format = tool_input.get("output_format", "jpeg")
prompt = tool_input.get("prompt", "A stylized picture of a cute old steampunk robot.")
model_id = tool_input.get("model_id", "stability.stable-diffusion-xl-v1")
model_id = tool_input.get("model_id", "stability.stable-image-core-v1:1")
region = tool_input.get("region", "us-west-2")
seed = tool_input.get("seed", random.randint(0, 4294967295))
style_preset = tool_input.get("style_preset", "photographic")
cfg_scale = tool_input.get("cfg_scale", 10)
steps = tool_input.get("steps", 30)
negative_prompt = tool_input.get("negative_prompt", "bad lighting, harsh lighting")

# Create a Bedrock Runtime client
client = boto3.client("bedrock-runtime", region_name="us-west-2")
client = boto3.client("bedrock-runtime", region_name=region)

# Initialize variables for later use
base64_image_data = None

# Format the request payload
# create the request body
native_request = {
"text_prompts": [{"text": prompt}],
"style_preset": style_preset,
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"seed": seed,
"cfg_scale": cfg_scale,
"steps": steps,
"output_format": output_format,
"negative_prompt": negative_prompt,
}
request = json.dumps(native_request)

# Invoke the model
response = client.invoke_model(modelId=model_id, body=request)

# Decode the response body
model_response = json.loads(response["body"].read())
model_response = json.loads(response["body"].read().decode("utf-8"))

# Extract the image data
base64_image_data = model_response["artifacts"][0]["base64"]

# Create a filename based on the prompt
def create_filename(prompt: str) -> str:
"""Generate a filename from the prompt text."""
words = re.findall(r"\w+", prompt.lower())[:5]
filename = "_".join(words)
filename = re.sub(r"[^\w\-_\.]", "_", filename)
return filename[:100] # Limit filename length

filename = create_filename(prompt)

# Save the generated image to a local folder
output_dir = "output"
if not os.path.exists(output_dir):
os.makedirs(output_dir)

i = 1
base_image_path = os.path.join(output_dir, f"{filename}.png")
image_path = base_image_path
while os.path.exists(image_path):
image_path = os.path.join(output_dir, f"{filename}_{i}.png")
i += 1

with open(image_path, "wb") as file:
file.write(base64.b64decode(base64_image_data))

base64_image_data = model_response["images"][0]

# If we have image data, process and save it
if base64_image_data:
filename = create_filename(prompt)

# Save the generated image to a local folder
output_dir = "output"
if not os.path.exists(output_dir):
os.makedirs(output_dir)

i = 1
base_image_path = os.path.join(output_dir, f"{filename}.png")
image_path = base_image_path
while os.path.exists(image_path):
image_path = os.path.join(output_dir, f"{filename}_{i}.png")
i += 1

with open(image_path, "wb") as file:
file.write(base64.b64decode(base64_image_data))

return {
"toolUseId": tool_use_id,
"status": "success",
"content": [
{"text": f"The generated image has been saved locally to {image_path}. "},
{
"image": {
"format": output_format,
"source": {"bytes": base64.b64decode(base64_image_data)},
}
},
],
}
else:
raise Exception("No image data found in the response.")
except Exception as e:
return {
"toolUseId": tool_use_id,
"status": "success",
"status": "error",
"content": [
{"text": f"The generated image has been saved locally to {image_path}. "},
{
"image": {
"format": "png",
"source": {"bytes": base64.b64decode(base64_image_data)},
}
},
"text": f"Error generating image: {str(e)} \n Try other supported models for this tool are: \n \
1. stability.sd3-5-large-v1:0 \n \
2. stability.stable-image-core-v1:1 \n \
3. stability.stable-image-ultra-v1:1"
}
],
}

except Exception as e:
return {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error generating image: {str(e)}"}],
}
2 changes: 1 addition & 1 deletion src/strands_tools/mem0_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
"description": "Optional metadata to store with the memory",
},
},
"required": ["action"]
"required": ["action"],
}
},
}
Expand Down
22 changes: 11 additions & 11 deletions tests/test_generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def mock_boto3_client():
# Set up mock response
mock_body = MagicMock()
mock_body.read.return_value = json.dumps(
{"artifacts": [{"base64": base64.b64encode(b"mock_image_data").decode("utf-8")}]}
{"images": [base64.b64encode(b"mock_image_data").decode("utf-8")]}
).encode("utf-8")

mock_client_instance = MagicMock()
Expand Down Expand Up @@ -76,9 +76,9 @@ def test_generate_image_direct(mock_boto3_client, mock_os_path_exists, mock_os_m
"input": {
"prompt": "A cute robot",
"seed": 123,
"steps": 30,
"cfg_scale": 10,
"style_preset": "photographic",
"aspect_ratio": "5:4",
"output_format": "png",
"negative_prompt": "blurry, low resolution, pixelated, grainy, unrealistic",
},
}

Expand All @@ -94,11 +94,11 @@ def test_generate_image_direct(mock_boto3_client, mock_os_path_exists, mock_os_m
args, kwargs = mock_client_instance.invoke_model.call_args
request_body = json.loads(kwargs["body"])

assert request_body["text_prompts"][0]["text"] == "A cute robot"
assert request_body["prompt"] == "A cute robot"
assert request_body["seed"] == 123
assert request_body["steps"] == 30
assert request_body["cfg_scale"] == 10
assert request_body["style_preset"] == "photographic"
assert request_body["aspect_ratio"] == "5:4"
assert request_body["output_format"] == "png"
assert request_body["negative_prompt"] == "blurry, low resolution, pixelated, grainy, unrealistic"

# Verify directory creation
mock_os_makedirs.assert_called_once()
Expand Down Expand Up @@ -128,9 +128,9 @@ def test_generate_image_default_params(mock_boto3_client, mock_os_path_exists, m
request_body = json.loads(kwargs["body"])

assert request_body["seed"] == 42 # From our mocked random.randint
assert request_body["steps"] == 30
assert request_body["cfg_scale"] == 10
assert request_body["style_preset"] == "photographic"
assert request_body["aspect_ratio"] == "1:1"
assert request_body["output_format"] == "jpeg"
assert request_body["negative_prompt"] == "bad lighting, harsh lighting"

assert result["status"] == "success"

Expand Down
Loading