Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ func TestGenerateAction(t *testing.T) {
t.Fatalf("action failed: %v", err)
}

if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" {
if diff := cmp.Diff(tc.ExpectChunks, chunks, cmp.Options{
cmpopts.IgnoreFields(ModelResponseChunk{}, "formatHandler"),
}); diff != "" {
t.Errorf("chunks mismatch (-want +got):\n%s", diff)
}

if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{
cmpopts.EquateEmpty(),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"),
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
Expand All @@ -155,7 +157,7 @@ func TestGenerateAction(t *testing.T) {

if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{
cmpopts.EquateEmpty(),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"),
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
Expand Down
70 changes: 30 additions & 40 deletions go/ai/format_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ package ai

import (
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/firebase/genkit/go/internal/base"
)
Expand All @@ -45,6 +43,7 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
handler := &arrayHandler{
instructions: instructions,
config: ModelOutputConfig{
Constrained: true,
Format: OutputFormatArray,
Schema: schema,
ContentType: "application/json",
Expand All @@ -55,58 +54,49 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
}

type arrayHandler struct {
instructions string
config ModelOutputConfig
instructions string
config ModelOutputConfig
accumulatedText string
currentIndex int
cursor int
}

// Instructions returns the instructions for the formatter.
func (a arrayHandler) Instructions() string {
func (a *arrayHandler) Instructions() string {
return a.instructions
}

// Config returns the output config for the formatter.
func (a arrayHandler) Config() ModelOutputConfig {
func (a *arrayHandler) Config() ModelOutputConfig {
return a.config
}

// ParseMessage parses the message and returns the formatted message.
func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
if a.config.Format == OutputFormatArray {
if m == nil {
return nil, errors.New("message is empty")
}
if len(m.Content) == 0 {
return nil, errors.New("message has no content")
}

var nonTextParts []*Part
accumulatedText := strings.Builder{}
// ParseOutput parses the final message and returns the parsed array.
func (a *arrayHandler) ParseOutput(m *Message) (any, error) {
result := base.ExtractItems(m.Text(), 0)
return result.Items, nil
}

for _, part := range m.Content {
if !part.IsText() {
nonTextParts = append(nonTextParts, part)
} else {
accumulatedText.WriteString(part.Text)
}
}
// ParseChunk processes a streaming chunk and returns parsed output.
func (a *arrayHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) {
if chunk.Index != a.currentIndex {
a.accumulatedText = ""
a.currentIndex = chunk.Index
a.cursor = 0
}

var newParts []*Part
lines := base.GetJSONObjectLines(accumulatedText.String())
for _, line := range lines {
var schemaBytes []byte
schemaBytes, err := json.Marshal(a.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}

newParts = append(newParts, NewJSONPart(line))
for _, part := range chunk.Content {
if part.IsText() {
a.accumulatedText += part.Text
}

m.Content = append(newParts, nonTextParts...)
}

result := base.ExtractItems(a.accumulatedText, a.cursor)
a.cursor = result.Cursor
return result.Items, nil
}

// ParseMessage parses the message and returns the formatted message.
func (a *arrayHandler) ParseMessage(m *Message) (*Message, error) {
return m, nil
}
60 changes: 53 additions & 7 deletions go/ai/format_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"regexp"
"slices"
"strings"

"github.com/firebase/genkit/go/core"
)

type enumFormatter struct{}
Expand All @@ -33,14 +35,15 @@ func (e enumFormatter) Name() string {
func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
enums := objectEnums(schema)
if schema == nil || len(enums) == 0 {
return nil, fmt.Errorf("schema is not valid JSON enum")
return nil, core.NewError(core.INVALID_ARGUMENT, "schema must be an object with an 'enum' property for enum format")
}

instructions := fmt.Sprintf("Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n```%s```", strings.Join(enums, "\n"))

handler := &enumHandler{
instructions: instructions,
config: ModelOutputConfig{
Constrained: true,
Format: OutputFormatEnum,
Schema: schema,
ContentType: "text/enum",
Expand All @@ -52,23 +55,49 @@ func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
}

type enumHandler struct {
instructions string
config ModelOutputConfig
enums []string
instructions string
config ModelOutputConfig
enums []string
accumulatedText string
currentIndex int
}

// Instructions returns the instructions for the formatter.
func (e enumHandler) Instructions() string {
func (e *enumHandler) Instructions() string {
return e.instructions
}

// Config returns the output config for the formatter.
func (e enumHandler) Config() ModelOutputConfig {
func (e *enumHandler) Config() ModelOutputConfig {
return e.config
}

// ParseOutput parses the final message and returns the enum value.
func (e *enumHandler) ParseOutput(m *Message) (any, error) {
return e.parseEnum(m.Text())
}

// ParseChunk processes a streaming chunk and returns parsed output.
func (e *enumHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) {
if chunk.Index != e.currentIndex {
e.accumulatedText = ""
e.currentIndex = chunk.Index
}

for _, part := range chunk.Content {
if part.IsText() {
e.accumulatedText += part.Text
}
}

// Ignore error since we are doing best effort parsing.
enum, _ := e.parseEnum(e.accumulatedText)

return enum, nil
}

// ParseMessage parses the message and returns the formatted message.
func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
func (e *enumHandler) ParseMessage(m *Message) (*Message, error) {
if e.config.Format == OutputFormatEnum {
if m == nil {
return nil, errors.New("message is empty")
Expand Down Expand Up @@ -127,3 +156,20 @@ func objectEnums(schema map[string]any) []string {

return enums
}

// parseEnum is the shared parsing logic used by both ParseOutput and ParseChunk.
func (e *enumHandler) parseEnum(text string) (string, error) {
if text == "" {
return "", nil
}

re := regexp.MustCompile(`['"]`)
clean := re.ReplaceAllString(text, "")
trimmed := strings.TrimSpace(clean)

if !slices.Contains(e.enums, trimmed) {
return "", fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
}

return trimmed, nil
}
Loading