Issues (58)

internal/authn/openid/authn.go (8 issues)

Severity
1
package openid
2
3
import (
4
	"context"
5
	"encoding/json"
6
	"errors"
7
	"fmt"
8
	"io"
9
	"log/slog"
10
	"net/http"
11
	"strings"
12
	"sync"
13
	"time"
14
15
	"github.com/golang-jwt/jwt/v4"
16
	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
17
	"github.com/hashicorp/go-retryablehttp"
18
	"github.com/lestrrat-go/jwx/jwk"
19
20
	"github.com/Permify/permify/internal/config"
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
type Authn struct {
25
	// URL of the issuer. This is typically the base URL of the identity provider.
26
	IssuerURL string
27
	// Audience for which the token is intended. It must match the audience in the JWT.
28
	Audience string
29
	// URL of the JSON Web Key Set (JWKS). This URL hosts public keys used to verify JWT signatures.
30
	JwksURI string
31
	// Pointer to an AutoRefresh object from the JWKS library. It helps in automatically refreshing the JWKS at predefined intervals.
32
	jwksSet *jwk.AutoRefresh
33
	// List of valid signing methods. Specifies which signing algorithms are considered valid for the JWTs.
34
	validMethods []string
35
	// Pointer to a JWT parser object. This is used to parse and validate the JWT tokens.
36
	jwtParser *jwt.Parser
37
	// Duration of the interval between retries for the backoff policy.
38
	backoffInterval time.Duration
39
	// Maximum number of retries for the backoff policy.
40
	backoffMaxRetries int
41
42
	backoffFrequency time.Duration
43
44
	// Global backoff state for tracking retry attempts across concurrent requests
45
	globalRetryCount int
46
	globalFirstSeen  time.Time
47
	retriedKeys      map[string]bool
48
	mutex            sync.Mutex // protects concurrent access to retry state
49
}
50
51
// NewOidcAuthn creates a new OIDC authenticator.
52
func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
53
	// Create a new HTTP client with retry capabilities. This client is used for making HTTP requests, particularly for fetching OIDC configuration.
54
	client := retryablehttp.NewClient()
55
	client.Logger = SlogAdapter{Logger: slog.Default()}
56
57
	// Fetch the OIDC configuration from the issuer's well-known configuration endpoint.
58
	oidcConf, err := fetchOIDCConfiguration(client.StandardClient(), strings.TrimSuffix(conf.Issuer, "/")+"/.well-known/openid-configuration")
59
	if err != nil {
60
		return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err)
0 ignored issues
show
unrecognized printf verb 'w'
Loading history...
61
	}
62
63
	// Set up automatic refresh of the JSON Web Key Set (JWKS) to ensure the public keys are always up-to-date.
64
	autoRefresh := jwk.NewAutoRefresh(ctx)
65
	autoRefresh.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(client.StandardClient()), jwk.WithRefreshInterval(conf.RefreshInterval))
66
67
	// Validate and set backoffInterval, backoffMaxRetries, and backoffFrequency
68
	backoffInterval := conf.BackoffInterval
69
	if backoffInterval <= 0 {
70
		return nil, errors.New("invalid or missing backoffInterval")
71
	}
72
73
	backoffMaxRetries := conf.BackoffMaxRetries
74
	if backoffMaxRetries <= 0 {
75
		return nil, errors.New("invalid or missing backoffMaxRetries")
76
	}
77
78
	backoffFrequency := conf.BackoffFrequency
79
	if backoffFrequency <= 0 {
80
		return nil, errors.New("invalid or missing backoffFrequency")
81
	}
82
83
	// Initialize the Authn struct with the OIDC configuration details and other relevant settings.
84
	oidc := &Authn{
85
		IssuerURL:         conf.Issuer,
86
		Audience:          conf.Audience,
87
		JwksURI:           oidcConf.JWKsURI,
88
		validMethods:      conf.ValidMethods,
89
		jwtParser:         jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)),
90
		jwksSet:           autoRefresh,
91
		backoffInterval:   backoffInterval,
92
		backoffMaxRetries: backoffMaxRetries,
93
		backoffFrequency:  backoffFrequency,
94
		globalRetryCount:  0,
95
		retriedKeys:       make(map[string]bool),
96
		globalFirstSeen:   time.Time{},
97
		mutex:             sync.Mutex{},
98
	}
99
100
	// Attempt to fetch the JWKS immediately to ensure it's available and valid.
101
	if _, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI); err != nil {
102
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
unrecognized printf verb 'w'
Loading history...
103
	}
