Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions github/apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"net/http"
"net/url"
"path"
"time"

"github.com/go-jose/go-jose/v3"
Expand All @@ -18,29 +17,22 @@ import (

// GenerateOAuthTokenFromApp generates a GitHub OAuth access token from a set of valid GitHub App credentials.
// The returned token can be used to interact with both GitHub's REST and GraphQL APIs.
func GenerateOAuthTokenFromApp(baseURL *url.URL, appID, appInstallationID, pemData string) (string, error) {
func GenerateOAuthTokenFromApp(apiURL *url.URL, appID, appInstallationID, pemData string) (string, error) {
appJWT, err := generateAppJWT(appID, time.Now(), []byte(pemData))
if err != nil {
return "", err
}

token, err := getInstallationAccessToken(baseURL, appJWT, appInstallationID)
token, err := getInstallationAccessToken(apiURL, appJWT, appInstallationID)
if err != nil {
return "", err
}

return token, nil
}

func getInstallationAccessToken(baseURL *url.URL, jwt, installationID string) (string, error) {
hostname := baseURL.Hostname()
if hostname != DotComHost && !GHECDataResidencyHostMatch.MatchString(hostname) {
baseURL.Path = path.Join(baseURL.Path, "api/v3/")
}

baseURL.Path = path.Join(baseURL.Path, "app/installations/", installationID, "access_tokens")

req, err := http.NewRequest(http.MethodPost, baseURL.String(), nil)
func getInstallationAccessToken(apiURL *url.URL, jwt, installationID string) (string, error) {
req, err := http.NewRequest(http.MethodPost, apiURL.JoinPath("app/installations", installationID, "access_tokens").String(), nil)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion github/apps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func TestGetInstallationAccessToken(t *testing.T) {

ts := githubApiMock([]*mockResponse{
{
ExpectedUri: fmt.Sprintf("/api/v3/app/installations/%s/access_tokens", testGitHubAppInstallationID),
ExpectedUri: fmt.Sprintf("/app/installations/%s/access_tokens", testGitHubAppInstallationID),
ExpectedHeaders: map[string]string{
"Accept": "application/vnd.github.v3+json",
"Authorization": fmt.Sprintf("Bearer %s", fakeJWT),
Expand Down
101 changes: 70 additions & 31 deletions github/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"net/url"
"path"
"regexp"
"strings"
"time"
Expand All @@ -19,7 +18,8 @@ import (
type Config struct {
Token string
Owner string
BaseURL string
BaseURL *url.URL
IsGHES bool
Insecure bool
WriteDelay time.Duration
ReadDelay time.Duration
Expand All @@ -38,12 +38,23 @@ type Owner struct {
IsOrganization bool
}

// DotComHost is the hostname for GitHub.com API.
const DotComHost = "api.github.com"
const (
// DotComAPIURL is the base API URL for github.com.
DotComAPIURL = "https://api.github.com/"
// DotComHost is the hostname for github.com.
DotComHost = "github.com"
// DotComAPIHost is the API hostname for github.com.
DotComAPIHost = "api.github.com"
// GHESRESTAPISuffix is the rest api suffix for GitHub Enterprise Server.
GHESRESTAPIPath = "api/v3/"
)

// GHECDataResidencyHostMatch is a regex to match a GitHub Enterprise Cloud data residency host:
// https://[hostname].ghe.com/ instances expect paths that behave similar to GitHub.com, not GitHub Enterprise Server.
var GHECDataResidencyHostMatch = regexp.MustCompile(`^[a-zA-Z0-9.\-]+\.ghe\.com\/?$`)
var (
// GHECHostMatch is a regex to match GitHub Enterprise Cloud hosts.
GHECHostMatch = regexp.MustCompile(`\.ghe\.com$`)
// GHECAPIHostMatch is a regex to match GitHub Enterprise Cloud API hosts.
GHECAPIHostMatch = regexp.MustCompile(`^api\.[a-zA-Z0-9-]+\.ghe\.com$`)
)

func RateLimitedHTTPClient(client *http.Client, writeDelay, readDelay, retryDelay time.Duration, parallelRequests bool, retryableErrors map[int]bool, maxRetries int) *http.Client {
client.Transport = NewEtagTransport(client.Transport)
Expand Down Expand Up @@ -81,38 +92,24 @@ func (c *Config) AnonymousHTTPClient() *http.Client {
}

func (c *Config) NewGraphQLClient(client *http.Client) (*githubv4.Client, error) {
uv4, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
}

hostname := uv4.Hostname()
if hostname != DotComHost && !GHECDataResidencyHostMatch.MatchString(hostname) {
uv4.Path = path.Join(uv4.Path, "api/graphql/")
var path string
if c.IsGHES {
path = "api/graphql"
} else {
uv4.Path = path.Join(uv4.Path, "graphql")
path = "graphql"
}

return githubv4.NewEnterpriseClient(uv4.String(), client), nil
return githubv4.NewEnterpriseClient(c.BaseURL.JoinPath(path).String(), client), nil
}

func (c *Config) NewRESTClient(client *http.Client) (*github.Client, error) {
uv3, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
path := ""
if c.IsGHES {
path = GHESRESTAPIPath
}

hostname := uv3.Hostname()
if hostname != DotComHost && !GHECDataResidencyHostMatch.MatchString(hostname) {
uv3.Path = fmt.Sprintf("%s/", path.Join(uv3.Path, "api/v3"))
}

v3client, err := github.NewClient(client).WithEnterpriseURLs(uv3.String(), "")
if err != nil {
return nil, err
}

v3client.BaseURL = uv3
v3client := github.NewClient(client)
v3client.BaseURL = c.BaseURL.JoinPath(path)

return v3client, nil
}
Expand Down Expand Up @@ -199,3 +196,45 @@ func (injector *previewHeaderInjectorTransport) RoundTrip(req *http.Request) (*h
}
return injector.rt.RoundTrip(req)
}

// getBaseURL returns a correctly configured base URL and a bool as to if this is GitHub Enterprise Server.
func getBaseURL(s string) (*url.URL, bool, error) {
if len(s) == 0 {
s = DotComAPIURL
}

u, err := url.Parse(s)
if err != nil {
return nil, false, err
}

if !u.IsAbs() {
return nil, false, fmt.Errorf("base url must be absolute")
}

u = u.JoinPath("/")

switch {
case u.Host == DotComAPIHost:
case u.Host == DotComHost:
u.Host = DotComAPIHost
case GHECAPIHostMatch.MatchString(u.Host):
case GHECHostMatch.MatchString(u.Host):
u.Host = fmt.Sprintf("api.%s", u.Host)
default:
u.Path = strings.TrimSuffix(u.Path, GHESRESTAPIPath)
return u, true, nil
}

if u.Scheme != "https" {
return nil, false, fmt.Errorf("base url for github.com or ghe.com must use the https scheme")
}

if len(u.Path) > 1 {
return nil, false, fmt.Errorf("base url for github.com or ghe.com must not contain a path, got %s", u.Path)
}

u.Path = "/"

return u, false, nil
}
Loading