Passed
Push — master ( c4386d...837d72 )
by Tolga
01:27 queued 15s
created

oidc.*Authn.Authenticate   C

Complexity

Conditions 9

Size

Total Lines 73
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 31
nop 1
dl 0
loc 73
rs 6.6666
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
package oidc
2
3
import (
4
	"context"
5
	"encoding/json"
6
	"errors"
7
	"fmt"
8
	"io"
9
	"log/slog"
10
	"net/http"
11
	"strings"
12
	"sync"
13
	"time"
14
15
	"github.com/golang-jwt/jwt/v4"
16
	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
17
	"github.com/hashicorp/go-retryablehttp"
18
	"github.com/lestrrat-go/jwx/jwk"
19
20
	"github.com/Permify/permify/internal/config"
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// Authenticator authenticates incoming requests.
//
// Authenticate inspects ctx and returns nil when the request carries valid
// credentials, or a non-nil error describing why it was rejected. For Authn,
// the credentials are read from the incoming gRPC metadata carried by ctx.
type Authenticator interface {
	Authenticate(ctx context.Context) error
}
28
29
type Authn struct {
30
	// URL of the issuer. This is typically the base URL of the identity provider.
31
	IssuerURL string
32
	// Audience for which the token is intended. It must match the audience in the JWT.
33
	Audience string
34
	// URL of the JSON Web Key Set (JWKS). This URL hosts public keys used to verify JWT signatures.
35
	JwksURI string
36
	// Pointer to an AutoRefresh object from the JWKS library. It helps in automatically refreshing the JWKS at predefined intervals.
37
	jwksSet *jwk.AutoRefresh
38
	// List of valid signing methods. Specifies which signing algorithms are considered valid for the JWTs.
39
	validMethods []string
40
	// Pointer to a JWT parser object. This is used to parse and validate the JWT tokens.
41
	jwtParser *jwt.Parser
42
	// Duration of the interval between retries for the backoff policy.
43
	backoffInterval time.Duration
44
	// Maximum number of retries for the backoff policy.
45
	backoffMaxRetries int
46
47
	backoffFrequency time.Duration
48
49
	// Global backoff state
50
	globalRetryCount int
51
	globalFirstSeen  time.Time
52
	mu               sync.Mutex
53
}
54
55
// NewOidcAuthn initializes a new instance of the Authn struct with OpenID Connect (OIDC) configuration.
56
// It takes in a context for managing cancellation and a configuration object. It returns a pointer to an Authn instance or an error.
57
func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
58
	// Create a new HTTP client with retry capabilities. This client is used for making HTTP requests, particularly for fetching OIDC configuration.
59
	client := retryablehttp.NewClient()
60
	client.Logger = SlogAdapter{Logger: slog.Default()}
61
62
	// Fetch the OIDC configuration from the issuer's well-known configuration endpoint.
63
	oidcConf, err := fetchOIDCConfiguration(client.StandardClient(), strings.TrimSuffix(conf.Issuer, "/")+"/.well-known/openid-configuration")
64
	if err != nil {
65
		// If there is an error fetching the OIDC configuration, return nil and the error.
66
		return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
67
	}
68
69
	// Set up automatic refresh of the JSON Web Key Set (JWKS) to ensure the public keys are always up-to-date.
70
	ar := jwk.NewAutoRefresh(ctx)                                                                                              // Create a new AutoRefresh instance for the JWKS.
71
	ar.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(client.StandardClient()), jwk.WithRefreshInterval(conf.RefreshInterval)) // Configure the auto-refresh parameters.
72
73
	// Validate and set backoffInterval, backoffMaxRetries, and backoffFrequency
74
	backoffInterval := conf.BackoffInterval
75
	if backoffInterval <= 0 {
76
		return nil, errors.New("invalid or missing backoffInterval")
77
	}
78
79
	backoffMaxRetries := conf.BackoffMaxRetries
80
	if backoffMaxRetries <= 0 {
81
		return nil, errors.New("invalid or missing backoffMaxRetries")
82
	}
83
84
	backoffFrequency := conf.BackoffFrequency
85
	if backoffFrequency <= 0 {
86
		return nil, errors.New("invalid or missing backoffFrequency")
87
	}
