diff --git a/services/connect/bridge.go b/services/connect/bridge.go index 4f83295..40167a1 100644 --- a/services/connect/bridge.go +++ b/services/connect/bridge.go @@ -1,11 +1,13 @@ package connect import ( + "bufio" "context" "crypto/rand" "encoding/hex" "encoding/json" "fmt" + "io" "net/http" "strings" "time" @@ -86,33 +88,27 @@ func (b *Bridge) connect() { resp, err := http.Get(b.url + "/events?client_id=" + b.id) if err != nil || resp.StatusCode != 200 { b.connected = false - fmt.Println("bridge", b.name, "can't connect", err) + fmt.Printf("bridge %s can't connect: %v\n", b.name, err) time.Sleep(time.Second * 10) continue - } else { - b.connected = true } - for { - var event string - var data string - _, err := fmt.Fscanf(resp.Body, "%s %s", &event, &data) - if err != nil { - if err.Error() == "unexpected newline" { - continue - } - fmt.Println("bridge", b.name, err) - b.connected = false - b.reconnectCounter++ - break - } - if event == "data:" { + b.connected = true + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data:") { + data := strings.TrimPrefix(line, "data: ") select { case b.data <- data: default: - } } } + if err = scanner.Err(); err != nil && err != io.EOF { + fmt.Printf("bridge %s error: %v\n", b.name, err) + b.connected = false + b.reconnectCounter++ + } time.Sleep(time.Second * 10) } } diff --git a/services/connect/bridge_test.go b/services/connect/bridge_test.go new file mode 100644 index 0000000..d6576a1 --- /dev/null +++ b/services/connect/bridge_test.go @@ -0,0 +1,97 @@ +package connect + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// TestBridgeForceReconnect tests that the bridge will reconnect if the server is down +func TestBridgeForceReconnect(t *testing.T) { + var ( + mu sync.Mutex + messageCount int + reconnects int + serverDown bool + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + + if serverDown { + http.Error(w, "Server down", http.StatusServiceUnavailable) + return + } + + if r.Method == http.MethodPost && r.URL.Path == "/message" { + clientID := r.URL.Query().Get("client_id") + to := r.URL.Query().Get("to") + ttl := r.URL.Query().Get("ttl") + var payload string + if r.Body != nil { + b, err := io.ReadAll(r.Body) + if err == nil { + payload = string(b) + } + } + fmt.Printf("Received POST request to /message: client_id=%v, to=%v, ttl=%v, payload=%v\n", clientID, to, ttl, payload) + messageCount++ + w.WriteHeader(http.StatusOK) + + } else if r.URL.Path == "/events" { + b, _ := json.Marshal(map[string]string{"from": "test_client", "message": "test_message"}) + w.Write([]byte(fmt.Sprintf("data: %v\n", string(b)))) + + } else { + http.Error(w, "Not found", http.StatusNotFound) + } + + })) + defer server.Close() + + bridge := NewBridge("test", server.URL) + + // Simulate the server going down and coming back up + go func() { + for { + time.Sleep(10 * time.Second) // Simulate the server being up for 10 seconds + mu.Lock() + serverDown = true // Simulate the server going down + mu.Unlock() + time.Sleep(time.Second) // Give the bridge time to reconnect + mu.Lock() + serverDown = false // Simulate the server coming back up + bridge.connected = false // Force the bridge to reconnect + bridge.reconnectCounter++ // Increment the reconnect counter + reconnects++ + mu.Unlock() + } + }() + + // Run the test for 1 minute + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + metrics := bridge.GetMetrics(context.Background()) + if metrics.Reconnects != reconnects { + t.Errorf("Expected %d reconnects, but got %v", reconnects, metrics.Reconnects) + } + return + + case <-ticker.C: + bridge.GetMetrics(context.Background()) + } + } +}