1 | package oidc |
||
2 | |||
3 | import ( |
||
4 | "context" |
||
5 | "encoding/json" |
||
6 | "errors" |
||
7 | "fmt" |
||
8 | "io" |
||
9 | "log/slog" |
||
10 | "net/http" |
||
11 | "strings" |
||
12 | "sync" |
||
13 | "time" |
||
14 | |||
15 | "github.com/golang-jwt/jwt/v4" |
||
16 | grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" |
||
17 | "github.com/hashicorp/go-retryablehttp" |
||
18 | "github.com/lestrrat-go/jwx/jwk" |
||
19 | |||
20 | "github.com/Permify/permify/internal/config" |
||
21 | base "github.com/Permify/permify/pkg/pb/base/v1" |
||
22 | ) |
||
23 | |||
24 | type Authn struct { |
||
25 | // URL of the issuer. This is typically the base URL of the identity provider. |
||
26 | IssuerURL string |
||
27 | // Audience for which the token is intended. It must match the audience in the JWT. |
||
28 | Audience string |
||
29 | // URL of the JSON Web Key Set (JWKS). This URL hosts public keys used to verify JWT signatures. |
||
30 | JwksURI string |
||
31 | // Pointer to an AutoRefresh object from the JWKS library. It helps in automatically refreshing the JWKS at predefined intervals. |
||
32 | jwksSet *jwk.AutoRefresh |
||
33 | // List of valid signing methods. Specifies which signing algorithms are considered valid for the JWTs. |
||
34 | validMethods []string |
||
35 | // Pointer to a JWT parser object. This is used to parse and validate the JWT tokens. |
||
36 | jwtParser *jwt.Parser |
||
37 | // Duration of the interval between retries for the backoff policy. |
||
38 | backoffInterval time.Duration |
||
39 | // Maximum number of retries for the backoff policy. |
||
40 | backoffMaxRetries int |
||
41 | |||
42 | backoffFrequency time.Duration |
||
43 | |||
44 | // Global backoff state |
||
45 | globalRetryCount int |
||
46 | globalFirstSeen time.Time |
||
47 | globalRetryKeyIds map[string]bool |
||
48 | mu sync.Mutex |
||
49 | } |
||
50 | |||
51 | // NewOidcAuthn initializes a new instance of the Authn struct with OpenID Connect (OIDC) configuration. |
||
52 | // It takes in a context for managing cancellation and a configuration object. It returns a pointer to an Authn instance or an error. |
||
53 | func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) { |
||
54 | // Create a new HTTP client with retry capabilities. This client is used for making HTTP requests, particularly for fetching OIDC configuration. |
||
55 | client := retryablehttp.NewClient() |
||
56 | client.Logger = SlogAdapter{Logger: slog.Default()} |
||
57 | |||
58 | // Fetch the OIDC configuration from the issuer's well-known configuration endpoint. |
||
59 | oidcConf, err := fetchOIDCConfiguration(client.StandardClient(), strings.TrimSuffix(conf.Issuer, "/")+"/.well-known/openid-configuration") |
||
60 | if err != nil { |
||
61 | // If there is an error fetching the OIDC configuration, return nil and the error. |
||
62 | return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err) |
||
0 ignored issues
–
show
introduced
by
![]() |
|||
63 | } |
||
64 | |||
65 | // Set up automatic refresh of the JSON Web Key Set (JWKS) to ensure the public keys are always up-to-date. |
||
66 | ar := jwk.NewAutoRefresh(ctx) // Create a new AutoRefresh instance for the JWKS. |
||
67 | ar.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(client.StandardClient()), jwk.WithRefreshInterval(conf.RefreshInterval)) // Configure the auto-refresh parameters. |
||
68 | |||
69 | // Validate and set backoffInterval, backoffMaxRetries, and backoffFrequency |
||
70 | backoffInterval := conf.BackoffInterval |
||
71 | if backoffInterval <= 0 { |
||
72 | return nil, errors.New("invalid or missing backoffInterval") |
||
73 | } |
||
74 | |||
75 | backoffMaxRetries := conf.BackoffMaxRetries |
||
76 | if backoffMaxRetries <= 0 { |
||
77 | return nil, errors.New("invalid or missing backoffMaxRetries") |
||
78 | } |
||
79 | |||
80 | backoffFrequency := conf.BackoffFrequency |
||
81 | if backoffFrequency <= 0 { |
||
82 | return nil, errors.New("invalid or missing backoffFrequency") |
||
83 | } |
||
84 | |||
85 | // Initialize the Authn struct with the OIDC configuration details and other relevant settings. |
||
86 | oidc := &Authn{ |
||
87 | IssuerURL: conf.Issuer, // URL of the token issuer. |
||
88 | Audience: conf.Audience, // Intended audience of the token. |
||
89 | JwksURI: oidcConf.JWKsURI, // URL of the JWKS endpoint. |
||
90 | validMethods: conf.ValidMethods, // List of acceptable signing methods for the tokens. |
||
91 | jwtParser: jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)), // JWT parser configured with the valid signing methods. |
||
92 | jwksSet: ar, // Set the JWKS auto-refresh instance. |
||
93 | backoffInterval: backoffInterval, |
||
94 | backoffMaxRetries: backoffMaxRetries, |
||
95 | backoffFrequency: backoffFrequency, |
||
96 | globalRetryCount: 0, |
||
97 | globalRetryKeyIds: make(map[string]bool), |
||
98 | globalFirstSeen: time.Time{}, |
||
99 | mu: sync.Mutex{}, |
||
100 | } |
||
101 | |||
102 | // Attempt to fetch the JWKS immediately to ensure it's available and valid. |
||
103 | _, err = oidc.jwksSet.Fetch(ctx, oidc.JwksURI) |
||
104 | if err != nil { |
||
105 | // If there is an error fetching the JWKS, return nil and the error. |
||
106 | return nil, fmt.Errorf("failed to fetch JWKS: %w", err) |
||
0 ignored issues
–
show
|
|||
107 | } |
||
108 | |||
109 | // Return the initialized OIDC authentication object and no error. |
||
110 | return oidc, nil |
||
111 | } |
||
112 | |||
113 | // Authenticate validates the JWT token found in the authorization header of the incoming request. |
||
114 | // It uses the OIDC configuration to validate the token against the issuer's public keys. |
||
115 | func (oidc *Authn) Authenticate(requestContext context.Context) error { |
||
116 | // Extract the authorization header from the metadata of the incoming gRPC request. |
||
117 | authHeader, err := grpcauth.AuthFromMD(requestContext, "Bearer") |
||
118 | if err != nil { |
||
119 | // Log the error if the authorization header is missing or does not start with "Bearer" |
||
120 | slog.Error("failed to extract authorization header from gRPC request", "error", err) |
||
121 | // Return an error indicating the missing or incorrect bearer token |
||
122 | return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String()) |
||
123 | } |
||
124 | |||
125 | // Log the successful extraction of the authorization header for debugging purposes. |
||
126 | // Remember, do not log the actual content of authHeader as it might contain sensitive information. |
||
127 | slog.Debug("Successfully extracted authorization header from gRPC request") |
||
128 | |||
129 | // Parse and validate the JWT token extracted from the authorization header. |
||
130 | parsedToken, err := oidc.jwtParser.Parse(authHeader, func(token *jwt.Token) (interface{}, error) { |
||
131 | slog.Info("starting JWT parsing and validation.") |
||
132 | |||
133 | // Retrieve the key ID from the JWT header and find the corresponding key in the JWKS. |
||
134 | if keyID, ok := token.Header["kid"].(string); ok { |
||
135 | return oidc.getKeyWithRetry(requestContext, keyID) |
||
136 | } |
||
137 | slog.Error("jwt does not contain a key ID") |
||
138 | // If the JWT does not contain a key ID, return an error. |
||
139 | return nil, errors.New("kid must be specified in the token header") |
||
140 | }) |
||
141 | if err != nil { |
||
142 | // Log that the token parsing or validation failed |
||
143 | slog.Error("token parsing or validation failed", "error", err) |
||
144 | // If token parsing or validation fails, return an error indicating the token is invalid. |
||
145 | return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()) |
||
146 | } |
||
147 | |||
148 | // Ensure the token is valid. |
||
149 | if !parsedToken.Valid { |
||
150 | // Log that the parsed token was not valid |
||
151 | slog.Warn("parsed token is invalid") |
||
152 | // Return an error indicating the invalid token |
||
153 | return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()) |
||
154 | } |
||
155 | |||
156 | // Extract the claims from the token. |
||
157 | claims, ok := parsedToken.Claims.(jwt.MapClaims) |
||
158 | if !ok { |
||
159 | // Log that the claims were in an incorrect format |
||
160 | slog.Warn("token claims are in an incorrect format") |
||
161 | // Return an error |
||
162 | return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String()) |
||
163 | } |
||
164 | |||
165 | slog.Debug("extracted token claims", "claims", claims) |
||
166 | |||
167 | // Verify the issuer of the token matches the expected issuer. |
||
168 | if ok := claims.VerifyIssuer(oidc.IssuerURL, true); !ok { |
||
169 | // Log that the issuer did not match the expected issuer |
||
170 | slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"]) |
||
171 | // Return an error |
||
172 | return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String()) |
||
173 | } |
||
174 | |||
175 | // Verify the audience of the token matches the expected audience. |
||
176 | if ok := claims.VerifyAudience(oidc.Audience, true); !ok { |
||
177 | // Log that the audience did not match the expected audience |
||
178 | slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"]) |
||
179 | // Return an error |
||
180 | return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String()) |
||
181 | } |
||
182 | |||
183 | // Log that the token's issuer and audience were successfully validated |
||
184 | slog.Info("token validation succeeded") |
||
185 | |||
186 | // If all validations pass, return nil indicating the token is valid. |
||
187 | return nil |
||
188 | } |
||
189 | |||
190 | // getKeyWithRetry attempts to retrieve the key for the given keyID with retries using a custom backoff strategy. |
||
191 | func (oidc *Authn) getKeyWithRetry(ctx context.Context, keyID string) (interface{}, error) { |
||
192 | var rawKey interface{} |
||
193 | var err error |
||
194 | |||
195 | oidc.mu.Lock() |
||
196 | now := time.Now() |
||
197 | |||
198 | // Reset global state if the interval has passed |
||
199 | if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval { |
||
200 | slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID) |
||
201 | oidc.globalFirstSeen = now |
||
202 | oidc.globalRetryCount = 0 |
||
203 | oidc.globalRetryKeyIds = make(map[string]bool) |
||
204 | } else if oidc.globalRetryCount >= oidc.backoffMaxRetries { |
||
205 | // If max retries reached within the interval, unlock and check keyID once |
||
206 | slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID) |
||
207 | oidc.mu.Unlock() |
||
208 | |||
209 | // Try to fetch the keyID once |
||
210 | rawKey, err = oidc.fetchKey(ctx, keyID) |
||
211 | if err == nil { |
||
212 | oidc.mu.Lock() |
||
213 | if _, ok := oidc.globalRetryKeyIds[keyID]; ok { |
||
214 | // Reset global backoff state if a valid key is found and that key had been retried. |
||
215 | // Use case would be someone trying to exploit with bad KeyIDs, and along comes a valid KeyID |
||
216 | // The valid KeyID should not reset the counters for a bad key |
||
217 | slog.Info("valid key found during backoff period, resetting state", "keyID", keyID) |
||
218 | oidc.globalRetryCount = 0 |
||
219 | oidc.globalFirstSeen = time.Time{} |
||
220 | oidc.globalRetryKeyIds = make(map[string]bool) |
||
221 | } |
||
222 | oidc.mu.Unlock() |
||
223 | return rawKey, nil |
||
224 | } |
||
225 | |||
226 | // Log the failure and return an error if keyID is not found |
||
227 | slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err) |
||
228 | return nil, errors.New("too many attempts, backoff in effect") |
||
229 | } |
||
230 | oidc.mu.Unlock() |
||
231 | |||
232 | // Retry mechanism |
||
233 | retries := 0 |
||
234 | for retries <= oidc.backoffMaxRetries { |
||
235 | rawKey, err = oidc.fetchKey(ctx, keyID) |
||
236 | if err == nil { |
||
237 | if retries != 0 { |
||
238 | oidc.mu.Lock() |
||
239 | oidc.globalRetryCount = 0 |
||
240 | oidc.globalFirstSeen = time.Time{} |
||
241 | oidc.globalRetryKeyIds = make(map[string]bool) |
||
242 | oidc.mu.Unlock() |
||
243 | } |
||
244 | return rawKey, nil |
||
245 | } |
||
246 | oidc.mu.Lock() |
||
247 | initialGlobalRetryCount := oidc.globalRetryCount |
||
248 | oidc.globalRetryKeyIds[keyID] = true |
||
249 | if oidc.globalRetryCount > oidc.backoffMaxRetries { |
||
250 | slog.Error("key ID not found in JWKS due to global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount) |
||
251 | oidc.mu.Unlock() |
||
252 | return nil, errors.New("too many attempts, backoff in effect due to global retry count") |
||
253 | } |
||
254 | oidc.mu.Unlock() |
||
255 | if retries > 0 { |
||
256 | select { |
||
257 | case <-time.After(oidc.backoffFrequency): |
||
258 | // Log the wait before retrying |
||
259 | slog.Info("waiting before retrying", "keyID", keyID, "retries", retries) |
||
260 | case <-ctx.Done(): |
||
261 | slog.Error("context cancelled during retry", "keyID", keyID) |
||
262 | return nil, ctx.Err() |
||
263 | } |
||
264 | } |
||
265 | |||
266 | oidc.mu.Lock() |
||
267 | if oidc.globalRetryCount > initialGlobalRetryCount { |
||
268 | // Concurrent requests in retry loop at same time, another concurrent request already refreshed the JWKS |
||
269 | retries++ |
||
270 | slog.Warn("another concurrent request already refreshed the JWKS") |
||
271 | oidc.mu.Unlock() |
||
272 | continue |
||
273 | } |
||
274 | |||
275 | oidc.globalRetryCount++ |
||
276 | slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err) |
||
277 | retries++ |
||
278 | |||
279 | if _, refreshErr := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); refreshErr != nil { |
||
280 | oidc.mu.Unlock() |
||
281 | slog.Error("failed to refresh JWKS", "error", refreshErr) |
||
282 | return nil, refreshErr |
||
283 | } |
||
284 | // Unlock needs to follow Refresh to ensure that concurrent requests don't make duplicate calls to Refresh |
||
285 | oidc.mu.Unlock() |
||
286 | } |
||
287 | |||
288 | // Mark the global state to prevent further retries for the backoff interval |
||
289 | oidc.mu.Lock() |
||
290 | if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval { |
||
291 | slog.Warn("marking state to prevent further retries", "keyID", keyID) |
||
292 | oidc.globalRetryCount = oidc.backoffMaxRetries |
||
293 | } |
||
294 | oidc.mu.Unlock() |
||
295 | |||
296 | slog.Error("key ID not found in JWKS after retries", "keyID", keyID) |
||
297 | return nil, errors.New("key ID not found in JWKS after retries") |
||
298 | } |
||
299 | |||
300 | // fetchKey attempts to fetch the JWKS and retrieve the key for the given keyID. |
||
301 | func (oidc *Authn) fetchKey(ctx context.Context, keyID string) (interface{}, error) { |
||
302 | // Log the attempt to find the key. |
||
303 | slog.DebugContext(ctx, "attempting to find key in JWKS", "kid", keyID) |
||
304 | |||
305 | // Fetch the JWKS from the configured URI. |
||
306 | jwks, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI) |
||
307 | if err != nil { |
||
308 | // Log an error and return if fetching fails. |
||
309 | slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", err) |
||
310 | return nil, fmt.Errorf("failed to fetch JWKS: %w", err) |
||
0 ignored issues
–
show
|
|||
311 | } |
||
312 | |||
313 | // Log a successful fetch of the JWKS. |
||
314 | slog.InfoContext(ctx, "successfully fetched JWKS") |
||
315 | |||
316 | // Attempt to find the key in the fetched JWKS using the key ID. |
||
317 | if key, found := jwks.LookupKeyID(keyID); found { |
||
318 | var k interface{} |
||
319 | // Convert the key to a usable format. |
||
320 | if err := key.Raw(&k); err != nil { |
||
321 | slog.ErrorContext(ctx, "failed to get raw public key", "kid", keyID, "error", err) |
||
322 | return nil, fmt.Errorf("failed to get raw public key: %w", err) |
||
0 ignored issues
–
show
|
|||
323 | } |
||
324 | // Log a successful retrieval of the raw public key. |
||
325 | slog.DebugContext(ctx, "successfully obtained raw public key", "key", k) |
||
326 | return k, nil // Return the public key for JWT signature verification. |
||
327 | } |
||
328 | // Log an error if the key ID is not found in the JWKS. |
||
329 | slog.ErrorContext(ctx, "key ID not found in JWKS", "kid", keyID) |
||
330 | return nil, fmt.Errorf("kid %s not found", keyID) |
||
331 | } |
||
332 | |||
// Config is the subset of an OpenID Connect discovery document that this
// package reads.
type Config struct {
	// Issuer is decoded from the document's "issuer" member.
	Issuer string `json:"issuer"`
	// JWKsURI is decoded from the document's "jwks_uri" member: the location of
	// the issuer's JSON Web Key Set.
	JWKsURI string `json:"jwks_uri"`
}
||
340 | |||
341 | // fetchOIDCConfiguration sends an HTTP request to the given URL to fetch the OpenID Connect (OIDC) configuration. |
||
342 | // It requires an HTTP client and the URL from which to fetch the configuration. |
||
343 | func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) { |
||
344 | // Send an HTTP GET request to the provided URL to fetch the OIDC configuration. |
||
345 | // This typically points to the well-known configuration endpoint of the OIDC provider. |
||
346 | body, err := doHTTPRequest(client, url) |
||
347 | if err != nil { |
||
348 | // If there is an error in fetching the configuration (network error, bad response, etc.), return nil and the error. |
||
349 | return nil, err |
||
350 | } |
||
351 | |||
352 | // Parse the JSON response body into an OIDC Config struct. |
||
353 | // This involves unmarshalling the JSON into a struct that matches the expected fields of the OIDC configuration. |
||
354 | oidcConfig, err := parseOIDCConfiguration(body) |
||
355 | if err != nil { |
||
356 | return nil, err |
||
357 | } |
||
358 | |||
359 | // Return the parsed OIDC configuration and nil as the error (indicating success). |
||
360 | return oidcConfig, nil |
||
361 | } |
||
362 | |||
// doHTTPRequest issues an HTTP GET to url using client and returns the raw
// response body.
//
// It returns an error when the request cannot be built or executed, when the
// response status is anything other than 200 OK, or when the body cannot be
// read. Underlying errors are wrapped with %w, so callers can inspect the
// cause with errors.Is and errors.As.
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	slog.Debug("creating new HTTP GET request", "url", url)

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)

	res, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	// Close the body on every return path below, including the non-200 one.
	defer res.Body.Close()

	slog.Debug("received HTTP response", "status_code", res.StatusCode, "url", url)

	// Only 200 OK is accepted; any other status is treated as a failure.
	if res.StatusCode != http.StatusOK {
		slog.Warn("received unexpected status code", "status_code", res.StatusCode, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", res.StatusCode)
	}

	slog.Debug("reading response body", "url", url)

	body, err := io.ReadAll(res.Body)
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}

	slog.Debug("successfully read response body", "url", url, "response_length", len(body))

	return body, nil
}
||
414 | |||
415 | // parseOIDCConfiguration decodes the OIDC configuration from the given JSON body. |
||
416 | func parseOIDCConfiguration(body []byte) (*Config, error) { |
||
417 | var oidcConfig Config |
||
418 | // Attempt to unmarshal the JSON body into the oidcConfig struct. |
||
419 | if err := json.Unmarshal(body, &oidcConfig); err != nil { |
||
420 | slog.Error("failed to unmarshal OIDC configuration", "error", err) |
||
421 | return nil, fmt.Errorf("failed to decode OIDC configuration: %s", err) |
||
422 | } |
||
423 | // Log the successful decoding of OIDC configuration |
||
424 | slog.Debug("successfully decoded OIDC configuration") |
||
425 | |||
426 | if oidcConfig.Issuer == "" { |
||
427 | slog.Warn("missing issuer value in OIDC configuration") |
||
428 | return nil, errors.New("issuer value is required but missing in OIDC configuration") |
||
429 | } |
||
430 | |||
431 | if oidcConfig.JWKsURI == "" { |
||
432 | slog.Warn("missing JWKsURI value in OIDC configuration") |
||
433 | return nil, errors.New("JWKsURI value is required but missing in OIDC configuration") |
||
434 | } |
||
435 | |||
436 | // Log the successful parsing of the OIDC configuration |
||
437 | slog.Info("successfully parsed OIDC configuration", "issuer", oidcConfig.Issuer, "jwks_uri", oidcConfig.JWKsURI) |
||
438 | |||
439 | // Return the successfully parsed configuration. |
||
440 | return &oidcConfig, nil |
||
441 | } |
||
442 |