Skip to content

Commit 703ce62

Browse files
authored
Feat: add support for instanceFirewallUpdate (#788)
Signed-off-by: Tarun Chinmai Sekar <[email protected]>
1 parent 2f5a5d2 commit 703ce62

File tree

4 files changed

+123
-14
lines changed

4 files changed

+123
-14
lines changed

instance_firewalls.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,13 @@ import (
88
func (c *Client) ListInstanceFirewalls(ctx context.Context, linodeID int, opts *ListOptions) ([]Firewall, error) {
99
return getPaginatedResults[Firewall](ctx, c, formatAPIPath("linode/instances/%d/firewalls", linodeID), opts)
1010
}
11+
12+
type InstanceFirewallUpdateOptions struct {
13+
FirewallIDs []int `json:"firewall_ids"`
14+
}
15+
16+
// UpdateInstanceFirewalls updates the Cloud Firewalls for a Linode instance
17+
// Followup this call with `ListInstanceFirewalls` to verify the changes if necessary.
18+
func (c *Client) UpdateInstanceFirewalls(ctx context.Context, linodeID int, opts InstanceFirewallUpdateOptions) ([]Firewall, error) {
19+
return putPaginatedResults[Firewall, InstanceFirewallUpdateOptions](ctx, c, formatAPIPath("linode/instances/%d/firewalls", linodeID), nil, opts)
20+
}

request_helpers.go

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@ type paginatedResponse[T any] struct {
1717
Data []T `json:"data"`
1818
}
1919

20-
// getPaginatedResults aggregates results from the given
21-
// paginated endpoint using the provided ListOptions.
20+
// handlePaginatedResults aggregates results from the given
21+
// paginated endpoint using the provided ListOptions and HTTP method.
2222
// nolint:funlen
23-
func getPaginatedResults[T any](
23+
func handlePaginatedResults[T any, O any](
2424
ctx context.Context,
2525
client *Client,
2626
endpoint string,
2727
opts *ListOptions,
28+
method string,
29+
options ...O,
2830
) ([]T, error) {
29-
var resultType paginatedResponse[T]
30-
3131
result := make([]T, 0)
3232

3333
if opts == nil {
@@ -38,38 +38,76 @@ func getPaginatedResults[T any](
3838
opts.PageOptions = &PageOptions{Page: 0}
3939
}
4040

41-
// Makes a request to a particular page and
42-
// appends the response to the result
41+
// Validate options
42+
numOpts := len(options)
43+
if numOpts > 1 {
44+
return nil, fmt.Errorf("invalid number of options: expected 0 or 1, got %d", numOpts)
45+
}
46+
47+
// Prepare request body if options are provided
48+
var reqBody string
49+
50+
if numOpts > 0 && !isNil(options[0]) {
51+
body, err := json.Marshal(options[0])
52+
if err != nil {
53+
return nil, fmt.Errorf("failed to marshal request body: %w", err)
54+
}
55+
56+
reqBody = string(body)
57+
}
58+
59+
// Makes a request to a particular page and appends the response to the result
4360
handlePage := func(page int) error {
61+
var resultType paginatedResponse[T]
62+
4463
// Override the page to be applied in applyListOptionsToRequest(...)
4564
opts.Page = page
4665

4766
// This request object cannot be reused for each page request
4867
// because it can lead to possible data corruption
49-
req := client.R(ctx).SetResult(resultType)
68+
req := client.R(ctx).SetResult(&resultType)
5069

5170
// Apply all user-provided list options to the request
5271
if err := applyListOptionsToRequest(opts, req); err != nil {
5372
return err
5473
}
5574

56-
res, err := coupleAPIErrors(req.Get(endpoint))
57-
if err != nil {
58-
return err
75+
// Set request body if provided
76+
if reqBody != "" {
77+
req.SetBody(reqBody)
5978
}
6079

61-
response := res.Result().(*paginatedResponse[T])
80+
var response *paginatedResponse[T]
81+
// Execute the appropriate HTTP method
82+
switch method {
83+
case "GET":
84+
res, err := coupleAPIErrors(req.Get(endpoint))
85+
if err != nil {
86+
return err
87+
}
88+
89+
response = res.Result().(*paginatedResponse[T])
90+
case "PUT":
91+
res, err := coupleAPIErrors(req.Put(endpoint))
92+
if err != nil {
93+
return err
94+
}
95+
96+
response = res.Result().(*paginatedResponse[T])
97+
default:
98+
return fmt.Errorf("unsupported HTTP method: %s", method)
99+
}
62100

101+
// Update pagination metadata
63102
opts.Page = page
64103
opts.Pages = response.Pages
65104
opts.Results = response.Results
66-
67105
result = append(result, response.Data...)
68106

69107
return nil
70108
}
71109

72-
// This helps simplify the logic below
110+
// Determine starting page
73111
startingPage := 1
74112
pageDefined := opts.Page > 0
75113

@@ -98,6 +136,29 @@ func getPaginatedResults[T any](
98136
return result, nil
99137
}
100138

139+
// getPaginatedResults aggregates results from the given
140+
// paginated endpoint using the provided ListOptions.
141+
func getPaginatedResults[T any](
142+
ctx context.Context,
143+
client *Client,
144+
endpoint string,
145+
opts *ListOptions,
146+
) ([]T, error) {
147+
return handlePaginatedResults[T, any](ctx, client, endpoint, opts, "GET")
148+
}
149+
150+
// putPaginatedResults sends a PUT request and aggregates the results from the given
151+
// paginated endpoint using the provided ListOptions.
152+
func putPaginatedResults[T, O any](
153+
ctx context.Context,
154+
client *Client,
155+
endpoint string,
156+
opts *ListOptions,
157+
options ...O,
158+
) ([]T, error) {
159+
return handlePaginatedResults[T, O](ctx, client, endpoint, opts, "PUT", options...)
160+
}
161+
101162
// doGETRequest runs a GET request using the given client and API endpoint,
102163
// and returns the result
103164
func doGETRequest[T any](
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"data": [
3+
{
4+
"id": 789,
5+
"label": "firewall-1",
6+
"status": "enabled",
7+
"rules": {
8+
"inbound": [],
9+
"outbound": []
10+
}
11+
}
12+
],
13+
"page": 1,
14+
"pages": 1,
15+
"results": 1
16+
}

test/unit/instance_firewall_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"testing"
66

7+
"github.com/linode/linodego"
78
"github.com/stretchr/testify/assert"
89
)
910

@@ -27,3 +28,24 @@ func TestInstanceFirewalls_List(t *testing.T) {
2728
assert.Equal(t, 789, firewalls[1].ID)
2829
assert.Equal(t, "firewall-2", firewalls[1].Label)
2930
}
31+
32+
func TestInstanceFirewalls_Update(t *testing.T) {
33+
fixtureData, err := fixtures.GetFixture("instance_firewall_update")
34+
assert.NoError(t, err)
35+
36+
var base ClientBaseCase
37+
base.SetUp(t)
38+
defer base.TearDown(t)
39+
40+
base.MockGet("linode/instances/123/firewalls", fixtureData)
41+
base.MockPut("linode/instances/123/firewalls", fixtureData)
42+
updateOpts := linodego.InstanceFirewallUpdateOptions{
43+
FirewallIDs: []int{789},
44+
}
45+
46+
firewalls, err := base.Client.UpdateInstanceFirewalls(context.Background(), 123, updateOpts)
47+
assert.NoError(t, err)
48+
assert.NotNil(t, firewalls)
49+
assert.Len(t, firewalls, 1)
50+
assert.Equal(t, 789, firewalls[0].ID)
51+
}

0 commit comments

Comments
 (0)