Skip to content

Commit 228420a

Browse files
committed
add Watsonx client
1 parent 03648e9 commit 228420a

11 files changed

+732
-1
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package ee.carlrobert.llm.client.watsonx;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
6+
@JsonIgnoreProperties(ignoreUnknown = true)
7+
public class IBMAuthBearerToken {
8+
@JsonProperty("access_token")
9+
String accessToken;
10+
@JsonProperty("expiration")
11+
int expiration;
12+
13+
String getAccessToken() {
14+
return this.accessToken;
15+
}
16+
17+
public void setAccessToken(String accessToken) {
18+
this.accessToken = accessToken;
19+
}
20+
21+
int getExpiration() {
22+
return this.expiration;
23+
}
24+
25+
public void setExpiration(int expiration) {
26+
this.expiration = expiration;
27+
}
28+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package ee.carlrobert.llm.client.watsonx;
2+
3+
import okhttp3.*;
4+
5+
import java.io.IOException;
6+
import java.util.Base64;
7+
import java.util.Date;
8+
9+
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;
10+
11+
public class WatsonxAuthenticator {
12+
13+
IBMAuthBearerToken bearerToken;
14+
OkHttpClient client;
15+
Request request;
16+
Boolean isZenApiKey=false;
17+
18+
// On Cloud
19+
public WatsonxAuthenticator(String apiKey) {
20+
this.client = new OkHttpClient().newBuilder()
21+
.build();
22+
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
23+
RequestBody body = RequestBody.create(mediaType, "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey="+apiKey);
24+
this.request = new Request.Builder()
25+
.url("https://iam.cloud.ibm.com/identity/token")
26+
.method("POST", body)
27+
.addHeader("Content-Type", "application/x-www-form-urlencoded")
28+
.build();
29+
try {
30+
Response response = client.newCall(request).execute();
31+
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class);
32+
} catch (IOException e) {
33+
System.out.println(e);
34+
}
35+
}
36+
37+
// Zen API Key
38+
public WatsonxAuthenticator(String username, String zenApiKey){
39+
IBMAuthBearerToken token = new IBMAuthBearerToken();
40+
String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes()).toString();
41+
token.setAccessToken(tokenStr);
42+
this.bearerToken = token;
43+
this.isZenApiKey = true;
44+
}
45+
46+
// Watsonx API Key
47+
public WatsonxAuthenticator(String username, String apiKey, String host){//TODO add support for password
48+
this.client = new OkHttpClient().newBuilder()
49+
.build();
50+
MediaType mediaType = MediaType.parse("application/json");
51+
RequestBody body = RequestBody.create(mediaType, "{\"username\":\""+username+"\",\"api_key\":\""+apiKey+"\"}");
52+
this.request = new Request.Builder()
53+
.url(host + "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken
54+
.method("POST", body)
55+
.addHeader("Content-Type", "application/json")
56+
.build();
57+
try {
58+
Response response = client.newCall(request).execute();
59+
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class);
60+
} catch (IOException e) {
61+
System.out.println(e);
62+
}
63+
}
64+
65+
private void generateNewBearerToken() {
66+
try {
67+
Response response = client.newCall(request).execute();
68+
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class);
69+
} catch (IOException e) {
70+
System.out.println(e);
71+
}
72+
}
73+
74+
public String getBearerTokenValue() {
75+
if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000) < new Date().getTime() + 1000000)) {//TODO add correct number of seconds
76+
generateNewBearerToken();
77+
}
78+
return this.bearerToken.getAccessToken();
79+
}
80+
}
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package ee.carlrobert.llm.client.watsonx;
2+
3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import ee.carlrobert.llm.PropertiesLoader;
5+
import ee.carlrobert.llm.client.DeserializationUtil;
6+
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
7+
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest;
8+
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse;
9+
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError;
10+
import ee.carlrobert.llm.completion.CompletionEventListener;
11+
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
12+
import okhttp3.*;
13+
import okhttp3.sse.EventSource;
14+
import okhttp3.sse.EventSources;
15+
16+
import java.io.IOException;
17+
import java.util.HashMap;
18+
import java.util.Map;
19+
20+
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;
21+
22+
public class WatsonxClient {
23+
24+
private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");
25+
private final OkHttpClient httpClient;
26+
private final String host;
27+
private final String apiVersion;
28+
private final WatsonxAuthenticator authenticator;
29+
30+
private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
31+
this.httpClient = httpClientBuilder.build();
32+
this.apiVersion = builder.apiVersion;
33+
this.host = builder.host;
34+
if (builder.isOnPrem) {
35+
if (builder.isZenApiKey)
36+
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey);
37+
else
38+
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey, builder.host);
39+
} else {
40+
this.authenticator = new WatsonxAuthenticator(builder.apiKey);
41+
}
42+
}
43+
44+
public EventSource getCompletionAsync(
45+
WatsonxCompletionRequest request,
46+
CompletionEventListener<String> eventListener) {
47+
return EventSources.createFactory(httpClient).newEventSource(
48+
buildCompletionRequest(request),
49+
getCompletionEventSourceListener(eventListener));
50+
}
51+
52+
public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) {
53+
try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) {
54+
return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class);
55+
} catch (IOException e) {
56+
throw new RuntimeException(e);
57+
}
58+
}
59+
60+
protected Request buildCompletionRequest(WatsonxCompletionRequest request) {
61+
var headers = new HashMap<>(getRequiredHeaders());
62+
if (request.getStream()) {
63+
headers.put("Accept", "text/event-stream");
64+
}
65+
try {
66+
return new Request.Builder()
67+
.url(host + "/ml/v1/text/" + (request.getStream() ? "generation_stream" : "generation") + "?version=" + apiVersion)
68+
.headers(Headers.of(headers))
69+
.post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON))
70+
.build();
71+
} catch (JsonProcessingException e) {
72+
throw new RuntimeException("Unable to process request", e);
73+
}
74+
}
75+
76+
private Map<String, String> getRequiredHeaders() {
77+
return new HashMap<>(Map.of("Authorization", "Bearer " + authenticator.getBearerTokenValue()));
78+
}
79+
80+
private CompletionEventSourceListener<String> getCompletionEventSourceListener(
81+
CompletionEventListener<String> eventListener) {
82+
return new CompletionEventSourceListener<>(eventListener) {
83+
@Override
84+
protected String getMessage(String data) {
85+
try {
86+
return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class)
87+
.getResults().get(0).getGeneratedText();
88+
} catch (Exception e) {
89+
try {
90+
String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class)
91+
.getError()
92+
.getMessage();
93+
if (message != null) return message;
94+
return "";
95+
} catch (Exception ex) {
96+
return "";
97+
}
98+
}
99+
}
100+
101+
@Override
102+
protected ErrorDetails getErrorDetails(String error) {
103+
try {
104+
return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError();
105+
} catch (JsonProcessingException e) {
106+
throw new RuntimeException(e);
107+
}
108+
}
109+
};
110+
}
111+
112+
public static class Builder {
113+
114+
private final String apiKey;
115+
private String host = PropertiesLoader.getValue("watsonx.baseUrl");
116+
private String apiVersion = "2024-03-14";
117+
private Boolean isOnPrem;
118+
private Boolean isZenApiKey;
119+
private String username;
120+
121+
public Builder(String apiKey){
122+
this.apiKey = apiKey;
123+
}
124+
public Builder setApiVersion(String apiVersion) {
125+
this.apiVersion = apiVersion;
126+
return this;
127+
}
128+
129+
public Builder setHost(String host) {
130+
this.host = host;
131+
return this;
132+
}
133+
134+
public Builder setIsZenApiKey(Boolean isZenApiKey) {
135+
this.isZenApiKey = isZenApiKey;
136+
return this;
137+
}
138+
139+
public Builder setIsOnPrem(Boolean isOnPrem) {
140+
this.isOnPrem = isOnPrem;
141+
return this;
142+
}
143+
144+
public Builder setUsername(String username) {
145+
this.username = username;
146+
return this;
147+
}
148+
149+
public WatsonxClient build(OkHttpClient.Builder builder) {
150+
return new WatsonxClient(this, builder);
151+
}
152+
153+
public WatsonxClient build() {
154+
return build(new OkHttpClient.Builder());
155+
}
156+
}
157+
}
158+
159+
160+
161+
162+
163+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package ee.carlrobert.llm.client.watsonx.completion;
2+
3+
import com.fasterxml.jackson.annotation.JsonCreator;
4+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
5+
import com.fasterxml.jackson.annotation.JsonProperty;
6+
import ee.carlrobert.llm.client.BaseError;
7+
8+
@JsonIgnoreProperties(ignoreUnknown = true)
9+
public class WatsonxCompletionErrorDetails extends BaseError {
10+
11+
private static final String DEFAULT_ERROR_MSG = "Something went wrong. Please try again later.";
12+
13+
String code;
14+
String message;
15+
16+
public WatsonxCompletionErrorDetails(String message) {
17+
this(message, null);
18+
}
19+
20+
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
21+
public WatsonxCompletionErrorDetails(
22+
@JsonProperty("message") String message,
23+
@JsonProperty("code") String code) {
24+
this.message = message;
25+
this.code = code;
26+
}
27+
28+
public static WatsonxCompletionErrorDetails DEFAULT_ERROR = new WatsonxCompletionErrorDetails(DEFAULT_ERROR_MSG,null);
29+
30+
public String getMessage() {
31+
return message;
32+
}
33+
34+
public String getCode() {
35+
return code;
36+
}
37+
38+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package ee.carlrobert.llm.client.watsonx.completion;
2+
3+
import ee.carlrobert.llm.completion.CompletionModel;
4+
5+
import java.util.Arrays;
6+
7+
public enum WatsonxCompletionModel implements CompletionModel {
8+
9+
GRANITE_3B_CODE_INSTRUCT("ibm/granite-3b-code-instruct","IBM Granite 3B Code Instruct", 8192),
10+
GRANITE_8B_CODE_INSTRUCT("ibm/granite-8b-code-instruct","IBM Granite 8B Code Instruct", 8192),
11+
GRANITE_20B_CODE_INSTRUCT( "ibm/granite-20b-code-instruct","IBM Granite 20B Code Instruct",8192),
12+
GRANITE_34B_CODE_INSTRUCT( "ibm/granite-34b-code-instruct","IBM Granite 34B Code Instruct",8192),
13+
CODELLAMA_34_B_INSTRUCT("codellama/codellama-34b-instruct-hf","Code Llama 34B Instruct", 8192),
14+
MIXTRAL_8_7B("mistralai/mixtral-8x7b-instruct-v01","Mixtral (8x7B)",32768),
15+
MIXTRAL_LARGE("mistralai/mistral-large","Mistral Large",128000),
16+
LLAMA_3_1_70B( "meta-llama/llama-3-1-70b-instruct","Llama 3.1 Instruct (70B)", 128000),
17+
LLAMA_3_1_8B( "meta-llama/llama-3-1-8b-instruct", "Llama 3.1 Instruct (8B)", 128000),
18+
LLAMA_2_7B("meta-llama/llama-2-70b-chat","Llama 2 Chat (70B)",4096),
19+
LLAMA_2_13B("meta-llama/llama-2-13b-chat","Llama 2 Chat (13B)",4096),
20+
GRANITE_13B_INSTRUCT_V2("ibm/granite-13b-instruct-v2","IBM Granite 13B Instruct V2",8192),
21+
GRANITE_13B_CHAT_V2("ibm/granite-13b-chat-v2","IBM Granite 13B Chat V2",8192),
22+
GRANITE_20B_MULTILINGUAL("ibm/granite-20b-multilingual","IBM Granite 20B Multilingual",8192);
23+
24+
private final String code;
25+
private final String description;
26+
private final int maxTokens;
27+
28+
WatsonxCompletionModel(String code, String description, int maxTokens) {
29+
this.code = code;
30+
this.description = description;
31+
this.maxTokens = maxTokens;
32+
}
33+
34+
public String getCode() {
35+
return code;
36+
}
37+
38+
public String getDescription() {
39+
return description;
40+
}
41+
42+
public int getMaxTokens() {
43+
return maxTokens;
44+
}
45+
46+
@Override
47+
public String toString() {
48+
return description;
49+
}
50+
51+
public static WatsonxCompletionModel findByCode(String code) {
52+
return Arrays.stream(WatsonxCompletionModel.values())
53+
.filter(item -> item.getCode().equals(code))
54+
.findFirst().orElseThrow();
55+
}
56+
}
57+

0 commit comments

Comments
 (0)