Skip to content

Commit 92426c1

Browse files
authored
Merge pull request #28 from fluxcd/code_cleanups
Restructure Options and Transport functionality to become generic
2 parents ee897f0 + ffaa4a4 commit 92426c1

File tree

10 files changed

+719
-231
lines changed

10 files changed

+719
-231
lines changed

github/auth.go

Lines changed: 145 additions & 212 deletions
Large diffs are not rendered by default.

github/auth_test.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
Copyright 2020 The Flux CD contributors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package github
18+
19+
import (
20+
"net/http"
21+
"reflect"
22+
"testing"
23+
24+
"github.com/fluxcd/go-git-providers/gitprovider"
25+
"github.com/fluxcd/go-git-providers/gitprovider/cache"
26+
"github.com/fluxcd/go-git-providers/validation"
27+
)
28+
29+
func dummyRoundTripper1(http.RoundTripper) http.RoundTripper { return nil }
30+
func dummyRoundTripper2(http.RoundTripper) http.RoundTripper { return nil }
31+
func dummyRoundTripper3(http.RoundTripper) http.RoundTripper { return nil }
32+
33+
func roundTrippersEqual(a, b gitprovider.ChainableRoundTripperFunc) bool {
34+
if a == nil && b == nil {
35+
return true
36+
} else if (a != nil && b == nil) || (a == nil && b != nil) {
37+
return false
38+
}
39+
// Note that this comparison relies on "undefined behavior" in the Go language spec, see:
40+
// https://stackoverflow.com/questions/9643205/how-do-i-compare-two-functions-for-pointer-equality-in-the-latest-go-weekly
41+
return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
42+
}
43+
44+
func Test_clientOptions_getTransportChain(t *testing.T) {
45+
tests := []struct {
46+
name string
47+
preChain gitprovider.ChainableRoundTripperFunc
48+
postChain gitprovider.ChainableRoundTripperFunc
49+
auth gitprovider.ChainableRoundTripperFunc
50+
cache bool
51+
wantChain []gitprovider.ChainableRoundTripperFunc
52+
}{
53+
{
54+
name: "all roundtrippers",
55+
preChain: dummyRoundTripper1,
56+
postChain: dummyRoundTripper2,
57+
auth: dummyRoundTripper3,
58+
cache: true,
59+
// expect: "post chain" <-> "auth" <-> "cache" <-> "pre chain"
60+
wantChain: []gitprovider.ChainableRoundTripperFunc{
61+
dummyRoundTripper2,
62+
dummyRoundTripper3,
63+
cache.NewHTTPCacheTransport,
64+
dummyRoundTripper1,
65+
},
66+
},
67+
{
68+
name: "only pre + auth",
69+
preChain: dummyRoundTripper1,
70+
auth: dummyRoundTripper2,
71+
// expect: "auth" <-> "pre chain"
72+
wantChain: []gitprovider.ChainableRoundTripperFunc{
73+
dummyRoundTripper2,
74+
dummyRoundTripper1,
75+
},
76+
},
77+
{
78+
name: "only cache + auth",
79+
cache: true,
80+
auth: dummyRoundTripper1,
81+
// expect: "auth" <-> "cache"
82+
wantChain: []gitprovider.ChainableRoundTripperFunc{
83+
dummyRoundTripper1,
84+
cache.NewHTTPCacheTransport,
85+
},
86+
},
87+
}
88+
for _, tt := range tests {
89+
t.Run(tt.name, func(t *testing.T) {
90+
opts := &clientOptions{
91+
CommonClientOptions: gitprovider.CommonClientOptions{
92+
PreChainTransportHook: tt.preChain,
93+
PostChainTransportHook: tt.postChain,
94+
},
95+
AuthTransport: tt.auth,
96+
EnableConditionalRequests: &tt.cache,
97+
}
98+
gotChain := opts.getTransportChain()
99+
for i := range tt.wantChain {
100+
if !roundTrippersEqual(tt.wantChain[i], gotChain[i]) {
101+
t.Errorf("clientOptions.getTransportChain() = %v, want %v", gotChain, tt.wantChain)
102+
}
103+
break
104+
}
105+
})
106+
}
107+
}
108+
109+
func Test_makeOptions(t *testing.T) {
110+
tests := []struct {
111+
name string
112+
opts []ClientOption
113+
want *clientOptions
114+
expectedErrs []error
115+
}{
116+
{
117+
name: "no options",
118+
want: &clientOptions{},
119+
},
120+
{
121+
name: "WithDomain",
122+
opts: []ClientOption{WithDomain("foo")},
123+
want: buildCommonOption(gitprovider.CommonClientOptions{Domain: gitprovider.StringVar("foo")}),
124+
},
125+
{
126+
name: "WithDomain, empty",
127+
opts: []ClientOption{WithDomain("")},
128+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
129+
},
130+
{
131+
name: "WithDestructiveAPICalls",
132+
opts: []ClientOption{WithDestructiveAPICalls(true)},
133+
want: buildCommonOption(gitprovider.CommonClientOptions{EnableDestructiveAPICalls: gitprovider.BoolVar(true)}),
134+
},
135+
{
136+
name: "WithPreChainTransportHook",
137+
opts: []ClientOption{WithPreChainTransportHook(dummyRoundTripper1)},
138+
want: buildCommonOption(gitprovider.CommonClientOptions{PreChainTransportHook: dummyRoundTripper1}),
139+
},
140+
{
141+
name: "WithPreChainTransportHook, nil",
142+
opts: []ClientOption{WithPreChainTransportHook(nil)},
143+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
144+
},
145+
{
146+
name: "WithPostChainTransportHook",
147+
opts: []ClientOption{WithPostChainTransportHook(dummyRoundTripper2)},
148+
want: buildCommonOption(gitprovider.CommonClientOptions{PostChainTransportHook: dummyRoundTripper2}),
149+
},
150+
{
151+
name: "WithPostChainTransportHook, nil",
152+
opts: []ClientOption{WithPostChainTransportHook(nil)},
153+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
154+
},
155+
{
156+
name: "WithOAuth2Token",
157+
opts: []ClientOption{WithOAuth2Token("foo")},
158+
want: &clientOptions{AuthTransport: oauth2Transport("foo")},
159+
},
160+
{
161+
name: "WithOAuth2Token, empty",
162+
opts: []ClientOption{WithOAuth2Token("")},
163+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
164+
},
165+
{
166+
name: "WithPersonalAccessToken",
167+
opts: []ClientOption{WithPersonalAccessToken("foo")},
168+
want: &clientOptions{AuthTransport: patTransport("foo")},
169+
},
170+
{
171+
name: "WithPersonalAccessToken, empty",
172+
opts: []ClientOption{WithPersonalAccessToken("")},
173+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
174+
},
175+
{
176+
name: "WithPersonalAccessToken and WithOAuth2Token, exclusive",
177+
opts: []ClientOption{WithPersonalAccessToken("foo"), WithOAuth2Token("foo")},
178+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
179+
},
180+
{
181+
name: "WithConditionalRequests",
182+
opts: []ClientOption{WithConditionalRequests(true)},
183+
want: &clientOptions{EnableConditionalRequests: gitprovider.BoolVar(true)},
184+
},
185+
{
186+
name: "WithConditionalRequests, exclusive",
187+
opts: []ClientOption{WithConditionalRequests(true), WithConditionalRequests(false)},
188+
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
189+
},
190+
}
191+
for _, tt := range tests {
192+
t.Run(tt.name, func(t *testing.T) {
193+
got, err := makeOptions(tt.opts...)
194+
validation.TestExpectErrors(t, "makeOptions", err, tt.expectedErrs...)
195+
if tt.want == nil {
196+
return
197+
}
198+
if !roundTrippersEqual(got.AuthTransport, tt.want.AuthTransport) ||
199+
!roundTrippersEqual(got.PostChainTransportHook, tt.want.PostChainTransportHook) ||
200+
!roundTrippersEqual(got.PreChainTransportHook, tt.want.PreChainTransportHook) {
201+
t.Errorf("makeOptions() = %v, want %v", got, tt.want)
202+
}
203+
got.AuthTransport = nil
204+
got.PostChainTransportHook = nil
205+
got.PreChainTransportHook = nil
206+
tt.want.AuthTransport = nil
207+
tt.want.PostChainTransportHook = nil
208+
tt.want.PreChainTransportHook = nil
209+
if !reflect.DeepEqual(got, tt.want) {
210+
t.Errorf("makeOptions() = %v, want %v", got, tt.want)
211+
}
212+
})
213+
}
214+
}

