openid.*Authn.fetchKey   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 33
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 19
nop 2
dl 0
loc 33
rs 9.45
c 0
b 0
f 0
1
package openid // OpenID Connect authentication package
2
import (       // Package imports
3
	"context" // Context for request cancellation and deadlines
4
	"encoding/json"
5
	"errors"
6
	"fmt" // formatting package for error messages
7
	"io"
8
	"log/slog"
9
	"net/http" // HTTP client and server implementations
10
	"strings"
11
	"sync"
12
	"time"
13
14
	// JWT package
15
	"github.com/golang-jwt/jwt/v4"
16
	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
17
	"github.com/hashicorp/go-retryablehttp"
18
19
	// JWT key sets
20
	"github.com/lestrrat-go/jwx/jwk" // JWT key sets
21
	// Internal packages
22
	"github.com/Permify/permify/internal/config"     // internal configuration
23
	base "github.com/Permify/permify/pkg/pb/base/v1" // Base protobuf definitions
24
) // End of imports
25
26
// Authn verifies OpenID Connect bearer tokens (JWTs) presented on incoming
// gRPC requests. Construct it with NewOidcAuthn.
//
// The backoff fields bound how often the JWKS is re-downloaded when a token
// names an unknown key ID. The globalRetryCount, globalFirstSeen and
// retriedKeys fields hold that bookkeeping; they are shared across calls and
// guarded by mutex.
type Authn struct {
	// IssuerURL is the expected "iss" claim value (the identity provider's issuer URL).
	IssuerURL string
	// Audience is the expected "aud" claim value.
	Audience string
	// JwksURI is the JSON Web Key Set URL taken from the provider's discovery
	// document; it hosts the public keys used to verify token signatures.
	JwksURI string

	// jwksSet keeps the JWKS at JwksURI cached and periodically refreshed.
	jwksSet *jwk.AutoRefresh
	// validMethods lists the accepted JWT signing algorithms.
	validMethods []string
	// jwtParser parses tokens and is configured to accept only validMethods.
	jwtParser *jwt.Parser

	// backoffInterval is the length of the window during which JWKS refresh
	// attempts are counted; once it has elapsed, the retry state is reset.
	backoffInterval time.Duration
	// backoffMaxRetries caps JWKS refreshes per window and bounds the retry
	// loop of a single key lookup.
	backoffMaxRetries int
	// backoffFrequency is the delay between successive retry attempts within a
	// single key lookup.
	backoffFrequency time.Duration

	// globalRetryCount counts JWKS refreshes performed in the current window.
	globalRetryCount int
	// globalFirstSeen marks the start of the current window; the zero value
	// means no window is open.
	globalFirstSeen time.Time
	// retriedKeys records key IDs whose lookup missed during the current window.
	retriedKeys map[string]bool
	// mutex guards globalRetryCount, globalFirstSeen and retriedKeys.
	mutex sync.Mutex
}
52
// NewOidcAuthn creates a new OIDC authenticator
53
// NewOidcAuthn initializes a new instance of the Authn struct with OpenID Connect (OIDC) configuration.
54
55
// It takes in a context for managing cancellation and a configuration object. It returns a pointer to an Authn instance or an error.
56
func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
57
	// Create a new HTTP client with retry capabilities. This client is used for making HTTP requests, particularly for fetching OIDC configuration.
58
	client := retryablehttp.NewClient()
59
	client.Logger = SlogAdapter{Logger: slog.Default()}
60
61
	// Fetch the OIDC configuration from the issuer's well-known configuration endpoint.
62
	oidcConf, err := fetchOIDCConfiguration(client.StandardClient(), strings.TrimSuffix(conf.Issuer, "/")+"/.well-known/openid-configuration")
63
	if err != nil {
64
		// If there is an error fetching the OIDC configuration, return nil and the error.
65
		return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
66
	}
67
68
	// Set up automatic refresh of the JSON Web Key Set (JWKS) to ensure the public keys are always up-to-date.
69
	autoRefresh := jwk.NewAutoRefresh(ctx)                                                                                              // Create a new AutoRefresh instance for the JWKS.
70
	autoRefresh.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(client.StandardClient()), jwk.WithRefreshInterval(conf.RefreshInterval)) // Configure the auto-refresh parameters.
71
72
	// Validate and set backoffInterval, backoffMaxRetries, and backoffFrequency
73
	backoffInterval := conf.BackoffInterval
