Skip to content

Commit 72efa42

Browse files
committed
Optimized schemaless more to use cache further
1 parent 131962f commit 72efa42

File tree

2 files changed

+72
-14
lines changed

2 files changed

+72
-14
lines changed

shuffle.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"net/url"
1414
"net/http"
1515
"io/ioutil"
16-
"math/rand"
16+
//"math/rand"
1717
"crypto/tls"
1818
"crypto/md5"
1919
"encoding/hex"
@@ -154,6 +154,19 @@ func AddShuffleFile(name, namespace string, data []byte, shuffleConfig ShuffleCo
154154
fileUrl += "&execution_id=" + shuffleConfig.ExecutionId
155155
}
156156

157+
// Check if the file has already been uploaded based on shuffleConfig.OrgId+namespace+data. No point in overwriting with the same data.
158+
hasher := md5.New()
159+
ctx := context.Background()
160+
hasher.Write([]byte(fmt.Sprintf("%s%s%s%s", shuffleConfig.OrgId, name, namespace, string(data))))
161+
cacheKey := hex.EncodeToString(hasher.Sum(nil))
162+
cache, err := GetCache(ctx, cacheKey)
163+
if err == nil {
164+
cacheData := []byte(cache.([]uint8))
165+
if len(cacheData) > 0 {
166+
return nil
167+
}
168+
}
169+
157170
fileDataJson, err := json.Marshal(fileData)
158171
if err != nil {
159172
return err
@@ -275,6 +288,12 @@ func AddShuffleFile(name, namespace string, data []byte, shuffleConfig ShuffleCo
275288
return err
276289
}
277290

291+
// Update with basically nothing, as the point isn't to get the file itself
292+
err = SetCache(ctx, cacheKey, []byte("1"), 10)
293+
if err != nil {
294+
log.Printf("[ERROR] Schemaless (8): Error setting cache for file %#v from Shuffle backend: %s", name, err)
295+
}
296+
278297
return nil
279298
}
280299

@@ -294,8 +313,6 @@ func GetShuffleFileById(id string, shuffleConfig ShuffleConfig) ([]byte, error)
294313
cacheKey := hex.EncodeToString(hasher.Sum(nil))
295314

296315
// The file will be grabbed a ton, hence the cache actually speeding things up and reducing requests
297-
sleepTime := time.Duration(25 + rand.Intn(1000-25)) * time.Millisecond
298-
time.Sleep(sleepTime)
299316

