View Javadoc
1   package de.dlr.shepard.common.filters;
2   
3   import de.dlr.shepard.auth.apikey.services.ApiKeyService;
4   import de.dlr.shepard.auth.security.ApiKeyLastSeenCache;
5   import de.dlr.shepard.auth.security.AuthenticationContext;
6   import de.dlr.shepard.auth.security.JWTPrincipal;
7   import de.dlr.shepard.auth.security.JWTSecurityContext;
8   import de.dlr.shepard.auth.security.RolesList;
9   import de.dlr.shepard.common.exceptions.ApiError;
10  import de.dlr.shepard.common.util.Constants;
11  import de.dlr.shepard.common.util.PKIHelper;
12  import io.jsonwebtoken.Claims;
13  import io.jsonwebtoken.Jws;
14  import io.jsonwebtoken.JwtException;
15  import io.jsonwebtoken.Jwts;
16  import io.jsonwebtoken.jackson.io.JacksonDeserializer;
17  import io.quarkus.logging.Log;
18  import jakarta.annotation.Priority;
19  import jakarta.enterprise.context.RequestScoped;
20  import jakarta.inject.Inject;
21  import jakarta.ws.rs.HttpMethod;
22  import jakarta.ws.rs.Priorities;
23  import jakarta.ws.rs.container.ContainerRequestContext;
24  import jakarta.ws.rs.container.ContainerRequestFilter;
25  import jakarta.ws.rs.core.HttpHeaders;
26  import jakarta.ws.rs.core.Response;
27  import jakarta.ws.rs.core.Response.Status;
28  import jakarta.ws.rs.ext.Provider;
29  import java.security.KeyFactory;
30  import java.security.NoSuchAlgorithmException;
31  import java.security.PublicKey;
32  import java.security.spec.InvalidKeySpecException;
33  import java.security.spec.X509EncodedKeySpec;
34  import java.util.Arrays;
35  import java.util.Base64;
36  import java.util.Map;
37  import java.util.Optional;
38  import java.util.UUID;
39  import org.eclipse.microprofile.config.inject.ConfigProperty;
40  
41  @Provider
42  @Priority(Priorities.AUTHENTICATION)
43  @RequestScoped
44  public class JWTFilter implements ContainerRequestFilter {
45  
46    private PublicKey jwtPublicKey;
47  
48    private PublicKey oidcPublicKey;
49  
50    private String role;
51  
52    private ApiKeyLastSeenCache apiKeyLastSeenCache;
53  
54    private ApiKeyService apiKeyService;
55  
56    private AuthenticationContext authenticationContext;
57  
58    JWTFilter() {}
59  
60    @Inject
61    public JWTFilter(
62      PKIHelper pkiHelper,
63      ApiKeyService apiKeyService,
64      ApiKeyLastSeenCache apiKeyLastSeenCache,
65      AuthenticationContext authenticationContext,
66      @ConfigProperty(name = "oidc.public") String oidcPublic,
67      @ConfigProperty(name = "oidc.role") Optional<String> oidcRole
68    ) throws NoSuchAlgorithmException, InvalidKeySpecException, IllegalArgumentException {
69      try {
70        this.apiKeyService = apiKeyService;
71        this.apiKeyLastSeenCache = apiKeyLastSeenCache;
72        this.authenticationContext = authenticationContext;
73        this.role = oidcRole.orElse("");
74  
75        var kFactory = KeyFactory.getInstance("RSA");
76        byte[] kcDecoded;
77        try {
78          kcDecoded = Base64.getDecoder().decode(oidcPublic);
79        } catch (IllegalArgumentException e) {
80          throw new IllegalArgumentException("The given oidc public key is invalid", e);
81        }
82        var kcSpec = new X509EncodedKeySpec(kcDecoded);
83        oidcPublicKey = kFactory.generatePublic(kcSpec);
84  
85        pkiHelper.init();
86        jwtPublicKey = pkiHelper.getPublicKey();
87      } catch (Exception ex) {
88        Log.fatalf("Cannot create instance of JWTFilter: %s", ex.getMessage());
89        throw ex;
90      }
91    }
92  
93    @Override
94    public void filter(ContainerRequestContext requestContext) {
95      if (PublicEndpointRegistry.isRequestPathPublic(requestContext)) return;
96  
97      // Allow CORS preflight requests
98      if (HttpMethod.OPTIONS.equals(requestContext.getMethod())) {
99        // Allow all requests with request method OPTIONS
100       return;
101     }
102 
103     JWTPrincipal principal;
104 
105     // Get the HTTP Authorization header from the request
106     String authorizationHeader = requestContext.getHeaderString(HttpHeaders.AUTHORIZATION);
107     String apiKeyHeader = requestContext.getHeaderString(Constants.API_KEY_HEADER);
108     if (authorizationHeader != null && authorizationHeader.startsWith("Bearer ")) {
109       principal = parseAccessToken(authorizationHeader);
110     } else if (apiKeyHeader != null) {
111       principal = parseApiKey(apiKeyHeader);
112     } else {
113       Log.warnf(
114         "Invalid/missing authorization header (Authorization: %s, X-API-KEY: %s) on endpoint %s",
115         authorizationHeader,
116         apiKeyHeader,
117         requestContext.getUriInfo().getAbsolutePath()
118       );
119       requestContext.abortWith(
120         Response.status(Status.UNAUTHORIZED)
121           .entity(
122             new ApiError(
123               Status.UNAUTHORIZED.getStatusCode(),
124               "AuthenticationException",
125               "Invalid/missing authorization header"
126             )
127           )
128           .build()
129       );
130       return;
131     }
132     if (principal == null) {
133       requestContext.abortWith(
134         Response.status(Status.UNAUTHORIZED)
135           .entity(
136             new ApiError(Status.UNAUTHORIZED.getStatusCode(), "AuthenticationException", "Invalid Authentication")
137           )
138           .build()
139       );
140       return;
141     }
142 
143     var securityContext = new JWTSecurityContext(requestContext.getSecurityContext(), principal);
144     requestContext.setSecurityContext(securityContext);
145     authenticationContext.setPrincipal(principal);
146   }
147 
148   private JWTPrincipal parseAccessToken(String header) {
149     JWTPrincipal result = null;
150     Jws<Claims> jws = parseAccessTokenFromHeader(header);
151     if (jws != null) {
152       result = parsePrincipalFromAccessToken(jws);
153     }
154     return result;
155   }
156 
157   private Jws<Claims> parseAccessTokenFromHeader(String header) {
158     // Extract the token from the HTTP Authorization header
159     String token = header.replace("Bearer ", "");
160     var parser = Jwts.parserBuilder()
161       .setSigningKey(oidcPublicKey)
162       .deserializeJsonWith(new JacksonDeserializer<>(Map.of("realm_access", RolesList.class)))
163       .build();
164 
165     Jws<Claims> jws;
166 
167     try {
168       jws = parser.parseClaimsJws(token);
169       Log.debugf("Valid token: %s", jws.getBody().getId());
170     } catch (JwtException ex) {
171       Log.warnf("Invalid token: %s", ex.getMessage());
172       return null;
173     }
174     return jws;
175   }
176 
177   private JWTPrincipal parsePrincipalFromAccessToken(Jws<Claims> jws) {
178     var body = jws.getBody();
179     String keyId = body.getId();
180     String subject = body.getSubject();
181     String audience = body.getAudience();
182     String issuedFor = body.get("azp", String.class);
183     Optional<RolesList> realmAccess = Optional.ofNullable(body.get("realm_access", RolesList.class));
184 
185     if (subject == null || subject.isEmpty()) {
186       Log.warn("Token is missing a subject");
187       return null;
188     }
189 
190     // Read realm roles
191     if (!role.isBlank()) {
192       var realmRoles = realmAccess.map(RolesList::getRoles).orElse(new String[0]);
193       var hasRole = Arrays.stream(realmRoles).anyMatch(r -> r.equals(role));
194       if (!hasRole) {
195         Log.warnf("User is missing required role: %s", role);
196         return null;
197       }
198     }
199 
200     // We only want the last part of the subject, since this is usually a human
201     // readable user name
202     var splitted = subject.split(":");
203     String username = splitted[splitted.length - 1];
204 
205     var principal = new JWTPrincipal(audience, issuedFor, username, keyId, new String[0]);
206 
207     return principal;
208   }
209 
210   private Jws<Claims> parseApiKeyFromHeader(String token) {
211     // Extract the api key from the HTTP Authorization header
212     Jws<Claims> jws;
213 
214     try {
215       jws = Jwts.parserBuilder().setSigningKey(jwtPublicKey).build().parseClaimsJws(token);
216       Log.debugf("Valid token: %s", jws.getBody().getId());
217     } catch (JwtException ex) {
218       Log.warnf("Invalid token: %s", ex.getMessage());
219       return null;
220     }
221     return jws;
222   }
223 
224   private JWTPrincipal parseApiKey(String header) {
225     JWTPrincipal principal = null;
226     Jws<Claims> jws = parseApiKeyFromHeader(header);
227     if (jws != null) {
228       principal = parsePrincipalFromApiKey(jws);
229       if (principal == null) return null;
230       UUID tokenId = UUID.fromString(jws.getBody().getId());
231 
232       if (apiKeyLastSeenCache.isKeyCached(tokenId.toString())) {
233         return principal;
234       }
235 
236       var storedKey = apiKeyService.getApiKey(tokenId);
237       if (storedKey == null) {
238         Log.warn("Token was not found in database");
239         return null;
240       } else if (!storedKey.getJws().equals(header)) {
241         Log.warn("Token from header is not equal to the token from database");
242         return null;
243       }
244       apiKeyLastSeenCache.cacheKey(tokenId.toString());
245     }
246     return principal;
247   }
248 
249   private JWTPrincipal parsePrincipalFromApiKey(Jws<Claims> jws) {
250     var body = jws.getBody();
251     String subject = body.getSubject();
252     String keyId = body.getId();
253     if (subject == null || subject.isEmpty()) {
254       Log.warn("Token is missing a subject");
255       return null;
256     }
257 
258     var principal = new JWTPrincipal(subject, keyId);
259     return principal;
260   }
261 }