Skip to content

Commit 1c073b1

Browse files
committed
Make the role configurable
1 parent 85b85d8 commit 1c073b1

File tree

7 files changed

+91
-20
lines changed

7 files changed

+91
-20
lines changed

README.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,18 @@ values, the `config.yaml` file, and environment variables, in that respective or
162162

163163
Configuration variables:
164164

165-
| Variable | Description | Default |
166-
|--------------------|-----------------------------------------------------------------------------------|--------------------------|
167-
| `name` | The prefix for environment variable overrides. | 'openai' |
168-
| `api_key` | Your OpenAI API key. | (none for security) |
169-
| `model` | The GPT model used by the application. | 'gpt-3.5-turbo' |
170-
| `max_tokens` | The maximum number of tokens that can be used in a single API call. | 4096 |
171-
| `thread` | The name of the current chat thread. Each unique thread name has its own context. | 'default' |
172-
| `omit_history` | If true, the chat history will not be used to provide context for the GPT model. | false |
173-
| `url` | The base URL for the OpenAI API. | 'https://api.openai.com' |
174-
| `completions_path` | The API endpoint for completions. | '/v1/chat/completions' |
175-
| `models_path` | The API endpoint for accessing model information. | '/v1/models' |
165+
| Variable | Description | Default |
166+
|--------------------|-----------------------------------------------------------------------------------|--------------------------------|
167+
| `name` | The prefix for environment variable overrides. | 'openai' |
168+
| `api_key` | Your OpenAI API key. | (none for security) |
169+
| `model` | The GPT model used by the application. | 'gpt-3.5-turbo' |
170+
| `max_tokens` | The maximum number of tokens that can be used in a single API call. | 4096 |
171+
| `role` | The system role | 'You are a helpful assistant.' |
172+
| `thread` | The name of the current chat thread. Each unique thread name has its own context. | 'default' |
173+
| `omit_history` | If true, the chat history will not be used to provide context for the GPT model. | false |
174+
| `url` | The base URL for the OpenAI API. | 'https://api.openai.com' |
175+
| `completions_path` | The API endpoint for completions. | '/v1/chat/completions' |
176+
| `models_path` | The API endpoint for accessing model information. | '/v1/models' |
176177

177178
The defaults can be overridden by providing your own values in the user configuration file,
178179
named `.chatgpt-cli/config.yaml`, located in your home directory.

client/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
)
1515