300317
cache, err := GetCache(ctx, cacheKey)
301318
if err == nil {
@@ -352,7 +369,6 @@ func FindShuffleFile(name, category string, shuffleConfig ShuffleConfig) ([]byte
352369
}
353370

354371

355-
log.Printf("[INFO] Schemaless: Finding file %#v in category %#v from Shuffle backend", name, category)
356372

357373
// 1. Get the category
358374
// 2. Find the file in the category output
@@ -365,17 +381,16 @@ func FindShuffleFile(name, category string, shuffleConfig ShuffleConfig) ([]byte
365381
hasher.Write([]byte(categoryUrl+shuffleConfig.Authorization+shuffleConfig.OrgId+shuffleConfig.ExecutionId))
366382
cacheKey := hex.EncodeToString(hasher.Sum(nil))
367383

368-
sleepTime := time.Duration(25 + rand.Intn(1000-25)) * time.Millisecond
369-
time.Sleep(sleepTime)
370-
371384
// Get the cache
372385
ctx := context.Background()
373386
var body []byte
374387
cache, err := GetCache(ctx, cacheKey)
375388
if err == nil {
389+
//log.Printf("[INFO] Schemaless: FOUND file %#v in category %#v from cache", name, category)
376390
body = []byte(cache.([]uint8))
377391
//return cacheData, nil
378392
} else {
393+
log.Printf("[INFO] Schemaless: Finding file %#v in category %#v from Shuffle backend", name, category)
379394
if len(shuffleConfig.ExecutionId) > 0 {
380395
categoryUrl += "&execution_id=" + shuffleConfig.ExecutionId
381396
}

translate.go

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ var maxInputSize = 4000
2828

2929
func SaveQuery(inputStandard, gptTranslated string, shuffleConfig ShuffleConfig) error {
3030
if len(shuffleConfig.URL) > 0 {
31+
//return nil
3132
return AddShuffleFile(inputStandard, "translation_ai_queries", []byte(gptTranslated), shuffleConfig)
3233
}
3334

@@ -58,7 +59,6 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC
5859

5960
systemMessage := fmt.Sprintf("Ensure the output is valid JSON, and does NOT add more keys to the standard. Make sure each key in the standard has a value from the user input. If values are nested, ALWAYS add the nested value in jq format such as 'secret.version.value'. Example: If the standard is ```{\"id\": \"The id of the ticket\", \"title\": \"The ticket title\"}```, and the user input is ```{\"key\": \"12345\", \"fields:\": {\"summary\": \"The title of the ticket\"}}```, the output should be ```{\"id\": \"key\", \"title\": \"fields.summary\"}```")
6061

61-
log.Printf("[DEBUG] Schemaless: Running GPT with system message: %s", systemMessage)
6262

6363
userQuery := fmt.Sprintf("Translate the given user input JSON structure to a standard format. Use the values from the standard to guide you what to look for. The standard format should follow the pattern:\n\n```json\n%s\n```\n\nUser Input:\n```json\n%s\n```\n\nGenerate the standard output structure without providing the expected output.", standardFormat, inputDataFormat)
6464

@@ -70,6 +70,19 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC
7070
return standardFormat, errors.New(fmt.Sprintf("Input data too long. Max is %d. Current is %d", maxInputSize, len(inputDataFormat)))
7171
}
7272

73+
// Make md5 of the query, and put it in cache to check
74+
ctx := context.Background()
75+
md5Query := fmt.Sprintf("%x", md5.Sum([]byte(shuffleConfig.OrgId+systemMessage+userQuery)))
76+
77+
cacheKey := fmt.Sprintf("translationquery-%s", md5Query)
78+
cache, err := GetCache(ctx, cacheKey)
79+
if err == nil {
80+
contentOutput := string([]byte(cache.([]uint8)))
81+
return contentOutput, nil
82+
}
83+
84+
log.Printf("[DEBUG] Schemaless: Running GPT with system message: %s", systemMessage)
85+
7386
SaveQuery(keyTokenFile, userQuery, shuffleConfig)
7487

7588
openaiClient := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
@@ -111,6 +124,12 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC
111124
break
112125
}
113126

127+
err = SetCache(ctx, cacheKey, []byte(contentOutput), 30)
128+
if err != nil {
129+
log.Printf("[ERROR] Schemaless: Error setting cache for key %s: %v", cacheKey, err)
130+
return contentOutput, err
131+
}
132+
114133
return contentOutput, nil
115134
}
116135

@@ -361,7 +380,8 @@ func SaveTranslation(inputStandard, gptTranslated string, shuffleConfig ShuffleC
361380
gptTranslated = FixTranslationStructure(gptTranslated)
362381

363382
if len(shuffleConfig.URL) > 0 {
364-
go AddShuffleFile(inputStandard, "translation_output", []byte(gptTranslated), shuffleConfig)
383+
// Used to be a goroutine
384+
return AddShuffleFile(inputStandard, "translation_output", []byte(gptTranslated), shuffleConfig)
365385
return nil
366386
}
367387

@@ -387,6 +407,8 @@ func SaveTranslation(inputStandard, gptTranslated string, shuffleConfig ShuffleC
387407

388408
func SaveParsedInput(inputStandard string, gptTranslated []byte, shuffleConfig ShuffleConfig) error {
389409
if len(shuffleConfig.URL) > 0 {
410+
// FIXME: Should we upload everything? I think not
411+
return nil
390412
return AddShuffleFile(inputStandard, "translation_input", gptTranslated, shuffleConfig)
391413
}
392414

@@ -644,12 +666,33 @@ func handleSubStandard(ctx context.Context, subStandard string, returnJson strin
644666

645667
// For each item in the list, translate it to the substandard
646668
// Maybe do this with recursive Translate() calls?
669+
670+
skipAfterCount := 50
647671
var wg sync.WaitGroup
648672
var mu sync.Mutex // Mutex to safely access parsedOutput slice
673+
649674
parsedOutput := [][]byte{}
650675
for cnt, listItem := range listJson {
651-
wg.Add(1) // Increment the wait group counter for each goroutine
676+
// No goroutine on the first ones as we want to make sure caching is done properly
677+
if cnt == 0 {
678+
marshalledBody, err := json.Marshal(listItem)
679+
if err != nil {
680+
log.Printf("[ERROR] Schemaless: Error in marshalling of list item: %v", err)
681+
continue
682+
}
683+
684+
schemalessOutput, err := Translate(ctx, subStandard, marshalledBody, authConfig, "skip_substandard")
685+
if err != nil {
686+
log.Printf("[ERROR] Schemaless: Error in schemaless.Translate for sub list item: %v", err)
687+
continue
688+
}
689+
690+
parsedOutput = append(parsedOutput, schemalessOutput)
691+
continue
692+
}
693+
652694

695+
wg.Add(1) // Increment the wait group counter for each goroutine
653696
go func(cnt int, listItem interface{}) {
654697
defer wg.Done() // Decrement the wait group counter when the goroutine completes
655698

@@ -671,8 +714,8 @@ func handleSubStandard(ctx context.Context, subStandard string, returnJson strin
671714
parsedOutput = append(parsedOutput, schemalessOutput)
672715
}(cnt, listItem)
673716

674-
if cnt > 50 {
675-
log.Printf("[WARNING] Schemaless: Breaking after 10 items in the list")
717+
if cnt > skipAfterCount {
718+
log.Printf("[WARNING] Schemaless: Breaking after %d items in the list", skipAfterCount)
676719
break
677720
}
678721
}
@@ -830,9 +873,9 @@ func Translate(ctx context.Context, inputStandard string, inputValue []byte, inp
830873
if err != nil {
831874
log.Printf("[ERROR] Error in SaveTranslation (3): %v", err)
832875
return []byte{}, err
833-
}
876+
}
834877

835-
log.Printf("[DEBUG] Saved translation to file. Should now run OpenAI and set cache!")
878+
//log.Printf("[DEBUG] Saved GPT translation to file. Should now run OpenAI and set cache!")
836879
inputStructure = []byte(gptTranslated)
837880
}
838881

0 commit comments

Comments
 (0)