JWTFilter.java

package de.dlr.shepard.filters;

import de.dlr.shepard.exceptions.ApiError;
import de.dlr.shepard.neo4Core.services.ApiKeyService;
import de.dlr.shepard.security.GracePeriodUtil;
import de.dlr.shepard.security.JWTPrincipal;
import de.dlr.shepard.security.JWTSecurityContext;
import de.dlr.shepard.security.JwtFilterGracePeriod;
import de.dlr.shepard.security.RolesList;
import de.dlr.shepard.util.Constants;
import de.dlr.shepard.util.PKIHelper;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.jackson.io.JacksonDeserializer;
import io.quarkus.logging.Log;
import jakarta.annotation.Priority;
import jakarta.enterprise.context.RequestScoped;
import jakarta.inject.Inject;
import jakarta.ws.rs.HttpMethod;
import jakarta.ws.rs.Priorities;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import jakarta.ws.rs.ext.Provider;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.eclipse.microprofile.config.inject.ConfigProperty;

@Provider
@Priority(Priorities.AUTHENTICATION)
@RequestScoped
public class JWTFilter implements ContainerRequestFilter {

  private PublicKey jwtPublicKey;

  private PublicKey oidcPublicKey;

  private String role;

  private GracePeriodUtil lastSeen;

  private ApiKeyService apiKeyService;

  JWTFilter() {}