1616
const (
17-
AssistantContent = "You are a helpful assistant."
1817
AssistantRole = "assistant"
1918
ErrEmptyResponse = "empty response"
2019
MaxTokenBufferPercentage = 20
@@ -180,10 +179,11 @@ func (c *Client) initHistory() {
180179

181180
if len(c.History) == 0 {
182181
c.History = []types.Message{{
183-
Role: SystemRole,
184-
Content: AssistantContent,
182+
Role: SystemRole,
185183
}}
186184
}
185+
186+
c.History[0].Content = c.Config.Role
187187
}
188188

189189
func (c *Client) addQuery(query string) {

client/client_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ const (
2929
defaultCompletionsPath = "/default/completions"
3030
defaultModelsPath = "/default/models"
3131
defaultThread = "default-thread"
32+
defaultRole = "You are a great default-role"
3233
envApiKey = "api-key"
3334
)
3435

@@ -209,7 +210,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
209210
history := []types.Message{
210211
{
211212
Role: client.SystemRole,
212-
Content: client.AssistantContent,
213+
Content: defaultRole,
213214
},
214215
{
215216
Role: client.UserRole,
@@ -252,7 +253,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
252253
history := []types.Message{
253254
{
254255
Role: client.SystemRole,
255-
Content: client.AssistantContent,
256+
Content: defaultRole,
256257
},
257258
{
258259
Role: client.UserRole,
@@ -352,7 +353,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
352353
history := []types.Message{
353354
{
354355
Role: client.SystemRole,
355-
Content: client.AssistantContent,
356+
Content: defaultRole,
356357
},
357358
{
358359
Role: client.UserRole,
@@ -433,7 +434,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
433434

434435
systemMessage := subject.History[0]
435436
Expect(systemMessage.Role).To(Equal(client.SystemRole))
436-
Expect(systemMessage.Content).To(Equal("You are a helpful assistant."))
437+
Expect(systemMessage.Content).To(Equal(defaultRole))
437438

438439
contextMessage := subject.History[1]
439440
Expect(contextMessage.Role).To(Equal(client.UserRole))
@@ -458,7 +459,7 @@ func createMessages(history []types.Message, query string) []types.Message {
458459
if len(history) == 0 {
459460
messages = append(messages, types.Message{
460461
Role: client.SystemRole,
461-
Content: client.AssistantContent,
462+
Content: defaultRole,
462463
})
463464
} else {
464465
messages = history
@@ -486,6 +487,7 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor
486487
URL: defaultURL,
487488
CompletionsPath: defaultCompletionsPath,
488489
ModelsPath: defaultModelsPath,
490+
Role: defaultRole,
489491
Thread: defaultThread,
490492
}).Times(1)
491493

config/store.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const (
1515
openAIURL = "https://api.openai.com"
1616
openAICompletionsPath = "/v1/chat/completions"
1717
openAIModelsPath = "/v1/models"
18+
openAIRole = "You are a helpful assistant."
1819
openAIThread = "default"
1920
)
2021

@@ -51,6 +52,7 @@ func (f *FileIO) ReadDefaults() types.Config {
5152
return types.Config{
5253
Name: openAIName,
5354
Model: openAIModel,
55+
Role: openAIRole,
5456
MaxTokens: openAIModelMaxTokens,
5557
URL: openAIURL,
5658
CompletionsPath: openAICompletionsPath,

configmanager/configmanager_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
2727
defaultName = "default-name"
2828
defaultURL = "default-url"
2929
defaultModel = "default-model"
30+
defaultRole = "default-role"
3031
defaultApiKey = "default-api-key"
3132
defaultThread = "default-thread"
3233
defaultCompletionsPath = "default-completions-path"
@@ -55,6 +56,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
5556
CompletionsPath: defaultCompletionsPath,
5657
ModelsPath: defaultModelsPath,
5758
OmitHistory: defaultOmitHistory,
59+
Role: defaultRole,
5860
Thread: defaultThread,
5961
}
6062

@@ -87,6 +89,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
8789
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
8890
Expect(subject.Config.Name).To(Equal(defaultName))
8991
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
92+
Expect(subject.Config.Role).To(Equal(defaultRole))
9093
Expect(subject.Config.Thread).To(Equal(defaultThread))
9194
})
9295
it("gives precedence to the user provided model", func() {
@@ -105,6 +108,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
105108
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
106109
Expect(subject.Config.Name).To(Equal(defaultName))
107110
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
111+
Expect(subject.Config.Role).To(Equal(defaultRole))
108112
Expect(subject.Config.Thread).To(Equal(defaultThread))
109113
})
110114
it("gives precedence to the user provided name", func() {
@@ -123,6 +127,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
123127
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
124128
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
125129
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
130+
Expect(subject.Config.Role).To(Equal(defaultRole))
126131
Expect(subject.Config.Thread).To(Equal(defaultThread))
127132
})
128133
it("gives precedence to the user provided max-tokens", func() {
@@ -141,6 +146,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
141146
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
142147
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
143148
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
149+
Expect(subject.Config.Role).To(Equal(defaultRole))
144150
Expect(subject.Config.Thread).To(Equal(defaultThread))
145151

146152
})
@@ -160,6 +166,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
160166
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
161167
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
162168
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
169+
Expect(subject.Config.Role).To(Equal(defaultRole))
163170
Expect(subject.Config.Thread).To(Equal(defaultThread))
164171
})
165172
it("gives precedence to the user provided completions-path", func() {
@@ -178,6 +185,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
178185
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
179186
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
180187
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
188+
Expect(subject.Config.Role).To(Equal(defaultRole))
181189
Expect(subject.Config.Thread).To(Equal(defaultThread))
182190
})
183191
it("gives precedence to the user provided models-path", func() {
@@ -196,6 +204,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
196204
Expect(subject.Config.ModelsPath).To(Equal(modelsPath))
197205
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
198206
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
207+
Expect(subject.Config.Role).To(Equal(defaultRole))
199208
Expect(subject.Config.Thread).To(Equal(defaultThread))
200209
})
201210
it("gives precedence to the user provided api-key", func() {
@@ -214,6 +223,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
214223
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
215224
Expect(subject.Config.APIKey).To(Equal(apiKey))
216225
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
226+
Expect(subject.Config.Role).To(Equal(defaultRole))
217227
Expect(subject.Config.Thread).To(Equal(defaultThread))
218228
})
219229
it("gives precedence to the user provided omit-history", func() {
@@ -231,6 +241,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
231241
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
232242
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
233243
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
244+
Expect(subject.Config.Role).To(Equal(defaultRole))
234245
Expect(subject.Config.Thread).To(Equal(defaultThread))
235246
Expect(subject.Config.OmitHistory).To(Equal(omitHistory))
236247
})
@@ -250,8 +261,28 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
250261
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
251262
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
252263
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
264+
Expect(subject.Config.Role).To(Equal(defaultRole))
253265
Expect(subject.Config.Thread).To(Equal(userThread))
254266
})
267+
it("gives precedence to the user provided role", func() {
268+
userRole := "user-role"
269+
270+
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
271+
mockConfigStore.EXPECT().Read().Return(types.Config{Role: userRole}, nil).Times(1)
272+
273+
subject := configmanager.New(mockConfigStore).WithEnvironment()
274+
275+
Expect(subject.Config.Name).To(Equal(defaultName))
276+
Expect(subject.Config.Model).To(Equal(defaultModel))
277+
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
278+
Expect(subject.Config.URL).To(Equal(defaultURL))
279+
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
280+
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
281+
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
282+
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
283+
Expect(subject.Config.Thread).To(Equal(defaultThread))
284+
Expect(subject.Config.Role).To(Equal(userRole))
285+
})
255286
it("gives precedence to the OMIT_HISTORY environment variable", func() {
256287
var (
257288
environmentValue = true
@@ -272,6 +303,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
272303
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
273304
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
274305
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
306+
Expect(subject.Config.Role).To(Equal(defaultRole))
275307
Expect(subject.Config.Thread).To(Equal(defaultThread))
276308
Expect(subject.Config.OmitHistory).To(Equal(environmentValue))
277309
})
@@ -296,8 +328,33 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
296328
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
297329
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
298330
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
331+
Expect(subject.Config.Role).To(Equal(defaultRole))
299332
Expect(subject.Config.Thread).To(Equal(environmentValue))
300333
})
334+
it("gives precedence to the ROLE environment variable", func() {
335+
var (
336+
environmentValue = "env-role"
337+
configValue = "conf-role"
338+
)
339+
340+
Expect(os.Setenv(envPrefix+"ROLE", environmentValue)).To(Succeed())
341+
342+
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
343+
mockConfigStore.EXPECT().Read().Return(types.Config{Role: configValue}, nil).Times(1)
344+
345+
subject := configmanager.New(mockConfigStore).WithEnvironment()
346+
347+
Expect(subject.Config.Name).To(Equal(defaultName))
348+
Expect(subject.Config.Model).To(Equal(defaultModel))
349+
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
350+
Expect(subject.Config.URL).To(Equal(defaultURL))
351+
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
352+
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
353+
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
354+
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
355+
Expect(subject.Config.Thread).To(Equal(defaultThread))
356+
Expect(subject.Config.Role).To(Equal(environmentValue))
357+
})
301358
it("gives precedence to the API_KEY environment variable", func() {
302359
var (
303360
environmentKey = "environment-api-key"
@@ -319,6 +376,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
319376
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
320377
Expect(subject.Config.APIKey).To(Equal(environmentKey))
321378
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
379+
Expect(subject.Config.Role).To(Equal(defaultRole))
322380
Expect(subject.Config.Thread).To(Equal(defaultThread))
323381
})
324382
it("gives precedence to the MODEL environment variable", func() {
@@ -342,6 +400,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
342400
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
343401
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
344402
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
403+
Expect(subject.Config.Role).To(Equal(defaultRole))
345404
Expect(subject.Config.Thread).To(Equal(defaultThread))
346405
})
347406
it("gives precedence to the MAX_TOKENS environment variable", func() {
@@ -365,6 +424,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
365424
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
366425
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
367426
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
427+
Expect(subject.Config.Role).To(Equal(defaultRole))
368428
Expect(subject.Config.Thread).To(Equal(defaultThread))
369429
})
370430
it("gives precedence to the URL environment variable", func() {
@@ -388,6 +448,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
388448
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
389449
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
390450
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
451+
Expect(subject.Config.Role).To(Equal(defaultRole))
391452
Expect(subject.Config.Thread).To(Equal(defaultThread))
392453
})
393454
it("gives precedence to the COMPLETIONS_PATH environment variable", func() {
@@ -411,6 +472,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
411472
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
412473
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
413474
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
475+
Expect(subject.Config.Role).To(Equal(defaultRole))
414476
Expect(subject.Config.Thread).To(Equal(defaultThread))
415477
})
416478
it("gives precedence to the MODELS_PATH environment variable", func() {
@@ -434,6 +496,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
434496
Expect(subject.Config.ModelsPath).To(Equal(envModelsPath))
435497
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
436498
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
499+
Expect(subject.Config.Role).To(Equal(defaultRole))
437500
Expect(subject.Config.Thread).To(Equal(defaultThread))
438501
})
439502
})
@@ -455,6 +518,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
455518
Expect(result).To(ContainSubstring(defaultModelsPath))
456519
Expect(result).To(ContainSubstring(fmt.Sprintf("%d", defaultMaxTokens)))
457520
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
521+
Expect(subject.Config.Role).To(Equal(defaultRole))
458522
Expect(subject.Config.Thread).To(Equal(defaultThread))
459523
})
460524
})
@@ -499,4 +563,5 @@ func cleanEnv(envPrefix string) {
499563
Expect(os.Unsetenv(envPrefix + "MODELS_PATH")).To(Succeed())
500564
Expect(os.Unsetenv(envPrefix + "OMIT_HISTORY")).To(Succeed())
501565
Expect(os.Unsetenv(envPrefix + "THREAD")).To(Succeed())
566+
Expect(os.Unsetenv(envPrefix + "ROLE")).To(Succeed())
502567
}

integration/contract_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
4141
body := types.CompletionsRequest{
4242
Messages: []types.Message{{
4343
Role: client.SystemRole,
44-
Content: client.AssistantContent,
44+
Content: defaults.Role,
4545
}},
4646
Model: defaults.Model,
4747
Stream: false,

types/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ type Config struct {
55
APIKey string `yaml:"api_key"`
66
Model string `yaml:"model"`
77
MaxTokens int `yaml:"max_tokens"`
8+
Role string `yaml:"role"`
89
Thread string `yaml:"thread"`
910
OmitHistory bool `yaml:"omit_history"`
1011
URL string `yaml:"url"`

0 commit comments

Comments
 (0)