104
105
	return oidc, nil
106
}
107
108
// Authenticate validates the JWT token found in the authorization header of the incoming request.
109
func (oidc *Authn) Authenticate(ctx context.Context) error {
110
	// Extract the authorization header from the metadata of the incoming gRPC request.
111
	authHeader, err := grpcauth.AuthFromMD(ctx, "Bearer")
112
	if err != nil { // Check for authentication errors
113
		slog.Error("failed to extract authorization header from gRPC request", "error", err)
114
		return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String())
115
	}
116
	slog.Debug("Successfully extracted authorization header from gRPC request")
117
118
	// Parse and validate the JWT token extracted from the authorization header.
119
	parsedToken, err := oidc.jwtParser.Parse(authHeader, func(token *jwt.Token) (interface{}, error) {
120
		slog.Info("starting JWT parsing and validation.")
121
122
		// Retrieve the key ID from the JWT header and find the corresponding key in the JWKS.
123
		keyID, ok := token.Header["kid"].(string)
124
		if ok { // Key ID found in token header
125
			return oidc.getKeyWithRetry(ctx, keyID)
126
		}
127
		slog.Error("jwt does not contain a key ID")
128
		return nil, errors.New("kid must be specified in the token header")
129
	})
130
	if err != nil {
131
		slog.Error("token parsing or validation failed", "error", err)
132
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
133
	}
134
135
	// Ensure the token is valid.
136
	if !parsedToken.Valid {
137
		slog.Warn("parsed token is invalid")
138
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
139
	}
140
141
	// Extract the claims from the token.
142
	claims, ok := parsedToken.Claims.(jwt.MapClaims)
143
	if !ok {
144
		slog.Warn("token claims are in an incorrect format")
145
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String())
146
	}
147
148
	slog.Debug("extracted token claims", "claims", claims)
149
150
	// Verify the issuer of the token matches the expected issuer.
151
	if ok := claims.VerifyIssuer(oidc.IssuerURL, true); !ok {
152
		slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"])
153
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String())
154
	}
155
	// Verify the audience of the token matches the expected audience.
156
157
	if ok := claims.VerifyAudience(oidc.Audience, true); !ok {
158
		slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"])
159
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String())
160
	}
161
162
	slog.Info("token validation succeeded")
163
164
	// If all validations pass, return nil indicating the token is valid.
165
	return nil
166
}
167
168
// getKeyWithRetry attempts to retrieve the key for the given keyID with retries using a custom backoff strategy.
169
func (oidc *Authn) getKeyWithRetry(
170
	ctx context.Context,
171
	keyID string,
172
) (interface{}, error) {
173
	var raw interface{}
174
	var err error
175
176
	oidc.mutex.Lock()
177
	now := time.Now()
178
179
	// Reset global state if the interval has passed
180
	if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval {
181
		slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
182
		oidc.globalFirstSeen = now
183
		oidc.globalRetryCount = 0
184
		oidc.retriedKeys = make(map[string]bool)
185
	} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
186
		// If max retries reached within the interval, unlock and check keyID once
187
		slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
188
		oidc.mutex.Unlock()
189
190
		// Try to fetch the keyID once
191
		raw, err = oidc.fetchKey(ctx, keyID)
192
		if err == nil { // Successfully fetched the key
193
			oidc.mutex.Lock()
194
			if _, wasRetried := oidc.retriedKeys[keyID]; wasRetried {
195
				// Reset global backoff state if a valid key is found and that key had been previously retried
196
				// Use case: prevents malicious keyIDs from blocking valid keyIDs
197
				// The valid KeyID should not reset counters for invalid keys
198
				slog.Info("valid key found in backoff period, resetting global state", "keyID", keyID)
199
				oidc.globalRetryCount = 0                // Reset retry counter
200
				oidc.globalFirstSeen = time.Time{}       // Reset timestamp
201
				oidc.retriedKeys = make(map[string]bool) // Clear retried keys
202
			}
203
			oidc.mutex.Unlock() // Release the lock
204
			return raw, nil
205
		}
206
207
		// Log the failure and return an error if keyID is not found
208
		slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err)
209
		return nil, errors.New("too many attempts, backoff in effect")
210
	}
211
	oidc.mutex.Unlock()
212
213
	// Retry mechanism
214
	retries := 0
