|
1 | 1 | package openai |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
4 | 5 | "context" |
5 | 6 | "encoding/json" |
6 | 7 | "errors" |
@@ -86,12 +87,34 @@ type ChatMessagePartType string |
86 | 87 | const ( |
87 | 88 | ChatMessagePartTypeText ChatMessagePartType = "text" |
88 | 89 | ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" |
| 90 | + ChatMessagePartTypeAudio ChatMessagePartType = "input_audio" |
| 91 | + ChatMessagePartTypeVideo ChatMessagePartType = "video" |
| 92 | + ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url" |
89 | 93 | ) |
90 | 94 |
|
// InputAudio is the payload of a ChatMessagePart whose Type is "input_audio".
//
// Reference:
//   - https://bailian.console.aliyun.com/?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576
//   - https://help.aliyun.com/zh/model-studio/qwen-omni#423736d367a7x
type InputAudio struct {
	// Data carries the audio content and is always emitted (no omitempty).
	// NOTE(review): the expected encoding (base64, data URI, URL) is not
	// visible in this file; confirm against the referenced docs.
	Data string `json:"data"`
	// Format names the audio format of Data and is always emitted.
	// NOTE(review): the accepted values are not visible in this file.
	Format string `json:"format"`
}
| 103 | + |
// CacheControl is the object serialized under the "cache_control" key of a
// ChatMessagePart.
type CacheControl struct {
	// Type must be "ephemeral".
	Type string `json:"type"`
}
| 107 | + |
91 | 108 | type ChatMessagePart struct { |
92 | | - Type ChatMessagePartType `json:"type,omitempty"` |
93 | | - Text string `json:"text,omitempty"` |
94 | | - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` |
| 109 | + Type ChatMessagePartType `json:"type,omitempty"` |
| 110 | + Text string `json:"text,omitempty"` |
| 111 | + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` |
| 112 | + Audio *InputAudio `json:"input_audio,omitempty"` // required when Type is "input_audio" |
| 113 | + VideoURL *ChatMessageImageURL `json:"video_url,omitempty"` // required when Type is "video_url" |
| 114 | + Video []string `json:"video,omitempty"` // required when Type is "video", array of image URLs |
| 115 | + MinPixels int `json:"min_pixels,omitempty"` |
| 116 | + MaxPixels int `json:"max_pixels,omitempty"` |
| 117 | + *CacheControl `json:"cache_control,omitempty"` |
95 | 118 | } |
96 | 119 |
|
97 | 120 | type ChatCompletionMessage struct { |
@@ -333,6 +356,33 @@ type ChatCompletionRequest struct { |
333 | 356 | SafetyIdentifier string `json:"safety_identifier,omitempty"` |
334 | 357 | // Embedded struct for non-OpenAI extensions |
335 | 358 | ChatCompletionRequestExtensions |
| 359 | + // non-OpenAI extensions |
| 360 | + Extensions map[string]interface{} `json:"-"` |
| 361 | +} |
| 362 | + |
| 363 | +type customChatCompletionRequest ChatCompletionRequest |
| 364 | + |
| 365 | +func (r *ChatCompletionRequest) MarshalJSON() ([]byte, error) { |
| 366 | + if len(r.Extensions) == 0 { |
| 367 | + return json.Marshal((*customChatCompletionRequest)(r)) |
| 368 | + } |
| 369 | + buf := bytes.NewBuffer(nil) |
| 370 | + encoder := json.NewEncoder(buf) |
| 371 | + if err := encoder.Encode((*customChatCompletionRequest)(r)); err != nil { |
| 372 | + return nil, err |
| 373 | + } |
| 374 | + // remove the trailing "}\n" |
| 375 | + buf.Truncate(buf.Len() - 2) |
| 376 | + // record the current position |
| 377 | + pos := buf.Len() |
| 378 | + // append extensions |
| 379 | + if err := encoder.Encode(r.Extensions); err != nil { |
| 380 | + return nil, err |
| 381 | + } |
| 382 | + data := buf.Bytes() |
| 383 | + // change the leading '{' of extensions to ',' |
| 384 | + data[pos] = ',' |
| 385 | + return data, nil |
336 | 386 | } |
337 | 387 |
|
338 | 388 | type StreamOptions struct { |
|
0 commit comments