Passed
Pull Request — master (#2577)
by Tolga
03:20
created

utils.WaitWithBackoff   A

Complexity

Conditions 5

Size

Total Lines 27
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 15
nop 3
dl 0
loc 27
rs 9.1832
c 0
b 0
f 0
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"strconv"
11
	"strings"
12
	"time"
13
14
	"github.com/jackc/pgx/v5/pgxpool"
15
	"go.opentelemetry.io/otel/codes"
16
17
	"go.opentelemetry.io/otel/trace"
18
19
	"github.com/Masterminds/squirrel"
20
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// SQL templates and transaction-ID sentinels shared by the Postgres helpers.
const (
	// TransactionTemplate inserts a transaction row for a tenant and returns
	// the generated id together with its snapshot.
	TransactionTemplate = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id, snapshot`
	// InsertTenantTemplate creates a tenant and returns its creation timestamp.
	InsertTenantTemplate = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`
	// DeleteTenantTemplate removes a tenant by id and returns its name and
	// creation timestamp.
	DeleteTenantTemplate = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`
	// DeleteAllByTenantTemplate removes every row of one tenant from a table;
	// the %s verb is the table name (presumably filled in by the caller with
	// fmt.Sprintf — verify at the call sites).
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`

	// ActiveRecordTxnID is the sentinel stored in expired_tx_id for records
	// that have not expired. Its value is 9223372036854775807 (2^63 - 1),
	// used instead of 0 to avoid XID wraparound issues.
	ActiveRecordTxnID uint64 = 1<<63 - 1
	// MaxXID8Value is the same sentinel written as a SQL xid8 literal.
	MaxXID8Value = "'9223372036854775807'::xid8"

	// earliestPostgresVersion is the oldest supported PostgreSQL release
	// (13.8), in the numeric form reported by server_version_num.
	earliestPostgresVersion = 130008
)
38
39
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
40
// Optimized version with parameterized queries for security.
41
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64, snapshotValue string) squirrel.SelectBuilder {
42
	slog.Info("SnapshotQuery called", slog.Uint64("xid", value), slog.String("snapshot", snapshotValue))
43
	// Backward compatibility: if snapshot is empty, use old method
44
	if snapshotValue == "" {
45
		// Create a subquery for the snapshot associated with the provided value.
46
		snapshotQuery := "(select snapshot from transactions where id = ?::xid8)"
47
48
		// Records that were created and are visible in the snapshot
49
		createdWhere := squirrel.Or{
50
			squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", squirrel.Expr(snapshotQuery, value)),
51
			squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
52
		}
53
54
		// Records that are still active (not expired) at snapshot time
55
		expiredWhere := squirrel.And{
56
			squirrel.Or{
57
				squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", squirrel.Expr(snapshotQuery, value)),
58
				squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
59
			},
60
			squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
61
		}
62
63
		// Add the created and expired conditions to the SELECT query.
64
		return sl.Where(createdWhere).Where(expiredWhere)
65
	}
66
67
	// Records that were created and are visible in the snapshot
68
	createdWhere := squirrel.Or{
69
		squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", snapshotValue),
70
		squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
71
	}
72
73
	// Records that are still active (not expired) at snapshot time
74
	expiredWhere := squirrel.And{
75
		squirrel.Or{
76
			squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", snapshotValue),
77
			squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
78
		},
79
		squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
80
	}
81
82
	// Add the created and expired conditions to the SELECT query.
83
	return sl.Where(createdWhere).Where(expiredWhere)
84
}
85
86
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
87
// It constructs a query to delete expired records from the specified table
88
// based on the provided value, which represents a transaction ID.
89
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
90
	// Create a Squirrel DELETE builder for the specified table.
91
	deleteBuilder := squirrel.Delete(table)
92
93
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
94
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
95
96
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
97
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
98
99
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
100
	return deleteBuilder.Where(expiredNotActiveExpr).Where(beforeExpr)
101
}
102
103
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
104
// It constructs a query to delete expired records from the specified table for a specific tenant
105
// based on the provided value, which represents a transaction ID.
106
func GenerateGCQueryForTenant(table, tenantID string, value uint64) squirrel.DeleteBuilder {
107
	// Create a Squirrel DELETE builder for the specified table.
108
	deleteBuilder := squirrel.Delete(table)
109
110
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
111
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
112
113
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
114
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
115
116
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
117
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredNotActiveExpr).Where(beforeExpr)
118
}
119
120
// HandleError records an error in the given span, logs the error, and returns a standardized error.
121
// This function is used for consistent error handling across different parts of the application.
122
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
123
	// Check if the error is context-related
