Passed
Push — main ( 2cb7aa...7c8444 )
by Acho
01:28
created

repositories.*gormUserRepository.LoadAuthUser   B

Complexity

Conditions 5

Size

Total Lines 32
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 21
nop 2
dl 0
loc 32
rs 8.9093
c 0
b 0
f 0
1
package repositories
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/base64"
7
	"errors"
8
	"fmt"
9
	"time"
10
11
	"gorm.io/gorm/clause"
12
13
	"github.com/cockroachdb/cockroach-go/v2/crdb/crdbgorm"
14
	"github.com/dgraph-io/ristretto"
15
16
	"github.com/NdoleStudio/httpsms/pkg/entities"
17
	"github.com/NdoleStudio/httpsms/pkg/telemetry"
18
	"github.com/palantir/stacktrace"
19
	"gorm.io/gorm"
20
)
21
22
// gormUserRepository is responsible for persisting entities.User
23
type gormUserRepository struct {
24
	logger telemetry.Logger
25
	tracer telemetry.Tracer
26
	cache  *ristretto.Cache[string, entities.AuthUser]
27
	db     *gorm.DB
28
}
29
30
// NewGormUserRepository creates the GORM version of the UserRepository
31
func NewGormUserRepository(
32
	logger telemetry.Logger,
33
	tracer telemetry.Tracer,
34
	cache *ristretto.Cache[string, entities.AuthUser],
35
	db *gorm.DB,
36
) UserRepository {
37
	return &gormUserRepository{
38
		logger: logger.WithService(fmt.Sprintf("%T", &gormUserRepository{})),
39
		tracer: tracer,
40
		cache:  cache,
41
		db:     db,
42
	}
43
}
44
45
func (repository *gormUserRepository) RotateAPIKey(ctx context.Context, userID entities.UserID) (*entities.User, error) {
46
	ctx, span := repository.tracer.Start(ctx)
47
	defer span.End()
48
49
	apiKey, err := repository.generateAPIKey(64)
50
	if err != nil {
51
		return nil, stacktrace.Propagate(err, fmt.Sprintf("cannot generate apiKey for user [%s]", userID))
52
	}
53
54
	user := new(entities.User)
55
	err = crdbgorm.ExecuteTx(ctx, repository.db, nil,
56
		func(tx *gorm.DB) error {
57
			return tx.WithContext(ctx).Model(user).
58
				Clauses(clause.Returning{}).
59
				Where("id = ?", userID).
60
				Update("api_key", apiKey).Error
61
		},
62
	)
63
	if errors.Is(err, gorm.ErrRecordNotFound) {
64
		msg := fmt.Sprintf("user with ID [%s] does not exist", userID)
65
		return nil, repository.tracer.WrapErrorSpan(span, stacktrace.PropagateWithCode(err, ErrCodeNotFound, msg))
66
	}
67
68
	return user, nil
69
}
70
71
func (repository *gormUserRepository) LoadBySubscriptionID(ctx context.Context, subscriptionID string) (*entities.User, error) {
72
	ctx, span := repository.tracer.Start(ctx)
73
	defer span.End()
74
75
	user := new(entities.User)
76
	err := repository.db.WithContext(ctx).
77
		Where("subscription_id = ?", subscriptionID).
78
		First(user).
79
		Error
80
	if errors.Is(err, gorm.ErrRecordNotFound) {
81
		msg := fmt.Sprintf("user with subscriptionID [%s] does not exist", subscriptionID)
82
		return nil, repository.tracer.WrapErrorSpan(span, stacktrace.PropagateWithCode(err, ErrCodeNotFound, msg))
83
	}
84
85
	if err != nil {
86
		msg := fmt.Sprintf("cannot load user with subscription ID [%s]", subscriptionID)
87
		return nil, repository.tracer.WrapErrorSpan(span, stacktrace.Propagate(err, msg))
88
	}
89
90
	return user, nil
91
}
92
93
// Store inserts user as a new row using GORM's Create. On failure the error
// is wrapped with the user's ID and passed through tracer.WrapErrorSpan.
func (repository *gormUserRepository) Store(ctx context.Context, user *entities.User) error {
	ctx, span := repository.tracer.Start(ctx)
	defer span.End()

	err := repository.db.WithContext(ctx).Create(user).Error
	if err == nil {
		return nil
	}

	msg := fmt.Sprintf("cannot save user with ID [%s]", user.ID)
	return repository.tracer.WrapErrorSpan(span, stacktrace.Propagate(err, msg))
}
104
105
// Update writes user back to the database using GORM's Save. On failure the
// error is wrapped with the user's ID and passed through tracer.WrapErrorSpan.
func (repository *gormUserRepository) Update(ctx context.Context, user *entities.User) error {
	ctx, span := repository.tracer.Start(ctx)
	defer span.End()

	err := repository.db.WithContext(ctx).Save(user).Error
	if err == nil {
		return nil
	}

	msg := fmt.Sprintf("cannot update user with ID [%s]", user.ID)
	return repository.tracer.WrapErrorSpan(span, stacktrace.Propagate(err, msg))
}
116
117
func (repository *gormUserRepository) LoadAuthUser(ctx context.Context, apiKey string) (entities.AuthUser, error) {
118
	ctx, span, ctxLogger := repository.tracer.StartWithLogger(ctx, repository.logger)
119
	defer span.End()
120
121
	if authUser, found := repository.cache.Get(apiKey); found {
122
		ctxLogger.Info(fmt.Sprintf("cache hit for user with ID [%s]", authUser.ID))
123
		return authUser, nil
124
	}
125
126
	user := new(entities.User)
127
	err := repository.db.WithContext(ctx).Where("api_key = ?", apiKey).First(user).Error
128
	if errors.Is(err, gorm.ErrRecordNotFound) {
129
		msg := fmt.Sprintf("user with api key [%s] does not exist", apiKey)
130
		return entities.AuthUser{}, repository.tracer.WrapErrorSpan(span, stacktrace.PropagateWithCode(err, ErrCodeNotFound, msg))
131
	}
132
133
	if err != nil {
134
		msg := fmt.Sprintf("cannot load user with api key [%s]", apiKey)
135
		return entities.AuthUser{}, repository.tracer.WrapErrorSpan(span, stacktrace.Propagate(err, msg))
136
	}
137
138
	authUser := entities.AuthUser{
139
		ID:    user.ID,
140
		Email: user.Email,
141
	}
142
143
	if result := repository.cache.SetWithTTL(apiKey, authUser, 1, 2*time.Hour); !result {
144
		msg := fmt.Sprintf("cannot cache [%T] with ID [%s] and result [%t]", authUser, user.ID, result)
145
		ctxLogger.Error(repository.tracer.WrapErrorSpan(span, stacktrace.NewError(msg)))
146
	}
147
148
	return authUser, nil
149
}
150
151
// Load fetches the user whose primary key is userID.
//
// It returns an error coded ErrCodeNotFound when no such user exists. Any
// other database failure is returned as a propagated error.
func (repository *gormUserRepository) Load(ctx context.Context, userID entities.UserID) (*entities.User, error) {
	ctx, span := repository.tracer.Start(ctx)
	defer span.End()

	user := new(entities.User)
	err := repository.db.WithContext(ctx).First(user, userID).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		// Report the requested ID; user.ID is still the zero value here
		// because no row was loaded into it.
		msg := fmt.Sprintf("user with ID [%s] does not exist", userID)
		return nil, repository.tracer.WrapErrorSpan(span, stacktrace.PropagateWithCode(err, ErrCodeNotFound, msg))
	}

	if err != nil {
		msg := fmt.Sprintf("cannot load user with ID [%s]", userID)
		return nil, repository.tracer.WrapErrorSpan(span, stacktrace.Propagate(err, msg))
	}

	return user, nil
}
169
170
func (repository *gormUserRepository) LoadOrStore(ctx context.Context, authUser entities.AuthUser) (*entities.User, bool, error) {
171
	ctx, span := repository.tracer.Start(ctx)
172
	defer span.End()
173
174
	user, err := repository.Load(ctx, authUser.ID)
175
	if err == nil {
176
		return user, false, nil
177
	}
178
179
	apiKey, err := repository.generateAPIKey(64)
180
	if err != nil {
181
		return nil, false, stacktrace.Propagate(err, fmt.Sprintf("cannot generate apiKey for user [%s]", authUser.ID))
182
	}
183
184
	user = &entities.User{
185
		ID:               authUser.ID,
186
		Email:            authUser.Email,
187
		APIKey:           apiKey,
188
		SubscriptionName: entities.SubscriptionNameFree,
189
		CreatedAt:        time.Now().UTC(),
190
		UpdatedAt:        time.Now().UTC(),
191
	}
192
193
	isNew := false
194
	err = crdbgorm.ExecuteTx(ctx, repository.db, nil, func(tx *gorm.DB) error {
195
		result := tx.WithContext(ctx).Where(entities.User{ID: user.ID}).FirstOrCreate(user)
196
		if result.Error != nil {
197
			return result.Error
198
		}
199
		if result.RowsAffected > 0 {
200
			isNew = true
201
		}
202
		return result.Error
203
	})
204
	if err != nil {
205
		msg := fmt.Sprintf("cannot create user from auth user [%+#v]", authUser)
206
		return user, isNew, repository.tracer.WrapErrorSpan(span, stacktrace.Propagate(err, msg))
207
	}
208
209
	return user, isNew, nil
210
}
211
212
// generateRandomBytes returns securely generated random bytes.
213
// It will return an error if the system's secure random
214
// number generator fails to function correctly, in which
215
// case the caller should not continue.
216
func (repository *gormUserRepository) generateRandomBytes(n int) ([]byte, error) {
217
	b := make([]byte, n)
218
	// Note that err == nil only if we read len(b) bytes.
219
	if _, err := rand.Read(b); err != nil {
220
		return nil, stacktrace.Propagate(err, fmt.Sprintf("cannot generate [%d] random bytes", n))
221
	}
222
223
	return b, nil
224
}
225
226
// generateAPIKey returns a URL-safe, base64 encoded
227
// securely generated random string.
228
// It will return an error if the system's secure random
229
// number generator fails to function correctly, in which
230
// case the caller should not continue.
231
func (repository *gormUserRepository) generateAPIKey(n int) (string, error) {
232
	b, err := repository.generateRandomBytes(n)
233
	return base64.URLEncoding.EncodeToString(b)[0:n], stacktrace.Propagate(err, "cannot generate random bytes")
234
}
235