
Commit ef900c4

Merge pull request #470 from kikumoto/feature/update_aws_bedrock_claude_implementation
Feature/update aws bedrock claude implementation
2 parents 4197791 + 488c43e commit ef900c4

1 file changed
examples/pipelines/providers/aws_bedrock_claude_pipeline.py

Lines changed: 151 additions & 61 deletions
@@ -12,7 +12,7 @@
 import json
 import logging
 from io import BytesIO
-from typing import List, Union, Generator, Iterator
+from typing import List, Union, Generator, Iterator, Optional, Any
 
 import boto3
 
@@ -23,12 +23,23 @@
 
 from utils.pipelines.main import pop_system_message
 
+REASONING_EFFORT_BUDGET_TOKEN_MAP = {
+    "none": None,
+    "low": 1024,
+    "medium": 4096,
+    "high": 16384,
+    "max": 32768,
+}
+
+# Maximum combined token limit for Claude 3.7
+MAX_COMBINED_TOKENS = 64000
+
 
 class Pipeline:
     class Valves(BaseModel):
-        AWS_ACCESS_KEY: str = ""
-        AWS_SECRET_KEY: str = ""
-        AWS_REGION_NAME: str = ""
+        AWS_ACCESS_KEY: Optional[str] = None
+        AWS_SECRET_KEY: Optional[str] = None
+        AWS_REGION_NAME: Optional[str] = None
 
     def __init__(self):
         self.type = "manifold"
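Note on the new constants: pipe() (further down) accepts reasoning_effort either as one of the named levels above or as a raw integer string. A minimal sketch of that resolution logic, where resolve_budget_tokens is a hypothetical helper name used only for illustration:

    def resolve_budget_tokens(reasoning_effort: str):
        # Named levels win; "none" maps to None and disables thinking.
        budget = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
        if not budget and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP:
            try:
                budget = int(reasoning_effort)  # raw integer budgets are allowed
            except ValueError:
                budget = None
        return budget

    # resolve_budget_tokens("medium") -> 4096
    # resolve_budget_tokens("8192")   -> 8192
    # resolve_budget_tokens("none")   -> None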
@@ -47,21 +58,25 @@ def __init__(self):
             }
         )
 
-        self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                    aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                    service_name="bedrock",
-                                    region_name=self.valves.AWS_REGION_NAME)
-        self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                            aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                            service_name="bedrock-runtime",
-                                            region_name=self.valves.AWS_REGION_NAME)
+        self.valves = self.Valves(
+            **{
+                "AWS_ACCESS_KEY": os.getenv("AWS_ACCESS_KEY", ""),
+                "AWS_SECRET_KEY": os.getenv("AWS_SECRET_KEY", ""),
+                "AWS_REGION_NAME": os.getenv(
+                    "AWS_REGION_NAME", os.getenv(
+                        "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "")
+                    )
+                ),
+            }
+        )
 
-        self.pipelines = self.get_models()
+        self.update_pipelines()
 
 
     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
+        self.update_pipelines()
         pass
 
     async def on_shutdown(self):
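Note: the credentials are now optional, and the region falls back through three environment variables, so the pipeline can also run on an instance or task role with no explicit keys. The lookup order, shown in isolation:

    import os

    # AWS_REGION_NAME wins, then AWS_REGION, then AWS_DEFAULT_REGION, else "".
    region = os.getenv(
        "AWS_REGION_NAME",
        os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "")),
    )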
@@ -72,40 +87,58 @@ async def on_shutdown(self):
     async def on_valves_updated(self):
         # This function is called when the valves are updated.
         print(f"on_valves_updated:{__name__}")
-        self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                    aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                    service_name="bedrock",
-                                    region_name=self.valves.AWS_REGION_NAME)
-        self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                            aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                            service_name="bedrock-runtime",
-                                            region_name=self.valves.AWS_REGION_NAME)
-        self.pipelines = self.get_models()
-
-    def pipelines(self) -> List[dict]:
-        return self.get_models()
+        self.update_pipelines()
+
+    def update_pipelines(self) -> None:
+        try:
+            self.bedrock = boto3.client(service_name="bedrock",
+                                        aws_access_key_id=self.valves.AWS_ACCESS_KEY,
+                                        aws_secret_access_key=self.valves.AWS_SECRET_KEY,
+                                        region_name=self.valves.AWS_REGION_NAME)
+            self.bedrock_runtime = boto3.client(service_name="bedrock-runtime",
+                                                aws_access_key_id=self.valves.AWS_ACCESS_KEY,
+                                                aws_secret_access_key=self.valves.AWS_SECRET_KEY,
+                                                region_name=self.valves.AWS_REGION_NAME)
+            self.pipelines = self.get_models()
+        except Exception as e:
+            print(f"Error: {e}")
+            self.pipelines = [
+                {
+                    "id": "error",
+                    "name": "Could not fetch models from Bedrock, please set up AWS Key/Secret or Instance/Task Role.",
+                },
+            ]
 
     def get_models(self):
-        if self.valves.AWS_ACCESS_KEY and self.valves.AWS_SECRET_KEY:
-            try:
-                response = self.bedrock.list_foundation_models(byProvider='Anthropic', byInferenceType='ON_DEMAND')
-                return [
-                    {
-                        "id": model["modelId"],
-                        "name": model["modelName"],
-                    }
-                    for model in response["modelSummaries"]
-                ]
-            except Exception as e:
-                print(f"Error: {e}")
-                return [
-                    {
-                        "id": "error",
-                        "name": "Could not fetch models from Bedrock, please update the Access/Secret Key in the valves.",
-                    },
-                ]
-        else:
-            return []
+        try:
+            res = []
+            response = self.bedrock.list_foundation_models(byProvider='Anthropic')
+            for model in response['modelSummaries']:
+                inference_types = model.get('inferenceTypesSupported', [])
+                if "ON_DEMAND" in inference_types:
+                    res.append({'id': model['modelId'], 'name': model['modelName']})
+                elif "INFERENCE_PROFILE" in inference_types:
+                    inferenceProfileId = self.getInferenceProfileId(model['modelArn'])
+                    if inferenceProfileId:
+                        res.append({'id': inferenceProfileId, 'name': model['modelName']})
+
+            return res
+        except Exception as e:
+            print(f"Error: {e}")
+            return [
+                {
+                    "id": "error",
+                    "name": "Could not fetch models from Bedrock, please check permission.",
+                },
+            ]
+
+    def getInferenceProfileId(self, modelArn: str) -> Optional[str]:
+        response = self.bedrock.list_inference_profiles()
+        for profile in response.get('inferenceProfileSummaries', []):
+            for model in profile.get('models', []):
+                if model.get('modelArn') == modelArn:
+                    return profile['inferenceProfileId']
+        return None
 
     def pipe(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
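Note: get_models() now also surfaces models that are only reachable through an inference profile, returning the profile ID in place of the bare model ID. A standalone sketch of the same lookup, building the ARN-to-profile map once instead of calling list_inference_profiles() per model (region name is illustrative; credentials and pagination are assumed handled by the environment):

    import boto3

    bedrock = boto3.client(service_name="bedrock", region_name="us-east-1")

    # Map each model ARN covered by an inference profile to that profile's ID.
    arn_to_profile = {
        m["modelArn"]: p["inferenceProfileId"]
        for p in bedrock.list_inference_profiles()["inferenceProfileSummaries"]
        for m in p.get("models", [])
    }

    for model in bedrock.list_foundation_models(byProvider="Anthropic")["modelSummaries"]:
        types = model.get("inferenceTypesSupported", [])
        if "ON_DEMAND" in types:
            print(model["modelId"], "-", model["modelName"])
        elif "INFERENCE_PROFILE" in types and model["modelArn"] in arn_to_profile:
            print(arn_to_profile[model["modelArn"]], "-", model["modelName"])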
@@ -139,11 +172,53 @@ def pipe(
 
         payload = {"modelId": model_id,
                    "messages": processed_messages,
-                   "system": [{'text': system_message if system_message else 'you are an intelligent ai assistant'}],
-                   "inferenceConfig": {"temperature": body.get("temperature", 0.5)},
-                   "additionalModelRequestFields": {"top_k": body.get("top_k", 200), "top_p": body.get("top_p", 0.9)}
+                   "system": [{'text': system_message["content"] if system_message else 'you are an intelligent ai assistant'}],
+                   "inferenceConfig": {
+                       "temperature": body.get("temperature", 0.5),
+                       "topP": body.get("top_p", 0.9),
+                       "maxTokens": body.get("max_tokens", 4096),
+                       "stopSequences": body.get("stop", []),
+                   },
+                   "additionalModelRequestFields": {"top_k": body.get("top_k", 200)}
                   }
+
         if body.get("stream", False):
+            supports_thinking = "claude-3-7" in model_id
+            reasoning_effort = body.get("reasoning_effort", "none")
+            budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
+
+            # Allow users to input an integer value representing budget tokens
+            if (
+                not budget_tokens
+                and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
+            ):
+                try:
+                    budget_tokens = int(reasoning_effort)
+                except ValueError as e:
+                    print("Failed to convert reasoning effort to int", e)
+                    budget_tokens = None
+
+            if supports_thinking and budget_tokens:
+                # Check whether the combined tokens (budget_tokens + max_tokens) exceed the limit
+                max_tokens = payload["inferenceConfig"].get("maxTokens", 4096)
+                combined_tokens = budget_tokens + max_tokens
+
+                if combined_tokens > MAX_COMBINED_TOKENS:
+                    error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
+                    print(error_message)
+                    return error_message
+
+                payload["inferenceConfig"]["maxTokens"] = combined_tokens
+                payload["additionalModelRequestFields"]["thinking"] = {
+                    "type": "enabled",
+                    "budget_tokens": budget_tokens,
+                }
+                # Thinking requires temperature 1.0 and does not support top_p, top_k
+                payload["inferenceConfig"]["temperature"] = 1.0
+                if "top_k" in payload["additionalModelRequestFields"]:
+                    del payload["additionalModelRequestFields"]["top_k"]
+                if "topP" in payload["inferenceConfig"]:
+                    del payload["inferenceConfig"]["topP"]
             return self.stream_response(model_id, payload)
         else:
            return self.get_completion(model_id, payload)
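For reference, this is roughly what the Converse payload looks like once a "low" (1024-token) budget has been applied to a Claude 3.7 model; the model ID and message content are illustrative:

    payload = {
        "modelId": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # illustrative
        "messages": [{"role": "user", "content": [{"text": "Why is the sky blue?"}]}],
        "system": [{"text": "you are an intelligent ai assistant"}],
        "inferenceConfig": {
            "temperature": 1.0,        # thinking forces temperature to 1.0
            "maxTokens": 4096 + 1024,  # max_tokens plus budget_tokens, capped at 64000
            "stopSequences": [],
            # "topP" removed: unsupported while thinking
        },
        "additionalModelRequestFields": {
            # "top_k" removed: unsupported while thinking
            "thinking": {"type": "enabled", "budget_tokens": 1024},
        },
    }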
@@ -152,30 +227,45 @@ def pipe(
 
     def process_image(self, image: str):
         img_stream = None
+        content_type = None
+
         if image["url"].startswith("data:image"):
-            if ',' in image["url"]:
-                base64_string = image["url"].split(',')[1]
+            mime_type, base64_string = image["url"].split(",", 1)
+            content_type = mime_type.split(":")[1].split(";")[0]
             image_data = base64.b64decode(base64_string)
-
             img_stream = BytesIO(image_data)
         else:
-            img_stream = requests.get(image["url"]).content
+            response = requests.get(image["url"])
+            img_stream = BytesIO(response.content)
+            content_type = response.headers.get('Content-Type', 'image/jpeg')
+
+        media_type = content_type.split('/')[-1] if '/' in content_type else content_type
         return {
-            "image": {"format": "png" if image["url"].endswith(".png") else "jpeg",
-                      "source": {"bytes": img_stream.read()}}
+            "image": {
+                "format": media_type,
+                "source": {"bytes": img_stream.read()}
+            }
         }
 
     def stream_response(self, model_id: str, payload: dict) -> Generator:
-        if "system" in payload:
-            del payload["system"]
-        if "additionalModelRequestFields" in payload:
-            del payload["additionalModelRequestFields"]
         streaming_response = self.bedrock_runtime.converse_stream(**payload)
+
+        in_reasoning_context = False
         for chunk in streaming_response["stream"]:
-            if "contentBlockDelta" in chunk:
-                yield chunk["contentBlockDelta"]["delta"]["text"]
+            if in_reasoning_context and "contentBlockStop" in chunk:
+                in_reasoning_context = False
+                yield "\n </think> \n\n"
+            elif "contentBlockDelta" in chunk and "delta" in chunk["contentBlockDelta"]:
+                if "reasoningContent" in chunk["contentBlockDelta"]["delta"]:
+                    if not in_reasoning_context:
+                        yield "<think>"
+
+                    in_reasoning_context = True
+                    if "text" in chunk["contentBlockDelta"]["delta"]["reasoningContent"]:
+                        yield chunk["contentBlockDelta"]["delta"]["reasoningContent"]["text"]
+                elif "text" in chunk["contentBlockDelta"]["delta"]:
+                    yield chunk["contentBlockDelta"]["delta"]["text"]
 
     def get_completion(self, model_id: str, payload: dict) -> str:
         response = self.bedrock_runtime.converse(**payload)
         return response['output']['message']['content'][0]['text']
-
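Downstream consumers now see the reasoning wrapped in <think> ... </think> markers ahead of the answer text. A sketch of splitting the two on the client side; split_stream is a hypothetical helper, and pipeline, model_id, and payload are assumed to be already set up:

    def split_stream(stream):
        # Collect reasoning and answer text from Pipeline.stream_response() output.
        reasoning, answer, in_think = [], [], False
        for delta in stream:
            if delta == "<think>":
                in_think = True
            elif delta.strip() == "</think>":
                in_think = False
            else:
                (reasoning if in_think else answer).append(delta)
        return "".join(reasoning), "".join(answer)

    # reasoning, answer = split_stream(pipeline.stream_response(model_id, payload))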