Skip to content

Add rate limit to Port client #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions internal/cli/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/go-resty/resty/v2"
"github.com/port-labs/terraform-provider-port-labs/v2/internal/ratelimit"
)

type Option func(*PortClient)
Expand All @@ -19,9 +20,14 @@ type PortClient struct {
featureFlags []string
JSONEscapeHTML bool
BlueprintPropertyTypeChangeProtection bool

// Rate limiting
rateLimitManager *ratelimit.Manager
}

func New(baseURL string, opts ...Option) (*PortClient, error) {
rateLimitManager := ratelimit.NewManager()

c := &PortClient{
Client: resty.New().
SetBaseURL(baseURL).
Expand All @@ -39,13 +45,34 @@ func New(baseURL string, opts ...Option) (*PortClient, error) {
err = json.Unmarshal(r.Body(), &b)
return err != nil || b["ok"] != true
}),
rateLimitManager: rateLimitManager,
}

c.Client.
OnBeforeRequest(rateLimitManager.RequestMiddleware).
OnAfterResponse(rateLimitManager.ResponseMiddleware)

for _, opt := range opts {
opt(c)
}
return c, nil
}

// GetRateLimitInfo returns the current rate limit information
func (c *PortClient) GetRateLimitInfo() *ratelimit.RateLimitInfo {
return c.rateLimitManager.GetInfo()
}

// SetRateLimitEnabled enables or disables rate limiting
func (c *PortClient) SetRateLimitEnabled(enabled bool) {
c.rateLimitManager.SetEnabled(enabled)
}

// SetRateLimitThreshold sets the threshold for when to start throttling
func (c *PortClient) SetRateLimitThreshold(threshold float64) {
c.rateLimitManager.SetThreshold(threshold)
}

// FeatureFlags Fetches the feature flags from the Organization API. It caches the feature flags locally to reduce call
// count.
func (c *PortClient) FeatureFlags(ctx context.Context) ([]string, error) {
Expand Down Expand Up @@ -110,3 +137,20 @@ func WithToken(token string) Option {
pc.Client.SetAuthToken(token)
}
}

// WithRateLimitDisabled disables rate limiting
func WithRateLimitDisabled() Option {
return func(pc *PortClient) {
pc.rateLimitManager.SetEnabled(false)
}
}

// WithRateLimitThreshold sets the threshold for when to start throttling
// threshold should be between 0.0 and 1.0 (e.g., 0.1 means start throttling when 10% of requests remain)
func WithRateLimitThreshold(threshold float64) Option {
return func(pc *PortClient) {
if threshold >= 0.0 && threshold <= 1.0 {
pc.rateLimitManager.SetThreshold(threshold)
}
}
}
Comment on lines +148 to +156
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too obscure. I think that you should return an error or panic if it is out of bounds.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option is to clamp the threshold between 0 and 1.

156 changes: 156 additions & 0 deletions internal/cli/client_ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package cli

import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func skipIfDisabled(t *testing.T) {
if os.Getenv("PORT_RATE_LIMIT_DISABLED") != "" {
t.Skip("Skipping rate limit test because PORT_RATE_LIMIT_DISABLED is set")
}
}

func TestClientRateLimitIntegration(t *testing.T) {
skipIfDisabled(t)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-ratelimit-limit", "1000")
w.Header().Set("x-ratelimit-period", "300")
w.Header().Set("x-ratelimit-remaining", "50")
w.Header().Set("x-ratelimit-reset", "120")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok": true}`))
}))
defer server.Close()

client, err := New(server.URL)
require.NoError(t, err)

resp, err := client.Client.R().Get("/test")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())

rateLimitInfo := client.GetRateLimitInfo()
require.NotNil(t, rateLimitInfo)
assert.Equal(t, 1000, rateLimitInfo.Limit)
assert.Equal(t, 300, rateLimitInfo.Period)
assert.Equal(t, 50, rateLimitInfo.Remaining)
assert.Equal(t, 120, rateLimitInfo.Reset)
}

func TestClientRateLimitNoHeaders(t *testing.T) {
skipIfDisabled(t)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok": true}`))
}))
defer server.Close()

client, err := New(server.URL)
require.NoError(t, err)

resp, err := client.Client.R().Get("/test")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())

rateLimitInfo := client.GetRateLimitInfo()
assert.Nil(t, rateLimitInfo)
}

func TestClientRateLimitDisabled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-ratelimit-limit", "1000")
w.Header().Set("x-ratelimit-remaining", "1")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok": true}`))
}))
defer server.Close()

client, err := New(server.URL, WithRateLimitDisabled())
require.NoError(t, err)

resp, err := client.Client.R().Get("/test")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())

rateLimitInfo := client.GetRateLimitInfo()
assert.Nil(t, rateLimitInfo)
}

func TestClientRateLimitThrottling(t *testing.T) {
skipIfDisabled(t)

requestCount := 0

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.Header().Set("x-ratelimit-limit", "100")
w.Header().Set("x-ratelimit-remaining", "5")
w.Header().Set("x-ratelimit-reset", "2")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok": true}`))
}))
defer server.Close()

client, err := New(server.URL, WithRateLimitThreshold(0.1))
require.NoError(t, err)

start := time.Now()
resp, err := client.Client.R().Get("/test1")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())

resp, err = client.Client.R().Get("/test2")
elapsed := time.Since(start)

require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())
assert.Equal(t, 2, requestCount)

assert.Greater(t, elapsed, 10*time.Millisecond, "Request should have been throttled")
}

