import json
import logging
from io import BytesIO
- from typing import List, Union, Generator, Iterator
+ from typing import List, Union, Generator, Iterator, Optional, Any

import boto3


from utils.pipelines.main import pop_system_message

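+ # Map reasoning_effort levels to Anthropic extended-thinking budget_tokens values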
+ REASONING_EFFORT_BUDGET_TOKEN_MAP = {
+     "none": None,
+     "low": 1024,
+     "medium": 4096,
+     "high": 16384,
+     "max": 32768,
+ }
+
+ # Maximum combined token limit for Claude 3.7
+ MAX_COMBINED_TOKENS = 64000
+

class Pipeline:
    class Valves(BaseModel):
-         AWS_ACCESS_KEY: str = ""
-         AWS_SECRET_KEY: str = ""
-         AWS_REGION_NAME: str = ""
+         AWS_ACCESS_KEY: Optional[str] = None
+         AWS_SECRET_KEY: Optional[str] = None
+         AWS_REGION_NAME: Optional[str] = None

    def __init__(self):
        self.type = "manifold"
@@ -47,21 +58,25 @@ def __init__(self):
            }
        )

-         self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                     aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                     service_name="bedrock",
-                                     region_name=self.valves.AWS_REGION_NAME)
-         self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                             aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                             service_name="bedrock-runtime",
-                                             region_name=self.valves.AWS_REGION_NAME)
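+         # Credentials are read from the environment; the region falls back
+         # through AWS_REGION_NAME -> AWS_REGION -> AWS_DEFAULT_REGION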
+         self.valves = self.Valves(
+             **{
+                 "AWS_ACCESS_KEY": os.getenv("AWS_ACCESS_KEY", ""),
+                 "AWS_SECRET_KEY": os.getenv("AWS_SECRET_KEY", ""),
+                 "AWS_REGION_NAME": os.getenv(
+                     "AWS_REGION_NAME", os.getenv(
+                         "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "")
+                     )
+                 ),
+             }
+         )

-         self.pipelines = self.get_models()
+         self.update_pipelines()


    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")
+         self.update_pipelines()
        pass

    async def on_shutdown(self):
@@ -72,40 +87,58 @@ async def on_shutdown(self):
    async def on_valves_updated(self):
        # This function is called when the valves are updated.
        print(f"on_valves_updated:{__name__}")
-         self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                     aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                     service_name="bedrock",
-                                     region_name=self.valves.AWS_REGION_NAME)
-         self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                             aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                             service_name="bedrock-runtime",
-                                             region_name=self.valves.AWS_REGION_NAME)
-         self.pipelines = self.get_models()
-
-     def pipelines(self) -> List[dict]:
-         return self.get_models()
+         self.update_pipelines()
+
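+     # Rebuild the boto3 clients from the current valves and refresh self.pipelines;
+     # called from __init__, on_startup, and on_valves_updated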
+     def update_pipelines(self) -> None:
+         try:
+             self.bedrock = boto3.client(service_name="bedrock",
+                                         aws_access_key_id=self.valves.AWS_ACCESS_KEY,
+                                         aws_secret_access_key=self.valves.AWS_SECRET_KEY,
+                                         region_name=self.valves.AWS_REGION_NAME)
+             self.bedrock_runtime = boto3.client(service_name="bedrock-runtime",
+                                                 aws_access_key_id=self.valves.AWS_ACCESS_KEY,
+                                                 aws_secret_access_key=self.valves.AWS_SECRET_KEY,
+                                                 region_name=self.valves.AWS_REGION_NAME)
+             self.pipelines = self.get_models()
+         except Exception as e:
+             print(f"Error: {e}")
+             self.pipelines = [
+                 {
+                     "id": "error",
+                     "name": "Could not fetch models from Bedrock, please set up AWS Key/Secret or Instance/Task Role.",
+                 },
+             ]

    def get_models(self):
-         if self.valves.AWS_ACCESS_KEY and self.valves.AWS_SECRET_KEY:
-             try:
-                 response = self.bedrock.list_foundation_models(byProvider='Anthropic', byInferenceType='ON_DEMAND')
-                 return [
-                     {
-                         "id": model["modelId"],
-                         "name": model["modelName"],
-                     }
-                     for model in response["modelSummaries"]
-                 ]
-             except Exception as e:
-                 print(f"Error: {e}")
-                 return [
-                     {
-                         "id": "error",
-                         "name": "Could not fetch models from Bedrock, please update the Access/Secret Key in the valves.",
-                     },
-                 ]
-         else:
-             return []
+         try:
+             res = []
+             response = self.bedrock.list_foundation_models(byProvider='Anthropic')
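+             # Expose on-demand models directly; models reachable only through an
+             # inference profile are listed under their profile ID instead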
+             for model in response['modelSummaries']:
+                 inference_types = model.get('inferenceTypesSupported', [])
+                 if "ON_DEMAND" in inference_types:
+                     res.append({'id': model['modelId'], 'name': model['modelName']})
+                 elif "INFERENCE_PROFILE" in inference_types:
+                     inferenceProfileId = self.getInferenceProfileId(model['modelArn'])
+                     if inferenceProfileId:
+                         res.append({'id': inferenceProfileId, 'name': model['modelName']})
+
+             return res
+         except Exception as e:
+             print(f"Error: {e}")
+             return [
+                 {
+                     "id": "error",
+                     "name": "Could not fetch models from Bedrock, please check permissions.",
+                 },
+             ]
+
+     def getInferenceProfileId(self, modelArn: str) -> Optional[str]:
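+         # Find the inference profile (e.g. a cross-region profile) that routes to the given model ARN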
+         response = self.bedrock.list_inference_profiles()
+         for profile in response.get('inferenceProfileSummaries', []):
+             for model in profile.get('models', []):
+                 if model.get('modelArn') == modelArn:
+                     return profile['inferenceProfileId']
+         return None

    def pipe(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
@@ -139,11 +172,53 @@ def pipe(

        payload = {"modelId": model_id,
                   "messages": processed_messages,
-                    "system": [{'text': system_message if system_message else 'you are an intelligent ai assistant'}],
-                    "inferenceConfig": {"temperature": body.get("temperature", 0.5)},
-                    "additionalModelRequestFields": {"top_k": body.get("top_k", 200), "top_p": body.get("top_p", 0.9)}
+                    "system": [{'text': system_message["content"] if system_message else 'you are an intelligent ai assistant'}],
+                    "inferenceConfig": {
+                        "temperature": body.get("temperature", 0.5),
+                        "topP": body.get("top_p", 0.9),
+                        "maxTokens": body.get("max_tokens", 4096),
+                        "stopSequences": body.get("stop", []),
+                    },
+                    "additionalModelRequestFields": {"top_k": body.get("top_k", 200)}
                   }
+
        if body.get("stream", False):
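+             # Claude 3.7 extended thinking: translate reasoning_effort into a
+             # thinking budget for streaming requests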
+             supports_thinking = "claude-3-7" in model_id
+             reasoning_effort = body.get("reasoning_effort", "none")
+             budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
+
+             # Allow users to input an integer value representing budget tokens
+             if (
+                 not budget_tokens
+                 and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
+             ):
+                 try:
+                     budget_tokens = int(reasoning_effort)
+                 except ValueError as e:
+                     print("Failed to convert reasoning effort to int", e)
+                     budget_tokens = None
+
+             if supports_thinking and budget_tokens:
+                 # Check whether the combined tokens (budget_tokens + max_tokens) exceed the limit
+                 max_tokens = payload["inferenceConfig"].get("maxTokens", 4096)
+                 combined_tokens = budget_tokens + max_tokens
+
+                 if combined_tokens > MAX_COMBINED_TOKENS:
+                     error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
+                     print(error_message)
+                     return error_message
+
+                 payload["inferenceConfig"]["maxTokens"] = combined_tokens
+                 payload["additionalModelRequestFields"]["thinking"] = {
+                     "type": "enabled",
+                     "budget_tokens": budget_tokens,
+                 }
+                 # Thinking requires temperature 1.0 and does not support top_p or top_k
+                 payload["inferenceConfig"]["temperature"] = 1.0
+                 if "top_k" in payload["additionalModelRequestFields"]:
+                     del payload["additionalModelRequestFields"]["top_k"]
+                 if "topP" in payload["inferenceConfig"]:
+                     del payload["inferenceConfig"]["topP"]
            return self.stream_response(model_id, payload)
        else:
            return self.get_completion(model_id, payload)
@@ -152,30 +227,45 @@ def pipe(

    def process_image(self, image: dict):
        img_stream = None
+         content_type = None
+
        if image["url"].startswith("data:image"):
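+             # Data URL, e.g. "data:image/png;base64,...": take the format from the
+             # MIME type instead of guessing from the file extension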
-             if ',' in image["url"]:
-                 base64_string = image["url"].split(',')[1]
+             mime_type, base64_string = image["url"].split(",", 1)
+             content_type = mime_type.split(":")[1].split(";")[0]
            image_data = base64.b64decode(base64_string)
-
            img_stream = BytesIO(image_data)
        else:
-             img_stream = requests.get(image["url"]).content
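+             # Remote URL: download the bytes and take the format from the
+             # Content-Type header, defaulting to image/jpeg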
+             response = requests.get(image["url"])
+             img_stream = BytesIO(response.content)
+             content_type = response.headers.get('Content-Type', 'image/jpeg')
+
+         media_type = content_type.split('/')[-1] if '/' in content_type else content_type
        return {
-             "image": {"format": "png" if image["url"].endswith(".png") else "jpeg",
-                       "source": {"bytes": img_stream.read()}}
+             "image": {
+                 "format": media_type,
+                 "source": {"bytes": img_stream.read()}
+             }
        }

    def stream_response(self, model_id: str, payload: dict) -> Generator:
-         if "system" in payload:
-             del payload["system"]
-         if "additionalModelRequestFields" in payload:
-             del payload["additionalModelRequestFields"]
        streaming_response = self.bedrock_runtime.converse_stream(**payload)
+
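+         # Wrap reasoningContent deltas in <think>...</think> so the frontend can
+         # render the model's reasoning separately from the answer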
+         in_reasoning_context = False
        for chunk in streaming_response["stream"]:
-             if "contentBlockDelta" in chunk:
-                 yield chunk["contentBlockDelta"]["delta"]["text"]
+             if in_reasoning_context and "contentBlockStop" in chunk:
+                 in_reasoning_context = False
+                 yield "\n</think>\n\n"
+             elif "contentBlockDelta" in chunk and "delta" in chunk["contentBlockDelta"]:
+                 if "reasoningContent" in chunk["contentBlockDelta"]["delta"]:
+                     if not in_reasoning_context:
+                         yield "<think>"
+
+                     in_reasoning_context = True
+                     if "text" in chunk["contentBlockDelta"]["delta"]["reasoningContent"]:
+                         yield chunk["contentBlockDelta"]["delta"]["reasoningContent"]["text"]
+                 elif "text" in chunk["contentBlockDelta"]["delta"]:
+                     yield chunk["contentBlockDelta"]["delta"]["text"]

    def get_completion(self, model_id: str, payload: dict) -> str:
        response = self.bedrock_runtime.converse(**payload)
        return response['output']['message']['content'][0]['text']
-
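
# For reference — a minimal, hypothetical invocation of this pipeline (the model
# ID and body fields below are illustrative assumptions, not part of this diff;
# "reasoning_effort" may be a named level or an integer budget, as handled in pipe()):
#
#     pipeline = Pipeline()
#     output = pipeline.pipe(
#         user_message="Hello",
#         model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
#         messages=[{"role": "user", "content": "Hello"}],
#         body={"stream": True, "reasoning_effort": "medium", "max_tokens": 4096},
#     )
#     for token in output:
#         print(token, end="")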