Skip to content

feat: Watsonx client #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ee.carlrobert.llm.client.watsonx;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * JSON-mapped holder for a token response.
 *
 * <p>Unknown JSON properties are ignored. The payload properties {@code token},
 * {@code access_token} and {@code expiration} are bound to the fields below.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
public class IBMAuthBearerToken {

  /** Value of the {@code token} JSON property. */
  @JsonProperty("token")
  String token;

  /** Value of the {@code access_token} JSON property. */
  @JsonProperty("access_token")
  String accessToken;

  /** Value of the {@code expiration} JSON property; {@code null} until it has been set. */
  @JsonProperty("expiration")
  Integer expiration;

  /** Returns the {@code token} value, or {@code null} if it was never set. */
  String getToken() {
    return token;
  }

  /** Replaces the {@code token} value. */
  public void setToken(String token) {
    this.token = token;
  }

  /** Returns the {@code access_token} value, or {@code null} if it was never set. */
  String getAccessToken() {
    return accessToken;
  }

  /** Replaces the {@code access_token} value. */
  public void setAccessToken(String accessToken) {
    this.accessToken = accessToken;
  }

  /** Returns the expiration value, or {@code null} if it was never set. */
  Integer getExpiration() {
    return expiration;
  }

  /** Replaces the expiration value. */
  public void setExpiration(int expiration) {
    this.expiration = expiration;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package ee.carlrobert.llm.client.watsonx;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * JSON-mapped holder for a token-expiry response.
 *
 * <p>Unknown JSON properties are ignored; only the {@code exp} property is bound.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
public class IBMAuthTokenExpiry {

  /** Value of the {@code exp} JSON property; {@code null} until it has been set. */
  @JsonProperty("exp")
  Integer expiry;

  /** Returns the {@code exp} value, or {@code null} if it was never set. */
  Integer getExpiry() {
    return expiry;
  }

  /** Replaces the {@code exp} value. */
  public void setExpiry(int expiry) {
    this.expiry = expiry;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package ee.carlrobert.llm.client.watsonx;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;

public class WatsonxAuthenticator {

IBMAuthBearerToken bearerToken;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fields could be private if they're not accessed outside this class

OkHttpClient client;
Request request;
Request expiryRequest;
Boolean isZenApiKey = false;

// Watsonx SaaS
public WatsonxAuthenticator(String apiKey) {
this.client = new OkHttpClient().newBuilder()
.build();
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
RequestBody body = RequestBody.create(mediaType,
"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + apiKey);
this.request = new Request.Builder()
.url("https://iam.cloud.ibm.com/identity/token")
.method("POST", body)
.addHeader("Content-Type", "application/x-www-form-urlencoded")
.build();
try {
Response response = client.newCall(request).execute();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't consider creating new bearer tokens on object initiation a good practise, please reconsider using a lazy initialization

this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
IBMAuthBearerToken.class);
} catch (IOException e) {
System.out.println(e);
}
}

// Zen API Key
public WatsonxAuthenticator(String username, String zenApiKey) {
IBMAuthBearerToken token = new IBMAuthBearerToken();
String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes())
.toString();
token.setAccessToken(tokenStr);
this.bearerToken = token;
this.isZenApiKey = true;
}

// Watsonx API Key
public WatsonxAuthenticator(String username, String apiKey,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor is too complex. Can we break it down into smaller units for better readability?

String host) { // TODO add support for password
this.client = new OkHttpClient().newBuilder()
.build();
ObjectMapper mapper = new ObjectMapper();
Map<String, String> authParams = new LinkedHashMap<>();
authParams.put("username", username);
authParams.put("api_key", apiKey);

String authParamsStr = "";
try {
authParamsStr = mapper.writeValueAsString(authParams);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, authParamsStr);
// TODO add support for IAM endpoint v1/auth/identitytoken
this.request = new Request.Builder()
.url(host
+ "/icp4d-api/v1/authorize")
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();

this.expiryRequest = new Request.Builder()
.url(host + "/usermgmt/v1/user/tokenExpiry")
.addHeader("Accept", "application/json")
.addHeader("Authorization", "Bearer " + this.bearerToken.getAccessToken())
.build();

try {
Response response = client.newCall(request).execute();
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
IBMAuthBearerToken.class);

Response expiry = client.newCall(request).execute();
this.bearerToken.setExpiration(
OBJECT_MAPPER.readValue(expiry.body().string(), IBMAuthTokenExpiry.class)
.getExpiry());

} catch (IOException e) {
System.out.println(e);
}
}

private void generateNewBearerToken() {
try {
Response response = client.newCall(request).execute();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resources need to be cleaned up - try (Response response = client.newCall(request).execute())

this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
IBMAuthBearerToken.class);
if (this.bearerToken.getExpiration() == null) {
Response expiry = client.newCall(expiryRequest).execute();
this.bearerToken.setExpiration(
OBJECT_MAPPER.readValue(expiry.body().string(), IBMAuthTokenExpiry.class)
.getExpiry());
}
} catch (IOException e) {
System.out.println(e);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please re-throw the exception with meaningful message

}
}

public String getBearerTokenValue() {
if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000)
< (new Date().getTime() + 60000))) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic number 60000 could be named constant for clarity