88
89
	// Initialize the Authn struct with the OIDC configuration details and other relevant settings.
90
	oidc := &Authn{
91
		IssuerURL:         conf.Issuer,                                            // URL of the token issuer.
92
		Audience:          conf.Audience,                                          // Intended audience of the token.
93
		JwksURI:           oidcConf.JWKsURI,                                       // URL of the JWKS endpoint.
94
		validMethods:      conf.ValidMethods,                                      // List of acceptable signing methods for the tokens.
95
		jwtParser:         jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)), // JWT parser configured with the valid signing methods.
96
		jwksSet:           ar,                                                     // Set the JWKS auto-refresh instance.
97
		backoffInterval:   backoffInterval,
98
		backoffMaxRetries: backoffMaxRetries,
99
		backoffFrequency:  backoffFrequency,
100
		globalRetryCount:  0,
101
		globalFirstSeen:   time.Time{},
102
		mu:                sync.Mutex{},
103
	}
104
105
	// Attempt to fetch the JWKS immediately to ensure it's available and valid.
106
	_, err = oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
107
	if err != nil {
108
		// If there is an error fetching the JWKS, return nil and the error.
109
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
110
	}
111
112
	// Return the initialized OIDC authentication object and no error.
113
	return oidc, nil
114
}
115
116
// Authenticate validates the JWT token found in the authorization header of the incoming request.
117
// It uses the OIDC configuration to validate the token against the issuer's public keys.
118
func (oidc *Authn) Authenticate(requestContext context.Context) error {
119
	// Extract the authorization header from the metadata of the incoming gRPC request.
120
	authHeader, err := grpcauth.AuthFromMD(requestContext, "Bearer")
121
	if err != nil {
122
		// Log the error if the authorization header is missing or does not start with "Bearer"
123
		slog.Error("failed to extract authorization header from gRPC request", "error", err)
124
		// Return an error indicating the missing or incorrect bearer token
125
		return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String())
126
	}
127
128
	// Log the successful extraction of the authorization header for debugging purposes.
129
	// Remember, do not log the actual content of authHeader as it might contain sensitive information.
130
	slog.Debug("Successfully extracted authorization header from gRPC request")
131
132
	// Parse and validate the JWT token extracted from the authorization header.
133
	parsedToken, err := oidc.jwtParser.Parse(authHeader, func(token *jwt.Token) (interface{}, error) {
134
		slog.Info("starting JWT parsing and validation.")
135
136
		// Retrieve the key ID from the JWT header and find the corresponding key in the JWKS.
137
		if keyID, ok := token.Header["kid"].(string); ok {
138
			return oidc.getKeyWithRetry(keyID, requestContext)
139
		}
140
		slog.Error("jwt does not contain a key ID")
141
		// If the JWT does not contain a key ID, return an error.
142
		return nil, errors.New("kid must be specified in the token header")
143
	})
144
	if err != nil {
145
		// Log that the token parsing or validation failed
146
		slog.Error("token parsing or validation failed", "error", err)
147
		// If token parsing or validation fails, return an error indicating the token is invalid.
148
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
149
	}
150
151
	// Ensure the token is valid.
152
	if !parsedToken.Valid {
153
		// Log that the parsed token was not valid
154
		slog.Warn("parsed token is invalid")
155
		// Return an error indicating the invalid token
156
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
157
	}
158
159
	// Extract the claims from the token.
160
	claims, ok := parsedToken.Claims.(jwt.MapClaims)
161
	if !ok {
162
		// Log that the claims were in an incorrect format
163
		slog.Warn("token claims are in an incorrect format")
164
		// Return an error
165
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String())
166
	}
167
168
	slog.Debug("extracted token claims", "claims", claims)
169
170
	// Verify the issuer of the token matches the expected issuer.
171
	if ok := claims.VerifyIssuer(oidc.IssuerURL, true); !ok {
172
		// Log that the issuer did not match the expected issuer
173
		slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"])
174
		// Return an error
175
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String())
176
	}
177
178
	// Verify the audience of the token matches the expected audience.
179
	if ok := claims.VerifyAudience(oidc.Audience, true); !ok {
180
		// Log that the audience did not match the expected audience
181
		slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"])
182
		// Return an error
183
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String())
184
	}
