Passed
Pull Request — master (#2520)
by Tolga
03:24
created

openid.fetchOIDCConfiguration   A

Complexity

Conditions 3

Size

Total Lines 18
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 2
dl 0
loc 18
rs 10
c 0
b 0
f 0
1
package openid
2
3
import (
4
	"context"
5
	"encoding/json"
6
	"errors"
7
	"fmt"
8
	"io"
9
	"log/slog"
10
	"net/http"
11
	"strings"
12
	"sync"
13
	"time"
14
15
	"github.com/golang-jwt/jwt/v4"
16
	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
17
	"github.com/hashicorp/go-retryablehttp"
18
	"github.com/lestrrat-go/jwx/jwk"
19
20
	"github.com/Permify/permify/internal/config"
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
type Authn struct {
25
	// URL of the issuer. This is typically the base URL of the identity provider.
26
	IssuerURL string
27
	// Audience for which the token is intended. It must match the audience in the JWT.
28
	Audience string
29
	// URL of the JSON Web Key Set (JWKS). This URL hosts public keys used to verify JWT signatures.
30
	JwksURI string
31
	// Pointer to an AutoRefresh object from the JWKS library. It helps in automatically refreshing the JWKS at predefined intervals.
32
	jwksSet *jwk.AutoRefresh
33
	// List of valid signing methods. Specifies which signing algorithms are considered valid for the JWTs.
34
	validMethods []string
35
	// Pointer to a JWT parser object. This is used to parse and validate the JWT tokens.
36
	jwtParser *jwt.Parser
37
	// Duration of the interval between retries for the backoff policy.
38
	backoffInterval time.Duration
39
	// Maximum number of retries for the backoff policy.
40
	backoffMaxRetries int
41
42
	backoffFrequency time.Duration
43
44
	// Global backoff state for tracking retry attempts across concurrent requests
45
	globalRetryCount int
46
	globalFirstSeen  time.Time
47
	retriedKeys      map[string]bool
48
	mutex            sync.Mutex
49
}
50
51
// NewOidcAuthn builds an Authn that validates bearer tokens issued by the
// OpenID Connect provider described in conf.
//
// Steps, in order:
//  1. reject non-positive BackoffInterval, BackoffMaxRetries or BackoffFrequency;
//  2. fetch "<issuer>/.well-known/openid-configuration" to discover the JWKS URI;
//  3. register that URI for automatic refresh every conf.RefreshInterval;
//  4. fetch the JWKS once eagerly, so an unusable provider fails here rather
//     than on the first request.
//
// ctx is handed to the JWKS auto-refresher and to the eager fetch. An error is
// returned if any backoff setting is invalid, or if the discovery document or
// the JWKS cannot be retrieved.
func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
	// Validate static settings before any network I/O and before creating the
	// JWKS auto-refresher (which, per jwx v1, keeps working in the background
	// until ctx is done). A bad config must not cost a discovery round-trip or
	// leave a refresher behind for an Authn that is never returned.
	if conf.BackoffInterval <= 0 {
		return nil, errors.New("invalid or missing backoffInterval")
	}
	if conf.BackoffMaxRetries <= 0 {
		return nil, errors.New("invalid or missing backoffMaxRetries")
	}
	if conf.BackoffFrequency <= 0 {
		return nil, errors.New("invalid or missing backoffFrequency")
	}

	// HTTP client with retry capabilities, used for both discovery and JWKS
	// fetches; its internal logging is routed through slog.
	client := retryablehttp.NewClient()
	client.Logger = SlogAdapter{Logger: slog.Default()}
	httpClient := client.StandardClient()

	// Discover the JWKS URI from the issuer's well-known configuration endpoint.
	discoveryURL := strings.TrimSuffix(conf.Issuer, "/") + "/.well-known/openid-configuration"
	oidcConf, err := fetchOIDCConfiguration(httpClient, discoveryURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err)
	}

	// Keep the JSON Web Key Set refreshed so key rotation is picked up.
	autoRefresh := jwk.NewAutoRefresh(ctx)
	autoRefresh.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(httpClient), jwk.WithRefreshInterval(conf.RefreshInterval))

	oidc := &Authn{
		IssuerURL:         conf.Issuer,
		Audience:          conf.Audience,
		JwksURI:           oidcConf.JWKsURI,
		validMethods:      conf.ValidMethods,
		jwtParser:         jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)),
		jwksSet:           autoRefresh,
		backoffInterval:   conf.BackoffInterval,
		backoffMaxRetries: conf.BackoffMaxRetries,
		backoffFrequency:  conf.BackoffFrequency,
		retriedKeys:       make(map[string]bool),
		// globalRetryCount, globalFirstSeen and mutex start at their zero values.
	}

	// Fetch the JWKS immediately to make sure it is reachable and parseable.
	if _, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI); err != nil {
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
	}

	return oidc, nil
}
112
113
// Authenticate checks the "Bearer" token found in the gRPC metadata of
// requestContext.
//
// The token must parse with one of the configured signing methods, carry a
// "kid" header that resolves to a key in the issuer's JWKS (see
// getKeyWithRetry), and have "iss"/"aud" claims matching IssuerURL/Audience.
// On failure it returns an error whose text is the string form of the
// matching base.ErrorCode; on success it returns nil.
func (oidc *Authn) Authenticate(requestContext context.Context) error {
	rawToken, err := grpcauth.AuthFromMD(requestContext, "Bearer")
	if err != nil {
		slog.Error("failed to extract authorization header from gRPC request", "error", err)
		return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String())
	}
	// rawToken is a credential: log that it was found, never its contents.
	slog.Debug("Successfully extracted authorization header from gRPC request")

	// resolveKey is the jwt key callback: it maps the token's "kid" header to
	// the issuer's public key, refreshing the JWKS if needed.
	resolveKey := func(tok *jwt.Token) (interface{}, error) {
		slog.Info("starting JWT parsing and validation.")

		keyID, hasKeyID := tok.Header["kid"].(string)
		if !hasKeyID {
			slog.Error("jwt does not contain a key ID")
			return nil, errors.New("kid must be specified in the token header")
		}
		return oidc.getKeyWithRetry(requestContext, keyID)
	}

	tok, err := oidc.jwtParser.Parse(rawToken, resolveKey)
	if err != nil {
		slog.Error("token parsing or validation failed", "error", err)
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
	}
	if !tok.Valid {
		slog.Warn("parsed token is invalid")
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
	}

	claims, isMapClaims := tok.Claims.(jwt.MapClaims)
	if !isMapClaims {
		slog.Warn("token claims are in an incorrect format")
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String())
	}
	slog.Debug("extracted token claims", "claims", claims)

	// Issuer is checked before audience; the first mismatch wins.
	switch {
	case !claims.VerifyIssuer(oidc.IssuerURL, true):
		slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"])
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String())
	case !claims.VerifyAudience(oidc.Audience, true):
		slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"])
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String())
	}

	slog.Info("token validation succeeded")
	return nil
}
189
190
// getKeyWithRetry returns the raw public key for keyID, refreshing the JWKS
// with a globally coordinated backoff when the key is not in the cached set.
//
// The backoff state (globalRetryCount, globalFirstSeen, retriedKeys) is shared
// by all concurrent callers and guarded by oidc.mutex:
//
//   - A new window starts when globalFirstSeen is zero or at least
//     backoffInterval old; starting a window clears the counter and retriedKeys.
//   - If globalRetryCount has already reached backoffMaxRetries inside the
//     current window, the key is looked up once without any refresh. A hit on
//     a key listed in retriedKeys clears the backoff state.
//   - Otherwise the caller alternates lookups and JWKS refreshes, waiting
//     backoffFrequency before every refresh after the first. It skips its own
//     refresh when globalRetryCount advanced since its last failed lookup,
//     i.e. a concurrent caller already refreshed.
//
// It returns ctx.Err() if ctx is cancelled while waiting, the refresh error if
// a JWKS refresh fails, and a "too many attempts" or "not found" error when
// the key cannot be resolved.
func (oidc *Authn) getKeyWithRetry(ctx context.Context, keyID string) (interface{}, error) {
	var rawKey interface{}
	var err error

	oidc.mutex.Lock()
	now := time.Now()

	// Reset global state if the interval has passed
	if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval {
		slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
		oidc.globalFirstSeen = now
		oidc.globalRetryCount = 0
		oidc.retriedKeys = make(map[string]bool)
	} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
		// Backoff in effect: unlock and look the key up exactly once, without
		// refreshing the JWKS.
		slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
		oidc.mutex.Unlock()

		rawKey, err = oidc.fetchKey(ctx, keyID)
		if err == nil {
			oidc.mutex.Lock()
			if _, ok := oidc.retriedKeys[keyID]; ok {
				// Reset the backoff only when the key that caused the retries now
				// resolves. A valid but unrelated key (e.g. arriving while someone
				// probes with bad key IDs) must not reset the counters for a bad key.
				slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
				oidc.globalRetryCount = 0
				oidc.globalFirstSeen = time.Time{}
				oidc.retriedKeys = make(map[string]bool)
			}
			oidc.mutex.Unlock()
			return rawKey, nil
		}

		slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err)
		return nil, errors.New("too many attempts, backoff in effect")
	}
	oidc.mutex.Unlock()

	// Retry mechanism: look up, and on a miss refresh the JWKS and try again.
	retries := 0
	for retries <= oidc.backoffMaxRetries {
		rawKey, err = oidc.fetchKey(ctx, keyID)
		if err == nil {
			if retries != 0 {
				oidc.mutex.Lock()
				oidc.globalRetryCount = 0
				oidc.globalFirstSeen = time.Time{}
				oidc.retriedKeys = make(map[string]bool)
				oidc.mutex.Unlock()
			}
			return rawKey, nil
		}

		oidc.mutex.Lock()
		snapshotRetryCount := oidc.globalRetryCount
		oidc.retriedKeys[keyID] = true
		if oidc.globalRetryCount > oidc.backoffMaxRetries {
			slog.Error("key ID not found in JWKS due to global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
			oidc.mutex.Unlock()
			return nil, errors.New("too many attempts, backoff in effect due to global retry count")
		}
		oidc.mutex.Unlock()

		if retries > 0 {
			select {
			case <-time.After(oidc.backoffFrequency):
				slog.Info("waiting before retrying", "keyID", keyID, "retries", retries)
			case <-ctx.Done():
				slog.Error("context cancelled during retry", "keyID", keyID)
				return nil, ctx.Err()
			}
		}

		oidc.mutex.Lock()
		if oidc.globalRetryCount > snapshotRetryCount {
			// A concurrent request already refreshed the JWKS while we waited;
			// reuse its result instead of refreshing again.
			retries++
			slog.Warn("another concurrent request already refreshed the JWKS")
			oidc.mutex.Unlock()
			continue
		}

		oidc.globalRetryCount++
		slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
		retries++

		if _, refreshErr := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); refreshErr != nil {
			oidc.mutex.Unlock()
			slog.Error("failed to refresh JWKS", "error", refreshErr)
			return nil, refreshErr
		}
		// Unlock only after Refresh so concurrent requests don't issue duplicate refreshes.
		oidc.mutex.Unlock()
	}

	// The loop only falls through right after a JWKS refresh (this request's, or a
	// concurrent one on the `continue` path) whose result it never looked at.
	// Check the refreshed set once more, the same way the loop does after each
	// refresh, so that last refresh is not wasted. No additional Refresh is issued.
	rawKey, err = oidc.fetchKey(ctx, keyID)
	if err == nil {
		oidc.mutex.Lock()
		oidc.globalRetryCount = 0
		oidc.globalFirstSeen = time.Time{}
		oidc.retriedKeys = make(map[string]bool)
		oidc.mutex.Unlock()
		return rawKey, nil
	}

	// Mark the global state to prevent further retries for the backoff interval
	oidc.mutex.Lock()
	if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval {
		slog.Warn("marking state to prevent further retries", "keyID", keyID)
		oidc.globalRetryCount = oidc.backoffMaxRetries
	}
	oidc.mutex.Unlock()

	slog.Error("key ID not found in JWKS after retries", "keyID", keyID)
	return nil, errors.New("key ID not found in JWKS after retries")
}
299
300
// fetchKey obtains the JWKS registered for oidc.JwksURI and returns the raw
// public key whose "kid" equals keyID, ready for JWT signature verification.
//
// Errors: the JWKS could not be fetched, no key carries keyID, or the matching
// key could not be converted to its raw form.
func (oidc *Authn) fetchKey(ctx context.Context, keyID string) (interface{}, error) {
	slog.DebugContext(ctx, "attempting to find key in JWKS", "kid", keyID)

	set, fetchErr := oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
	if fetchErr != nil {
		slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", fetchErr)
		return nil, fmt.Errorf("failed to fetch JWKS: %w", fetchErr)
	}
	slog.InfoContext(ctx, "successfully fetched JWKS")

	jwkKey, found := set.LookupKeyID(keyID)
	if !found {
		slog.ErrorContext(ctx, "key ID not found in JWKS", "kid", keyID)
		return nil, fmt.Errorf("kid %s not found", keyID)
	}

	// Raw unwraps the JWK into the underlying key value used for verification.
	var publicKey interface{}
	if rawErr := jwkKey.Raw(&publicKey); rawErr != nil {
		slog.ErrorContext(ctx, "failed to get raw public key", "kid", keyID, "error", rawErr)
		return nil, fmt.Errorf("failed to get raw public key: %w", rawErr)
	}
	slog.DebugContext(ctx, "successfully obtained raw public key", "key", publicKey)
	return publicKey, nil
}
332
333
// Config is the subset of an OpenID Connect discovery document
// (".well-known/openid-configuration") that this package uses.
// Other fields in the document are ignored when decoding.
type Config struct {
	// Issuer is the provider's issuer identifier URL ("issuer").
	Issuer string `json:"issuer"`
	// JWKsURI locates the provider's JSON Web Key Set ("jwks_uri").
	JWKsURI string `json:"jwks_uri"`
}
340
341
// fetchOIDCConfiguration downloads the OpenID Connect discovery document at url
// using client and returns its decoded, validated form.
//
// Transport and non-200 errors come from doHTTPRequest; decoding and
// missing-field errors come from parseOIDCConfiguration. Both are returned
// unchanged so the caller decides how to wrap them.
func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) {
	document, err := doHTTPRequest(client, url)
	if err != nil {
		return nil, err
	}
	return parseOIDCConfiguration(document)
}
362
363
// doHTTPRequest issues a GET request for url with client and returns the whole
// response body.
//
// A status other than 200 OK is reported as an error. Request-construction,
// transport and body-read errors are wrapped with %w, so callers can still
// inspect the cause with errors.Is / errors.As (e.g. to detect a timeout).
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	slog.Debug("creating new HTTP GET request", "url", url)

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)

	res, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	// Close the body on every path, including the non-200 return below.
	defer res.Body.Close()

	slog.Debug("received HTTP response", "status_code", res.StatusCode, "url", url)

	if res.StatusCode != http.StatusOK {
		slog.Warn("received unexpected status code", "status_code", res.StatusCode, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", res.StatusCode)
	}

	slog.Debug("reading response body", "url", url)

	body, err := io.ReadAll(res.Body)
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}

	slog.Debug("successfully read response body", "url", url, "response_length", len(body))
	return body, nil
}
414
415
// parseOIDCConfiguration decodes an OpenID Connect discovery document and
// checks that the fields this package relies on are present.
//
// Invalid JSON yields an error wrapping the encoding/json error (reachable via
// errors.As, e.g. *json.SyntaxError). An empty or absent "issuer" or
// "jwks_uri" also yields an error. Fields not declared on Config are ignored.
func parseOIDCConfiguration(body []byte) (*Config, error) {
	var oidcConfig Config
	if err := json.Unmarshal(body, &oidcConfig); err != nil {
		slog.Error("failed to unmarshal OIDC configuration", "error", err)
		// %w (not %s) keeps the underlying json error inspectable by callers.
		return nil, fmt.Errorf("failed to decode OIDC configuration: %w", err)
	}
	slog.Debug("successfully decoded OIDC configuration")

	if oidcConfig.Issuer == "" {
		slog.Warn("missing issuer value in OIDC configuration")
		return nil, errors.New("issuer value is required but missing in OIDC configuration")
	}

	if oidcConfig.JWKsURI == "" {
		slog.Warn("missing JWKsURI value in OIDC configuration")
		return nil, errors.New("JWKsURI value is required but missing in OIDC configuration")
	}

	slog.Info("successfully parsed OIDC configuration", "issuer", oidcConfig.Issuer, "jwks_uri", oidcConfig.JWKsURI)
	return &oidcConfig, nil
}
442