  @Inject
  public JWTFilter(
    PKIHelper pkiHelper,
    ApiKeyService apiKeyService,
    JwtFilterGracePeriod jwtFilterGracePeriod,
    @ConfigProperty(name = "oidc.public") String oidcPublic,
    @ConfigProperty(name = "oidc.role") Optional<String> oidcRole
  ) throws NoSuchAlgorithmException, InvalidKeySpecException, IllegalArgumentException {
    try {
      this.apiKeyService = apiKeyService;
      this.lastSeen = jwtFilterGracePeriod;
      this.role = oidcRole.orElse("");

      var kFactory = KeyFactory.getInstance("RSA");
      byte[] kcDecoded;
      try {
        kcDecoded = Base64.getDecoder().decode(oidcPublic);
      } catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("The given oidc public key is invalid", e);
      }
      var kcSpec = new X509EncodedKeySpec(kcDecoded);
      oidcPublicKey = kFactory.generatePublic(kcSpec);

      pkiHelper.init();
      jwtPublicKey = pkiHelper.getPublicKey();
    } catch (Exception ex) {
      Log.fatalf("Cannot create instance of JWTFilter: %s", ex.getMessage());
      throw ex;
    }
  }

  @Override
  public void filter(ContainerRequestContext requestContext) {
    if (PublicEndpointRegistry.isRequestPathPublic(requestContext)) return;

    // Allow CORS preflight requests
    if (HttpMethod.OPTIONS.equals(requestContext.getMethod())) {
      // Allow all requests with request method OPTIONS
      return;
    }

    JWTPrincipal principal;

    // Get the HTTP Authorization header from the request
    String authorizationHeader = requestContext.getHeaderString(HttpHeaders.AUTHORIZATION);
    String apiKeyHeader = requestContext.getHeaderString(Constants.API_KEY_HEADER);
    if (authorizationHeader != null && authorizationHeader.startsWith("Bearer ")) {
      principal = parseAccessToken(authorizationHeader);
    } else if (apiKeyHeader != null) {
      principal = parseApiKey(apiKeyHeader);
    } else {
      Log.warnf(
        "Invalid/missing authorization header (Authorization: %s, X-API-KEY: %s) on endpoint %s",
        authorizationHeader,
        apiKeyHeader,
        requestContext.getUriInfo().getAbsolutePath()
      );
      requestContext.abortWith(
        Response.status(Status.UNAUTHORIZED)
          .entity(
            new ApiError(
              Status.UNAUTHORIZED.getStatusCode(),
              "AuthenticationException",
              "Invalid/missing authorization header"
            )
          )
          .build()
      );
      return;
    }
    if (principal == null) {
      requestContext.abortWith(
        Response.status(Status.UNAUTHORIZED)
          .entity(
            new ApiError(Status.UNAUTHORIZED.getStatusCode(), "AuthenticationException", "Invalid Authentication")
          )
          .build()
      );
      return;
    }

    var securityContext = new JWTSecurityContext(requestContext.getSecurityContext(), principal);
    requestContext.setSecurityContext(securityContext);
  }

  private JWTPrincipal parseAccessToken(String header) {
    JWTPrincipal result = null;
    Jws<Claims> jws = parseAccessTokenFromHeader(header);
    if (jws != null) {
      result = parsePrincipalFromAccessToken(jws);
    }
    return result;
  }

  private Jws<Claims> parseAccessTokenFromHeader(String header) {
    // Extract the token from the HTTP Authorization header
    String token = header.replace("Bearer ", "");
    var parser = Jwts.parserBuilder()
      .setSigningKey(oidcPublicKey)
      .deserializeJsonWith(new JacksonDeserializer<>(Map.of("realm_access", RolesList.class)))
      .build();

    Jws<Claims> jws;

    try {
      jws = parser.parseClaimsJws(token);
      Log.debugf("Valid token: %s", jws.getBody().getId());
    } catch (JwtException ex) {
      Log.warnf("Invalid token: %s", ex.getMessage());
      return null;
    }
    return jws;
  }

  private JWTPrincipal parsePrincipalFromAccessToken(Jws<Claims> jws) {
    var body = jws.getBody();
    String keyId = body.getId();
    String subject = body.getSubject();
    String audience = body.getAudience();
    String issuedFor = body.get("azp", String.class);
    Optional<RolesList> realmAccess = Optional.ofNullable(body.get("realm_access", RolesList.class));

    if (subject == null || subject.isEmpty()) {
      Log.warn("Token is missing a subject");
      return null;
    }

    // Read realm roles
    if (!role.isBlank()) {
      var realmRoles = realmAccess.map(RolesList::getRoles).orElse(new String[0]);
      var hasRole = Arrays.stream(realmRoles).anyMatch(r -> r.equals(role));
      if (!hasRole) {
        Log.warnf("User is missing required role: %s", role);
        return null;
      }
    }

    // We only want the last part of the subject, since this is usually a human
    // readable user name
    var splitted = subject.split(":");
    String username = splitted[splitted.length - 1];

    var principal = new JWTPrincipal(audience, issuedFor, username, keyId, new String[0]);

    return principal;
  }

  private Jws<Claims> parseApiKeyFromHeader(String token) {
    // Extract the api key from the HTTP Authorization header
    Jws<Claims> jws;

    try {
      jws = Jwts.parserBuilder().setSigningKey(jwtPublicKey).build().parseClaimsJws(token);
      Log.debugf("Valid token: %s", jws.getBody().getId());
    } catch (JwtException ex) {
      Log.warnf("Invalid token: %s", ex.getMessage());
      return null;
    }
    return jws;
  }

  private JWTPrincipal parseApiKey(String header) {
    JWTPrincipal principal = null;
    Jws<Claims> jws = parseApiKeyFromHeader(header);
    if (jws != null) {
      principal = parsePrincipalFromApiKey(jws);
      if (principal == null) return null;
      UUID tokenId = UUID.fromString(jws.getBody().getId());

      if (lastSeen.elementIsKnown(tokenId.toString())) {
        return principal;
      }

      var storedKey = apiKeyService.getApiKey(tokenId);
      if (storedKey == null) {
        Log.warn("Token was not found in database");
        return null;
      } else if (!storedKey.getJws().equals(header)) {
        Log.warn("Token from header is not equal to the token from database");
        return null;
      }
      lastSeen.elementSeen(tokenId.toString());
    }
    return principal;
  }

  private JWTPrincipal parsePrincipalFromApiKey(Jws<Claims> jws) {
    var body = jws.getBody();
    String subject = body.getSubject();
    String keyId = body.getId();
    if (subject == null || subject.isEmpty()) {
      Log.warn("Token is missing a subject");
      return null;
    }

    var principal = new JWTPrincipal(subject, keyId);
    return principal;
  }
}