185
186
	// Log that the token's issuer and audience were successfully validated
187
	slog.Info("token validation succeeded")
188
189
	// If all validations pass, return nil indicating the token is valid.
190
	return nil
191
}
192
193
// getKeyWithRetry attempts to retrieve the key for the given keyID with retries using a custom backoff strategy.
194
func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface{}, error) {
195
	var rawKey interface{}
196
	var err error
197
198
	oidc.mu.Lock()
199
	now := time.Now()
200
201
	// Reset global state if the interval has passed
202
	if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval {
203
		slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
204
		oidc.globalFirstSeen = now
205
		oidc.globalRetryCount = 0
206
	} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
207
		// If max retries reached within the interval, unlock and check keyID once
208
		slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
209
		oidc.mu.Unlock()
210
211
		// Try to fetch the keyID once
212
		rawKey, err = oidc.fetchKey(keyID, ctx)
213
		if err == nil {
214
			// Reset global backoff state if a valid key is found
215
			slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
216
			oidc.mu.Lock()
217
			oidc.globalRetryCount = 0
218
			oidc.globalFirstSeen = time.Time{}
219
			oidc.mu.Unlock()
220
			return rawKey, nil
221
		}
222
223
		// Log the failure and return an error if keyID is not found
224
		slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err)
225
		return nil, errors.New("too many attempts, backoff in effect")
226
	}
227
	oidc.mu.Unlock()
228
229
	// Retry mechanism
230
	retries := 0
231
	for retries <= oidc.backoffMaxRetries {
232
		if retries > 0 {
233
			select {
234
			case <-time.After(oidc.backoffFrequency):
235
				// Log the wait before retrying
236
				slog.Info("waiting before retrying", "keyID", keyID, "retries", retries)
237
			case <-ctx.Done():
238
				slog.Error("context cancelled during retry", "keyID", keyID)
239
				return nil, ctx.Err()
240
			}
241
		}
242
243
		rawKey, err = oidc.fetchKey(keyID, ctx)
244
		if err == nil {
245
			// Log the successful key fetch and reset global state
246
			slog.Info("successfully fetched key", "keyID", keyID)
247
			oidc.mu.Lock()
248
			oidc.globalRetryCount = 0
249
			oidc.globalFirstSeen = time.Time{}
250
			oidc.mu.Unlock()
251
			return rawKey, nil
252
		}
253
254
		slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
255
		retries++
256
257
		oidc.mu.Lock()
258
		oidc.globalRetryCount++
259
		oidc.mu.Unlock()
260
261
		if _, refreshErr := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); refreshErr != nil {
262
			slog.Error("failed to refresh JWKS", "error", refreshErr)
263
			return nil, refreshErr
264
		}
265
	}
266
267
	// Mark the global state to prevent further retries for the backoff interval
268
	oidc.mu.Lock()
269
	if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval {
270
		slog.Warn("marking state to prevent further retries", "keyID", keyID)
271
		oidc.globalRetryCount = oidc.backoffMaxRetries
272
	}
273
	oidc.mu.Unlock()
274
275
	slog.Error("key ID not found in JWKS after retries", "keyID", keyID)
276
	return nil, errors.New("key ID not found in JWKS after retries")