github/example_organization_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func checkErr(err error) {
2020
func ExampleOrganizationsClient_Get() {
2121
// Create a new client
2222
ctx := context.Background()
23-
c, err := github.NewClient(ctx)
23+
c, err := github.NewClient()
2424
checkErr(err)
2525

2626
// Get public information about the fluxcd organization

github/example_repository_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
func ExampleOrgRepositoriesClient_Get() {
1313
// Create a new client
1414
ctx := context.Background()
15-
c, err := github.NewClient(ctx)
15+
c, err := github.NewClient()
1616
checkErr(err)
1717

1818
// Parse the URL into an OrgRepositoryRef

github/githubclient.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ func (c *githubClientImpl) UpdateRepo(ctx context.Context, owner, repo string, r
262262
func (c *githubClientImpl) DeleteRepo(ctx context.Context, owner, repo string) error {
263263
// Don't allow deleting repositories if the user didn't explicitly allow dangerous API calls.
264264
if !c.destructiveActions {
265-
return fmt.Errorf("cannot delete repository: %w", ErrDestructiveCallDisallowed)
265+
return fmt.Errorf("cannot delete repository: %w", gitprovider.ErrDestructiveCallDisallowed)
266266
}
267267
// DELETE /repos/{owner}/{repo}
268268
_, err := c.c.Repositories.Delete(ctx, owner, repo)

github/integration_test.go

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ const (
4545
defaultBranch = "master"
4646
)
4747

48+
var (
49+
// customTransportImpl is a shared instance of a customTransport, allowing counting of cache hits.
50+
customTransportImpl *customTransport
51+
)
52+
4853
func init() {
4954
// Call testing.Init() prior to tests.NewParams(), as otherwise -test.* will not be recognised. See also: https://golang.org/doc/go1.13#testing
5055
testing.Init()
@@ -56,21 +61,17 @@ func TestProvider(t *testing.T) {
5661
RunSpecs(t, "GitHub Provider Suite")
5762
}
5863

59-
type customTransportFactory struct {
60-
customTransport *customTransport
61-
}
62-
63-
func (f *customTransportFactory) Transport(transport http.RoundTripper) http.RoundTripper {
64-
if f.customTransport != nil {
64+
func customTransportFactory(transport http.RoundTripper) http.RoundTripper {
65+
if customTransportImpl != nil {
6566
panic("didn't expect this function to be called twice")
6667
}
67-
f.customTransport = &customTransport{
68+
customTransportImpl = &customTransport{
6869
transport: transport,
6970
countCacheHits: false,
7071
cacheHits: 0,
7172
mux: &sync.Mutex{},
7273
}
73-
return f.customTransport
74+
return customTransportImpl
7475
}
7576

7677
type customTransport struct {
@@ -125,9 +126,8 @@ func (t *customTransport) countCacheHitsForFunc(fn func()) int {
125126

126127
var _ = Describe("GitHub Provider", func() {
127128
var (
128-
ctx context.Context
129-
c gitprovider.Client
130-
transportFactory = &customTransportFactory{}
129+
ctx context.Context = context.Background()
130+
c gitprovider.Client
131131

132132
testRepoName string
133133
testOrgName string = "fluxcd-testing"
@@ -148,13 +148,12 @@ var _ = Describe("GitHub Provider", func() {
148148
testOrgName = orgName
149149
}
150150

151-
ctx = context.Background()
152151
var err error
153-
c, err = NewClient(ctx,
152+
c, err = NewClient(
154153
WithPersonalAccessToken(githubToken),
155154
WithDestructiveAPICalls(true),
156155
WithConditionalRequests(true),
157-
WithRoundTripper(transportFactory),
156+
WithPreChainTransportHook(customTransportFactory),
158157
)
159158
Expect(err).ToNot(HaveOccurred())
160159
})
@@ -174,7 +173,7 @@ var _ = Describe("GitHub Provider", func() {
174173
}
175174
Expect(listedOrg).ToNot(BeNil())
176175

177-
hits := transportFactory.customTransport.countCacheHitsForFunc(func() {
176+
hits := customTransportImpl.countCacheHitsForFunc(func() {
178177
// Do a GET call for that organization
179178
getOrg, err = c.Organizations().Get(ctx, listedOrg.Organization())
180179
Expect(err).ToNot(HaveOccurred())
@@ -200,7 +199,7 @@ var _ = Describe("GitHub Provider", func() {
200199
Expect(getOrg.Get().Description).To(Equal(internal.Description))
201200

202201
// Expect that when we do the same request a second time, it will hit the cache
203-
hits = transportFactory.customTransport.countCacheHitsForFunc(func() {
202+
hits = customTransportImpl.countCacheHitsForFunc(func() {
204203
getOrg2, err := c.Organizations().Get(ctx, listedOrg.Organization())
205204
Expect(err).ToNot(HaveOccurred())
206205
Expect(getOrg2).ToNot(BeNil())

0 commit comments

Comments
 (0)