diff --git a/go/plugins/compat_oai/anthropic/README.md b/go/plugins/compat_oai/anthropic/README.md index 1a66fd0406..8995dd4a7a 100644 --- a/go/plugins/compat_oai/anthropic/README.md +++ b/go/plugins/compat_oai/anthropic/README.md @@ -14,7 +14,8 @@ First, set your Anthropic API key as an environment variable: ```bash export ANTHROPIC_API_KEY= ``` -By default, `baseURL` is set to "https://api.anthropic.com/v1". However, if you + +By default, `baseURL` is set to "https://api.anthropic.com/v1". However, if you want to use a custom value, you can set `ANTHROPIC_BASE_URL` environment variable: ```bash @@ -22,26 +23,32 @@ export ANTHROPIC_BASE_URL= ``` ### Running All Tests + To run all tests in the directory: + ```bash go test -v . ``` ### Running Tests from Specific Files + To run tests from a specific file: + ```bash # Run only generate_live_test.go tests go test -run "^TestGenerator" # Run only anthropic_live_test.go tests -go test -run "^TestPlugin" +go test -run "^TestAnthropicLive" ``` ### Running Individual Tests + To run a specific test case: + ```bash # Run only the streaming test from anthropic_live_test.go -go test -run "TestPlugin/streaming" +go test -run "TestAnthropicLive/streaming" # Run only the Complete test from generate_live_test.go go test -run "TestGenerator_Complete" @@ -51,9 +58,11 @@ go test -run "TestGenerator_Stream" ``` ### Test Output Verbosity + Add the `-v` flag for verbose output: + ```bash -go test -v -run "TestPlugin/streaming" +go test -v -run "TestAnthropicLive/streaming" ``` Note: All live tests require the ANTHROPIC_API_KEY environment variable to be set. Tests will be skipped if the API key is not provided. 
diff --git a/go/plugins/compat_oai/anthropic/anthropic_live_test.go b/go/plugins/compat_oai/anthropic/anthropic_live_test.go index adacf90804..605cfdc698 100644 --- a/go/plugins/compat_oai/anthropic/anthropic_live_test.go +++ b/go/plugins/compat_oai/anthropic/anthropic_live_test.go @@ -26,7 +26,7 @@ import ( "github.com/openai/openai-go/option" ) -func TestPlugin(t *testing.T) { +func TestAnthropicLive(t *testing.T) { apiKey := os.Getenv("ANTHROPIC_API_KEY") if apiKey == "" { t.Skip("Skipping test: ANTHROPIC_API_KEY environment variable not set") diff --git a/go/plugins/compat_oai/compat_oai.go b/go/plugins/compat_oai/compat_oai.go index 683172d477..2c3b821752 100644 --- a/go/plugins/compat_oai/compat_oai.go +++ b/go/plugins/compat_oai/compat_oai.go @@ -119,16 +119,7 @@ func (o *OpenAICompatible) DefineModel(provider, id string, opts ai.ModelOptions input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { - // Configure the response generator with input - generator := NewModelGenerator(o.client, id).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools) - - // Generate response - resp, err := generator.Generate(ctx, input, cb) - if err != nil { - return nil, err - } - - return resp, nil + return generate(ctx, o.client, id, input, cb) }) } diff --git a/go/plugins/compat_oai/generate.go b/go/plugins/compat_oai/generate.go index fd5050a291..c00fb2b97b 100644 --- a/go/plugins/compat_oai/generate.go +++ b/go/plugins/compat_oai/generate.go @@ -34,47 +34,82 @@ func mapToStruct(m map[string]any, v any) error { return json.Unmarshal(jsonData, v) } -// ModelGenerator handles OpenAI generation requests -type ModelGenerator struct { - client *openai.Client - modelName string - request *openai.ChatCompletionNewParams - messages []openai.ChatCompletionMessageParamUnion - tools []openai.ChatCompletionToolParam - toolChoice openai.ChatCompletionToolChoiceOptionUnionParam - // Store any errors that occur 
during building - err error -} +// generate executes the generation request using the new functional approach +func generate( + ctx context.Context, + client *openai.Client, + model string, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, +) (*ai.ModelResponse, error) { + request, err := toOpenAIRequest(model, input) + if err != nil { + return nil, err + } -func (g *ModelGenerator) GetRequest() *openai.ChatCompletionNewParams { - return g.request + if cb != nil { + return generateStream(ctx, client, request, cb) + } + return generateComplete(ctx, client, request, input) } -// NewModelGenerator creates a new ModelGenerator instance -func NewModelGenerator(client *openai.Client, modelName string) *ModelGenerator { - return &ModelGenerator{ - client: client, - modelName: modelName, - request: &openai.ChatCompletionNewParams{ - Model: (modelName), - }, +func toOpenAIRequest(model string, input *ai.ModelRequest) (*openai.ChatCompletionNewParams, error) { + request, err := configFromRequest(input.Config) + if err != nil { + return nil, err + } + if request == nil { + request = &openai.ChatCompletionNewParams{} + } + + request.Model = model + + msgs, err := toOpenAIMessages(input.Messages) + if err != nil { + return nil, err + } + if len(msgs) == 0 { + return nil, fmt.Errorf("no messages provided") + } + request.Messages = msgs + + tools := toOpenAITools(input.Tools) + if len(tools) > 0 { + request.Tools = tools } + + return request, nil } -// WithMessages adds messages to the request -func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator { - // Return early if we already have an error - if g.err != nil { - return g +func configFromRequest(config any) (*openai.ChatCompletionNewParams, error) { + if config == nil { + return nil, nil } + var openaiConfig openai.ChatCompletionNewParams + switch cfg := config.(type) { + case openai.ChatCompletionNewParams: + openaiConfig = cfg + case *openai.ChatCompletionNewParams: + 
openaiConfig = *cfg + case map[string]any: + if err := mapToStruct(cfg, &openaiConfig); err != nil { + return nil, fmt.Errorf("failed to convert config to openai.ChatCompletionNewParams: %w", err) + } + default: + return nil, fmt.Errorf("unexpected config type: %T", config) + } + return &openaiConfig, nil +} + +func toOpenAIMessages(messages []*ai.Message) ([]openai.ChatCompletionMessageParamUnion, error) { if messages == nil { - return g + return nil, nil } oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) for _, msg := range messages { - content := g.concatenateContent(msg.Content) + content := concatenateContent(msg.Content) switch msg.Role { case ai.RoleSystem: oaiMessages = append(oaiMessages, openai.SystemMessage(content)) @@ -83,8 +118,7 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator { am.Content.OfString = param.NewOpt(content) toolCalls, err := convertToolCalls(msg.Content) if err != nil { - g.err = err - return g + return nil, err } if len(toolCalls) > 0 { am.ToolCalls = (toolCalls) @@ -105,8 +139,7 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator { toolOutput, err := anyToJSONString(p.ToolResponse.Output) if err != nil { - g.err = err - return g + return nil, err } tm := openai.ToolMessage(toolOutput, toolCallID) oaiMessages = append(oaiMessages, tm) @@ -139,53 +172,12 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator { } } - g.messages = oaiMessages - return g + return oaiMessages, nil } -// WithConfig adds configuration parameters from the model request -// see https://platform.openai.com/docs/api-reference/responses/create -// for more details on openai's request fields -func (g *ModelGenerator) WithConfig(config any) *ModelGenerator { - // Return early if we already have an error - if g.err != nil { - return g - } - - if config == nil { - return g - } - - var openaiConfig openai.ChatCompletionNewParams - switch cfg := 
config.(type) { - case openai.ChatCompletionNewParams: - openaiConfig = cfg - case *openai.ChatCompletionNewParams: - openaiConfig = *cfg - case map[string]any: - if err := mapToStruct(cfg, &openaiConfig); err != nil { - g.err = fmt.Errorf("failed to convert config to openai.ChatCompletionNewParams: %w", err) - return g - } - default: - g.err = fmt.Errorf("unexpected config type: %T", config) - return g - } - - // keep the original model in the updated config structure - openaiConfig.Model = g.request.Model - g.request = &openaiConfig - return g -} - -// WithTools adds tools to the request -func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator { - if g.err != nil { - return g - } - +func toOpenAITools(tools []*ai.ToolDefinition) []openai.ChatCompletionToolParam { if tools == nil { - return g + return nil } toolParams := make([]openai.ChatCompletionToolParam, 0, len(tools)) @@ -203,42 +195,11 @@ func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator { }), }) } - - // Set the tools in the request - // If no tools are provided, set it to nil - // This is important to avoid sending an empty array in the request - // which is not supported by some vendor APIs - if len(toolParams) > 0 { - g.tools = toolParams - } - - return g -} - -// Generate executes the generation request -func (g *ModelGenerator) Generate(ctx context.Context, req *ai.ModelRequest, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { - // Check for any errors that occurred during building - if g.err != nil { - return nil, g.err - } - - if len(g.messages) == 0 { - return nil, fmt.Errorf("no messages provided") - } - g.request.Messages = (g.messages) - - if len(g.tools) > 0 { - g.request.Tools = (g.tools) - } - - if handleChunk != nil { - return g.generateStream(ctx, handleChunk) - } - return g.generateComplete(ctx, req) + return toolParams } // concatenateContent concatenates text content into a single string 
-func (g *ModelGenerator) concatenateContent(parts []*ai.Part) string { +func concatenateContent(parts []*ai.Part) string { content := "" for _, part := range parts { content += part.Text @@ -247,8 +208,8 @@ func (g *ModelGenerator) concatenateContent(parts []*ai.Part) string { } // generateStream generates a streaming model response -func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { - stream := g.client.Chat.Completions.NewStreaming(ctx, *g.request) +func generateStream(ctx context.Context, client *openai.Client, request *openai.ChatCompletionNewParams, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { + stream := client.Chat.Completions.NewStreaming(ctx, *request) defer stream.Close() // Use openai-go's accumulator to collect the complete response @@ -405,8 +366,8 @@ func convertChatCompletionToModelResponse(completion *openai.ChatCompletion) (*a } // generateComplete generates a complete model response -func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) { - completion, err := g.client.Chat.Completions.New(ctx, *g.request) +func generateComplete(ctx context.Context, client *openai.Client, request *openai.ChatCompletionNewParams, req *ai.ModelRequest) (*ai.ModelResponse, error) { + completion, err := client.Chat.Completions.New(ctx, *request) if err != nil { return nil, fmt.Errorf("failed to create completion: %w", err) } diff --git a/go/plugins/compat_oai/generate_live_test.go b/go/plugins/compat_oai/generate_live_test.go index 3e2e340b19..f0e4bee718 100644 --- a/go/plugins/compat_oai/generate_live_test.go +++ b/go/plugins/compat_oai/generate_live_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package compat_oai_test +package compat_oai import ( "context" @@ -23,15 +23,13 @@ import ( "testing" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/plugins/compat_oai" "github.com/openai/openai-go" "github.com/openai/openai-go/option" - "github.com/stretchr/testify/assert" ) const defaultModel = "gpt-4o-mini" -func setupTestClient(t *testing.T) *compat_oai.ModelGenerator { +func setupTestClient(t *testing.T) *openai.Client { t.Helper() apiKey := os.Getenv("OPENAI_API_KEY") if apiKey == "" { @@ -39,11 +37,11 @@ func setupTestClient(t *testing.T) *compat_oai.ModelGenerator { } client := openai.NewClient(option.WithAPIKey(apiKey)) - return compat_oai.NewModelGenerator(&client, defaultModel) + return &client } func TestGenerator_Complete(t *testing.T) { - g := setupTestClient(t) + client := setupTestClient(t) messages := []*ai.Message{ { Role: ai.RoleUser, @@ -68,7 +66,7 @@ func TestGenerator_Complete(t *testing.T) { Messages: messages, } - resp, err := g.WithMessages(messages).Generate(context.Background(), req, nil) + resp, err := generate(context.Background(), client, defaultModel, req, nil) if err != nil { t.Error(err) } @@ -81,7 +79,7 @@ func TestGenerator_Complete(t *testing.T) { } func TestGenerator_Stream(t *testing.T) { - g := setupTestClient(t) + client := setupTestClient(t) messages := []*ai.Message{ { Role: ai.RoleUser, @@ -102,7 +100,7 @@ func TestGenerator_Stream(t *testing.T) { return nil } - _, err := g.WithMessages(messages).Generate(context.Background(), req, handleChunk) + _, err := generate(context.Background(), client, defaultModel, req, handleChunk) if err != nil { t.Error(err) } @@ -234,28 +232,37 @@ func TestWithConfig(t *testing.T) { }, }, } - req := &ai.ModelRequest{ - Messages: messages, - } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - generator := setupTestClient(t) - result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), req, nil) + req := 
&ai.ModelRequest{ + Messages: messages, + Config: tt.config, + } + + oaiReq, err := toOpenAIRequest(defaultModel, req) if tt.err != nil { - assert.Error(t, err) - assert.Equal(t, tt.err.Error(), err.Error()) + if err == nil { + t.Fatal("expected error, got nil") + } + if err.Error() != tt.err.Error() { + t.Errorf("got error %q, want %q", err.Error(), tt.err.Error()) + } return } // validate that the response was successful - assert.NoError(t, err) - assert.NotNil(t, result) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if oaiReq == nil { + t.Fatal("expected oaiReq to be not nil") + } // validate the input request was transformed correctly if tt.validate != nil { - tt.validate(t, generator.GetRequest()) + tt.validate(t, oaiReq) } }) } diff --git a/go/plugins/compat_oai/openai/README.md b/go/plugins/compat_oai/openai/README.md index 0cc107423e..81515f0d5e 100644 --- a/go/plugins/compat_oai/openai/README.md +++ b/go/plugins/compat_oai/openai/README.md @@ -48,26 +48,32 @@ export OPENAI_API_KEY= ``` ### Running All Tests + To run all tests in the directory: + ```bash go test -v . ``` ### Running Tests from Specific Files + To run tests from a specific file: + ```bash # Run only generate_live_test.go tests go test -run "^TestGenerator" # Run only openai_live_test.go tests -go test -run "^TestPlugin" +go test -run "^TestOpenAILive" ``` ### Running Individual Tests + To run a specific test case: + ```bash # Run only the streaming test from openai_live_test.go -go test -run "TestPlugin/streaming" +go test -run "TestOpenAILive/streaming" # Run only the Complete test from generate_live_test.go go test -run "TestGenerator_Complete" @@ -77,9 +83,11 @@ go test -run "TestGenerator_Stream" ``` ### Test Output Verbosity + Add the `-v` flag for verbose output: + ```bash -go test -v -run "TestPlugin/streaming" +go test -v -run "TestOpenAILive/streaming" ``` Note: All live tests require the OPENAI_API_KEY environment variable to be set. 
Tests will be skipped if the API key is not provided. diff --git a/go/plugins/compat_oai/openai/openai_live_test.go b/go/plugins/compat_oai/openai/openai_live_test.go index 3f9b3cc327..2cb0ec7d1a 100644 --- a/go/plugins/compat_oai/openai/openai_live_test.go +++ b/go/plugins/compat_oai/openai/openai_live_test.go @@ -31,7 +31,7 @@ import ( "github.com/openai/openai-go" ) -func TestPlugin(t *testing.T) { +func TestOpenAILive(t *testing.T) { apiKey := os.Getenv("OPENAI_API_KEY") if apiKey == "" { t.Skip("Skipping test: OPENAI_API_KEY environment variable not set")