Passed
Push — master ( fded09...ba6f13 )
by Tolga
01:20 queued 14s
created

internal/storage/postgres/utils/common.go   A

Size/Duplication

Total Lines 184
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
cc 20
eloc 89
dl 0
loc 184
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A utils.secureRandomFloat64 0 6 2
B utils.IsContextRelatedError 0 10 6
A utils.GenerateGCQuery 0 15 1
A utils.IsSerializationRelatedError 0 6 3
A utils.snapshotQuery 0 25 1
A utils.WaitWithBackoff 0 17 3
A utils.SnapshotQuery 0 25 1
A utils.HandleError 0 23 3
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"math"
11
	"strings"
12
	"time"
13
14
	"go.opentelemetry.io/otel/codes"
15
16
	"go.opentelemetry.io/otel/trace"
17
18
	"github.com/Masterminds/squirrel"
19
20
	base "github.com/Permify/permify/pkg/pb/base/v1"
21
)
22
23
// SQL statement templates. Values are bound through positional ($n) parameters.
const (
	// TransactionTemplate inserts a row into transactions for the tenant bound
	// to $1 and returns the id of the new row.
	TransactionTemplate = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id`

	// InsertTenantTemplate inserts a tenant with id $1 and name $2 and returns
	// its created_at value.
	InsertTenantTemplate = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`

	// DeleteTenantTemplate deletes the tenant with id $1 and returns the name
	// and created_at of the deleted row.
	DeleteTenantTemplate = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`

	// DeleteAllByTenantTemplate deletes every row belonging to tenant $1 from a
	// table. The %s verb stands for the table name and must be filled in before
	// execution (presumably via fmt.Sprintf with a trusted table name; verify
	// against callers).
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)
29
30
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
31
// The query checks if transactions are visible in a snapshot associated with the provided value.
32
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64) squirrel.SelectBuilder {
33
	// Convert the value to a string once to reduce redundant calls to fmt.Sprintf.
34
	valStr := fmt.Sprintf("'%v'::xid8", value)
35
36
	// Create a subquery for the snapshot associated with the provided value.
37
	snapshotQuery := fmt.Sprintf("(select snapshot from transactions where id = %s)", valStr)
38
39
	// Create an expression to check if a transaction with a specific created_tx_id is visible in the snapshot.
40
	visibilityExpr := squirrel.Expr(fmt.Sprintf("pg_visible_in_snapshot(created_tx_id, %s) = true", snapshotQuery))
41
	// Create an expression to check if the created_tx_id is equal to the provided value.
42
	createdExpr := squirrel.Expr(fmt.Sprintf("created_tx_id = %s", valStr))
43
	// Use OR condition for the created expressions.
44
	createdWhere := squirrel.Or{visibilityExpr, createdExpr}
45
46
	// Create an expression to check if a transaction with a specific expired_tx_id is not visible in the snapshot.
47
	expiredVisibilityExpr := squirrel.Expr(fmt.Sprintf("pg_visible_in_snapshot(expired_tx_id, %s) = false", snapshotQuery))
48
	// Create an expression to check if the expired_tx_id is equal to zero.
49
	expiredZeroExpr := squirrel.Expr("expired_tx_id = '0'::xid8")
50
	// Create an expression to check if the expired_tx_id is not equal to the provided value.
51
	expiredNotExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id <> %s", valStr))
52
	// Use AND condition for the expired expressions, checking both visibility and non-equality with value.
53
	expiredWhere := squirrel.And{squirrel.Or{expiredVisibilityExpr, expiredZeroExpr}, expiredNotExpr}
54
55
	// Add the created and expired conditions to the SELECT query.
56
	return sl.Where(createdWhere).Where(expiredWhere)
57
}
58
59
// snapshotQuery returns two raw SQL condition strings for the transaction
// whose ID is value. They express the same created/expired visibility rules as
// SnapshotQuery, but as plain text rather than squirrel expressions.
//
// The first result is the condition on created_tx_id, the second the condition
// on expired_tx_id. Each is wrapped in its own parentheses so it can be placed
// inside a larger WHERE clause.
func snapshotQuery(value uint64) (string, string) {
	// xid8 literal for the transaction ID.
	txID := fmt.Sprintf("'%v'::xid8", value)

	// Subquery that loads the snapshot recorded for that transaction.
	snapshot := "(SELECT snapshot FROM transactions WHERE id = " + txID + ")"

	// Row was created by a transaction visible in the snapshot, or by the
	// transaction itself.
	createdWhere := "(" +
		"pg_visible_in_snapshot(created_tx_id, " + snapshot + ") = true" +
		" OR " +
		"created_tx_id = " + txID +
		")"

	// Row has no expiry (zero) or its expiry is not visible in the snapshot,
	// and it was not expired by the transaction itself.
	expiredWhere := "(" +
		"(pg_visible_in_snapshot(expired_tx_id, " + snapshot + ") = false OR expired_tx_id = '0'::xid8)" +
		" AND " +
		"expired_tx_id <> " + txID +
		")"

	return createdWhere, expiredWhere
}
86
87
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
88
// It constructs a query to delete expired records from the specified table
89
// based on the provided value, which represents a transaction ID.
90
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
91
	// Convert the provided value into a string format suitable for our SQL query, formatted as a transaction ID.
92
	valStr := fmt.Sprintf("'%v'::xid8", value)
93
94
	// Create a Squirrel DELETE builder for the specified table.
95
	deleteBuilder := squirrel.Delete(table)
96
97
	// Create an expression to check if 'expired_tx_id' is not equal to '0' (not expired).
98
	expiredZeroExpr := squirrel.Expr("expired_tx_id <> '0'::xid8")
99
100
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
101
	beforeExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id < %s", valStr))
102
103
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
104
	return deleteBuilder.Where(expiredZeroExpr).Where(beforeExpr)
105
}
106
107
// HandleError records an error in the given span, logs the error, and returns a standardized error.
108
// This function is used for consistent error handling across different parts of the application.
109
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
110
	// Check if the error is context-related
111
	if IsContextRelatedError(ctx, err) {
112
		slog.DebugContext(ctx, "A context-related error occurred",
113
			slog.String("error", err.Error()))
114
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
115
	}
116
117
	// Check if the error is serialization-related
118
	if IsSerializationRelatedError(err) {
119
		slog.DebugContext(ctx, "A serialization-related error occurred",
120
			slog.String("error", err.Error()))
121
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
122
	}
123
124
	// For all other types of errors, log them at the error level and record them in the span
125
	slog.ErrorContext(ctx, "An operational error occurred",
126
		slog.Any("error", err))
127
	span.RecordError(err)
128
	span.SetStatus(codes.Error, err.Error())
129
130
	// Return a new error with the standard error code provided
131
	return errors.New(errorCode.String())
132
}
133
134
// IsContextRelatedError reports whether a failure should be attributed to
// context cancellation, a deadline being exceeded, or a closed connection.
//
// It returns true when:
//   - ctx itself has already been cancelled or has timed out, regardless of
//     what err is; or
//   - err wraps context.Canceled or context.DeadlineExceeded; or
//   - err's message contains the substring "conn closed".
//
// A nil err is never context-related on its own, so with a live ctx the result
// is false rather than a nil-dereference panic.
func IsContextRelatedError(ctx context.Context, err error) bool {
	// The state of the context takes precedence over the error value.
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}

	// Guard before calling err.Error() below: a nil error has no message.
	if err == nil {
		return false
	}

	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		strings.Contains(err.Error(), "conn closed")
}
146
147
// IsSerializationRelatedError reports whether err looks like a database
// serialization failure. Detection is by message text: err is treated as
// serialization-related when its message contains "could not serialize" or
// "duplicate key value".
//
// A nil err yields false instead of panicking on err.Error().
func IsSerializationRelatedError(err error) bool {
	if err == nil {
		return false
	}

	// Render the message once; both checks inspect the same text.
	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
155
156
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
157
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
158
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
159
	// Calculate the base backoff
160
	backoff := time.Duration(math.Min(float64(20*time.Millisecond)*math.Pow(2, float64(retries)), float64(1*time.Second)))
161
162
	// Generate jitter using crypto/rand
163
	jitter := time.Duration(secureRandomFloat64() * float64(backoff) * 0.5)
164
	nextBackoff := backoff + jitter
165
166
	// Log the retry wait
167
	slog.WarnContext(ctx, "waiting before retry",
168
		slog.String("tenant_id", tenantID),
169
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
170
171
	// Wait or exit on context cancellation
172
	select {
173
	case <-time.After(nextBackoff):
174
	case <-ctx.Done():
175
	}
176
}
177
178
// secureRandomFloat64 returns a float64 in the half-open range [0, 1) drawn
// from crypto/rand. If reading random bytes fails it returns 0.
//
// Only the top 53 bits of the random 64-bit word are used. A float64 has a
// 53-bit significand, so every such value is represented exactly and the
// result can never round up to 1.0 (which dividing the full 64-bit value by
// 2^64 would do for inputs close to the maximum).
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0 // Default to 0 jitter on error
	}
	return float64(binary.BigEndian.Uint64(b[:])>>11) / (1 << 53)
}
186