generateNewBearerToken();
}
return this.bearerToken.getAccessToken();
}
}
175 changes: 175 additions & 0 deletions src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package ee.carlrobert.llm.client.watsonx;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import okhttp3.Headers;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

public class WatsonxClient {

private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");
private final OkHttpClient httpClient;
private final String host;
private final String apiVersion;
private final WatsonxAuthenticator authenticator;

private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
this.httpClient = httpClientBuilder.build();
this.apiVersion = builder.apiVersion;
this.host = builder.host;
if (builder.isOnPrem) {
if (builder.isZenApiKey) {
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey);
} else {
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey,
builder.host);
}
} else {
this.authenticator = new WatsonxAuthenticator(builder.apiKey);
}
}

public EventSource getCompletionAsync(
WatsonxCompletionRequest request,
CompletionEventListener<String> eventListener) {
return EventSources.createFactory(httpClient).newEventSource(
buildCompletionRequest(request),
getCompletionEventSourceListener(eventListener));
}

public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) {
try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) {
return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

protected Request buildCompletionRequest(WatsonxCompletionRequest request) {
var headers = new HashMap<>(getRequiredHeaders());
if (request.getStream()) {
headers.put("Accept", "text/event-stream");
}
try {
String path =
(request.getDeploymentId() == null || request.getDeploymentId().isEmpty()) ? "text/"
: "deployments/" + request.getDeploymentId() + "/";
String generation = request.getStream() ? "generation_stream" : "generation";
return new Request.Builder()
.url(host + "/ml/v1/" + path + generation + "?version=" + apiVersion)
.headers(Headers.of(headers))
.post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException("Unable to process request", e);
}
}

private Map<String, String> getRequiredHeaders() {
return new HashMap<>(Map.of("Authorization",
(this.authenticator.isZenApiKey ? "ZenApiKey " : "Bearer ")
+ authenticator.getBearerTokenValue()));
}

private CompletionEventSourceListener<String> getCompletionEventSourceListener(
CompletionEventListener<String> eventListener) {
return new CompletionEventSourceListener<>(eventListener) {
@Override
protected String getMessage(String data) {
try {
return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class)
.getResults().get(0).getGeneratedText();
} catch (Exception e) {
try {
String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class)
.getError()
.getMessage();
return message == null ? "" : message;
} catch (Exception ex) {
System.out.println(ex);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use proper logging

return "";
}
}
}

@Override
protected ErrorDetails getErrorDetails(String error) {
try {
return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
};
}

public static class Builder {

private final String apiKey;
private String host = PropertiesLoader.getValue("watsonx.baseUrl");
private String apiVersion = "2024-03-14";
private Boolean isOnPrem;
private Boolean isZenApiKey;
private String username;

public Builder(String apiKey) {
this.apiKey = apiKey;
}

public Builder setApiVersion(String apiVersion) {
this.apiVersion = apiVersion;
return this;
}

public Builder setHost(String host) {
this.host = host;
return this;
}

public Builder setIsZenApiKey(Boolean isZenApiKey) {
this.isZenApiKey = isZenApiKey;
return this;
}

public Builder setIsOnPrem(Boolean isOnPrem) {
this.isOnPrem = isOnPrem;
return this;
}

public Builder setUsername(String username) {
this.username = username;
return this;
}

public WatsonxClient build(OkHttpClient.Builder builder) {
return new WatsonxClient(this, builder);
}

public WatsonxClient build() {
return build(new OkHttpClient.Builder());
}
}
}






Loading