277
}
278
279
// fetchKey attempts to fetch the JWKS and retrieve the key for the given keyID.
280
func (oidc *Authn) fetchKey(keyID string, ctx context.Context) (interface{}, error) {
281
	// Log the attempt to find the key.
282
	slog.Debug("attempting to find key in JWKS", "kid", keyID)
283
284
	// Fetch the JWKS from the configured URI.
285
	jwks, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
286
	if err != nil {
287
		// Log an error and return if fetching fails.
288
		slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", err)
289
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
290
	}
291
292
	// Log a successful fetch of the JWKS.
293
	slog.Info("successfully fetched JWKS")
294
295
	// Attempt to find the key in the fetched JWKS using the key ID.
296
	if key, found := jwks.LookupKeyID(keyID); found {
297
		var k interface{}
298
		// Convert the key to a usable format.
299
		if err := key.Raw(&k); err != nil {
300
			slog.Error("failed to get raw public key", "kid", keyID, "error", err)
301
			return nil, fmt.Errorf("failed to get raw public key: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
302
		}
303
		// Log a successful retrieval of the raw public key.
304
		slog.Debug("successfully obtained raw public key", "key", k)
305
		return k, nil // Return the public key for JWT signature verification.
306
	}
307
	// Log an error if the key ID is not found in the JWKS.
308
	slog.Error("key ID not found in JWKS", "kid", keyID)
309
	return nil, fmt.Errorf("kid %s not found", keyID)
310
}
311
312
// Config is the subset of an OpenID Connect discovery document
// (served at "/.well-known/openid-configuration") that this package uses.
type Config struct {
	// Issuer is decoded from the document's "issuer" member and identifies
	// the provider.
	Issuer string `json:"issuer"`
	// JWKsURI is decoded from the document's "jwks_uri" member. It is where
	// the provider publishes its signing keys.
	JWKsURI string `json:"jwks_uri"`
}
319
320
// fetchOIDCConfiguration sends an HTTP request to the given URL to fetch the OpenID Connect (OIDC) configuration.
321
// It requires an HTTP client and the URL from which to fetch the configuration.
322
func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) {
323
	// Send an HTTP GET request to the provided URL to fetch the OIDC configuration.
324
	// This typically points to the well-known configuration endpoint of the OIDC provider.
325
	body, err := doHTTPRequest(client, url)
326
	if err != nil {
327
		// If there is an error in fetching the configuration (network error, bad response, etc.), return nil and the error.
328
		return nil, err
329
	}
330
331
	// Parse the JSON response body into an OIDC Config struct.
332
	// This involves unmarshalling the JSON into a struct that matches the expected fields of the OIDC configuration.
333
	oidcConfig, err := parseOIDCConfiguration(body)
334
	if err != nil {
335
		return nil, err
336
	}
337
338
	// Return the parsed OIDC configuration and nil as the error (indicating success).
339
	return oidcConfig, nil
340
}
341
342
// doHTTPRequest performs a GET request on url with client and returns the full
// response body. Any status other than 200 OK is treated as an error.
//
// Errors from building the request, executing it, and reading the body are
// wrapped with %w, so callers can inspect them with errors.Is/errors.As
// (for example *url.Error or a context deadline).
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	slog.Debug("creating new HTTP GET request", "url", url)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)
	res, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	// Close the body on every path below, including the non-200 early return.
	defer res.Body.Close()

	slog.Debug("received HTTP response", "status_code", res.StatusCode, "url", url)
	if res.StatusCode != http.StatusOK {
		slog.Warn("received unexpected status code", "status_code", res.StatusCode, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", res.StatusCode)
	}

	slog.Debug("reading response body", "url", url)
	body, err := io.ReadAll(res.Body)
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}

	slog.Debug("successfully read response body", "url", url, "response_length", len(body))
	return body, nil
}
393
394
// parseOIDCConfiguration decodes the OIDC configuration from the given JSON body.
395
func parseOIDCConfiguration(body []byte) (*Config, error) {
396
	var oidcConfig Config
397
	// Attempt to unmarshal the JSON body into the oidcConfig struct.
398
	if err := json.Unmarshal(body, &oidcConfig); err != nil {
399
		slog.Error("failed to unmarshal OIDC configuration", "error", err)
400
		return nil, fmt.Errorf("failed to decode OIDC configuration: %s", err)
401
	}
402
	// Log the successful decoding of OIDC configuration
403
	slog.Debug("successfully decoded OIDC configuration")
404
405
	if oidcConfig.Issuer == "" {
406
		slog.Warn("missing issuer value in OIDC configuration")
407
		return nil, errors.New("issuer value is required but missing in OIDC configuration")
408
	}
409
410
	if oidcConfig.JWKsURI == "" {
411
		slog.Warn("missing JWKsURI value in OIDC configuration")
412
		return nil, errors.New("JWKsURI value is required but missing in OIDC configuration")
413
	}
414
415
	// Log the successful parsing of the OIDC configuration
416
	slog.Info("successfully parsed OIDC configuration", "issuer", oidcConfig.Issuer, "jwks_uri", oidcConfig.JWKsURI)
417
418
	// Return the successfully parsed configuration.
419
	return &oidcConfig, nil
420
}
421