Skip to content

Commit 1f9c8d3

Browse files
committed
feat: add support for image editing
1 parent 70ba9f1 commit 1f9c8d3

File tree

7 files changed

+361
-124
lines changed

7 files changed

+361
-124
lines changed

api/client/client.go

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"golang.org/x/text/language"
1818
"io"
1919
"mime/multipart"
20+
"net/textproto"
2021
"net/url"
2122
"os"
2223
"path/filepath"
@@ -399,22 +400,115 @@ func (c *Client) SynthesizeSpeech(inputText, outputPath string) error {
399400
// Returns:
400401
// - An error if any part of the request, decoding, or file writing fails.
401402
func (c *Client) GenerateImage(inputText, outputPath string) error {
402-
reqBody := api.Draw{
403+
req := api.Draw{
403404
Model: c.Config.Model,
404405
Prompt: inputText,
405406
}
406407

407-
body, err := json.Marshal(reqBody)
408+
return c.postAndWriteBinaryOutput(
409+
c.getEndpoint(c.Config.ImageGenerationsPath),
410+
req,
411+
outputPath,
412+
"image",
413+
func(respBytes []byte) ([]byte, error) {
414+
var response struct {
415+
Data []struct {
416+
B64 string `json:"b64_json"`
417+
} `json:"data"`
418+
}
419+
if err := json.Unmarshal(respBytes, &response); err != nil {
420+
return nil, fmt.Errorf("failed to decode response: %w", err)
421+
}
422+
if len(response.Data) == 0 {
423+
return nil, fmt.Errorf("no image data returned")
424+
}
425+
decoded, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
426+
if err != nil {
427+
return nil, fmt.Errorf("failed to decode base64 image: %w", err)
428+
}
429+
return decoded, nil
430+
},
431+
)
432+
}
433+
434+
// EditImage edits an input image using a text prompt and writes the modified image to the specified output path.
435+
//
436+
// This method sends a multipart/form-data POST request to the image editing endpoint
437+
// (typically OpenAI's /v1/images/edits). The request includes:
438+
// - The image file to edit.
439+
// - A text prompt describing how the image should be modified.
440+
// - The model ID (e.g., gpt-image-1).
441+
//
442+
// The response is expected to contain a base64-encoded image, which is decoded and written to the outputPath.
443+
//
444+
// Parameters:
445+
// - inputText: A text prompt describing the desired modifications to the image.
446+
// - inputPath: The file path to the source image (must be a supported format: PNG, JPEG, or WebP).
447+
// - outputPath: The file path where the edited image will be saved.
448+
//
449+
// Returns:
450+
// - An error if any step of the process fails: reading the file, building the request, sending it,
451+
// decoding the response, or writing the output image.
452+
//
453+
// Example:
454+
//
455+
// err := client.EditImage("Add a rainbow in the sky", "input.png", "output.png")
456+
// if err != nil {
457+
// log.Fatal(err)
458+
// }
459+
func (c *Client) EditImage(inputText, inputPath, outputPath string) error {
460+
endpoint := c.getEndpoint(c.Config.ImageEditsPath)
461+
462+
file, err := c.reader.Open(inputPath)
408463
if err != nil {
409-
return fmt.Errorf("failed to marshal request: %w", err)
464+
return fmt.Errorf("failed to open input image: %w", err)
410465
}
466+
defer file.Close()
411467

412-
endpoint := c.getEndpoint(c.Config.DrawPath)
413-
c.printRequestDebugInfo(endpoint, body, nil)
468+
var buf bytes.Buffer
469+
writer := multipart.NewWriter(&buf)
414470

415-
respBytes, err := c.caller.Post(endpoint, body, false)
471+
mimeType, err := c.getMimeTypeFromFileContent(inputPath)
472+
if err != nil {
473+
return fmt.Errorf("failed to detect MIME type: %w", err)
474+
}
475+
if !strings.HasPrefix(mimeType, "image/") {
476+
return fmt.Errorf("unsupported MIME type: %s", mimeType)
477+
}
478+
479+
header := make(textproto.MIMEHeader)
480+
header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="image"; filename="%s"`, filepath.Base(inputPath)))
481+
header.Set("Content-Type", mimeType)
482+
483+
part, err := writer.CreatePart(header)
484+
if err != nil {
485+
return fmt.Errorf("failed to create image part: %w", err)
486+
}
487+
if _, err := io.Copy(part, file); err != nil {
488+
return fmt.Errorf("failed to copy image data: %w", err)
489+
}
490+
491+
if err := writer.WriteField("prompt", inputText); err != nil {
492+
return fmt.Errorf("failed to add prompt: %w", err)
493+
}
494+
if err := writer.WriteField("model", c.Config.Model); err != nil {
495+
return fmt.Errorf("failed to add model: %w", err)
496+
}
497+
498+
if err := writer.Close(); err != nil {
499+
return fmt.Errorf("failed to close multipart writer: %w", err)
500+
}
501+
502+
c.printRequestDebugInfo(endpoint, buf.Bytes(), map[string]string{
503+
"Content-Type": writer.FormDataContentType(),
504+
})
505+
506+
respBytes, err := c.caller.PostWithHeaders(endpoint, buf.Bytes(), map[string]string{
507+
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
508+
"Content-Type": writer.FormDataContentType(),
509+
})
416510
if err != nil {
417-
return fmt.Errorf("failed to generate image: %w", err)
511+
return fmt.Errorf("failed to edit image: %w", err)
418512
}
419513

420514
// Parse the JSON and extract b64_json

api/client/client_test.go

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package client_test
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/base64"
67
"encoding/json"
@@ -16,6 +17,7 @@ import (
1617
"github.com/kardolus/chatgpt-cli/history"
1718
"github.com/kardolus/chatgpt-cli/internal"
1819
"github.com/kardolus/chatgpt-cli/test"
20+
"io"
1921
"os"
2022
"strings"
2123
"testing"
@@ -1060,7 +1062,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
10601062
})
10611063
it("throws an error when the http call fails", func() {
10621064
mockCaller.EXPECT().
1063-
Post(subject.Config.URL+subject.Config.DrawPath, body, false).
1065+
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
10641066
Return(nil, errors.New(errorText))
10651067

10661068
err := subject.GenerateImage(inputText, outputFile)
@@ -1069,7 +1071,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
10691071
})
10701072
it("throws an error when no image data is returned", func() {
10711073
mockCaller.EXPECT().
1072-
Post(subject.Config.URL+subject.Config.DrawPath, body, false).
1074+
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
10731075
Return([]byte(`{"data":[]}`), nil)
10741076

10751077
err := subject.GenerateImage(inputText, outputFile)
@@ -1078,7 +1080,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
10781080
})
10791081
it("throws an error when base64 is invalid", func() {
10801082
mockCaller.EXPECT().
1081-
Post(subject.Config.URL+subject.Config.DrawPath, body, false).
1083+
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
10821084
Return([]byte(`{"data":[{"b64_json":"!!notbase64!!"}]}`), nil)
10831085

10841086
err := subject.GenerateImage(inputText, outputFile)
@@ -1089,7 +1091,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
10891091
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
10901092

10911093
mockCaller.EXPECT().
1092-
Post(subject.Config.URL+subject.Config.DrawPath, body, false).
1094+
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
10931095
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
10941096

10951097
mockWriter.EXPECT().Create(outputFile).Return(nil, errors.New(errorText))
@@ -1105,7 +1107,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
11051107
defer file.Close()
11061108

11071109
mockCaller.EXPECT().
1108-
Post(subject.Config.URL+subject.Config.DrawPath, body, false).
1110+
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
11091111
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
11101112

11111113
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
@@ -1122,7 +1124,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
11221124
defer file.Close()
11231125

11241126
mockCaller.EXPECT().
1125-
Post(subject.Config.URL+subject.Config.DrawPath, body, false).
1127+
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
11261128
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
11271129

11281130
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
@@ -1132,6 +1134,100 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
11321134
Expect(err).NotTo(HaveOccurred())
11331135
})
11341136
})
1137+
when("EditImage()", func() {
1138+
const (
1139+
inputText = "give the dog sunglasses"
1140+
inputFile = "dog.png"
1141+
outputFile = "dog_cool.png"
1142+
errorText = "mock error occurred"
1143+
)
1144+
1145+
var (
1146+
subject *client.Client
1147+
validB64 string
1148+
imageBytes = []byte("image-bytes")
1149+
respBytes []byte
1150+
)
1151+
1152+
it.Before(func() {
1153+
subject = factory.buildClientWithoutConfig()
1154+
validB64 = base64.StdEncoding.EncodeToString(imageBytes)
1155+
respBytes = []byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, validB64))
1156+
})
1157+
1158+
it("returns error when input file can't be opened", func() {
1159+
mockReader.EXPECT().Open(inputFile).Return(nil, errors.New(errorText))
1160+
1161+
err := subject.EditImage(inputText, inputFile, outputFile)
1162+
Expect(err).To(HaveOccurred())
1163+
Expect(err.Error()).To(ContainSubstring("failed to open input image"))
1164+
})
1165+
it("returns error on invalid mime type", func() {
1166+
file := openDummy()
1167+
mockReader.EXPECT().Open(inputFile).Return(file, nil).Times(2)
1168+
mockReader.EXPECT().ReadBufferFromFile(file).Return([]byte("not an image"), nil)
1169+
1170+
err := subject.EditImage(inputText, inputFile, outputFile)
1171+
Expect(err).To(HaveOccurred())
1172+
Expect(err.Error()).To(ContainSubstring("unsupported MIME type"))
1173+
})
1174+
it("returns error when HTTP call fails", func() {
1175+
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
1176+
return openDummy(), nil
1177+
}).Times(2)
1178+
1179+
mockReader.EXPECT().
1180+
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
1181+
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
1182+
1183+
mockCaller.EXPECT().
1184+
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
1185+
Return(nil, errors.New(errorText))
1186+
1187+
err := subject.EditImage(inputText, inputFile, outputFile)
1188+
Expect(err).To(HaveOccurred())
1189+
Expect(err.Error()).To(ContainSubstring("failed to edit image"))
1190+
})
1191+
it("returns error when base64 is invalid", func() {
1192+
invalidResp := []byte(`{"data":[{"b64_json":"!notbase64"}]}`)
1193+
1194+
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
1195+
return openDummy(), nil
1196+
}).Times(2)
1197+
1198+
mockReader.EXPECT().
1199+
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
1200+
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
1201+
1202+
mockCaller.EXPECT().
1203+
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
1204+
Return(invalidResp, nil)
1205+
1206+
err := subject.EditImage(inputText, inputFile, outputFile)
1207+
Expect(err).To(HaveOccurred())
1208+
Expect(err.Error()).To(ContainSubstring("failed to decode base64 image"))
1209+
})
1210+
it("writes image when all steps succeed", func() {
1211+
file := openDummy()
1212+
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
1213+
return openDummy(), nil
1214+
}).Times(2)
1215+
1216+
mockReader.EXPECT().
1217+
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
1218+
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
1219+
1220+
mockCaller.EXPECT().
1221+
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
1222+
Return(respBytes, nil)
1223+
1224+
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
1225+
mockWriter.EXPECT().Write(file, imageBytes).Return(nil)
1226+
1227+
err := subject.EditImage(inputText, inputFile, outputFile)
1228+
Expect(err).NotTo(HaveOccurred())
1229+
})
1230+
})
11351231
when("Transcribe()", func() {
11361232
const audioPath = "path/to/audio.wav"
11371233
const transcribedText = "Hello, this is a test."
@@ -1535,6 +1631,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
15351631
})
15361632
}
15371633

1634+
func openDummy() *os.File {
1635+
// Use os.Pipe to get an *os.File without needing a real disk file.
1636+
r, w, _ := os.Pipe()
1637+
go func() {
1638+
_, _ = io.Copy(w, bytes.NewBuffer([]byte("\x89PNG\r\n\x1a\n")))
1639+
_ = w.Close()
1640+
}()
1641+
return r
1642+
}
1643+
15381644
func createBody(messages []api.Message, stream bool) ([]byte, error) {
15391645
req := api.CompletionsRequest{
15401646
Model: config.Model,

0 commit comments

Comments
 (0)