-
Notifications
You must be signed in to change notification settings - Fork 35
feat: Watsonx client #45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
public class IBMAuthBearerToken { | ||
|
||
@JsonProperty("token") | ||
String token; | ||
@JsonProperty("access_token") | ||
String accessToken; | ||
@JsonProperty("expiration") | ||
Integer expiration; | ||
|
||
String getToken() { | ||
return this.token; | ||
} | ||
|
||
public void setToken(String token) { | ||
this.token = token; | ||
} | ||
|
||
String getAccessToken() { | ||
return this.accessToken; | ||
} | ||
|
||
public void setAccessToken(String accessToken) { | ||
this.accessToken = accessToken; | ||
} | ||
|
||
Integer getExpiration() { | ||
return this.expiration; | ||
} | ||
|
||
public void setExpiration(int expiration) { | ||
this.expiration = expiration; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
public class IBMAuthTokenExpiry { | ||
|
||
@JsonProperty("exp") | ||
Integer expiry; | ||
|
||
Integer getExpiry() { | ||
return this.expiry; | ||
} | ||
|
||
public void setExpiry(int expiry) { | ||
this.expiry = expiry; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import java.io.IOException; | ||
import java.util.Base64; | ||
import java.util.Date; | ||
import java.util.LinkedHashMap; | ||
import java.util.Map; | ||
import okhttp3.MediaType; | ||
import okhttp3.OkHttpClient; | ||
import okhttp3.Request; | ||
import okhttp3.RequestBody; | ||
import okhttp3.Response; | ||
|
||
public class WatsonxAuthenticator { | ||
|
||
IBMAuthBearerToken bearerToken; | ||
OkHttpClient client; | ||
Request request; | ||
Request expiryRequest; | ||
Boolean isZenApiKey = false; | ||
|
||
// Watsonx SaaS | ||
public WatsonxAuthenticator(String apiKey) { | ||
this.client = new OkHttpClient().newBuilder() | ||
.build(); | ||
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); | ||
RequestBody body = RequestBody.create(mediaType, | ||
"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + apiKey); | ||
this.request = new Request.Builder() | ||
.url("https://iam.cloud.ibm.com/identity/token") | ||
.method("POST", body) | ||
.addHeader("Content-Type", "application/x-www-form-urlencoded") | ||
.build(); | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't consider creating new bearer tokens on object initiation a good practise, please reconsider using a lazy initialization |
||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), | ||
IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
// Zen API Key | ||
public WatsonxAuthenticator(String username, String zenApiKey) { | ||
IBMAuthBearerToken token = new IBMAuthBearerToken(); | ||
String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes()) | ||
.toString(); | ||
token.setAccessToken(tokenStr); | ||
this.bearerToken = token; | ||
this.isZenApiKey = true; | ||
} | ||
|
||
// Watsonx API Key | ||
public WatsonxAuthenticator(String username, String apiKey, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The constructor is too complex. Can we break it down into smaller units for better readability? |
||
String host) { // TODO add support for password | ||
this.client = new OkHttpClient().newBuilder() | ||
.build(); | ||
ObjectMapper mapper = new ObjectMapper(); | ||
Map<String, String> authParams = new LinkedHashMap<>(); | ||
authParams.put("username", username); | ||
authParams.put("api_key", apiKey); | ||
|
||
String authParamsStr = ""; | ||
try { | ||
authParamsStr = mapper.writeValueAsString(authParams); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); | ||
} | ||
|
||
MediaType mediaType = MediaType.parse("application/json"); | ||
RequestBody body = RequestBody.create(mediaType, authParamsStr); | ||
// TODO add support for IAM endpoint v1/auth/identitytoken | ||
this.request = new Request.Builder() | ||
.url(host | ||
+ "/icp4d-api/v1/authorize") | ||
.method("POST", body) | ||
.addHeader("Content-Type", "application/json") | ||
.build(); | ||
|
||
this.expiryRequest = new Request.Builder() | ||
.url(host + "/usermgmt/v1/user/tokenExpiry") | ||
.addHeader("Accept", "application/json") | ||
.addHeader("Authorization", "Bearer " + this.bearerToken.getAccessToken()) | ||
.build(); | ||
|
||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), | ||
IBMAuthBearerToken.class); | ||
|
||
Response expiry = client.newCall(request).execute(); | ||
this.bearerToken.setExpiration( | ||
OBJECT_MAPPER.readValue(expiry.body().string(), IBMAuthTokenExpiry.class) | ||
.getExpiry()); | ||
|
||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
private void generateNewBearerToken() { | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Resources need to be cleaned up - |
||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), | ||
IBMAuthBearerToken.class); | ||
if (this.bearerToken.getExpiration() == null) { | ||
Response expiry = client.newCall(expiryRequest).execute(); | ||
this.bearerToken.setExpiration( | ||
OBJECT_MAPPER.readValue(expiry.body().string(), IBMAuthTokenExpiry.class) | ||
.getExpiry()); | ||
} | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please re-throw the exception with meaningful message |
||
} | ||
} | ||
|
||
public String getBearerTokenValue() { | ||
if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000) | ||
< (new Date().getTime() + 60000))) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The magic number |
||
generateNewBearerToken(); | ||
} | ||
return this.bearerToken.getAccessToken(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import ee.carlrobert.llm.PropertiesLoader; | ||
import ee.carlrobert.llm.client.DeserializationUtil; | ||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError; | ||
import ee.carlrobert.llm.completion.CompletionEventListener; | ||
import ee.carlrobert.llm.completion.CompletionEventSourceListener; | ||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import okhttp3.Headers; | ||
import okhttp3.MediaType; | ||
import okhttp3.OkHttpClient; | ||
import okhttp3.Request; | ||
import okhttp3.RequestBody; | ||
import okhttp3.sse.EventSource; | ||
import okhttp3.sse.EventSources; | ||
|
||
public class WatsonxClient { | ||
|
||
private static final MediaType APPLICATION_JSON = MediaType.parse("application/json"); | ||
private final OkHttpClient httpClient; | ||
private final String host; | ||
private final String apiVersion; | ||
private final WatsonxAuthenticator authenticator; | ||
|
||
private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) { | ||
this.httpClient = httpClientBuilder.build(); | ||
this.apiVersion = builder.apiVersion; | ||
this.host = builder.host; | ||
if (builder.isOnPrem) { | ||
if (builder.isZenApiKey) { | ||
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey); | ||
} else { | ||
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey, | ||
builder.host); | ||
} | ||
} else { | ||
this.authenticator = new WatsonxAuthenticator(builder.apiKey); | ||
} | ||
} | ||
|
||
public EventSource getCompletionAsync( | ||
WatsonxCompletionRequest request, | ||
CompletionEventListener<String> eventListener) { | ||
return EventSources.createFactory(httpClient).newEventSource( | ||
buildCompletionRequest(request), | ||
getCompletionEventSourceListener(eventListener)); | ||
} | ||
|
||
public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) { | ||
try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) { | ||
return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
protected Request buildCompletionRequest(WatsonxCompletionRequest request) { | ||
var headers = new HashMap<>(getRequiredHeaders()); | ||
if (request.getStream()) { | ||
headers.put("Accept", "text/event-stream"); | ||
} | ||
try { | ||
String path = | ||
(request.getDeploymentId() == null || request.getDeploymentId().isEmpty()) ? "text/" | ||
: "deployments/" + request.getDeploymentId() + "/"; | ||
String generation = request.getStream() ? "generation_stream" : "generation"; | ||
return new Request.Builder() | ||
.url(host + "/ml/v1/" + path + generation + "?version=" + apiVersion) | ||
.headers(Headers.of(headers)) | ||
.post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON)) | ||
.build(); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException("Unable to process request", e); | ||
} | ||
} | ||
|
||
private Map<String, String> getRequiredHeaders() { | ||
return new HashMap<>(Map.of("Authorization", | ||
(this.authenticator.isZenApiKey ? "ZenApiKey " : "Bearer ") | ||
+ authenticator.getBearerTokenValue())); | ||
} | ||
|
||
private CompletionEventSourceListener<String> getCompletionEventSourceListener( | ||
CompletionEventListener<String> eventListener) { | ||
return new CompletionEventSourceListener<>(eventListener) { | ||
@Override | ||
protected String getMessage(String data) { | ||
try { | ||
return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class) | ||
.getResults().get(0).getGeneratedText(); | ||
} catch (Exception e) { | ||
try { | ||
String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class) | ||
.getError() | ||
.getMessage(); | ||
return message == null ? "" : message; | ||
} catch (Exception ex) { | ||
System.out.println(ex); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use proper logging |
||
return ""; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
protected ErrorDetails getErrorDetails(String error) { | ||
try { | ||
return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError(); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
}; | ||
} | ||
|
||
public static class Builder { | ||
|
||
private final String apiKey; | ||
private String host = PropertiesLoader.getValue("watsonx.baseUrl"); | ||
private String apiVersion = "2024-03-14"; | ||
private Boolean isOnPrem; | ||
private Boolean isZenApiKey; | ||
private String username; | ||
|
||
public Builder(String apiKey) { | ||
this.apiKey = apiKey; | ||
} | ||
|
||
public Builder setApiVersion(String apiVersion) { | ||
this.apiVersion = apiVersion; | ||
return this; | ||
} | ||
|
||
public Builder setHost(String host) { | ||
this.host = host; | ||
return this; | ||
} | ||
|
||
public Builder setIsZenApiKey(Boolean isZenApiKey) { | ||
this.isZenApiKey = isZenApiKey; | ||
return this; | ||
} | ||
|
||
public Builder setIsOnPrem(Boolean isOnPrem) { | ||
this.isOnPrem = isOnPrem; | ||
return this; | ||
} | ||
|
||
public Builder setUsername(String username) { | ||
this.username = username; | ||
return this; | ||
} | ||
|
||
public WatsonxClient build(OkHttpClient.Builder builder) { | ||
return new WatsonxClient(this, builder); | ||
} | ||
|
||
public WatsonxClient build() { | ||
return build(new OkHttpClient.Builder()); | ||
} | ||
} | ||
} | ||
|
||
|
||
|
||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These fields could be private if they're not accessed outside this class