Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@
<artifactId>spring-security-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-jwt</artifactId>
<version>1.1.1.RELEASE</version>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
<version>10.3</version>
</dependency>
<dependency>
<groupId>com.auth0</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,7 @@
*/
package org.cbioportal.application.security.token.oauth2;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkException;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.stereotype.Component;

@Component
Expand All @@ -48,10 +40,7 @@ public class JwtTokenVerifierBuilder {
@Value("${dat.oauth2.jwkUrl:}")
private String jwkUrl;

public RsaVerifier build(final String kid) throws MalformedURLException, JwkException {
final JwkProvider provider = new UrlJwkProvider(new URL(jwkUrl));
final Jwk jwk = provider.get(kid);
final RSAPublicKey publicKey = (RSAPublicKey) jwk.getPublicKey();
return new RsaVerifier(publicKey, "SHA512withRSA");
}
// Functionality of this class will be integrated into OAuth2DataAccessTokenServiceImpl
// or this class will be re-purposed. For now, build() method is removed.
// The jwkUrl field might be accessed by other beans or injected directly where needed.
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,43 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import jakarta.annotation.PostConstruct;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.util.Date;
import java.util.List;
import org.cbioportal.legacy.model.DataAccessToken;
import org.cbioportal.legacy.service.DataAccessTokenService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

public class OAuth2DataAccessTokenServiceImpl implements DataAccessTokenService {

private static final Logger LOG = LoggerFactory.getLogger(OAuth2DataAccessTokenServiceImpl.class);

@Value("${dat.oauth2.issuer}")
private String issuer;

Expand All @@ -68,15 +86,31 @@ public class OAuth2DataAccessTokenServiceImpl implements DataAccessTokenService
@Value("${dat.oauth2.redirectUri}")
private String redirectUri;

private final RestTemplate template;
@Value("${dat.oauth2.jwkUrl:}")
private String jwkUrl;

private final JwtTokenVerifierBuilder jwtTokenVerifierBuilder;
private final RestTemplate template;
private DefaultJWTProcessor<SecurityContext> jwtProcessor;

@Autowired
public OAuth2DataAccessTokenServiceImpl(
RestTemplate template, JwtTokenVerifierBuilder jwtTokenVerifierBuilder) {
public OAuth2DataAccessTokenServiceImpl(RestTemplate template) {
this.template = template;
this.jwtTokenVerifierBuilder = jwtTokenVerifierBuilder;
}

@PostConstruct
public void init() {
try {
JWKSource<SecurityContext> keySource = new RemoteJWKSet<>(new URL(this.jwkUrl));
JWSKeySelector<SecurityContext> keySelector =
new JWSVerificationKeySelector<>(JWSAlgorithm.RS512, keySource);
jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(keySelector);
} catch (MalformedURLException e) {
LOG.error("Invalid JWK URL: {}", this.jwkUrl, e);
// Handle initialization failure, perhaps by preventing the application from starting
// or by setting jwtProcessor to null and checking it in methods.
throw new RuntimeException("Failed to initialize JWT processor due to invalid JWK URL", e);
}
}

@Override
Expand Down Expand Up @@ -143,56 +177,67 @@ public void revokeDataAccessToken(final String token) {

@Override
public Boolean isValid(final String token) {
final String kid = JwtHelper.headers(token).get("kid");
if (jwtProcessor == null) {
LOG.error("JWT Processor not initialized, cannot validate token.");
throw new BadCredentialsException(
"Token validation system not initialized properly.");
}
try {
SignedJWT signedJWT = SignedJWT.parse(token);
JWTClaimsSet claimsSet = jwtProcessor.process(signedJWT, null);

final Jwt tokenDecoded = JwtHelper.decodeAndVerify(token, jwtTokenVerifierBuilder.build(kid));
final String claims = tokenDecoded.getClaims();
final JsonNode claimsMap = new ObjectMapper().readTree(claims);

hasValidIssuer(claimsMap);
hasValidClientId(claimsMap);
hasValidIssuer(claimsSet);
hasValidClientId(claimsSet);

} catch (Exception e) {
throw new BadCredentialsException("Token is not valid (wrong key, issuer, or audience).");
} catch (ParseException | BadJOSEException | JOSEException e) {
LOG.warn("Token validation failed: {}", e.getMessage());
throw new BadCredentialsException(
"Token is not valid (parsing/signature/claims validation failed).", e);
}
return true;
}

@Override
public String getUsername(final String token) {

final Jwt tokenDecoded = JwtHelper.decode(token);

final String claims = tokenDecoded.getClaims();
JsonNode claimsMap;
try {
claimsMap = new ObjectMapper().readTree(claims);
} catch (IOException e) {
throw new BadCredentialsException("User name could not be found in offline token.");
}

if (!claimsMap.has("sub")) {
throw new BadCredentialsException("User name could not be found in offline token.");
SignedJWT signedJWT = SignedJWT.parse(token);
JWTClaimsSet claimsSet = signedJWT.getJWTClaimsSet(); // No validation here, just parsing

if (claimsSet.getSubject() == null) {
throw new BadCredentialsException("User name (sub claim) could not be found in token.");
}
return claimsSet.getSubject();
} catch (ParseException e) {
LOG.warn("Token parsing failed while trying to get username: {}", e.getMessage());
throw new BadCredentialsException("User name could not be found in token (parse error).", e);
}

return claimsMap.get("sub").asText();
}

@Override
public Date getExpiration(final String token) {
return null;
// Nimbus JWT library can parse expiration time if needed.
// Example:
// try {
// SignedJWT signedJWT = SignedJWT.parse(token);
// JWTClaimsSet claimsSet = signedJWT.getJWTClaimsSet();
// return claimsSet.getExpirationTime();
// } catch (ParseException e) {
// LOG.warn("Failed to parse token for expiration: {}", e.getMessage());
// return null;
// }
return null; // Current behavior is to return null
}

private void hasValidIssuer(final JsonNode claimsMap) throws BadCredentialsException {
if (!claimsMap.get("iss").asText().equals(issuer)) {
private void hasValidIssuer(final JWTClaimsSet claimsSet) throws BadCredentialsException {
if (claimsSet.getIssuer() == null || !claimsSet.getIssuer().equals(issuer)) {
throw new BadCredentialsException("Wrong Issuer found in token");
}
}

private void hasValidClientId(final JsonNode claimsMap) throws BadCredentialsException {
if (!claimsMap.get("aud").asText().equals(clientId)) {
throw new BadCredentialsException("Wrong clientId found in token");
private void hasValidClientId(final JWTClaimsSet claimsSet) throws BadCredentialsException {
List<String> audience = claimsSet.getAudience();
if (audience == null || !audience.contains(clientId)) {
throw new BadCredentialsException("Wrong clientId (audience) found in token");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,26 @@

package org.cbioportal.application.security.token.oauth2;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.text.ParseException;
import java.util.Collection;
import org.cbioportal.application.security.util.ClaimRoleExtractorUtil;
import org.cbioportal.application.security.util.GrantedAuthorityUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;

public class OAuth2TokenAuthenticationProvider implements AuthenticationProvider {

private static final Logger LOG =
LoggerFactory.getLogger(OAuth2TokenAuthenticationProvider.class);

@Value("${dat.oauth2.jwtRolesPath:resource_access::cbioportal::roles}")
private String jwtRolesPath;

Expand All @@ -75,42 +78,39 @@ public Authentication authenticate(Authentication authentication) throws Authent
// request an access token from the OAuth2 identity provider
final String accessToken = tokenRefreshRestTemplate.getAccessToken(offlineToken);

Collection<GrantedAuthority> authorities = extractAuthorities(accessToken);
String username = getUsername(accessToken);
try {
SignedJWT signedJWT = SignedJWT.parse(accessToken);
JWTClaimsSet claimsSet = signedJWT.getJWTClaimsSet();

String username = claimsSet.getSubject();
if (username == null) {
throw new BadCredentialsException("Username (sub claim) not found in access token.");
}

Collection<GrantedAuthority> authorities = extractAuthorities(claimsSet);
return new OAuth2BearerAuthenticationToken(username, authorities);

return new OAuth2BearerAuthenticationToken(username, authorities);
} catch (ParseException e) {
LOG.warn("Access token parsing failed: {}", e.getMessage());
throw new BadCredentialsException("Invalid access token: " + e.getMessage(), e);
}
}

// Read roles/authorities from JWT token.
private Collection<GrantedAuthority> extractAuthorities(final String token)
private Collection<GrantedAuthority> extractAuthorities(final JWTClaimsSet claimsSet)
throws BadCredentialsException {
try {
final Jwt tokenDecoded = JwtHelper.decode(token);
final String claims = tokenDecoded.getClaims();
// ClaimRoleExtractorUtil expects a JSON string representation of the claims
String claimsJson = claimsSet.toJSONObject().toJSONString();
return GrantedAuthorityUtil.generateGrantedAuthoritiesFromRoles(
ClaimRoleExtractorUtil.extractClientRoles(claims, jwtRolesPath));
ClaimRoleExtractorUtil.extractClientRoles(claimsJson, jwtRolesPath));

} catch (Exception e) {
throw new BadCredentialsException("Authorities could not be extracted from access token.");
}
}

private String getUsername(final String token) {

final Jwt tokenDecoded = JwtHelper.decode(token);

final String claims = tokenDecoded.getClaims();
JsonNode claimsMap;
try {
claimsMap = new ObjectMapper().readTree(claims);
} catch (IOException e) {
throw new BadCredentialsException("User name could not be found in access token.");
// Catching a broader exception here as ClaimRoleExtractorUtil might throw various things
// if the claims structure is unexpected.
LOG.warn("Authorities extraction failed: {}", e.getMessage());
throw new BadCredentialsException(
"Authorities could not be extracted from access token.", e);
}

if (!claimsMap.has("sub")) {
throw new BadCredentialsException("User name could not be found in access token.");
}

return claimsMap.get("sub").asText();
}
}