215
	for retries <= oidc.backoffMaxRetries {
216
		raw, err = oidc.fetchKey(ctx, keyID)
217
		if err == nil { // Key successfully retrieved
218
			if retries != 0 { // Reset state if retry was successful
219
				oidc.mutex.Lock()
220
				oidc.globalRetryCount = 0
221
				oidc.globalFirstSeen = time.Time{}
222
				oidc.retriedKeys = make(map[string]bool)
223
				oidc.mutex.Unlock()
224
			}
225
			return raw, nil
226
		}
227
		oidc.mutex.Lock()
228
		snapshotCount := oidc.globalRetryCount
229
		oidc.retriedKeys[keyID] = true
230
		if oidc.globalRetryCount > oidc.backoffMaxRetries {
231
			slog.Error("key ID not found in JWKS due to exceeding global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
232
			oidc.mutex.Unlock() // Unlock before returning
233
			return nil, errors.New("too many retry attempts, backoff policy active due to global retry limit")
234
		}
235
		oidc.mutex.Unlock() // Release mutex
236
		if retries > 0 {
237
			select {
238
			case <-time.After(oidc.backoffFrequency):
239
				slog.Info("waiting before retrying", "keyID", keyID, "retries", retries)
240
			case <-ctx.Done():
241
				slog.Error("context cancelled during retry", "keyID", keyID)
242
				return nil, ctx.Err()
243
			}
244
		}
245
246
		oidc.mutex.Lock()
247
		if oidc.globalRetryCount > snapshotCount { // Another goroutine already refreshed
248
			// Another concurrent request in retry loop has already refreshed the JWKS
249
			retries++
250
			slog.Warn("concurrent request has already refreshed JWKS, skipping refresh")
251
			oidc.mutex.Unlock() // Unlock and continue
252
			continue            // Skip to next iteration
253
		}
254
255
		oidc.globalRetryCount++ // Increment the global retry counter
256
		slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
257
		retries++ // Increment retry counter
258
259
		if _, err := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); err != nil {
260
			oidc.mutex.Unlock()
261
			slog.Error("failed to refresh JWKS", "error", err)
262
			return nil, err
263
		}
264
		// Unlock after Refresh to prevent concurrent duplicate refresh calls
265
		oidc.mutex.Unlock() // Release lock after successful refresh
266
	}
267
268
	// Mark the global state to prevent further retries for the backoff interval
269
	oidc.mutex.Lock()
270
	if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval {
271
		slog.Warn("marking state to prevent further retries", "keyID", keyID)
272
		oidc.globalRetryCount = oidc.backoffMaxRetries
273
	}
274
	oidc.mutex.Unlock()
275
276
	slog.Error("key ID not found in JWKS after retries", "keyID", keyID)
277
	return nil, errors.New("key ID not found in JWKS after retries")