74
	if backoffInterval <= 0 {
75
		return nil, errors.New("invalid or missing backoffInterval")
76
	}
77
78
	backoffMaxRetries := conf.BackoffMaxRetries
79
	if backoffMaxRetries <= 0 {
80
		return nil, errors.New("invalid or missing backoffMaxRetries")
81
	}
82
83
	backoffFrequency := conf.BackoffFrequency
84
	if backoffFrequency <= 0 {
85
		return nil, errors.New("invalid or missing backoffFrequency")
86
	}
87
88
	// Initialize the Authn struct with the OIDC configuration details and other relevant settings.
89
	oidc := &Authn{
90
		IssuerURL:         conf.Issuer,                                            // URL of the token issuer.
91
		Audience:          conf.Audience,                                          // Intended audience of the token.
92
		JwksURI:           oidcConf.JWKsURI,                                       // URL of the JWKS endpoint.
93
		validMethods:      conf.ValidMethods,                                      // List of acceptable signing methods for the tokens.
94
		jwtParser:         jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)), // JWT parser configured with the valid signing methods.
95
		jwksSet:           autoRefresh,                                            // Set the JWKS auto-refresh instance.
96
		backoffInterval:   backoffInterval,
97
		backoffMaxRetries: backoffMaxRetries,
98
		backoffFrequency:  backoffFrequency,
99
		globalRetryCount:  0,
100
		retriedKeys:       make(map[string]bool),
101
		globalFirstSeen:   time.Time{},
102
		mutex:             sync.Mutex{},
103
	}
104
105
	// Attempt to fetch the JWKS immediately to ensure it's available and valid.
106
	if _, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI); err != nil {
107
		// If there is an error fetching the JWKS, return nil and the error.
108
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
109
	} // End of JWKS fetch error handling
110
	// Return the initialized OIDC authentication object and no error.
111
112
	return oidc, nil // Successfully initialized
113
} // End of NewOidcAuthn
114
// Authenticate validates the JWT token found in the authorization header of the incoming request.
115
116
// It uses the OIDC configuration to validate the token against the issuer's public keys.
117
func (oidc *Authn) Authenticate(ctx context.Context) error {
118
	// Extract the authorization header from the metadata of the incoming gRPC request.
119
	authHeader, err := grpcauth.AuthFromMD(ctx, "Bearer")
120
	if err != nil { // Check for authentication errors
121
		// Log the error if the authorization header is missing or does not start with "Bearer"
122
		slog.Error("failed to extract authorization header from gRPC request", "error", err)
123
		// Return an error indicating the missing or incorrect bearer token
124
		return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String())
125
	} // End of auth header error handling
126
	// Log the successful extraction of the authorization header for debugging purposes.
127
128
	// Remember, do not log the actual content of authHeader as it might contain sensitive information.
129
	slog.Debug("Successfully extracted authorization header from gRPC request")
130
131
	// Parse and validate the JWT token extracted from the authorization header.
132
	parsedToken, err := oidc.jwtParser.Parse(authHeader, func(token *jwt.Token) (interface{}, error) {
133
		slog.Info("starting JWT parsing and validation.")
134
135
		// Retrieve the key ID from the JWT header and find the corresponding key in the JWKS.
136
		keyID, ok := token.Header["kid"].(string)
137
		if ok { // Key ID found in token header
138
			return oidc.getKeyWithRetry(ctx, keyID)
139
		} // End of key ID check
140
		slog.Error("jwt does not contain a key ID")
141
		// If the JWT does not contain a key ID, return an error.
142
		return nil, errors.New("kid must be specified in the token header")
143
	})
144
	if err != nil { // Check for parsing errors
145
		// Log that the token parsing or validation failed
146
		slog.Error("token parsing or validation failed", "error", err)
147
		// If token parsing or validation fails, return an error indicating the token is invalid.
148
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
149
	}
150
151
	// Ensure the token is valid.
152
	if !parsedToken.Valid {
153
		// Log that the parsed token was not valid
154
		slog.Warn("parsed token is invalid")
155
		// Return an error indicating the invalid token
156
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
157
	}
158
159
	// Extract the claims from the token.
160
	claims, ok := parsedToken.Claims.(jwt.MapClaims)
161
	if !ok {
162
		// Log that the claims were in an incorrect format
163
		slog.Warn("token claims are in an incorrect format")
164
		// Return an error
165
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String())
166
	}
167
168
	slog.Debug("extracted token claims", "claims", claims)
