17 changes: 13 additions & 4 deletions go/plugins/compat_oai/anthropic/README.md
@@ -14,34 +14,41 @@ First, set your Anthropic API key as an environment variable:
```bash
export ANTHROPIC_API_KEY=<your-api-key>
```
By default, `baseURL` is set to "https://api.anthropic.com/v1". However, if you

By default, `baseURL` is set to "<https://api.anthropic.com/v1>". However, if you
want to use a custom value, you can set the `ANTHROPIC_BASE_URL` environment variable:

```bash
export ANTHROPIC_BASE_URL=<your-custom-base-url>
```

### Running All Tests

To run all tests in the directory:

```bash
go test -v .
```

### Running Tests from Specific Files

To run tests from a specific file:

```bash
# Run only generate_live_test.go tests
go test -run "^TestGenerator"

# Run only anthropic_live_test.go tests
go test -run "^TestPlugin"
go test -run "^TestAnthropicLive"
```

### Running Individual Tests

To run a specific test case:

```bash
# Run only the streaming test from anthropic_live_test.go
go test -run "TestPlugin/streaming"
go test -run "TestAnthropicLive/streaming"

# Run only the Complete test from generate_live_test.go
go test -run "TestGenerator_Complete"
@@ -51,9 +58,11 @@ go test -run "TestGenerator_Stream"
```

### Test Output Verbosity

Add the `-v` flag for verbose output:

```bash
go test -v -run "TestPlugin/streaming"
go test -v -run "TestAnthropicLive/streaming"
```

Note: All live tests require the `ANTHROPIC_API_KEY` environment variable to be set. Tests will be skipped if the API key is not provided.
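
For reference, the skip behavior comes from a guard at the top of the live test, visible in the `anthropic_live_test.go` hunk below. A minimal, self-contained sketch of that pattern (the package name and everything outside the guard are assumptions, not taken from this PR):

```go
package anthropic_test // assumed package name

import (
	"os"
	"testing"
)

func TestAnthropicLive(t *testing.T) {
	// Skip the whole live suite when no API key is configured.
	apiKey := os.Getenv("ANTHROPIC_API_KEY")
	if apiKey == "" {
		t.Skip("Skipping test: ANTHROPIC_API_KEY environment variable not set")
	}

	_ = apiKey // the real test wires the key into the plugin configuration
	// Subtests such as t.Run("streaming", ...) follow in the actual file.
}
```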
2 changes: 1 addition & 1 deletion go/plugins/compat_oai/anthropic/anthropic_live_test.go
@@ -26,7 +26,7 @@ import (
"github.com/openai/openai-go/option"
)

func TestPlugin(t *testing.T) {
func TestAnthropicLive(t *testing.T) {
apiKey := os.Getenv("ANTHROPIC_API_KEY")
if apiKey == "" {
t.Skip("Skipping test: ANTHROPIC_API_KEY environment variable not set")
11 changes: 1 addition & 10 deletions go/plugins/compat_oai/compat_oai.go
@@ -119,16 +119,7 @@ func (o *OpenAICompatible) DefineModel(provider, id string, opts ai.ModelOptions
input *ai.ModelRequest,
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
// Configure the response generator with input
generator := NewModelGenerator(o.client, id).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools)

// Generate response
resp, err := generator.Generate(ctx, input, cb)
if err != nil {
return nil, err
}

return resp, nil
return generate(ctx, o.client, id, input, cb)
})
}

187 changes: 74 additions & 113 deletions go/plugins/compat_oai/generate.go
@@ -34,47 +34,82 @@ func mapToStruct(m map[string]any, v any) error {
return json.Unmarshal(jsonData, v)
}

// ModelGenerator handles OpenAI generation requests
type ModelGenerator struct {
client *openai.Client
modelName string
request *openai.ChatCompletionNewParams
messages []openai.ChatCompletionMessageParamUnion
tools []openai.ChatCompletionToolParam
toolChoice openai.ChatCompletionToolChoiceOptionUnionParam
// Store any errors that occur during building
err error
}
// generate executes the generation request using the new functional approach
func generate(
ctx context.Context,
client *openai.Client,
model string,
input *ai.ModelRequest,
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
request, err := toOpenAIRequest(model, input)
if err != nil {
return nil, err
}

func (g *ModelGenerator) GetRequest() *openai.ChatCompletionNewParams {
return g.request
if cb != nil {
return generateStream(ctx, client, request, cb)
}
return generateComplete(ctx, client, request, input)
}

// NewModelGenerator creates a new ModelGenerator instance
func NewModelGenerator(client *openai.Client, modelName string) *ModelGenerator {
return &ModelGenerator{
client: client,
modelName: modelName,
request: &openai.ChatCompletionNewParams{
Model: (modelName),
},
func toOpenAIRequest(model string, input *ai.ModelRequest) (*openai.ChatCompletionNewParams, error) {
request, err := configFromRequest(input.Config)
if err != nil {
return nil, err
}
if request == nil {
request = &openai.ChatCompletionNewParams{}
}

request.Model = model

msgs, err := toOpenAIMessages(input.Messages)
if err != nil {
return nil, err
}
if len(msgs) == 0 {
return nil, fmt.Errorf("no messages provided")
}
request.Messages = msgs

tools := toOpenAITools(input.Tools)
if len(tools) > 0 {
request.Tools = tools
}

return request, nil
}

// WithMessages adds messages to the request
func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
// Return early if we already have an error
if g.err != nil {
return g
func configFromRequest(config any) (*openai.ChatCompletionNewParams, error) {
if config == nil {
return nil, nil
}

var openaiConfig openai.ChatCompletionNewParams
switch cfg := config.(type) {
case openai.ChatCompletionNewParams:
openaiConfig = cfg
case *openai.ChatCompletionNewParams:
openaiConfig = *cfg
case map[string]any:
if err := mapToStruct(cfg, &openaiConfig); err != nil {
return nil, fmt.Errorf("failed to convert config to openai.ChatCompletionNewParams: %w", err)
}
default:
return nil, fmt.Errorf("unexpected config type: %T", config)
}
return &openaiConfig, nil
}

func toOpenAIMessages(messages []*ai.Message) ([]openai.ChatCompletionMessageParamUnion, error) {
if messages == nil {
return g
return nil, nil
}

oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages))
for _, msg := range messages {
content := g.concatenateContent(msg.Content)
content := concatenateContent(msg.Content)
switch msg.Role {
case ai.RoleSystem:
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
@@ -83,8 +118,7 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
am.Content.OfString = param.NewOpt(content)
toolCalls, err := convertToolCalls(msg.Content)
if err != nil {
g.err = err
return g
return nil, err
}
if len(toolCalls) > 0 {
am.ToolCalls = (toolCalls)
@@ -105,8 +139,7 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {

toolOutput, err := anyToJSONString(p.ToolResponse.Output)
if err != nil {
g.err = err
return g
return nil, err
}
tm := openai.ToolMessage(toolOutput, toolCallID)
oaiMessages = append(oaiMessages, tm)
@@ -139,53 +172,12 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
}

}
g.messages = oaiMessages
return g
return oaiMessages, nil
}

// WithConfig adds configuration parameters from the model request
// see https://platform.openai.com/docs/api-reference/responses/create
// for more details on openai's request fields
func (g *ModelGenerator) WithConfig(config any) *ModelGenerator {
// Return early if we already have an error
if g.err != nil {
return g
}

if config == nil {
return g
}

var openaiConfig openai.ChatCompletionNewParams
switch cfg := config.(type) {
case openai.ChatCompletionNewParams:
openaiConfig = cfg
case *openai.ChatCompletionNewParams:
openaiConfig = *cfg
case map[string]any:
if err := mapToStruct(cfg, &openaiConfig); err != nil {
g.err = fmt.Errorf("failed to convert config to openai.ChatCompletionNewParams: %w", err)
return g
}
default:
g.err = fmt.Errorf("unexpected config type: %T", config)
return g
}

// keep the original model in the updated config structure
openaiConfig.Model = g.request.Model
g.request = &openaiConfig
return g
}

// WithTools adds tools to the request
func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator {
if g.err != nil {
return g
}

func toOpenAITools(tools []*ai.ToolDefinition) []openai.ChatCompletionToolParam {
if tools == nil {
return g
return nil
}

toolParams := make([]openai.ChatCompletionToolParam, 0, len(tools))
@@ -203,42 +195,11 @@ func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator {
}),
})
}

// Set the tools in the request
// If no tools are provided, set it to nil
// This is important to avoid sending an empty array in the request
// which is not supported by some vendor APIs
if len(toolParams) > 0 {
g.tools = toolParams
}

return g
}

// Generate executes the generation request
func (g *ModelGenerator) Generate(ctx context.Context, req *ai.ModelRequest, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// Check for any errors that occurred during building
if g.err != nil {
return nil, g.err
}

if len(g.messages) == 0 {
return nil, fmt.Errorf("no messages provided")
}
g.request.Messages = (g.messages)

if len(g.tools) > 0 {
g.request.Tools = (g.tools)
}

if handleChunk != nil {
return g.generateStream(ctx, handleChunk)
}
return g.generateComplete(ctx, req)
return toolParams
}

// concatenateContent concatenates text content into a single string
func (g *ModelGenerator) concatenateContent(parts []*ai.Part) string {
func concatenateContent(parts []*ai.Part) string {
content := ""
for _, part := range parts {
content += part.Text
Expand All @@ -247,8 +208,8 @@ func (g *ModelGenerator) concatenateContent(parts []*ai.Part) string {
}

// generateStream generates a streaming model response
func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
stream := g.client.Chat.Completions.NewStreaming(ctx, *g.request)
func generateStream(ctx context.Context, client *openai.Client, request *openai.ChatCompletionNewParams, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
stream := client.Chat.Completions.NewStreaming(ctx, *request)
defer stream.Close()

// Use openai-go's accumulator to collect the complete response
@@ -405,8 +366,8 @@ func convertChatCompletionToModelResponse(completion *openai.ChatCompletion) (*a
}

// generateComplete generates a complete model response
func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
completion, err := g.client.Chat.Completions.New(ctx, *g.request)
func generateComplete(ctx context.Context, client *openai.Client, request *openai.ChatCompletionNewParams, req *ai.ModelRequest) (*ai.ModelResponse, error) {
completion, err := client.Chat.Completions.New(ctx, *request)
if err != nil {
return nil, fmt.Errorf("failed to create completion: %w", err)
}
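
One practical consequence of the hunks above: `configFromRequest` accepts the OpenAI chat-completion parameters as a typed value, a pointer, or a plain `map[string]any` (the map is round-tripped through JSON by `mapToStruct`). A hedged sketch of the two caller-side shapes, assuming the typed path, is below; the `ai.ModelRequest` field names come from this diff, while the helper constructors (`ai.NewUserTextMessage`, `openai.Float`) and the option values are assumptions for illustration:

```go
package example

import (
	"github.com/firebase/genkit/go/ai"
	"github.com/openai/openai-go"
)

// buildRequests shows two equivalent ways a caller might populate
// ai.ModelRequest.Config for the configFromRequest type switch.
func buildRequests() (typed, untyped *ai.ModelRequest) {
	// Typed config: matched by the openai.ChatCompletionNewParams case.
	typed = &ai.ModelRequest{
		Messages: []*ai.Message{ai.NewUserTextMessage("hello")}, // assumed helper
		Config: openai.ChatCompletionNewParams{
			Temperature: openai.Float(0.7), // assumed option helper and illustrative value
		},
	}

	// Map config: matched by the map[string]any case and converted via JSON,
	// so keys use the OpenAI wire names.
	untyped = &ai.ModelRequest{
		Messages: []*ai.Message{ai.NewUserTextMessage("hello")},
		Config: map[string]any{
			"temperature": 0.7,
		},
	}
	return typed, untyped
}
```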