func TestClientRateLimitSettings(t *testing.T) {
client, err := New("http://example.com")
require.NoError(t, err)

// somewhat dummy tests since we really can't test, so we check that they don't panic
client.SetRateLimitEnabled(false)
client.SetRateLimitEnabled(true)
client.SetRateLimitThreshold(0.25)
}

func TestClientRateLimitDisabledViaEnv(t *testing.T) {
t.Setenv("PORT_RATE_LIMIT_DISABLED", "1")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-ratelimit-limit", "1000")
w.Header().Set("x-ratelimit-remaining", "1")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok": true}`))
}))
defer server.Close()

client, err := New(server.URL)
require.NoError(t, err)

start := time.Now()
resp, err := client.Client.R().Get("/test")
elapsed := time.Since(start)

require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())

assert.Less(t, elapsed, 100*time.Millisecond, "Request should not be throttled when rate limiting is disabled")

rateLimitInfo := client.GetRateLimitInfo()
assert.Nil(t, rateLimitInfo)
}
74 changes: 37 additions & 37 deletions internal/cli/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,28 @@ type (
}

BlueprintProperty struct {
Type string `json:"type,omitempty"`
Title *string `json:"title,omitempty"`
Identifier string `json:"identifier,omitempty"`
Items map[string]any `json:"items,omitempty"`
Default any `json:"default,omitempty"`
Icon *string `json:"icon,omitempty"`
Format *string `json:"format,omitempty"`
MaxLength *int `json:"maxLength,omitempty"`
MinLength *int `json:"minLength,omitempty"`
MaxItems *int `json:"maxItems,omitempty"`
MinItems *int `json:"minItems,omitempty"`
Maximum *float64 `json:"maximum,omitempty"`
Minimum *float64 `json:"minimum,omitempty"`
Description *string `json:"description,omitempty"`
Blueprint *string `json:"blueprint,omitempty"`
Pattern *string `json:"pattern,omitempty"`
Enum []any `json:"enum,omitempty"`
Spec *string `json:"spec,omitempty"`
SpecAuthentication *SpecAuthentication `json:"specAuthentication,omitempty"`
EnumColors map[string]string `json:"enumColors,omitempty"`
Type string `json:"type,omitempty"`
Title *string `json:"title,omitempty"`
Identifier string `json:"identifier,omitempty"`
Items map[string]any `json:"items,omitempty"`
Default any `json:"default,omitempty"`
Icon *string `json:"icon,omitempty"`
Format *string `json:"format,omitempty"`
MaxLength *int `json:"maxLength,omitempty"`
MinLength *int `json:"minLength,omitempty"`
MaxItems *int `json:"maxItems,omitempty"`
MinItems *int `json:"minItems,omitempty"`
Maximum *float64 `json:"maximum,omitempty"`
Minimum *float64 `json:"minimum,omitempty"`
Description *string `json:"description,omitempty"`
Blueprint *string `json:"blueprint,omitempty"`
Pattern *string `json:"pattern,omitempty"`
Enum []any `json:"enum,omitempty"`
Spec *string `json:"spec,omitempty"`
SpecAuthentication *SpecAuthentication `json:"specAuthentication,omitempty"`
EnumColors map[string]string `json:"enumColors,omitempty"`
// UnknownFields captures any dynamic fields not explicitly defined above
UnknownFields map[string]any `json:"-"`
UnknownFields map[string]any `json:"-"`
}

EntitiesSortModel struct {
Expand Down Expand Up @@ -502,87 +502,87 @@ type (
func getKnownFields(bp *BlueprintProperty) map[string]bool {
knownFields := make(map[string]bool)
t := reflect.TypeOf(*bp)

for i := 0; i < t.NumField(); i++ {
field := t.Field(i)

// Get the JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue // Skip fields without JSON tags or with "-"
}

// Handle "fieldname,omitempty" format
fieldName, _, _ := strings.Cut(jsonTag, ",")
if fieldName != "" {
knownFields[fieldName] = true
}
}

return knownFields
}

// Custom UnmarshalJSON for BlueprintProperty to capture dynamic fields
func (bp *BlueprintProperty) UnmarshalJSON(data []byte) error {
// Define an alias to avoid infinite recursion
type Alias BlueprintProperty

// First, unmarshal into the alias to populate known fields
aux := &struct {
*Alias
}{
Alias: (*Alias)(bp),
}

if err := json.Unmarshal(data, aux); err != nil {
return err
}

// Now unmarshal into a map to capture all fields
var all map[string]any
if err := json.Unmarshal(data, &all); err != nil {
return err
}

// Initialize UnknownFields map
bp.UnknownFields = make(map[string]any)

// Use reflection to get known fields instead of hardcoding
knownFields := getKnownFields(bp)

// Add any unknown fields to UnknownFields
for key, value := range all {
if !knownFields[key] {
bp.UnknownFields[key] = value
}
}

return nil
}

// Custom MarshalJSON for BlueprintProperty to include dynamic fields
func (bp BlueprintProperty) MarshalJSON() ([]byte, error) {
// Define an alias to avoid infinite recursion
type Alias BlueprintProperty

// Marshal the known fields first
aux := Alias(bp)
aux.UnknownFields = nil // Don't marshal this field directly

data, err := json.Marshal(aux)
if err != nil {
return nil, err
}

var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}

for key, value := range bp.UnknownFields {
result[key] = value
}

return json.Marshal(result)
}

Expand Down
Loading
Loading