124
	if IsContextRelatedError(ctx, err) {
125
		slog.DebugContext(ctx, "A context-related error occurred",
126
			slog.String("error", err.Error()))
127
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
128
	}
129
130
	// Check if the error is serialization-related
131
	if IsSerializationRelatedError(err) {
132
		slog.DebugContext(ctx, "A serialization-related error occurred",
133
			slog.String("error", err.Error()))
134
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
135
	}
136
137
	// For all other types of errors, log them at the error level and record them in the span
138
	slog.ErrorContext(ctx, "An operational error occurred",
139
		slog.Any("error", err))
140
	span.RecordError(err)
141
	span.SetStatus(codes.Error, err.Error())
142
143
	// Return a new error with the standard error code provided
144
	return errors.New(errorCode.String())
145
}
146
147
// IsContextRelatedError reports whether a failure is attributable to context
// cancellation, a missed deadline, or a closed connection.
//
// It returns true when ctx itself has been cancelled or has exceeded its
// deadline, regardless of err. Otherwise it returns true when err wraps
// context.Canceled or context.DeadlineExceeded, or when its message contains
// "conn closed". A nil err with a live context yields false.
func IsContextRelatedError(ctx context.Context, err error) bool {
	// The context state wins even when err carries no hint of cancellation.
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}

	// Guard against a nil error before inspecting its message; calling
	// err.Error() on a nil interface would panic.
	if err == nil {
		return false
	}

	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		strings.Contains(err.Error(), "conn closed")
}
159
160
// IsSerializationRelatedError reports whether err looks like a database
// serialization failure.
//
// Detection is based on the error message: it returns true when the message
// contains "could not serialize" or "duplicate key value". A nil err yields
// false.
func IsSerializationRelatedError(err error) bool {
	// Guard against a nil error; calling err.Error() on a nil interface
	// would panic.
	if err == nil {
		return false
	}

	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
168
169
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
170
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
171
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
172
	// Calculate the base backoff with bit shifting for better performance
173
	baseBackoff := 20 * time.Millisecond
174
	if retries > 0 {
175
		// Use bit shifting instead of math.Pow for better performance
176
		shift := min(retries, 5) // Cap at 2^5 = 32, so max backoff is 640ms
177
		baseBackoff = baseBackoff << shift
178
	}
179
180
	// Cap at 1 second
181
	if baseBackoff > time.Second {
182
		baseBackoff = time.Second
183
	}
184
185
	// Generate jitter using crypto/rand
186
	jitter := time.Duration(secureRandomFloat64() * float64(baseBackoff) * 0.5)
187
	nextBackoff := baseBackoff + jitter
188
189
	// Log the retry wait
190
	slog.WarnContext(ctx, "waiting before retry",
191
		slog.String("tenant_id", tenantID),
192
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
193
194
	// Wait or exit on context cancellation
195
	select {
196
	case <-time.After(nextBackoff):
197
	case <-ctx.Done():
198
	}
199
}
200
201
// secureRandomFloat64 returns a float64 in the half-open range [0, 1) drawn
// from crypto/rand.
//
// If the random source reports an error, it returns 0.5 so callers still get
// a non-zero jitter factor.
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		// Fall back to the midpoint rather than 0 to keep some jitter.
		return 0.5
	}

	// Keep only the top 53 bits: every such integer is exactly representable
	// as a float64, so the quotient is strictly below 1. Converting the full
	// 64-bit value instead can round up to 2^64 and produce exactly 1.0.
	return float64(binary.BigEndian.Uint64(b[:])>>11) / (1 << 53)
}
212
213
// EnsureDBVersion checks the version of the given database connection
214
// and returns an error if the version is not supported.
215
func EnsureDBVersion(db *pgxpool.Pool) (version string, err error) {
216
	err = db.QueryRow(context.Background(), "SHOW server_version_num;").Scan(&version)
217
	if err != nil {
218
		return version, err
219
	}
220
	v, err := strconv.Atoi(version)
221
	if v < earliestPostgresVersion {
222
		err = fmt.Errorf("unsupported postgres version: %s, expected >= %d", version, earliestPostgresVersion)
223
	}
224
	return version, err
225
}
226