JWTFilter.java

  1. package de.dlr.shepard.filters;

  2. import de.dlr.shepard.exceptions.ApiError;
  3. import de.dlr.shepard.neo4Core.services.ApiKeyService;
  4. import de.dlr.shepard.security.GracePeriodUtil;
  5. import de.dlr.shepard.security.JWTPrincipal;
  6. import de.dlr.shepard.security.JWTSecurityContext;
  7. import de.dlr.shepard.security.JwtFilterGracePeriod;
  8. import de.dlr.shepard.security.RolesList;
  9. import de.dlr.shepard.util.Constants;
  10. import de.dlr.shepard.util.PKIHelper;
  11. import io.jsonwebtoken.Claims;
  12. import io.jsonwebtoken.Jws;
  13. import io.jsonwebtoken.JwtException;
  14. import io.jsonwebtoken.Jwts;
  15. import io.jsonwebtoken.jackson.io.JacksonDeserializer;
  16. import io.quarkus.logging.Log;
  17. import jakarta.annotation.Priority;
  18. import jakarta.enterprise.context.RequestScoped;
  19. import jakarta.inject.Inject;
  20. import jakarta.ws.rs.HttpMethod;
  21. import jakarta.ws.rs.Priorities;
  22. import jakarta.ws.rs.container.ContainerRequestContext;
  23. import jakarta.ws.rs.container.ContainerRequestFilter;
  24. import jakarta.ws.rs.core.HttpHeaders;
  25. import jakarta.ws.rs.core.Response;
  26. import jakarta.ws.rs.core.Response.Status;
  27. import jakarta.ws.rs.ext.Provider;
  28. import java.security.KeyFactory;
  29. import java.security.NoSuchAlgorithmException;
  30. import java.security.PublicKey;
  31. import java.security.spec.InvalidKeySpecException;
  32. import java.security.spec.X509EncodedKeySpec;
  33. import java.util.Arrays;
  34. import java.util.Base64;
  35. import java.util.Map;
  36. import java.util.Optional;
  37. import java.util.UUID;
  38. import org.eclipse.microprofile.config.inject.ConfigProperty;

  39. @Provider
  40. @Priority(Priorities.AUTHENTICATION)
  41. @RequestScoped
  42. public class JWTFilter implements ContainerRequestFilter {

  43.   private PublicKey jwtPublicKey;

  44.   private PublicKey oidcPublicKey;

  45.   private String role;

  46.   private GracePeriodUtil lastSeen;

  47.   private ApiKeyService apiKeyService;

  48.   JWTFilter() {}

  49.   @Inject
  50.   public JWTFilter(
  51.     PKIHelper pkiHelper,
  52.     ApiKeyService apiKeyService,
  53.     JwtFilterGracePeriod jwtFilterGracePeriod,
  54.     @ConfigProperty(name = "oidc.public") String oidcPublic,
  55.     @ConfigProperty(name = "oidc.role") Optional<String> oidcRole
  56.   ) throws NoSuchAlgorithmException, InvalidKeySpecException, IllegalArgumentException {
  57.     try {
  58.       this.apiKeyService = apiKeyService;
  59.       this.lastSeen = jwtFilterGracePeriod;
  60.       this.role = oidcRole.orElse("");

  61.       var kFactory = KeyFactory.getInstance("RSA");
  62.       byte[] kcDecoded;
  63.       try {
  64.         kcDecoded = Base64.getDecoder().decode(oidcPublic);
  65.       } catch (IllegalArgumentException e) {
  66.         throw new IllegalArgumentException("The given oidc public key is invalid", e);
  67.       }
  68.       var kcSpec = new X509EncodedKeySpec(kcDecoded);
  69.       oidcPublicKey = kFactory.generatePublic(kcSpec);

  70.       pkiHelper.init();
  71.       jwtPublicKey = pkiHelper.getPublicKey();
  72.     } catch (Exception ex) {
  73.       Log.fatalf("Cannot create instance of JWTFilter: %s", ex.getMessage());
  74.       throw ex;
  75.     }
  76.   }

  77.   @Override
  78.   public void filter(ContainerRequestContext requestContext) {
  79.     if (PublicEndpointRegistry.isRequestPathPublic(requestContext)) return;

  80.     // Allow CORS preflight requests
  81.     if (HttpMethod.OPTIONS.equals(requestContext.getMethod())) {
  82.       // Allow all requests with request method OPTIONS
  83.       return;
  84.     }

  85.     JWTPrincipal principal;

  86.     // Get the HTTP Authorization header from the request
  87.     String authorizationHeader = requestContext.getHeaderString(HttpHeaders.AUTHORIZATION);
  88.     String apiKeyHeader = requestContext.getHeaderString(Constants.API_KEY_HEADER);
  89.     if (authorizationHeader != null && authorizationHeader.startsWith("Bearer ")) {
  90.       principal = parseAccessToken(authorizationHeader);
  91.     } else if (apiKeyHeader != null) {
  92.       principal = parseApiKey(apiKeyHeader);
  93.     } else {
  94.       Log.warnf(
  95.         "Invalid/missing authorization header (Authorization: %s, X-API-KEY: %s) on endpoint %s",
  96.         authorizationHeader,
  97.         apiKeyHeader,
  98.         requestContext.getUriInfo().getAbsolutePath()
  99.       );
  100.       requestContext.abortWith(
  101.         Response.status(Status.UNAUTHORIZED)
  102.           .entity(
  103.             new ApiError(
  104.               Status.UNAUTHORIZED.getStatusCode(),
  105.               "AuthenticationException",
  106.               "Invalid/missing authorization header"
  107.             )
  108.           )
  109.           .build()
  110.       );
  111.       return;
  112.     }
  113.     if (principal == null) {
  114.       requestContext.abortWith(
  115.         Response.status(Status.UNAUTHORIZED)
  116.           .entity(
  117.             new ApiError(Status.UNAUTHORIZED.getStatusCode(), "AuthenticationException", "Invalid Authentication")
  118.           )
  119.           .build()
  120.       );
  121.       return;
  122.     }

  123.     var securityContext = new JWTSecurityContext(requestContext.getSecurityContext(), principal);
  124.     requestContext.setSecurityContext(securityContext);
  125.   }

  126.   private JWTPrincipal parseAccessToken(String header) {
  127.     JWTPrincipal result = null;
  128.     Jws<Claims> jws = parseAccessTokenFromHeader(header);
  129.     if (jws != null) {
  130.       result = parsePrincipalFromAccessToken(jws);
  131.     }
  132.     return result;
  133.   }

  134.   private Jws<Claims> parseAccessTokenFromHeader(String header) {
  135.     // Extract the token from the HTTP Authorization header
  136.     String token = header.replace("Bearer ", "");
  137.     var parser = Jwts.parserBuilder()
  138.       .setSigningKey(oidcPublicKey)
  139.       .deserializeJsonWith(new JacksonDeserializer<>(Map.of("realm_access", RolesList.class)))
  140.       .build();

  141.     Jws<Claims> jws;

  142.     try {
  143.       jws = parser.parseClaimsJws(token);
  144.       Log.debugf("Valid token: %s", jws.getBody().getId());
  145.     } catch (JwtException ex) {
  146.       Log.warnf("Invalid token: %s", ex.getMessage());
  147.       return null;
  148.     }
  149.     return jws;
  150.   }

  151.   private JWTPrincipal parsePrincipalFromAccessToken(Jws<Claims> jws) {
  152.     var body = jws.getBody();
  153.     String keyId = body.getId();
  154.     String subject = body.getSubject();
  155.     String audience = body.getAudience();
  156.     String issuedFor = body.get("azp", String.class);
  157.     Optional<RolesList> realmAccess = Optional.ofNullable(body.get("realm_access", RolesList.class));

  158.     if (subject == null || subject.isEmpty()) {
  159.       Log.warn("Token is missing a subject");
  160.       return null;
  161.     }

  162.     // Read realm roles
  163.     if (!role.isBlank()) {
  164.       var realmRoles = realmAccess.map(RolesList::getRoles).orElse(new String[0]);
  165.       var hasRole = Arrays.stream(realmRoles).anyMatch(r -> r.equals(role));
  166.       if (!hasRole) {
  167.         Log.warnf("User is missing required role: %s", role);
  168.         return null;
  169.       }
  170.     }

  171.     // We only want the last part of the subject, since this is usually a human
  172.     // readable user name
  173.     var splitted = subject.split(":");
  174.     String username = splitted[splitted.length - 1];

  175.     var principal = new JWTPrincipal(audience, issuedFor, username, keyId, new String[0]);

  176.     return principal;
  177.   }

  178.   private Jws<Claims> parseApiKeyFromHeader(String token) {
  179.     // Extract the api key from the HTTP Authorization header
  180.     Jws<Claims> jws;

  181.     try {
  182.       jws = Jwts.parserBuilder().setSigningKey(jwtPublicKey).build().parseClaimsJws(token);
  183.       Log.debugf("Valid token: %s", jws.getBody().getId());
  184.     } catch (JwtException ex) {
  185.       Log.warnf("Invalid token: %s", ex.getMessage());
  186.       return null;
  187.     }
  188.     return jws;
  189.   }

  190.   private JWTPrincipal parseApiKey(String header) {
  191.     JWTPrincipal principal = null;
  192.     Jws<Claims> jws = parseApiKeyFromHeader(header);
  193.     if (jws != null) {
  194.       principal = parsePrincipalFromApiKey(jws);
  195.       if (principal == null) return null;
  196.       UUID tokenId = UUID.fromString(jws.getBody().getId());

  197.       if (lastSeen.elementIsKnown(tokenId.toString())) {
  198.         return principal;
  199.       }

  200.       var storedKey = apiKeyService.getApiKey(tokenId);
  201.       if (storedKey == null) {
  202.         Log.warn("Token was not found in database");
  203.         return null;
  204.       } else if (!storedKey.getJws().equals(header)) {
  205.         Log.warn("Token from header is not equal to the token from database");
  206.         return null;
  207.       }
  208.       lastSeen.elementSeen(tokenId.toString());
  209.     }
  210.     return principal;
  211.   }

  212.   private JWTPrincipal parsePrincipalFromApiKey(Jws<Claims> jws) {
  213.     var body = jws.getBody();
  214.     String subject = body.getSubject();
  215.     String keyId = body.getId();
  216.     if (subject == null || subject.isEmpty()) {
  217.       Log.warn("Token is missing a subject");
  218.       return null;
  219.     }

  220.     var principal = new JWTPrincipal(subject, keyId);
  221.     return principal;
  222.   }
  223. }