Passed
Push — master ( f45b5e...2f1c8e )
by Tolga
01:30 queued 36s
created

oidc.NewOidcAuthn   B

Complexity

Conditions 6

Size

Total Lines 58
Code Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 35
nop 2
dl 0
loc 58
rs 8.1066
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

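Commonly applied refactorings include Extract Method: move a cohesive group of statements (for example, the backoff-setting validation in `NewOidcAuthn`) into a well-named helper function.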

1
package oidc
2
3
import (
4
	"context"
5
	"encoding/json"
6
	"errors"
7
	"fmt"
8
	"io"
9
	"log/slog"
10
	"net/http"
11
	"strings"
12
	"sync"
13
	"time"
14
15
	"github.com/golang-jwt/jwt/v4"
16
	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
17
	"github.com/hashicorp/go-retryablehttp"
18
	"github.com/lestrrat-go/jwx/jwk"
19
20
	"github.com/Permify/permify/internal/config"
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// Authenticator authenticates an incoming request whose credentials are
// carried by a context (for the OIDC implementation in this package, the
// bearer token in the gRPC metadata).
type Authenticator interface {
	// Authenticate returns nil when the request described by ctx is
	// authenticated and a non-nil error when it must be rejected.
	Authenticate(ctx context.Context) error
}
28
29
type Authn struct {
30
	// URL of the issuer. This is typically the base URL of the identity provider.
31
	IssuerURL string
32
	// Audience for which the token is intended. It must match the audience in the JWT.
33
	Audience string
34
	// URL of the JSON Web Key Set (JWKS). This URL hosts public keys used to verify JWT signatures.
35
	JwksURI string
36
	// Pointer to an AutoRefresh object from the JWKS library. It helps in automatically refreshing the JWKS at predefined intervals.
37
	jwksSet *jwk.AutoRefresh
38
	// List of valid signing methods. Specifies which signing algorithms are considered valid for the JWTs.
39
	validMethods []string
40
	// Pointer to a JWT parser object. This is used to parse and validate the JWT tokens.
41
	jwtParser *jwt.Parser
42
	// Duration of the interval between retries for the backoff policy.
43
	backoffInterval time.Duration
44
	// Maximum number of retries for the backoff policy.
45
	backoffMaxRetries int
46
47
	backoffFrequency time.Duration
48
49
	// Global backoff state
50
	globalRetryCount  int
51
	globalFirstSeen   time.Time
52
	globalRetryKeyIds map[string]bool
53
	mu                sync.Mutex
54
}
55
56
// NewOidcAuthn initializes a new instance of the Authn struct with OpenID Connect (OIDC) configuration.
57
// It takes in a context for managing cancellation and a configuration object. It returns a pointer to an Authn instance or an error.
58
func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
59
	// Create a new HTTP client with retry capabilities. This client is used for making HTTP requests, particularly for fetching OIDC configuration.
60
	client := retryablehttp.NewClient()
61
	client.Logger = SlogAdapter{Logger: slog.Default()}
62
63
	// Fetch the OIDC configuration from the issuer's well-known configuration endpoint.
64
	oidcConf, err := fetchOIDCConfiguration(client.StandardClient(), strings.TrimSuffix(conf.Issuer, "/")+"/.well-known/openid-configuration")
65
	if err != nil {
66
		// If there is an error fetching the OIDC configuration, return nil and the error.
67
		return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
68
	}
69
70
	// Set up automatic refresh of the JSON Web Key Set (JWKS) to ensure the public keys are always up-to-date.
71
	ar := jwk.NewAutoRefresh(ctx)                                                                                              // Create a new AutoRefresh instance for the JWKS.
72
	ar.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(client.StandardClient()), jwk.WithRefreshInterval(conf.RefreshInterval)) // Configure the auto-refresh parameters.
73
74
	// Validate and set backoffInterval, backoffMaxRetries, and backoffFrequency
75
	backoffInterval := conf.BackoffInterval
76
	if backoffInterval <= 0 {
77
		return nil, errors.New("invalid or missing backoffInterval")
78
	}
79
80
	backoffMaxRetries := conf.BackoffMaxRetries
81
	if backoffMaxRetries <= 0 {
82
		return nil, errors.New("invalid or missing backoffMaxRetries")
83
	}
84
85
	backoffFrequency := conf.BackoffFrequency
86
	if backoffFrequency <= 0 {
87
		return nil, errors.New("invalid or missing backoffFrequency")
88
	}
89
90
	// Initialize the Authn struct with the OIDC configuration details and other relevant settings.
91
	oidc := &Authn{
92
		IssuerURL:         conf.Issuer,                                            // URL of the token issuer.
93
		Audience:          conf.Audience,                                          // Intended audience of the token.
94
		JwksURI:           oidcConf.JWKsURI,                                       // URL of the JWKS endpoint.
95
		validMethods:      conf.ValidMethods,                                      // List of acceptable signing methods for the tokens.
96
		jwtParser:         jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)), // JWT parser configured with the valid signing methods.
97
		jwksSet:           ar,                                                     // Set the JWKS auto-refresh instance.
98
		backoffInterval:   backoffInterval,
99
		backoffMaxRetries: backoffMaxRetries,
100
		backoffFrequency:  backoffFrequency,
101
		globalRetryCount:  0,
102
		globalRetryKeyIds: make(map[string]bool),
103
		globalFirstSeen:   time.Time{},
104
		mu:                sync.Mutex{},
105
	}
106
107
	// Attempt to fetch the JWKS immediately to ensure it's available and valid.
108
	_, err = oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
109
	if err != nil {
110
		// If there is an error fetching the JWKS, return nil and the error.
111
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
112
	}
113
114
	// Return the initialized OIDC authentication object and no error.
115
	return oidc, nil
116
}
117
118
// Authenticate validates the JWT token found in the authorization header of the incoming request.
119
// It uses the OIDC configuration to validate the token against the issuer's public keys.
120
func (oidc *Authn) Authenticate(requestContext context.Context) error {
121
	// Extract the authorization header from the metadata of the incoming gRPC request.
122
	authHeader, err := grpcauth.AuthFromMD(requestContext, "Bearer")
123
	if err != nil {
124
		// Log the error if the authorization header is missing or does not start with "Bearer"
125
		slog.Error("failed to extract authorization header from gRPC request", "error", err)
126
		// Return an error indicating the missing or incorrect bearer token
127
		return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String())
128
	}
129
130
	// Log the successful extraction of the authorization header for debugging purposes.
131
	// Remember, do not log the actual content of authHeader as it might contain sensitive information.
132
	slog.Debug("Successfully extracted authorization header from gRPC request")
133
134
	// Parse and validate the JWT token extracted from the authorization header.
135
	parsedToken, err := oidc.jwtParser.Parse(authHeader, func(token *jwt.Token) (interface{}, error) {
136
		slog.Info("starting JWT parsing and validation.")
137
138
		// Retrieve the key ID from the JWT header and find the corresponding key in the JWKS.
139
		if keyID, ok := token.Header["kid"].(string); ok {
140
			return oidc.getKeyWithRetry(keyID, requestContext)
141
		}
142
		slog.Error("jwt does not contain a key ID")
143
		// If the JWT does not contain a key ID, return an error.
144
		return nil, errors.New("kid must be specified in the token header")
145
	})
146
	if err != nil {
147
		// Log that the token parsing or validation failed
148
		slog.Error("token parsing or validation failed", "error", err)
149
		// If token parsing or validation fails, return an error indicating the token is invalid.
150
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
151
	}
152
153
	// Ensure the token is valid.
154
	if !parsedToken.Valid {
155
		// Log that the parsed token was not valid
156
		slog.Warn("parsed token is invalid")
157
		// Return an error indicating the invalid token
158
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
159
	}
160
161
	// Extract the claims from the token.
162
	claims, ok := parsedToken.Claims.(jwt.MapClaims)
163
	if !ok {
164
		// Log that the claims were in an incorrect format
165
		slog.Warn("token claims are in an incorrect format")
166
		// Return an error
167
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String())
168
	}
169
170
	slog.Debug("extracted token claims", "claims", claims)
171
172
	// Verify the issuer of the token matches the expected issuer.
173
	if ok := claims.VerifyIssuer(oidc.IssuerURL, true); !ok {
174
		// Log that the issuer did not match the expected issuer
175
		slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"])
176
		// Return an error
177
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String())
178
	}
179
180
	// Verify the audience of the token matches the expected audience.
181
	if ok := claims.VerifyAudience(oidc.Audience, true); !ok {
182
		// Log that the audience did not match the expected audience
183
		slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"])
184
		// Return an error
185
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String())
186
	}
187
188
	// Log that the token's issuer and audience were successfully validated
189
	slog.Info("token validation succeeded")
190
191
	// If all validations pass, return nil indicating the token is valid.
192
	return nil
193
}
194
195
// getKeyWithRetry attempts to retrieve the key for the given keyID with retries using a custom backoff strategy.
196
func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface{}, error) {
197
	var rawKey interface{}
198
	var err error
199
200
	oidc.mu.Lock()
201
	now := time.Now()
202
203
	// Reset global state if the interval has passed
204
	if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval {
205
		slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
206
		oidc.globalFirstSeen = now
207
		oidc.globalRetryCount = 0
208
		oidc.globalRetryKeyIds = make(map[string]bool)
209
	} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
210
		// If max retries reached within the interval, unlock and check keyID once
211
		slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
212
		oidc.mu.Unlock()
213
214
		// Try to fetch the keyID once
215
		rawKey, err = oidc.fetchKey(keyID, ctx)
216
		if err == nil {
217
			oidc.mu.Lock()
218
			if _, ok := oidc.globalRetryKeyIds[keyID]; ok {
219
				// Reset global backoff state if a valid key is found and that key had been retried.
220
				// Use case would be someone trying to exploit with bad KeyIDs, and along comes a valid KeyID
221
				// The valid KeyID should not reset the counters for a bad key
222
				slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
223
				oidc.globalRetryCount = 0
224
				oidc.globalFirstSeen = time.Time{}
225
				oidc.globalRetryKeyIds = make(map[string]bool)
226
			}
227
			oidc.mu.Unlock()
228
			return rawKey, nil
229
		}
230
231
		// Log the failure and return an error if keyID is not found
232
		slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err)
233
		return nil, errors.New("too many attempts, backoff in effect")
234
	}
235
	oidc.mu.Unlock()
236
237
	// Retry mechanism
238
	retries := 0
239
	for retries <= oidc.backoffMaxRetries {
240
		rawKey, err = oidc.fetchKey(keyID, ctx)
241
		if err == nil {
242
			if retries != 0 {
243
				oidc.mu.Lock()
244
				oidc.globalRetryCount = 0
245
				oidc.globalFirstSeen = time.Time{}
246
				oidc.globalRetryKeyIds = make(map[string]bool)
247
				oidc.mu.Unlock()
248
			}
249
			return rawKey, nil
250
		}
251
		oidc.mu.Lock()
252
		initialGlobalRetryCount := oidc.globalRetryCount
253
		oidc.globalRetryKeyIds[keyID] = true
254
		if oidc.globalRetryCount > oidc.backoffMaxRetries {
255
			slog.Error("key ID not found in JWKS due to global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
256
			oidc.mu.Unlock()
257
			return nil, errors.New("too many attempts, backoff in effect due to global retry count")
258
		}
259
		oidc.mu.Unlock()
260
		if retries > 0 {
261
			select {
262
			case <-time.After(oidc.backoffFrequency):
263
				// Log the wait before retrying
264
				slog.Info("waiting before retrying", "keyID", keyID, "retries", retries)
265
			case <-ctx.Done():
266
				slog.Error("context cancelled during retry", "keyID", keyID)
267
				return nil, ctx.Err()
268
			}
269
		}
270
271
		oidc.mu.Lock()
272
		if oidc.globalRetryCount > initialGlobalRetryCount {
273
			// Concurrent requests in retry loop at same time, another concurrent request already refreshed the JWKS
274
			retries++
275
			slog.Warn("another concurrent request already refreshed the JWKS")
276
			oidc.mu.Unlock()
277
			continue
278
		}
279
280
		oidc.globalRetryCount++
281
		slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
282
		retries++
283
284
		if _, refreshErr := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); refreshErr != nil {
285
			oidc.mu.Unlock()
286
			slog.Error("failed to refresh JWKS", "error", refreshErr)
287
			return nil, refreshErr
288
		}
289
		// Unlock needs to follow Refresh to ensure that concurrent requests don't make duplicate calls to Refresh
290
		oidc.mu.Unlock()
291
	}
292
293
	// Mark the global state to prevent further retries for the backoff interval
294
	oidc.mu.Lock()
295
	if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval {
296
		slog.Warn("marking state to prevent further retries", "keyID", keyID)
297
		oidc.globalRetryCount = oidc.backoffMaxRetries
298
	}
299
	oidc.mu.Unlock()
300
301
	slog.Error("key ID not found in JWKS after retries", "keyID", keyID)
302
	return nil, errors.New("key ID not found in JWKS after retries")
303
}
304
305
// fetchKey attempts to fetch the JWKS and retrieve the key for the given keyID.
306
func (oidc *Authn) fetchKey(keyID string, ctx context.Context) (interface{}, error) {
307
	// Log the attempt to find the key.
308
	slog.Debug("attempting to find key in JWKS", "kid", keyID)
309
310
	// Fetch the JWKS from the configured URI.
311
	jwks, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
312
	if err != nil {
313
		// Log an error and return if fetching fails.
314
		slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", err)
315
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
316
	}
317
318
	// Log a successful fetch of the JWKS.
319
	slog.Info("successfully fetched JWKS")
320
321
	// Attempt to find the key in the fetched JWKS using the key ID.
322
	if key, found := jwks.LookupKeyID(keyID); found {
323
		var k interface{}
324
		// Convert the key to a usable format.
325
		if err := key.Raw(&k); err != nil {
326
			slog.Error("failed to get raw public key", "kid", keyID, "error", err)
327
			return nil, fmt.Errorf("failed to get raw public key: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
328
		}
329
		// Log a successful retrieval of the raw public key.
330
		slog.Debug("successfully obtained raw public key", "key", k)
331
		return k, nil // Return the public key for JWT signature verification.
332
	}
333
	// Log an error if the key ID is not found in the JWKS.
334
	slog.Error("key ID not found in JWKS", "kid", keyID)
335
	return nil, fmt.Errorf("kid %s not found", keyID)
336
}
337
338
// Config is the part of an OpenID Connect (OIDC) discovery document that this
// file reads. Each field maps to the JSON key named in its tag.
type Config struct {
	// Issuer is the identifier URL the OIDC provider reports for itself ("issuer").
	Issuer string `json:"issuer"`
	// JWKsURI is where the provider publishes its JSON Web Key Set ("jwks_uri").
	JWKsURI string `json:"jwks_uri"`
}
345
346
// fetchOIDCConfiguration sends an HTTP request to the given URL to fetch the OpenID Connect (OIDC) configuration.
347
// It requires an HTTP client and the URL from which to fetch the configuration.
348
func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) {
349
	// Send an HTTP GET request to the provided URL to fetch the OIDC configuration.
350
	// This typically points to the well-known configuration endpoint of the OIDC provider.
351
	body, err := doHTTPRequest(client, url)
352
	if err != nil {
353
		// If there is an error in fetching the configuration (network error, bad response, etc.), return nil and the error.
354
		return nil, err
355
	}
356
357
	// Parse the JSON response body into an OIDC Config struct.
358
	// This involves unmarshalling the JSON into a struct that matches the expected fields of the OIDC configuration.
359
	oidcConfig, err := parseOIDCConfiguration(body)
360
	if err != nil {
361
		return nil, err
362
	}
363
364
	// Return the parsed OIDC configuration and nil as the error (indicating success).
365
	return oidcConfig, nil
366
}
367
368
// doHTTPRequest performs an HTTP GET of url with client and returns the
// complete response body.
//
// It returns an error when the request cannot be built or sent, when the
// response status is not 200 OK, or when the body cannot be read. Underlying
// errors are wrapped with %w so callers can inspect them with errors.Is/As.
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	slog.Debug("creating new HTTP GET request", "url", url)

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)

	res, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	// Close the body on every path below, including the non-200 early return.
	defer res.Body.Close()

	slog.Debug("received HTTP response", "status_code", res.StatusCode, "url", url)

	if res.StatusCode != http.StatusOK {
		slog.Warn("received unexpected status code", "status_code", res.StatusCode, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", res.StatusCode)
	}

	slog.Debug("reading response body", "url", url)

	body, err := io.ReadAll(res.Body)
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}

	slog.Debug("successfully read response body", "url", url, "response_length", len(body))
	return body, nil
}
419
420
// parseOIDCConfiguration decodes the OIDC configuration from the given JSON body.
421
func parseOIDCConfiguration(body []byte) (*Config, error) {
422
	var oidcConfig Config
423
	// Attempt to unmarshal the JSON body into the oidcConfig struct.
424
	if err := json.Unmarshal(body, &oidcConfig); err != nil {
425
		slog.Error("failed to unmarshal OIDC configuration", "error", err)
426
		return nil, fmt.Errorf("failed to decode OIDC configuration: %s", err)
427
	}
428
	// Log the successful decoding of OIDC configuration
429
	slog.Debug("successfully decoded OIDC configuration")
430
431
	if oidcConfig.Issuer == "" {
432
		slog.Warn("missing issuer value in OIDC configuration")
433
		return nil, errors.New("issuer value is required but missing in OIDC configuration")
434
	}
435
436
	if oidcConfig.JWKsURI == "" {
437
		slog.Warn("missing JWKsURI value in OIDC configuration")
438
		return nil, errors.New("JWKsURI value is required but missing in OIDC configuration")
439
	}
440
441
	// Log the successful parsing of the OIDC configuration
442
	slog.Info("successfully parsed OIDC configuration", "issuer", oidcConfig.Issuer, "jwks_uri", oidcConfig.JWKsURI)
443
444
	// Return the successfully parsed configuration.
445
	return &oidcConfig, nil
446
}
447