diff --git a/go/ai/action_test.go b/go/ai/action_test.go index e41679bd5b..b24e7c8c87 100644 --- a/go/ai/action_test.go +++ b/go/ai/action_test.go @@ -136,13 +136,15 @@ func TestGenerateAction(t *testing.T) { t.Fatalf("action failed: %v", err) } - if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" { + if diff := cmp.Diff(tc.ExpectChunks, chunks, cmp.Options{ + cmpopts.IgnoreFields(ModelResponseChunk{}, "formatHandler"), + }); diff != "" { t.Errorf("chunks mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{ cmpopts.EquateEmpty(), - cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"), + cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"), cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"), }); diff != "" { t.Errorf("response mismatch (-want +got):\n%s", diff) @@ -155,7 +157,7 @@ func TestGenerateAction(t *testing.T) { if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{ cmpopts.EquateEmpty(), - cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"), + cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"), cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"), }); diff != "" { t.Errorf("response mismatch (-want +got):\n%s", diff) diff --git a/go/ai/format_array.go b/go/ai/format_array.go index a063231e45..8ca9db2780 100644 --- a/go/ai/format_array.go +++ b/go/ai/format_array.go @@ -16,9 +16,7 @@ package ai import ( "encoding/json" - "errors" "fmt" - "strings" "github.com/firebase/genkit/go/internal/base" ) @@ -45,6 +43,7 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) { handler := &arrayHandler{ instructions: instructions, config: ModelOutputConfig{ + Constrained: true, Format: OutputFormatArray, Schema: schema, ContentType: "application/json", @@ -55,58 +54,49 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) { } type arrayHandler struct { - instructions string - config ModelOutputConfig + instructions string + config ModelOutputConfig + accumulatedText string + currentIndex int + cursor int } // Instructions returns the instructions for the formatter. -func (a arrayHandler) Instructions() string { +func (a *arrayHandler) Instructions() string { return a.instructions } // Config returns the output config for the formatter. -func (a arrayHandler) Config() ModelOutputConfig { +func (a *arrayHandler) Config() ModelOutputConfig { return a.config } -// ParseMessage parses the message and returns the formatted message. -func (a arrayHandler) ParseMessage(m *Message) (*Message, error) { - if a.config.Format == OutputFormatArray { - if m == nil { - return nil, errors.New("message is empty") - } - if len(m.Content) == 0 { - return nil, errors.New("message has no content") - } - - var nonTextParts []*Part - accumulatedText := strings.Builder{} +// ParseOutput parses the final message and returns the parsed array. +func (a *arrayHandler) ParseOutput(m *Message) (any, error) { + result := base.ExtractItems(m.Text(), 0) + return result.Items, nil +} - for _, part := range m.Content { - if !part.IsText() { - nonTextParts = append(nonTextParts, part) - } else { - accumulatedText.WriteString(part.Text) - } - } +// ParseChunk processes a streaming chunk and returns parsed output. +func (a *arrayHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) { + if chunk.Index != a.currentIndex { + a.accumulatedText = "" + a.currentIndex = chunk.Index + a.cursor = 0 + } - var newParts []*Part - lines := base.GetJSONObjectLines(accumulatedText.String()) - for _, line := range lines { - var schemaBytes []byte - schemaBytes, err := json.Marshal(a.config.Schema["items"]) - if err != nil { - return nil, fmt.Errorf("expected schema is not valid: %w", err) - } - if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil { - return nil, err - } - - newParts = append(newParts, NewJSONPart(line)) + for _, part := range chunk.Content { + if part.IsText() { + a.accumulatedText += part.Text } - - m.Content = append(newParts, nonTextParts...) } + result := base.ExtractItems(a.accumulatedText, a.cursor) + a.cursor = result.Cursor + return result.Items, nil +} + +// ParseMessage parses the message and returns the formatted message. +func (a *arrayHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } diff --git a/go/ai/format_enum.go b/go/ai/format_enum.go index 8fa07f62dc..0d2aa390f9 100644 --- a/go/ai/format_enum.go +++ b/go/ai/format_enum.go @@ -20,6 +20,8 @@ import ( "regexp" "slices" "strings" + + "github.com/firebase/genkit/go/core" ) type enumFormatter struct{} @@ -33,7 +35,7 @@ func (e enumFormatter) Name() string { func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) { enums := objectEnums(schema) if schema == nil || len(enums) == 0 { - return nil, fmt.Errorf("schema is not valid JSON enum") + return nil, core.NewError(core.INVALID_ARGUMENT, "schema must be an object with an 'enum' property for enum format") } instructions := fmt.Sprintf("Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n```%s```", strings.Join(enums, "\n")) @@ -41,6 +43,7 @@ func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) { handler := &enumHandler{ instructions: instructions, config: ModelOutputConfig{ + Constrained: true, Format: OutputFormatEnum, Schema: schema, ContentType: "text/enum", @@ -52,23 +55,49 @@ func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) { } type enumHandler struct { - instructions string - config ModelOutputConfig - enums []string + instructions string + config ModelOutputConfig + enums []string + accumulatedText string + currentIndex int } // Instructions returns the instructions for the formatter. -func (e enumHandler) Instructions() string { +func (e *enumHandler) Instructions() string { return e.instructions } // Config returns the output config for the formatter. -func (e enumHandler) Config() ModelOutputConfig { +func (e *enumHandler) Config() ModelOutputConfig { return e.config } +// ParseOutput parses the final message and returns the enum value. +func (e *enumHandler) ParseOutput(m *Message) (any, error) { + return e.parseEnum(m.Text()) +} + +// ParseChunk processes a streaming chunk and returns parsed output. +func (e *enumHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) { + if chunk.Index != e.currentIndex { + e.accumulatedText = "" + e.currentIndex = chunk.Index + } + + for _, part := range chunk.Content { + if part.IsText() { + e.accumulatedText += part.Text + } + } + + // Ignore error since we are doing best effort parsing. + enum, _ := e.parseEnum(e.accumulatedText) + + return enum, nil +} + // ParseMessage parses the message and returns the formatted message. -func (e enumHandler) ParseMessage(m *Message) (*Message, error) { +func (e *enumHandler) ParseMessage(m *Message) (*Message, error) { if e.config.Format == OutputFormatEnum { if m == nil { return nil, errors.New("message is empty") @@ -127,3 +156,20 @@ func objectEnums(schema map[string]any) []string { return enums } + +// parseEnum is the shared parsing logic used by both ParseOutput and ParseChunk. +func (e *enumHandler) parseEnum(text string) (string, error) { + if text == "" { + return "", nil + } + + re := regexp.MustCompile(`['"]`) + clean := re.ReplaceAllString(text, "") + trimmed := strings.TrimSpace(clean) + + if !slices.Contains(e.enums, trimmed) { + return "", fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", ")) + } + + return trimmed, nil +} diff --git a/go/ai/format_json.go b/go/ai/format_json.go index b38e3b3805..b3f60c2635 100644 --- a/go/ai/format_json.go +++ b/go/ai/format_json.go @@ -23,10 +23,16 @@ import ( "github.com/firebase/genkit/go/internal/base" ) -type jsonFormatter struct{} +type jsonFormatter struct { + // v2 does not implement ParseMessage. + v2 bool +} // Name returns the name of the formatter. func (j jsonFormatter) Name() string { + if j.v2 { + return OutputFormatJSONV2 + } return OutputFormatJSON } @@ -43,8 +49,10 @@ func (j jsonFormatter) Handler(schema map[string]any) (FormatHandler, error) { } handler := &jsonHandler{ + v2: j.v2, instructions: instructions, config: ModelOutputConfig{ + Constrained: true, Format: OutputFormatJSON, Schema: schema, ContentType: "application/json", @@ -56,64 +64,121 @@ func (j jsonFormatter) Handler(schema map[string]any) (FormatHandler, error) { // jsonHandler is a handler for the JSON formatter. type jsonHandler struct { - instructions string - config ModelOutputConfig + v2 bool + instructions string + config ModelOutputConfig + accumulatedText string + currentIndex int } // Instructions returns the instructions for the formatter. -func (j jsonHandler) Instructions() string { +func (j *jsonHandler) Instructions() string { return j.instructions } // Config returns the output config for the formatter. -func (j jsonHandler) Config() ModelOutputConfig { +func (j *jsonHandler) Config() ModelOutputConfig { return j.config } -// ParseMessage parses the message and returns the formatted message. -func (j jsonHandler) ParseMessage(m *Message) (*Message, error) { - if j.config.Format == OutputFormatJSON { - if m == nil { - return nil, errors.New("message is empty") +// ParseOutput parses the final message and returns the parsed JSON value. +func (j *jsonHandler) ParseOutput(m *Message) (any, error) { + result, err := j.parseJSON(m.Text()) + if err != nil { + return nil, err + } + + if j.config.Schema != nil { + if err := base.ValidateValue(result, j.config.Schema); err != nil { + return nil, err } - if len(m.Content) == 0 { - return nil, errors.New("message has no content") + } + + return result, nil +} + +// ParseChunk processes a streaming chunk and returns parsed output. +func (j *jsonHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) { + if chunk.Index != j.currentIndex { + j.accumulatedText = "" + j.currentIndex = chunk.Index + } + + for _, part := range chunk.Content { + if part.IsText() { + j.accumulatedText += part.Text } + } + + return j.parseJSON(j.accumulatedText) +} - var nonTextParts []*Part - accumulatedText := strings.Builder{} +// ParseMessage parses the message and returns the formatted message. +func (j *jsonHandler) ParseMessage(m *Message) (*Message, error) { + if j.v2 { + return m, nil + } - for _, part := range m.Content { - if !part.IsText() { - nonTextParts = append(nonTextParts, part) - } else { - accumulatedText.WriteString(part.Text) - } + // Legacy behavior. + if m == nil { + return nil, errors.New("message is empty") + } + if len(m.Content) == 0 { + return nil, errors.New("message has no content") + } + + var nonTextParts []*Part + accumulatedText := strings.Builder{} + + for _, part := range m.Content { + if !part.IsText() { + nonTextParts = append(nonTextParts, part) + } else { + accumulatedText.WriteString(part.Text) } + } - newParts := []*Part{} - text := base.ExtractJSONFromMarkdown(accumulatedText.String()) - if text != "" { - if j.config.Schema != nil { - schemaBytes, err := json.Marshal(j.config.Schema) - if err != nil { - return nil, fmt.Errorf("expected schema is not valid: %w", err) - } - if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil { - return nil, err - } - } else { - if !base.ValidJSON(text) { - return nil, errors.New("message is not a valid JSON") - } + newParts := []*Part{} + text := base.ExtractJSONFromMarkdown(accumulatedText.String()) + if text != "" { + if j.config.Schema != nil { + schemaBytes, err := json.Marshal(j.config.Schema) + if err != nil { + return nil, fmt.Errorf("expected schema is not valid: %w", err) + } + if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil { + return nil, err + } + } else { + if !base.ValidJSON(text) { + return nil, errors.New("message is not a valid JSON") } - newParts = append(newParts, NewJSONPart(text)) } + newParts = append(newParts, NewJSONPart(text)) + } - newParts = append(newParts, nonTextParts...) + newParts = append(newParts, nonTextParts...) - m.Content = newParts - } + m.Content = newParts return m, nil } + +// parseJSON is the shared parsing logic used by both ParseOutput and ParseChunk. +func (j *jsonHandler) parseJSON(text string) (any, error) { + if text == "" { + return nil, nil + } + + extracted := base.ExtractJSONFromMarkdown(text) + if extracted == "" { + return nil, nil + } + + result, err := base.ExtractJSON(extracted) + if err != nil { + return nil, nil + } + + return result, nil +} diff --git a/go/ai/format_jsonl.go b/go/ai/format_jsonl.go index f45c180129..032b7d90f7 100644 --- a/go/ai/format_jsonl.go +++ b/go/ai/format_jsonl.go @@ -20,20 +20,27 @@ import ( "fmt" "strings" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/internal/base" ) -type jsonlFormatter struct{} +type jsonlFormatter struct { + // v2 does not implement ParseMessage. + v2 bool +} // Name returns the name of the formatter. func (j jsonlFormatter) Name() string { + if j.v2 { + return OutputFormatJSONLV2 + } return OutputFormatJSONL } // Handler returns a new formatter handler for the given schema. func (j jsonlFormatter) Handler(schema map[string]any) (FormatHandler, error) { if schema == nil || !base.ValidateIsJSONArray(schema) { - return nil, fmt.Errorf("schema is not valid JSONL") + return nil, core.NewError(core.INVALID_ARGUMENT, "schema must be an array of objects for JSONL format") } jsonBytes, err := json.Marshal(schema["items"]) @@ -44,6 +51,7 @@ func (j jsonlFormatter) Handler(schema map[string]any) (FormatHandler, error) { instructions := fmt.Sprintf("Output should be JSONL format, a sequence of JSON objects (one per line) separated by a newline '\\n' character. Each line should be a JSON object conforming to the following schema:\n\n```%s```", string(jsonBytes)) handler := &jsonlHandler{ + v2: j.v2, instructions: instructions, config: ModelOutputConfig{ Format: OutputFormatJSONL, @@ -56,60 +64,170 @@ func (j jsonlFormatter) Handler(schema map[string]any) (FormatHandler, error) { } type jsonlHandler struct { - instructions string - config ModelOutputConfig + v2 bool + instructions string + config ModelOutputConfig + accumulatedText string + currentIndex int + cursor int } // Instructions returns the instructions for the formatter. -func (j jsonlHandler) Instructions() string { +func (j *jsonlHandler) Instructions() string { return j.instructions } // Config returns the output config for the formatter. -func (j jsonlHandler) Config() ModelOutputConfig { +func (j *jsonlHandler) Config() ModelOutputConfig { return j.config } -// ParseMessage parses the message and returns the formatted message. -func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) { - if j.config.Format == OutputFormatJSONL { - if m == nil { - return nil, errors.New("message is empty") +// ParseOutput parses the final message and returns the parsed array of objects. +func (j *jsonlHandler) ParseOutput(m *Message) (any, error) { + // Handle legacy behavior where ParseMessage split out content into multiple JSON parts. + var jsonParts []string + for _, part := range m.Content { + if part.IsText() && part.ContentType == "application/json" { + jsonParts = append(jsonParts, part.Text) } - if len(m.Content) == 0 { - return nil, errors.New("message has no content") + } + + var text string + if len(jsonParts) > 0 { + text = strings.Join(jsonParts, "\n") + } else { + var sb strings.Builder + for _, part := range m.Content { + if part.IsText() { + sb.WriteString(part.Text) + } } + text = sb.String() + } - var nonTextParts []*Part - accumulatedText := strings.Builder{} + result, _, err := j.parseJSONL(text, 0, false) + if err != nil { + return nil, err + } - for _, part := range m.Content { - if !part.IsText() { - nonTextParts = append(nonTextParts, part) - } else { - accumulatedText.WriteString(part.Text) + if j.config.Schema != nil { + if err := base.ValidateValue(result, j.config.Schema); err != nil { + return nil, err + } + } + + return result, nil +} + +// ParseChunk processes a streaming chunk and returns parsed output. +func (j *jsonlHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) { + if chunk.Index != j.currentIndex { + j.accumulatedText = "" + j.currentIndex = chunk.Index + j.cursor = 0 + } + + for _, part := range chunk.Content { + if part.IsText() { + j.accumulatedText += part.Text + } + } + + items, newCursor, err := j.parseJSONL(j.accumulatedText, j.cursor, true) + if err != nil { + return nil, err + } + j.cursor = newCursor + return items, nil +} + +// ParseMessage parses the message and returns the formatted message. +func (j *jsonlHandler) ParseMessage(m *Message) (*Message, error) { + if j.v2 { + return m, nil + } + + // Legacy behavior. + if m == nil { + return nil, errors.New("message is empty") + } + if len(m.Content) == 0 { + return nil, errors.New("message has no content") + } + + var nonTextParts []*Part + accumulatedText := strings.Builder{} + + for _, part := range m.Content { + if !part.IsText() { + nonTextParts = append(nonTextParts, part) + } else { + accumulatedText.WriteString(part.Text) + } + } + + var newParts []*Part + lines := base.GetJSONObjectLines(accumulatedText.String()) + for _, line := range lines { + if j.config.Schema != nil { + var schemaBytes []byte + schemaBytes, err := json.Marshal(j.config.Schema["items"]) + if err != nil { + return nil, fmt.Errorf("expected schema is not valid: %w", err) + } + if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil { + return nil, err } } - var newParts []*Part - lines := base.GetJSONObjectLines(accumulatedText.String()) - for _, line := range lines { - if j.config.Schema != nil { - var schemaBytes []byte - schemaBytes, err := json.Marshal(j.config.Schema["items"]) - if err != nil { - return nil, fmt.Errorf("expected schema is not valid: %w", err) - } - if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil { - return nil, err + newParts = append(newParts, NewJSONPart(line)) + } + + m.Content = append(newParts, nonTextParts...) + + return m, nil +} + +// parseJSONL parses JSONL starting from the cursor position. +// Returns the parsed items, the new cursor position, and any error. +func (j *jsonlHandler) parseJSONL(text string, cursor int, allowPartial bool) ([]any, int, error) { + if text == "" || cursor >= len(text) { + return nil, cursor, nil + } + + results := []any{} + remaining := text[cursor:] + lines := strings.Split(remaining, "\n") + currentPos := cursor + + for i, line := range lines { + isLastLine := i == len(lines)-1 + lineLen := len(line) + trimmed := strings.TrimSpace(line) + + if strings.HasPrefix(trimmed, "{") { + var result any + err := json.Unmarshal([]byte(trimmed), &result) + if err != nil { + if allowPartial && isLastLine { + partialResult, partialErr := base.ParsePartialJSON(trimmed) + if partialErr == nil && partialResult != nil { + results = append(results, partialResult) + } + // Don't advance cursor for partial line. + break } + return nil, cursor, fmt.Errorf("invalid JSON on line %d: %w", i+1, err) + } + if result != nil { + results = append(results, result) } - - newParts = append(newParts, NewJSONPart(line)) } - m.Content = append(newParts, nonTextParts...) + if !isLastLine { + currentPos += lineLen + 1 // +1 for newline + } } - return m, nil + return results, currentPos, nil } diff --git a/go/ai/format_text.go b/go/ai/format_text.go index b9b85d9b8a..2ca3fefb6f 100644 --- a/go/ai/format_text.go +++ b/go/ai/format_text.go @@ -33,21 +33,44 @@ func (t textFormatter) Handler(schema map[string]any) (FormatHandler, error) { } type textHandler struct { - instructions string - config ModelOutputConfig + instructions string + config ModelOutputConfig + accumulatedText string + currentIndex int } // Config returns the output config for the formatter. -func (t textHandler) Config() ModelOutputConfig { +func (t *textHandler) Config() ModelOutputConfig { return t.config } // Instructions returns the instructions for the formatter. -func (t textHandler) Instructions() string { +func (t *textHandler) Instructions() string { return t.instructions } +// ParseOutput parses the final message and returns the text content. +func (t *textHandler) ParseOutput(m *Message) (any, error) { + return m.Text(), nil +} + +// ParseChunk processes a streaming chunk and returns parsed output. +func (t *textHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) { + if chunk.Index != t.currentIndex { + t.accumulatedText = "" + t.currentIndex = chunk.Index + } + + for _, part := range chunk.Content { + if part.IsText() { + t.accumulatedText += part.Text + } + } + + return t.accumulatedText, nil +} + // ParseMessage parses the message and returns the formatted message. -func (t textHandler) ParseMessage(m *Message) (*Message, error) { +func (t *textHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } diff --git a/go/ai/formatter.go b/go/ai/formatter.go index 8bf2a8fb18..f6d07b390a 100644 --- a/go/ai/formatter.go +++ b/go/ai/formatter.go @@ -21,19 +21,36 @@ import ( ) const ( - OutputFormatText string = "text" - OutputFormatJSON string = "json" + // OutputFormatText is the default format. + OutputFormatText string = "text" + // OutputFormatJSON is the legacy format for JSON content. It modifies the message content in place by stripping out non-JSON content. + // + // Deprecated: Use OutputFormatJSONV2 instead. + OutputFormatJSON string = "json" + // OutputFormatJSONV2 is the format for JSON content. + OutputFormatJSONV2 string = "json.v2" + // OutputFormatJSONL is the legacy format for JSONL content. It modifies the message content in place by splitting out a single text part into multiple JSON parts, one per JSON object line. + // It does not support [ModelResponse.Output] to an array; each JSON part must be unmarshaled manually. + // + // Deprecated: Use OutputFormatJSONLV2 instead. OutputFormatJSONL string = "jsonl" + // OutputFormatJSONLV2 is the format for JSONL content. + OutputFormatJSONLV2 string = "jsonl.v2" + // OutputFormatMedia is the format for media content. OutputFormatMedia string = "media" + // OutputFormatArray is the format for array content. OutputFormatArray string = "array" - OutputFormatEnum string = "enum" + // OutputFormatEnum is the format for enum content. + OutputFormatEnum string = "enum" ) // Default formats get automatically registered on registry init var DEFAULT_FORMATS = []Formatter{ + textFormatter{}, jsonFormatter{}, + jsonFormatter{v2: true}, jsonlFormatter{}, - textFormatter{}, + jsonlFormatter{v2: true}, arrayFormatter{}, enumFormatter{}, } @@ -46,9 +63,12 @@ type Formatter interface { Handler(schema map[string]any) (FormatHandler, error) } -// FormatHandler represents the handler part of the Formatter interface. +// FormatHandler is a handler for formatting messages. +// A new instance is created via [Formatter.Handler] for each request. type FormatHandler interface { // ParseMessage parses the message and returns a new formatted message. + // + // Legacy: New format handlers should implement this as a no-op passthrough and implement [StreamingFormatHandler] instead. ParseMessage(message *Message) (*Message, error) // Instructions returns the formatter instructions to embed in the prompt. Instructions() string @@ -56,6 +76,17 @@ type FormatHandler interface { Config() ModelOutputConfig } +// StreamingFormatHandler is a handler for formatting messages that supports streaming. +// This interface must be implemented to be able to use [ModelResponse.Output] and [ModelResponseChunk.Output]. +type StreamingFormatHandler interface { + // ParseOutput parses the final output and returns parsed output. + ParseOutput(message *Message) (any, error) + // ParseChunk processes a streaming chunk and returns parsed output. + // The handler maintains its own internal state. When the chunk's index changes, the state is reset for the new turn. + // Returns parsed output, or nil if nothing can be parsed yet. + ParseChunk(chunk *ModelResponseChunk) (any, error) +} + // ConfigureFormats registers default formats in the registry func ConfigureFormats(reg api.Registry) { for _, format := range DEFAULT_FORMATS { diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index 1b55f1bad7..8a96c0b0b4 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -1518,3 +1518,349 @@ func TestEnumParserStreaming(t *testing.T) { }) } } + +func TestJSONFormatterProcessChunk(t *testing.T) { + tests := []struct { + name string + text string + want any + wantNil bool + }{ + { + name: "complete JSON", + text: `{"name": "John", "age": 30}`, + want: map[string]any{"name": "John", "age": float64(30)}, + }, + { + name: "partial JSON", + text: `{"name": "John", "age": 3`, + want: map[string]any{"name": "John", "age": float64(3)}, + }, + { + name: "JSON in markdown", + text: "```json\n{\"name\": \"Jane\"}\n```", + want: map[string]any{"name": "Jane"}, + }, + { + name: "empty text", + text: "", + wantNil: true, + }, + { + name: "incremental chunks simulated", + text: `{"name": "John", "age": 30}`, + want: map[string]any{"name": "John", "age": float64(30)}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler, err := jsonFormatter{}.Handler(map[string]any{"type": "object"}) + if err != nil { + t.Fatal(err) + } + + sfh, ok := handler.(StreamingFormatHandler) + if !ok { + t.Fatal("handler does not implement StreamingFormatHandler") + } + + chunk := &ModelResponseChunk{ + Content: []*Part{NewTextPart(tt.text)}, + Index: 0, + } + got, err := sfh.ParseChunk(chunk) + if err != nil { + t.Errorf("ProcessChunk() error = %v", err) + return + } + if tt.wantNil { + if got != nil { + t.Errorf("ProcessChunk() = %v, want nil", got) + } + return + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("ProcessChunk() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestJSONLFormatterProcessChunk(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + }, + } + + tests := []struct { + name string + text string + want []any + wantNil bool + }{ + { + name: "two complete lines", + text: "{\"name\": \"John\"}\n{\"name\": \"Jane\"}", + want: []any{ + map[string]any{"name": "John"}, + map[string]any{"name": "Jane"}, + }, + }, + { + name: "one complete, one partial", + text: "{\"name\": \"John\"}\n{\"name\": \"J", + want: []any{ + map[string]any{"name": "John"}, + map[string]any{"name": "J"}, + }, + }, + { + name: "incremental chunks simulated", + text: "{\"name\": \"John\"}\n{\"name\": \"Jane\"}", + want: []any{ + map[string]any{"name": "John"}, + map[string]any{"name": "Jane"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler, err := jsonlFormatter{}.Handler(schema) + if err != nil { + t.Fatal(err) + } + + sfh, ok := handler.(StreamingFormatHandler) + if !ok { + t.Fatal("handler does not implement StreamingFormatHandler") + } + + chunk := &ModelResponseChunk{ + Content: []*Part{NewTextPart(tt.text)}, + Index: 0, + } + got, err := sfh.ParseChunk(chunk) + if err != nil { + t.Errorf("ProcessChunk() error = %v", err) + return + } + if tt.wantNil { + if got != nil { + t.Errorf("ProcessChunk() = %v, want nil", got) + } + return + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("ProcessChunk() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestArrayFormatterProcessChunk(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + }, + } + + tests := []struct { + name string + text string + want []any + wantNil bool + }{ + { + name: "complete array", + text: `[{"name": "John"}, {"name": "Jane"}]`, + want: []any{ + map[string]any{"name": "John"}, + map[string]any{"name": "Jane"}, + }, + }, + { + name: "partial array", + text: `[{"name": "John"}, {"name": "J`, + want: []any{ + map[string]any{"name": "John"}, + }, + }, + { + name: "incremental chunks simulated", + text: `[{"name": "John"}, {"name": "Jane"}]`, + want: []any{ + map[string]any{"name": "John"}, + map[string]any{"name": "Jane"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler, err := arrayFormatter{}.Handler(schema) + if err != nil { + t.Fatal(err) + } + + sfh, ok := handler.(StreamingFormatHandler) + if !ok { + t.Fatal("handler does not implement StreamingFormatHandler") + } + + chunk := &ModelResponseChunk{ + Content: []*Part{NewTextPart(tt.text)}, + Index: 0, + } + got, err := sfh.ParseChunk(chunk) + if err != nil { + t.Errorf("ProcessChunk() error = %v", err) + return + } + if tt.wantNil { + if got != nil { + t.Errorf("ProcessChunk() = %v, want nil", got) + } + return + } + + gotSlice, ok := got.([]any) + if !ok { + t.Errorf("ProcessChunk() returned %T, want []any", got) + return + } + + if diff := cmp.Diff(tt.want, gotSlice); diff != "" { + t.Errorf("ProcessChunk() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestTextFormatterProcessChunk(t *testing.T) { + tests := []struct { + name string + text string + want string + }{ + { + name: "simple text", + text: "Hello, world!", + want: "Hello, world!", + }, + { + name: "incremental text simulated", + text: "Hello, world!", + want: "Hello, world!", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler, err := textFormatter{}.Handler(nil) + if err != nil { + t.Fatal(err) + } + + sfh, ok := handler.(StreamingFormatHandler) + if !ok { + t.Fatal("handler does not implement StreamingFormatHandler") + } + + chunk := &ModelResponseChunk{ + Content: []*Part{NewTextPart(tt.text)}, + Index: 0, + } + got, err := sfh.ParseChunk(chunk) + if err != nil { + t.Errorf("ProcessChunk() error = %v", err) + return + } + if got != tt.want { + t.Errorf("ProcessChunk() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEnumFormatterProcessChunk(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + "enum": []any{"option1", "option2", "option3"}, + }, + }, + } + + tests := []struct { + name string + text string + want string + wantNil bool + }{ + { + name: "valid enum", + text: "option1", + want: "option1", + }, + { + name: "valid enum with quotes", + text: "\"option2\"", + want: "option2", + }, + { + name: "invalid enum", + text: "invalid", + wantNil: true, + }, + { + name: "partial match", + text: "opt", + wantNil: true, + }, + { + name: "incremental chunks simulated", + text: "option1", + want: "option1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler, err := enumFormatter{}.Handler(schema) + if err != nil { + t.Fatal(err) + } + + sfh, ok := handler.(StreamingFormatHandler) + if !ok { + t.Fatal("handler does not implement StreamingFormatHandler") + } + + chunk := &ModelResponseChunk{ + Content: []*Part{NewTextPart(tt.text)}, + Index: 0, + } + got, err := sfh.ParseChunk(chunk) + if err != nil { + t.Errorf("ProcessChunk() error = %v", err) + return + } + if tt.wantNil { + if got != nil { + t.Errorf("ProcessChunk() = %v, want nil", got) + } + return + } + if got != tt.want { + t.Errorf("ProcessChunk() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go/ai/gen.go b/go/ai/gen.go index 5d24d51bf0..0aba7f4ca9 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -262,17 +262,19 @@ type ModelResponse struct { // Request is the [ModelRequest] struct used to trigger this response. Request *ModelRequest `json:"request,omitempty"` // Usage describes how many resources were used by this generation request. - Usage *GenerationUsage `json:"usage,omitempty"` + Usage *GenerationUsage `json:"usage,omitempty"` + formatHandler StreamingFormatHandler } // A ModelResponseChunk is the portion of the [ModelResponse] // that is passed to a streaming callback. type ModelResponseChunk struct { - Aggregated bool `json:"aggregated,omitempty"` - Content []*Part `json:"content,omitempty"` - Custom any `json:"custom,omitempty"` - Index int `json:"index,omitempty"` - Role Role `json:"role,omitempty"` + Aggregated bool `json:"aggregated,omitempty"` + Content []*Part `json:"content,omitempty"` + Custom any `json:"custom,omitempty"` + Index int `json:"index,omitempty"` + Role Role `json:"role,omitempty"` + formatHandler StreamingFormatHandler } // OutputConfig describes the structure that the model's output diff --git a/go/ai/generate.go b/go/ai/generate.go index fe8470d301..4db116eaa6 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -288,7 +288,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi // Native constrained output is enabled only when the user has // requested it, the model supports it, and there's a JSON schema. outputCfg.Constrained = opts.Output.JsonSchema != nil && - opts.Output.Constrained && m.(*model).supportsConstrained(len(toolDefs) > 0) + opts.Output.Constrained && outputCfg.Constrained && m.(*model).supportsConstrained(len(toolDefs) > 0) // Add schema instructions to prompt when not using native constraints. // This is a no-op for unstructured output requests. @@ -334,6 +334,11 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi currentRole := RoleModel currentIndex := messageIndex + var streamingHandler StreamingFormatHandler + if sfh, ok := formatHandler.(StreamingFormatHandler); ok { + streamingHandler = sfh + } + if cb != nil { wrappedCb = func(ctx context.Context, chunk *ModelResponseChunk) error { if chunk.Role != currentRole && chunk.Role != "" { @@ -344,6 +349,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi if chunk.Role == "" { chunk.Role = RoleModel } + chunk.formatHandler = streamingHandler return cb(ctx, chunk) } } @@ -354,6 +360,8 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } if formatHandler != nil { + resp.formatHandler = streamingHandler + // This is legacy behavior. New format handlers should implement ParseMessage as a passthrough. resp.Message, err = formatHandler.ParseMessage(resp.Message) if err != nil { logger.FromContext(ctx).Debug("model failed to generate output matching expected schema", "error", err.Error()) @@ -526,7 +534,6 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) ( } // Generate run generate request for this model. Returns ModelResponse struct. -// TODO: Stream GenerateData with partial JSON func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) { var value Out opts = append(opts, WithOutputType(value)) @@ -758,14 +765,29 @@ func (mr *ModelResponse) Reasoning() string { return sb.String() } -// Output unmarshals structured JSON output into the provided -// struct pointer. +// Output parses the structured output from the response and unmarshals it into v. +// If a format handler is set, it uses the handler's ParseOutput method. +// Otherwise, it falls back to parsing the response text as JSON. func (mr *ModelResponse) Output(v any) error { - j := base.ExtractJSONFromMarkdown(mr.Text()) - if j == "" { - return errors.New("unable to parse JSON from response text") + if mr.Message == nil || len(mr.Message.Content) == 0 { + return errors.New("no content in response") + } + + if mr.formatHandler != nil { + output, err := mr.formatHandler.ParseOutput(mr.Message) + if err != nil { + return err + } + + b, err := json.Marshal(output) + if err != nil { + return fmt.Errorf("failed to marshal output: %w", err) + } + return json.Unmarshal(b, v) } - return json.Unmarshal([]byte(j), v) + + // For backward compatibility, extract JSON from the response text. + return json.Unmarshal([]byte(base.ExtractJSONFromMarkdown(mr.Message.Text())), v) } // ToolRequests returns the tool requests from the response. @@ -809,9 +831,9 @@ func (mr *ModelResponse) Media() string { return "" } -// Text returns the text content of the [ModelResponseChunk] -// as a string. It returns an error if there is no Content -// in the response chunk. +// Text returns the text content of the ModelResponseChunk as a string. +// It returns an empty string if there is no Content in the response chunk. +// For the parsed structured output, use [ModelResponseChunk.Output] instead. func (c *ModelResponseChunk) Text() string { if len(c.Content) == 0 { return "" @@ -828,6 +850,40 @@ func (c *ModelResponseChunk) Text() string { return sb.String() } +// Output parses the chunk using the format handler and unmarshals the result into v. +// Returns an error if the format handler is not set or does not support parsing chunks. +func (c *ModelResponseChunk) Output(v any) error { + if c.formatHandler == nil { + return errors.New("output format chosen does not support parsing chunks") + } + + output, err := c.formatHandler.ParseChunk(c) + if err != nil { + return err + } + + b, err := json.Marshal(output) + if err != nil { + return fmt.Errorf("failed to marshal chunk output: %w", err) + } + return json.Unmarshal(b, v) +} + +// outputer is an interface for types that can unmarshal structured output. +type outputer interface { + Output(v any) error +} + +// OutputFrom is a convenience function that parses structured output from a +// [ModelResponse] or [ModelResponseChunk] and returns it as a typed value. +// This is equivalent to calling Output() but returns the value directly instead +// of requiring a pointer argument. If you need to handle the error, use Output() instead. +func OutputFrom[T any](src outputer) T { + var v T + src.Output(&v) + return v +} + // Text returns the contents of a [Message] as a string. It // returns an empty string if the message has no content. // If you want to get reasoning from the message, use Reasoning() instead. diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index 95299f1852..1996fdb7fb 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -1269,3 +1269,235 @@ func TestResourceProcessingError(t *testing.T) { t.Fatalf("wrong error: %v", err) } } + +func TestModelResponseOutput(t *testing.T) { + t.Run("single JSON part (json format)", func(t *testing.T) { + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"name":"Alice","age":30}`), + }, + }, + } + + var result struct { + Name string `json:"name"` + Age int `json:"age"` + } + err := mr.Output(&result) + if err != nil { + t.Fatalf("Output() error = %v", err) + } + if result.Name != "Alice" || result.Age != 30 { + t.Errorf("Output() = %+v, want {Alice 30}", result) + } + }) + + t.Run("JSON array without format handler", func(t *testing.T) { + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart(`[{"id":1},{"id":2},{"id":3}]`), + }, + }, + } + + var result []struct { + ID int `json:"id"` + } + err := mr.Output(&result) + if err != nil { + t.Fatalf("Output() error = %v", err) + } + if len(result) != 3 { + t.Fatalf("Output() got %d items, want 3", len(result)) + } + for i, item := range result { + if item.ID != i+1 { + t.Errorf("Output()[%d].ID = %d, want %d", i, item.ID, i+1) + } + } + }) + + t.Run("plain JSON text without format handler", func(t *testing.T) { + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart(`{"value":42}`), + }, + }, + } + + var result struct { + Value int `json:"value"` + } + err := mr.Output(&result) + if err != nil { + t.Fatalf("Output() error = %v", err) + } + if result.Value != 42 { + t.Errorf("Output().Value = %d, want 42", result.Value) + } + }) + + t.Run("no content error", func(t *testing.T) { + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{}, + }, + } + + var result any + err := mr.Output(&result) + if err == nil { + t.Error("Output() expected error for empty content") + } + }) + + t.Run("nil message error", func(t *testing.T) { + mr := &ModelResponse{ + Message: nil, + } + + var result any + err := mr.Output(&result) + if err == nil { + t.Error("Output() expected error for nil message") + } + }) + + t.Run("no JSON found error", func(t *testing.T) { + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart("Just plain text with no JSON"), + }, + }, + } + + var result any + err := mr.Output(&result) + if err == nil { + t.Error("Output() expected error when no JSON found") + } + }) + + t.Run("format-aware: jsonl format with handler", func(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "line": map[string]any{"type": "integer"}, + }, + }, + } + formatter := jsonlFormatter{} + handler, err := formatter.Handler(schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + streamingHandler := handler.(StreamingFormatHandler) + + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart("{\"line\":1}\n{\"line\":2}"), + }, + }, + formatHandler: streamingHandler, + } + + var result []struct { + Line int `json:"line"` + } + err = mr.Output(&result) + if err != nil { + t.Fatalf("Output() error = %v", err) + } + if len(result) != 2 || result[0].Line != 1 || result[1].Line != 2 { + t.Errorf("Output() = %+v, want [{1} {2}]", result) + } + }) + + t.Run("format-aware: array format with handler", func(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "item": map[string]any{"type": "string"}, + }, + }, + } + formatter := arrayFormatter{} + handler, err := formatter.Handler(schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + streamingHandler := handler.(StreamingFormatHandler) + + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart(`[{"item":"a"},{"item":"b"}]`), + }, + }, + formatHandler: streamingHandler, + } + + var result []struct { + Item string `json:"item"` + } + err = mr.Output(&result) + if err != nil { + t.Fatalf("Output() error = %v", err) + } + if len(result) != 2 || result[0].Item != "a" || result[1].Item != "b" { + t.Errorf("Output() = %+v, want [{a} {b}]", result) + } + }) + + t.Run("format-aware: json format with handler", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "key": map[string]any{"type": "string"}, + }, + } + formatter := jsonFormatter{} + handler, err := formatter.Handler(schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + streamingHandler := handler.(StreamingFormatHandler) + + mr := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart(`{"key":"value"}`), + }, + }, + formatHandler: streamingHandler, + } + + var result struct { + Key string `json:"key"` + } + err = mr.Output(&result) + if err != nil { + t.Fatalf("Output() error = %v", err) + } + if result.Key != "value" { + t.Errorf("Output().Key = %q, want %q", result.Key, "value") + } + }) +} diff --git a/go/core/schemas.config b/go/core/schemas.config index fba66996b5..30b3998b73 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -261,6 +261,7 @@ ModelResponse.request type *ModelRequest ModelResponse.usage type *GenerationUsage ModelResponse.raw omit ModelResponse.operation omit +ModelResponse field formatHandler StreamingFormatHandler # ModelResponseChunk ModelResponseChunk pkg ai @@ -269,6 +270,7 @@ ModelResponseChunk.content type []*Part ModelResponseChunk.custom type any ModelResponseChunk.index type int ModelResponseChunk.role type Role +ModelResponseChunk field formatHandler StreamingFormatHandler GenerationCommonConfig doc GenerationCommonConfig holds configuration for generation. diff --git a/go/internal/base/extract.go b/go/internal/base/extract.go new file mode 100644 index 0000000000..6e6917bbf0 --- /dev/null +++ b/go/internal/base/extract.go @@ -0,0 +1,253 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "encoding/json" + "strings" +) + +// ExtractJSON extracts JSON from string with lenient parsing rules. +// It handles both complete and partial JSON structures. +func ExtractJSON(text string) (any, error) { + var openingChar, closingChar rune + var startPos int = -1 + nestingCount := 0 + inString := false + escapeNext := false + + for i, char := range text { + // Replace non-breaking space with regular space + if char == '\u00A0' { + char = ' ' + } + + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + if openingChar == 0 && (char == '{' || char == '[') { + // Look for opening character + openingChar = char + if char == '{' { + closingChar = '}' + } else { + closingChar = ']' + } + startPos = i + nestingCount++ + } else if char == openingChar { + // Increment nesting for matching opening character + nestingCount++ + } else if char == closingChar { + // Decrement nesting for matching closing character + nestingCount-- + if nestingCount == 0 { + // Reached end of target element + jsonStr := text[startPos : i+1] + var result any + err := json.Unmarshal([]byte(jsonStr), &result) + if err != nil { + return nil, err + } + return result, nil + } + } + } + + if startPos != -1 && nestingCount > 0 { + // If an incomplete JSON structure is detected, try to parse it partially + jsonStr := text[startPos:] + result, err := ParsePartialJSON(jsonStr) + if err != nil { + return nil, err + } + return result, nil + } + + return nil, nil +} + +// ParsePartialJSON attempts to parse incomplete JSON by completing it. +func ParsePartialJSON(jsonStr string) (any, error) { + // Try to parse as-is first + var result any + err := json.Unmarshal([]byte(jsonStr), &result) + if err == nil { + return result, nil + } + + // If it fails, try to complete the JSON structure + completed := CompleteJSON(jsonStr) + err = json.Unmarshal([]byte(completed), &result) + return result, err +} + +// CompleteJSON attempts to complete an incomplete JSON string. +func CompleteJSON(jsonStr string) string { + jsonStr = strings.TrimSpace(jsonStr) + if jsonStr == "" { + return "{}" + } + + // Count unclosed structures + var openBraces, openBrackets int + inString := false + escapeNext := false + + for _, char := range jsonStr { + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + switch char { + case '{': + openBraces++ + case '}': + openBraces-- + case '[': + openBrackets++ + case ']': + openBrackets-- + } + } + + // Close any unclosed string + if inString { + jsonStr += "\"" + } + + // Remove trailing comma if present (before closing) + jsonStr = strings.TrimRight(jsonStr, " \t\n\r") + jsonStr = strings.TrimSuffix(jsonStr, ",") + + // Close open structures + for i := 0; i < openBrackets; i++ { + jsonStr += "]" + } + for i := 0; i < openBraces; i++ { + jsonStr += "}" + } + + return jsonStr +} + +// ExtractItemsResult contains the result of extracting items from an array. +type ExtractItemsResult struct { + Items []any + Cursor int +} + +// ExtractItems extracts complete objects from the first array found in the text. +// Processes text from the cursor position and returns both complete items +// and the new cursor position. +func ExtractItems(text string, cursor int) ExtractItemsResult { + items := []any{} + currentCursor := cursor + + // Find the first array start if we haven't already processed any text + if cursor == 0 { + arrayStart := strings.Index(text, "[") + if arrayStart == -1 { + return ExtractItemsResult{Items: items, Cursor: len(text)} + } + currentCursor = arrayStart + 1 + } + + objectStart := -1 + braceCount := 0 + inString := false + escapeNext := false + + // Process the text from the cursor position + for i := currentCursor; i < len(text); i++ { + char := rune(text[i]) + + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + if char == '{' { + if braceCount == 0 { + objectStart = i + } + braceCount++ + } else if char == '}' { + braceCount-- + if braceCount == 0 && objectStart != -1 { + var obj any + err := json.Unmarshal([]byte(text[objectStart:i+1]), &obj) + if err == nil { + items = append(items, obj) + currentCursor = i + 1 + objectStart = -1 + } + } + } else if char == ']' && braceCount == 0 { + // End of array + break + } + } + + return ExtractItemsResult{ + Items: items, + Cursor: currentCursor, + } +} diff --git a/go/internal/base/extract_test.go b/go/internal/base/extract_test.go new file mode 100644 index 0000000000..4142f82de3 --- /dev/null +++ b/go/internal/base/extract_test.go @@ -0,0 +1,194 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestExtractJSON(t *testing.T) { + tests := []struct { + name string + input string + want any + wantErr bool + }{ + { + name: "complete object", + input: `{"name": "John", "age": 30}`, + want: map[string]any{"name": "John", "age": float64(30)}, + }, + { + name: "complete array", + input: `[1, 2, 3]`, + want: []any{float64(1), float64(2), float64(3)}, + }, + { + name: "object with prefix text", + input: `Some text before {"name": "Jane"}`, + want: map[string]any{"name": "Jane"}, + }, + { + name: "incomplete object", + input: `{"name": "John", "age": 3`, + want: map[string]any{"name": "John", "age": float64(3)}, + }, + { + name: "incomplete object with partial string", + input: `{"name": "Jo`, + want: map[string]any{"name": "Jo"}, + }, + { + name: "incomplete nested object", + input: `{"person": {"name": "John"`, + want: map[string]any{"person": map[string]any{"name": "John"}}, + }, + { + name: "object with trailing comma", + input: `{"name": "John",`, + want: map[string]any{"name": "John"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ExtractJSON(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ExtractJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("ExtractJSON() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestExtractItems(t *testing.T) { + tests := []struct { + name string + input string + cursor int + wantItems []any + wantCursor int + }{ + { + name: "complete array", + input: `[{"name": "John"}, {"name": "Jane"}]`, + cursor: 0, + wantItems: []any{map[string]any{"name": "John"}, map[string]any{"name": "Jane"}}, + wantCursor: 35, + }, + { + name: "partial array - first item", + input: `[{"name": "John"}`, + cursor: 0, + wantItems: []any{map[string]any{"name": "John"}}, + wantCursor: 17, + }, + { + name: "partial array - incomplete second item", + input: `[{"name": "John"}, {"name": "J`, + cursor: 0, + wantItems: []any{map[string]any{"name": "John"}}, + wantCursor: 17, + }, + { + name: "incremental parsing from cursor", + input: `[{"name": "John"}, {"name": "Jane"}]`, + cursor: 18, + wantItems: []any{map[string]any{"name": "Jane"}}, + wantCursor: 35, + }, + { + name: "array with prefix text", + input: `Some text [{"name": "John"}]`, + cursor: 0, + wantItems: []any{map[string]any{"name": "John"}}, + wantCursor: 27, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractItems(tt.input, tt.cursor) + if diff := cmp.Diff(tt.wantItems, result.Items); diff != "" { + t.Errorf("ExtractItems() items mismatch (-want +got):\n%s", diff) + } + if result.Cursor != tt.wantCursor { + t.Errorf("ExtractItems() cursor = %v, want %v", result.Cursor, tt.wantCursor) + } + }) + } +} + +func TestCompleteJSON(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "unclosed object", + input: `{"name": "John"`, + want: `{"name": "John"}`, + }, + { + name: "unclosed array", + input: `[1, 2, 3`, + want: `[1, 2, 3]`, + }, + { + name: "unclosed string", + input: `{"name": "John`, + want: `{"name": "John"}`, + }, + { + name: "nested unclosed", + input: `{"person": {"name": "John"`, + want: `{"person": {"name": "John"}}`, + }, + { + name: "trailing comma", + input: `{"name": "John",`, + want: `{"name": "John"}`, + }, + { + name: "empty string", + input: "", + want: "{}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CompleteJSON(tt.input) + if got != tt.want { + t.Errorf("CompleteJSON() = %v, want %v", got, tt.want) + } + + // Verify result is valid JSON + var result any + err := json.Unmarshal([]byte(got), &result) + if err != nil { + t.Errorf("CompleteJSON() produced invalid JSON: %v", err) + } + }) + } +} + diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 6714c0dce8..d692b40109 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -410,6 +410,9 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err jsonTag := fmt.Sprintf(`json:"%s,omitempty"`, field) g.pr(fmt.Sprintf(" %s %s `%s`\n", adjustIdentifier(field), typeExpr, jsonTag)) } + for _, f := range tcfg.fields { + g.pr(fmt.Sprintf(" %s %s\n", f.name, f.typeExpr)) + } g.pr("}\n\n") return nil } @@ -580,6 +583,13 @@ type itemConfig struct { pkgPath string typeExpr string docLines []string + fields []extraField +} + +// extraField represents an additional unexported field to add to a struct. +type extraField struct { + name string + typeExpr string } // parseConfigFile parses the config file. @@ -602,6 +612,8 @@ type itemConfig struct { // package path, relative to outdir (last component is package name) // import // path of package to import (for packages only) +// field NAME TYPE +// add an unexported field to the struct (for types only) func parseConfigFile(filename string) (config, error) { c := config{ itemConfigs: map[string]*itemConfig{}, @@ -667,6 +679,11 @@ func parseConfigFile(filename string) (config, error) { return errf("need NAME import PATH") } ic.pkgPath = words[2] + case "field": + if len(words) < 4 { + return errf("need NAME field FIELDNAME TYPE") + } + ic.fields = append(ic.fields, extraField{name: words[2], typeExpr: words[3]}) default: return errf("unknown directive %q", words[1]) }