278
}
279
280
// fetchKey attempts to fetch the JWKS and retrieve the key for the given keyID.
281
// It fetches from the configured JWKS URI and looks up the key by its ID.
282
func (oidc *Authn) fetchKey(
283
	ctx context.Context,
284
	keyID string,
285
) (interface{}, error) {
286
	// Log the attempt to find the key in JWKS
287
	slog.DebugContext(ctx, "attempting to find key in JWKS", "kid", keyID)
288
289
	// Fetch the JWKS from the configured URI
290
	jwks, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
291
	if err != nil { // Check for fetch errors
292
		// Log an error and return if fetching fails
293
		slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", err)
294
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
0 ignored issues
show
unrecognized printf verb 'w'
Loading history...
295
	}
296
297
	// Log a successful fetch of the JWKS
298
	slog.InfoContext(ctx, "successfully fetched JWKS")
299
300
	// Attempt to find the key in the fetched JWKS using the key ID
301
	if key, found := jwks.LookupKeyID(keyID); found {
302
		var k interface{} // Variable to hold the raw key
303
		// Convert the key to a usable format
304
		if err := key.Raw(&k); err != nil {
305
			slog.ErrorContext(ctx, "failed to get raw public key", "kid", keyID, "error", err)
306
			return nil, fmt.Errorf("failed to get raw public key: %w", err)
0 ignored issues
show
unrecognized printf verb 'w'
Loading history...
307
		}
308
		// Log a successful retrieval of the raw public key
309
		slog.DebugContext(ctx, "successfully obtained raw public key", "key", k)
310
		return k, nil // Return the public key for JWT signature verification
311
	}
312
	// Log an error if the key ID is not found in the JWKS
313
	slog.ErrorContext(ctx, "key ID not found in JWKS", "kid", keyID)
314
	return nil, fmt.Errorf("kid %s not found", keyID)
315
}
316
317
// Config is the subset of an OpenID Connect (OIDC) discovery document that
// this package reads.
type Config struct {
	// Issuer is decoded from the "issuer" member: the provider's identifier URL.
	Issuer string `json:"issuer"`
	// JWKsURI is decoded from the "jwks_uri" member: the location of the
	// provider's JSON Web Key Set (JWKS).
	JWKsURI string `json:"jwks_uri"`
}
324
325
// fetchOIDCConfiguration sends an HTTP request to the given URL to fetch the OpenID Connect (OIDC) configuration.
326
// It requires an HTTP client and the URL from which to fetch the configuration.
327
func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) {
328
	// Send an HTTP GET request to the provided URL to fetch the OIDC configuration.
329
	// This typically points to the well-known configuration endpoint of the OIDC provider.
330
	body, err := doHTTPRequest(client, url)
331
	if err != nil {
332
		// If there is an error in fetching the configuration (network error, bad response, etc.), return nil and the error.
333
		return nil, err
334
	}
335
336
	// Parse the JSON response body into an OIDC Config struct.
337
	// This involves unmarshalling the JSON into a struct that matches the expected fields of the OIDC configuration.
338
	oidcConfig, err := parseOIDCConfiguration(body)
339
	if err != nil {
340
		return nil, err
341
	}
342
343
	// Return the parsed OIDC configuration and nil as the error (indicating success).
344
	return oidcConfig, nil
345
}
346
347
// doHTTPRequest makes an HTTP GET request to the specified URL and returns
// the response body.
//
// An error is returned when the request cannot be built or executed, when the
// response status is not 200 OK, when the body cannot be read, or when the
// body is larger than 1 MiB. The size cap keeps a misbehaving endpoint from
// making the process buffer an arbitrarily large response in memory.
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	// maxResponseBytes is the largest body accepted from the endpoint.
	const maxResponseBytes = 1 << 20

	slog.Debug("creating new HTTP GET request", "url", url)

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)

	res, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	// Close the body on every return path below.
	defer res.Body.Close()

	slog.Debug("received HTTP response", "status_code", res.StatusCode, "url", url)

	if res.StatusCode != http.StatusOK {
		slog.Warn("received unexpected status code", "status_code", res.StatusCode, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", res.StatusCode)
	}

	slog.Debug("reading response body", "url", url)

	// Read one byte past the cap so an oversized body is detected and
	// rejected instead of being silently truncated.
	body, err := io.ReadAll(io.LimitReader(res.Body, maxResponseBytes+1))
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}
	if len(body) > maxResponseBytes {
		slog.Error("response body exceeds size limit", "url", url, "limit", maxResponseBytes)
		return nil, fmt.Errorf("response body from OIDC configuration request exceeds %d bytes", maxResponseBytes)
	}

	slog.Debug("successfully read response body", "url", url, "response_length", len(body))

	return body, nil
}
399
400
// parseOIDCConfiguration decodes the OIDC configuration from the given JSON body.
401
// It validates that required fields like Issuer and JWKsURI are present.
402
func parseOIDCConfiguration(body []byte) (*Config, error) {
403
	var oidcConfig Config
404
	// Attempt to unmarshal the JSON body into the oidcConfig struct.
405
	if err := json.Unmarshal(body, &oidcConfig); err != nil {
406
		slog.Error("failed to unmarshal OIDC configuration", "error", err)
407
		return nil, fmt.Errorf("failed to decode OIDC configuration: %w", err)
0 ignored issues
show
unrecognized printf verb 'w'
Loading history...
408
	}
409
	// Log the successful decoding of OIDC configuration
410
	slog.Debug("successfully decoded OIDC configuration")
411
412
	if oidcConfig.Issuer == "" {
413
		slog.Warn("missing issuer value in OIDC configuration")
414
		return nil, errors.New("issuer value is required but missing in OIDC configuration")
415
	}
416
417
	if oidcConfig.JWKsURI == "" {
418
		slog.Warn("missing JWKsURI value in OIDC configuration")
419
		return nil, errors.New("JWKsURI value is required but missing in OIDC configuration")
420
	}
421
422
	// Log the successful parsing of the OIDC configuration
423
	slog.Info("successfully parsed OIDC configuration", "issuer", oidcConfig.Issuer, "jwks_uri", oidcConfig.JWKsURI)
424
425
	// Return the successfully parsed configuration.
426
	return &oidcConfig, nil
427
}
428