169
170
	// Verify the issuer of the token matches the expected issuer.
171
	if ok := claims.VerifyIssuer(oidc.IssuerURL, true); !ok {
172
		// Log that the issuer did not match the expected issuer
173
		slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"])
174
		// Return an error
175
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String())
176
	} // End of issuer verification
177
	// Verify the audience of the token matches the expected audience.
178
179
	if ok := claims.VerifyAudience(oidc.Audience, true); !ok {
180
		// Log that the audience did not match the expected audience
181
		slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"])
182
		// Return an error
183
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String())
184
	} // End of audience verification
185
186
	// Log that the token's issuer and audience were successfully validated
187
	slog.Info("token validation succeeded")
188
189
	// If all validations pass, return nil indicating the token is valid.
190
	return nil // Token is valid
191
} // End of Authenticate
192
// getKeyWithRetry returns the verification key for keyID. When the key is not
// in the current JWKS, it refreshes the JWKS and retries.
//
// Retry accounting is shared by all callers through the mutex-guarded fields
// globalRetryCount, globalFirstSeen and retriedKeys. This keeps a burst of
// tokens with unknown key IDs from causing an unbounded number of JWKS
// refreshes:
//   - A window starts at globalFirstSeen and lasts backoffInterval. If no
//     window is open, or the current one has expired, all state is reset and
//     a new window starts now.
//   - Within an open window, once globalRetryCount has reached
//     backoffMaxRetries ("backoff mode"), the caller performs a single lookup
//     against the current JWKS without refreshing it.
//   - Otherwise the caller loops: look the key up, and on a miss refresh the
//     JWKS. The loop makes at most backoffMaxRetries+1 lookups and waits
//     backoffFrequency before every attempt's refresh after the first.
//
// A successful lookup clears the shared state in two cases: when it needed at
// least one retry, or, in backoff mode, when the key had missed earlier in the
// window. If the loop is exhausted while the window is still open, the caller
// switches the window into backoff mode.
//
// It returns the raw key on success. Otherwise it returns ctx.Err() if ctx is
// done while waiting, the Refresh error if a refresh fails, or an error saying
// the key could not be found or that backoff is in effect.
func (oidc *Authn) getKeyWithRetry(
	ctx context.Context,
	keyID string,
) (interface{}, error) {
	var raw interface{}
	var err error

	oidc.mutex.Lock()
	now := time.Now()

	// Open a fresh window if none is active or the current one has lasted at
	// least backoffInterval; this discards all previous retry bookkeeping.
	if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval {
		slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
		oidc.globalFirstSeen = now
		oidc.globalRetryCount = 0
		oidc.retriedKeys = make(map[string]bool)
	} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
		// Backoff mode: this window's refresh budget is spent. Release the lock
		// and do one lookup against the JWKS as it currently is (no refresh).
		slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
		oidc.mutex.Unlock()

		raw, err = oidc.fetchKey(ctx, keyID)
		if err == nil {
			oidc.mutex.Lock()
			if _, wasRetried := oidc.retriedKeys[keyID]; wasRetried {
				// Only a key that previously missed in this window (e.g. one
				// that appeared after a key rotation) lifts the backoff. A key
				// that always resolved does not, so valid traffic cannot reset
				// a budget that bogus key IDs have exhausted.
				slog.Info("valid key found in backoff period, resetting global state", "keyID", keyID)
				oidc.globalRetryCount = 0
				oidc.globalFirstSeen = time.Time{}
				oidc.retriedKeys = make(map[string]bool)
			}
			oidc.mutex.Unlock()
			return raw, nil
		}

		slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err)
		return nil, errors.New("too many attempts, backoff in effect")
	}
	oidc.mutex.Unlock()

	// Retry loop: at most backoffMaxRetries+1 lookups.
	retries := 0
	for retries <= oidc.backoffMaxRetries {
		raw, err = oidc.fetchKey(ctx, keyID)
		if err == nil {
			if retries != 0 {
				// A refresh resolved the miss, so clear the shared retry state.
				oidc.mutex.Lock()
				oidc.globalRetryCount = 0
				oidc.globalFirstSeen = time.Time{}
				oidc.retriedKeys = make(map[string]bool)
				oidc.mutex.Unlock()
			}
			return raw, nil
		}

		// Miss: remember this key ID for the window and snapshot the global
		// counter so we can tell below whether someone else refreshed while we
		// were waiting.
		oidc.mutex.Lock()
		snapshotCount := oidc.globalRetryCount
		oidc.retriedKeys[keyID] = true
		if oidc.globalRetryCount > oidc.backoffMaxRetries {
			slog.Error("key ID not found in JWKS due to exceeding global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
			oidc.mutex.Unlock()
			return nil, errors.New("too many retry attempts, backoff policy active due to global retry limit")
		}
		oidc.mutex.Unlock()

		// The first refresh happens immediately; later ones wait
		// backoffFrequency, abandoning the lookup if ctx is done first.
		if retries > 0 {
			select {
			case <-time.After(oidc.backoffFrequency):
				// Note: logged after the delay has already elapsed.
				slog.Info("waiting before retrying", "keyID", keyID, "retries", retries)
			case <-ctx.Done():
				slog.Error("context cancelled during retry", "keyID", keyID)
				return nil, ctx.Err()
			}
		}

		oidc.mutex.Lock()
		if oidc.globalRetryCount > snapshotCount {
			// Another caller refreshed the JWKS since our snapshot. Skip our
			// own refresh but still count the attempt, then look up again.
			retries++
			slog.Warn("concurrent request has already refreshed JWKS, skipping refresh")
			oidc.mutex.Unlock()
			continue
		}

		oidc.globalRetryCount++
		slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
		retries++

		// Refresh while still holding the mutex so that concurrent callers do
		// not issue duplicate refreshes. NOTE(review): this network call blocks
		// every other goroutine that needs the mutex until it completes.
		if _, err := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); err != nil {
			oidc.mutex.Unlock()
			slog.Error("failed to refresh JWKS", "error", err)
			return nil, err
		}
		oidc.mutex.Unlock()
	}

	// Retries exhausted. If the window is still open, switch it to backoff mode
	// so later callers only do a single lookup. If another goroutine cleared
	// globalFirstSeen to the zero time, time.Since is huge and nothing is marked.
	oidc.mutex.Lock()
	if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval {
		slog.Warn("marking state to prevent further retries", "keyID", keyID)
		oidc.globalRetryCount = oidc.backoffMaxRetries
	}
	oidc.mutex.Unlock()

	slog.Error("key ID not found in JWKS after retries", "keyID", keyID)
	return nil, errors.New("key ID not found in JWKS after retries")
}
306
307
// fetchKey attempts to fetch the JWKS and retrieve the key for the given keyID.
308
// It fetches from the configured JWKS URI and looks up the key by its ID.
309
func (oidc *Authn) fetchKey(
310
	ctx context.Context,
311
	keyID string,
312
) (interface{}, error) {
313
	// Log the attempt to find the key in JWKS
314
	slog.DebugContext(ctx, "attempting to find key in JWKS", "kid", keyID)
315
316
	// Fetch the JWKS from the configured URI
317
	jwks, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
318
	if err != nil { // Check for fetch errors
319
		// Log an error and return if fetching fails
320
		slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", err)
321
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
322
	}
323
324
	// Log a successful fetch of the JWKS
325
	slog.InfoContext(ctx, "successfully fetched JWKS")
326
327
	// Attempt to find the key in the fetched JWKS using the key ID
328
	if key, found := jwks.LookupKeyID(keyID); found {
329
		var k interface{} // Variable to hold the raw key
330
		// Convert the key to a usable format
331
		if err := key.Raw(&k); err != nil {
332
			slog.ErrorContext(ctx, "failed to get raw public key", "kid", keyID, "error", err)
333
			return nil, fmt.Errorf("failed to get raw public key: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
334
		}
335
		// Log a successful retrieval of the raw public key
336
		slog.DebugContext(ctx, "successfully obtained raw public key", "key", k)
337
		return k, nil // Return the public key for JWT signature verification
338
	}
339
	// Log an error if the key ID is not found in the JWKS
340
	slog.ErrorContext(ctx, "key ID not found in JWKS", "kid", keyID)
341
	return nil, fmt.Errorf("kid %s not found", keyID)
342
}
343
344
// Config holds the subset of an OpenID Connect discovery document (served at
// "/.well-known/openid-configuration") that this package uses.
type Config struct {
	// Issuer is the OIDC provider's unique identifier URL, decoded from the
	// document's "issuer" field.
	Issuer string `json:"issuer"`
	// JWKsURI is the URL of the provider's JSON Web Key Set, decoded from the
	// document's "jwks_uri" field.
	JWKsURI string `json:"jwks_uri"`
}
351
352
// fetchOIDCConfiguration sends an HTTP request to the given URL to fetch the OpenID Connect (OIDC) configuration.
353
// It requires an HTTP client and the URL from which to fetch the configuration.
354
func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) {
355
	// Send an HTTP GET request to the provided URL to fetch the OIDC configuration.
356
	// This typically points to the well-known configuration endpoint of the OIDC provider.
357
	body, err := doHTTPRequest(client, url)
358
	if err != nil {
359
		// If there is an error in fetching the configuration (network error, bad response, etc.), return nil and the error.
360
		return nil, err
361
	} // End of HTTP request error handling
362
363
	// Parse the JSON response body into an OIDC Config struct.
364
	// This involves unmarshalling the JSON into a struct that matches the expected fields of the OIDC configuration.
365
	oidcConfig, err := parseOIDCConfiguration(body)
366
	if err != nil {
367
		return nil, err
368
	}
369
370
	// Return the parsed OIDC configuration and nil as the error (indicating success).
371
	return oidcConfig, nil
372
}
373
374
// doHTTPRequest performs an HTTP GET for url with client and returns the
// response body.
//
// Only 200 OK is accepted; any other status is returned as an error. The body
// is read with an upper bound of maxResponseBytes (1 MiB). The endpoint is the
// OIDC discovery document, which is expected to be a small JSON object, and an
// unbounded read would let a misbehaving or hostile endpoint make the process
// buffer an arbitrarily large response. Oversized bodies are rejected with an
// error rather than silently truncated.
//
// NOTE(review): the request is not bound to a context, so cancellation relies
// on the client's own settings; adding one would require a signature change.
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	const maxResponseBytes = 1 << 20

	slog.Debug("creating new HTTP GET request", "url", url)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)
	res, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	defer res.Body.Close()

	slog.Debug("received HTTP response", "status_code", res.StatusCode, "url", url)

	if res.StatusCode != http.StatusOK {
		// Best-effort, bounded drain so the transport can reuse the connection.
		// Its error is irrelevant because this call fails regardless.
		_, _ = io.Copy(io.Discard, io.LimitReader(res.Body, maxResponseBytes))
		slog.Warn("received unexpected status code", "status_code", res.StatusCode, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", res.StatusCode)
	}

	slog.Debug("reading response body", "url", url)
	// Read one byte past the limit so an oversized body is detected instead of
	// being silently truncated.
	body, err := io.ReadAll(io.LimitReader(res.Body, maxResponseBytes+1))
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}
	if len(body) > maxResponseBytes {
		slog.Error("response body exceeds size limit", "url", url, "limit_bytes", maxResponseBytes)
		return nil, fmt.Errorf("response body from OIDC configuration request exceeds %d bytes", maxResponseBytes)
	}

	slog.Debug("successfully read response body", "url", url, "response_length", len(body))
	return body, nil
}
426
427
// parseOIDCConfiguration decodes the OIDC configuration from the given JSON body.
428
// It validates that required fields like Issuer and JWKsURI are present.
429
func parseOIDCConfiguration(body []byte) (*Config, error) {
430
	var oidcConfig Config
431
	// Attempt to unmarshal the JSON body into the oidcConfig struct.
432
	if err := json.Unmarshal(body, &oidcConfig); err != nil {
433
		slog.Error("failed to unmarshal OIDC configuration", "error", err)
434
		return nil, fmt.Errorf("failed to decode OIDC configuration: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
435
	}
436
	// Log the successful decoding of OIDC configuration
437
	slog.Debug("successfully decoded OIDC configuration")
438
439
	if oidcConfig.Issuer == "" {
440
		slog.Warn("missing issuer value in OIDC configuration")
441
		return nil, errors.New("issuer value is required but missing in OIDC configuration")
442
	}
443
444
	if oidcConfig.JWKsURI == "" {
445
		slog.Warn("missing JWKsURI value in OIDC configuration")
446
		return nil, errors.New("JWKsURI value is required but missing in OIDC configuration")
447
	}
448
449
	// Log the successful parsing of the OIDC configuration
450
	slog.Info("successfully parsed OIDC configuration", "issuer", oidcConfig.Issuer, "jwks_uri", oidcConfig.JWKsURI)
451
452
	// Return the successfully parsed configuration.
453
	return &oidcConfig, nil
454
}
455