Passed
Pull Request — master (#2520)
by Tolga
03:24
created

openid.newFakeOidcProvider   A

Complexity

Conditions 4

Size

Total Lines 42
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 34
nop 1
dl 0
loc 42
rs 9.064
c 0
b 0
f 0
1
package openid
2
3
import (
4
	"crypto/ecdsa"
5
	"crypto/elliptic"
6
	"crypto/rand"
7
	"crypto/rsa"
8
	"encoding/json"
9
	"fmt"
10
	"net"
11
	"net/http"
12
	"net/http/httptest"
13
	"sync"
14
15
	"github.com/go-jose/go-jose/v3"
16
	"github.com/golang-jwt/jwt/v4"
17
)
18
19
type fakeOidcProvider struct {
20
	issuerURL    string
21
	authPath     string
22
	tokenPath    string
23
	userInfoPath string
24
	JWKSPath     string
25
26
	algorithms         []string
27
	signingKeyMap      map[jwt.SigningMethod]string
28
	jwks               []jose.JSONWebKey
29
	rsaPrivateKey      *rsa.PrivateKey
30
	rsaPrivateKeyForPS *rsa.PrivateKey
31
	ecdsaPrivateKey    *ecdsa.PrivateKey
32
	hmacKey            []byte
33
34
	mu sync.RWMutex
35
}
36
37
// ProviderConfig carries the settings newFakeOidcProvider copies into a
// fakeOidcProvider.
type ProviderConfig struct {
	// IssuerURL is the issuer value; it is also the prefix of every endpoint
	// URL in the discovery document.
	IssuerURL string

	// Endpoint paths, appended to IssuerURL in the discovery document.
	AuthPath     string
	TokenPath    string
	UserInfoPath string
	JWKSPath     string

	// Algorithms is the list advertised as
	// id_token_signing_alg_values_supported.
	Algorithms []string
}
45
46
func newFakeOidcProvider(config ProviderConfig) (*fakeOidcProvider, error) {
47
	rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
48
	if err != nil {
49
		return nil, fmt.Errorf("failed to generate RSA key: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
50
	}
51
	rsaPrivateKeyForPS, err := rsa.GenerateKey(rand.Reader, 2048)
52
	if err != nil {
53
		return nil, fmt.Errorf("failed to generate RSA key for PS: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
54
	}
55
	ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
56
	if err != nil {
57
		return nil, fmt.Errorf("failed to generate ECDSA key: %w", err)
0 ignored issues
show
introduced by
unrecognized printf verb 'w'
Loading history...
58
	}
59
	hmacKey := []byte("hmackeysecret")
60
61
	signingKeyMap := map[jwt.SigningMethod]string{
62
		jwt.SigningMethodRS256: "rs256keyid",
63
	}
64
	jwks := []jose.JSONWebKey{
65
		{
66
			Key:       rsaPrivateKey.Public(),
67
			KeyID:     signingKeyMap[jwt.SigningMethodRS256],
68
			Algorithm: "RS256",
69
			Use:       "sig",
70
		},
71
	}
72
73
	return &fakeOidcProvider{
74
		issuerURL:          config.IssuerURL,
75
		authPath:           config.AuthPath,
76
		tokenPath:          config.TokenPath,
77
		userInfoPath:       config.UserInfoPath,
78
		JWKSPath:           config.JWKSPath,
79
		algorithms:         config.Algorithms,
80
		rsaPrivateKey:      rsaPrivateKey,
81
		rsaPrivateKeyForPS: rsaPrivateKeyForPS,
82
		hmacKey:            hmacKey,
83
		jwks:               jwks,
84
		ecdsaPrivateKey:    ecdsaPrivateKey,
85
		signingKeyMap:      signingKeyMap,
86
		mu:                 sync.RWMutex{},
87
	}, nil
88
}
89
90
func (s *fakeOidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
91
	s.mu.RLock()
92
	defer s.mu.RUnlock()
93
94
	switch r.URL.Path {
95
	case "/.well-known/openid-configuration":
96
		s.responseWellKnown(w)
97
	case s.JWKSPath:
98
		s.responseJWKS(w)
99
	case s.authPath, s.tokenPath, s.userInfoPath:
100
		httpError(w, http.StatusNotFound)
101
	default:
102
		httpError(w, http.StatusNotFound)
103
	}
104
}
105
106
// providerJSON is the JSON shape of the discovery document written by
// responseWellKnown. The struct tags fix the emitted key names.
type providerJSON struct {
	// Issuer is the provider's issuer URL.
	Issuer string `json:"issuer"`

	// Endpoint URLs.
	AuthURL     string `json:"authorization_endpoint"`
	TokenURL    string `json:"token_endpoint"`
	JWKSURL     string `json:"jwks_uri"`
	UserInfoURL string `json:"userinfo_endpoint"`

	// Algorithms lists the advertised ID token signing algorithms.
	Algorithms []string `json:"id_token_signing_alg_values_supported"`
}
114
115
func (s *fakeOidcProvider) responseWellKnown(w http.ResponseWriter) {
116
	jso := providerJSON{
117
		Issuer:      s.issuerURL,
118
		AuthURL:     s.issuerURL + s.authPath,
119
		TokenURL:    s.issuerURL + s.tokenPath,
120
		JWKSURL:     s.issuerURL + s.JWKSPath,
121
		UserInfoURL: s.issuerURL + s.userInfoPath,
122
		Algorithms:  s.algorithms,
123
	}
124
	httpJSON(w, jso)
125
}
126
127
func (s *fakeOidcProvider) responseJWKS(w http.ResponseWriter) {
128
	jwks := &jose.JSONWebKeySet{
129
		Keys: s.jwks,
130
	}
131
	httpJSON(w, jwks)
132
}
133
134
func httpJSON(w http.ResponseWriter, v interface{}) {
135
	w.Header().Set("Content-Type", "application/json")
136
	encoder := json.NewEncoder(w)
137
	encoder.SetIndent("", "  ")
138
	if err := encoder.Encode(v); err != nil {
139
		httpError(w, http.StatusInternalServerError)
140
	}
141
}
142
143
// httpError replies with the given status code, using the standard status
// text for that code as the response body.
func httpError(w http.ResponseWriter, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}
146
147
// UpdateKeyID sets newKeyID as the key ID for method, both in the signing key
// map and on every JWKS entry whose Algorithm equals method.Alg(). The write
// lock is held for the whole update.
func (s *fakeOidcProvider) UpdateKeyID(method jwt.SigningMethod, newKeyID string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.signingKeyMap[method] = newKeyID

	// Index into the slice so the stored entry is modified, not a copy.
	for idx := range s.jwks {
		if s.jwks[idx].Algorithm != method.Alg() {
			continue
		}
		s.jwks[idx].KeyID = newKeyID
	}
}
158
159
func (s *fakeOidcProvider) SignIDToken(unsignedToken *jwt.Token) (string, error) {
160
	var signedToken string
161
	var err error
162
163
	switch unsignedToken.Method {
164
	case jwt.SigningMethodRS256:
165
		signedToken, err = unsignedToken.SignedString(s.rsaPrivateKey)
166
	default:
167
		return "", fmt.Errorf("incorrect signing method type, supported algorithms: HS256, RS256, ES256, PS256")
168
	}
169
170
	if err != nil {
171
		return "", err
172
	}
173
174
	return signedToken, nil
175
}
176
177
func createUnsignedToken(regClaims jwt.RegisteredClaims, method jwt.SigningMethod) *jwt.Token {
178
	claims := struct {
179
		jwt.RegisteredClaims
180
	}{
181
		RegisteredClaims: regClaims,
182
	}
183
	return jwt.NewWithClaims(method, claims)
184
}
185
186
// fakeHttpServer starts an httptest.Server for handler that listens on the
// TCP address given in url, instead of the address httptest picks itself.
// It returns an error if the address cannot be listened on. The caller is
// responsible for closing the returned server.
func fakeHttpServer(url string, handler http.HandlerFunc) (*httptest.Server, error) {
	ln, listenErr := net.Listen("tcp", url)
	if listenErr != nil {
		return nil, fmt.Errorf("failed to start listener: %w", listenErr)
	}

	srv := httptest.NewUnstartedServer(handler)

	// Swap the listener httptest allocated for the one bound to url. The
	// close error is ignored on purpose: the default listener is discarded
	// either way.
	_ = srv.Listener.Close()
	srv.Listener = ln

	srv.Start()
	return srv, nil
}
197