Passed
Pull Request — master (#2575)
by Tolga
03:57
created

utils.createFinalSnapshot   C

Complexity

Conditions 9

Size

Total Lines 45
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 21
nop 2
dl 0
loc 45
rs 6.6666
c 0
b 0
f 0
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"strconv"
11
	"strings"
12
	"time"
13
14
	"github.com/jackc/pgx/v5/pgxpool"
15
	"go.opentelemetry.io/otel/codes"
16
17
	"go.opentelemetry.io/otel/trace"
18
19
	"github.com/Masterminds/squirrel"
20
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// SQL statement templates. Positional parameters ($1, $2) are bound by the
// database driver when the statement is executed.
const (
	// TransactionTemplate inserts a transactions row for a tenant and returns
	// the row's id and snapshot columns.
	TransactionTemplate = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id, snapshot`

	// InsertTenantTemplate inserts a tenant (id, name) and returns its created_at column.
	InsertTenantTemplate = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`

	// DeleteTenantTemplate deletes a tenant by id and returns its name and created_at columns.
	DeleteTenantTemplate = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`

	// DeleteAllByTenantTemplate is a format string whose %s verb is replaced
	// with a table name. NOTE(review): the table name is interpolated rather
	// than bound, so it must only ever come from a fixed, trusted set.
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)

// Transaction ID sentinels.
const (
	// ActiveRecordTxnID is the expired_tx_id value that marks a record as never
	// expired. It equals 2^63-1; the author's stated rationale for using it
	// instead of 0 is avoiding XID wraparound issues.
	ActiveRecordTxnID = uint64(9223372036854775807)

	// MaxXID8Value is the same 2^63-1 value written as a SQL literal cast to xid8.
	MaxXID8Value = "'9223372036854775807'::xid8"
)

// earliestPostgresVersion is the oldest supported PostgreSQL release, 13.8,
// expressed in server_version_num form.
const earliestPostgresVersion = 130008
38
39
// createFinalSnapshot creates a final snapshot string for proper transaction visibility.
40
// If xmax != xid, it adds xid to the xip_list to make the snapshot unique.
41
func createFinalSnapshot(snapshotValue string, xid uint64) string {
42
	// Parse snapshot: "xmin:xmax:xip_list"
43
	parts := strings.Split(snapshotValue, ":")
44
	if len(parts) < 2 {
45
		// Invalid snapshot format, return original
46
		return snapshotValue
47
	}
48
49
	// Parse xmax as integer for proper comparison
50
	xmaxStr := parts[1]
51
	xmax, err := strconv.ParseUint(xmaxStr, 10, 64)
52
	if err != nil {
53
		// If parsing fails, use original snapshot
54
		return snapshotValue
55
	}
56
57
	// If xmax == xid, no need to modify snapshot
58
	if xmax == xid {
59
		return snapshotValue
60
	}
61
62
	// If xid > xmax, this is invalid - xid should be visible in snapshot
63
	// Return original snapshot to avoid creating invalid format
64
	if xid > xmax {
65
		return snapshotValue
66
	}
67
68
	// xmax > xid, need to add xid to xip_list for uniqueness
69
	xmin := parts[0]
70
71
	// Check if xip_list exists and is not empty
72
	if len(parts) == 3 && parts[2] != "" {
73
		// Check if xid is already in xip_list
74
		xipList := parts[2]
75
		for xip := range strings.SplitSeq(xipList, ",") {
76
			if strings.TrimSpace(xip) == fmt.Sprintf("%d", xid) {
77
				// xid already in xip_list, return original snapshot
78
				return snapshotValue
79
			}
80
		}
81
		// xid not in xip_list, append it
82
		return fmt.Sprintf("%s:%s:%s,%d", xmin, xmaxStr, parts[2], xid)
83
	} else {
84
		// xip_list is empty, add xid
85
		return fmt.Sprintf("%s:%s:%d", xmin, xmaxStr, xid)
86
	}
87
}
88
89
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
90
// Optimized version with parameterized queries for security.
91
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64, snapshotValue string) squirrel.SelectBuilder {
92
	// Backward compatibility: if snapshot is empty, use old method
93
	if snapshotValue == "" {
94
		// Create a subquery for the snapshot associated with the provided value.
95
		snapshotQuery := "(select snapshot from transactions where id = ?::xid8)"
96
97
		// Records that were created and are visible in the snapshot
98
		createdWhere := squirrel.Or{
99
			squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", squirrel.Expr(snapshotQuery, value)),
100
			squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
101
		}
102
103
		// Records that are still active (not expired) at snapshot time
104
		expiredWhere := squirrel.And{
105
			squirrel.Or{
106
				squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", squirrel.Expr(snapshotQuery, value)),
107
				squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
108
			},
109
			squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
110
		}
111
112
		// Add the created and expired conditions to the SELECT query.
113
		return sl.Where(createdWhere).Where(expiredWhere)
114
	}
115
116
	// Create final snapshot with proper visibility
117
	finalSnapshot := createFinalSnapshot(snapshotValue, value)
118
119
	// Records that were created and are visible in the snapshot
120
	createdWhere := squirrel.Or{
121
		squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", finalSnapshot),
122
		squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
123
	}
124
125
	// Records that are still active (not expired) at snapshot time
126
	expiredWhere := squirrel.And{
127
		squirrel.Or{
128
			squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", finalSnapshot),
129
			squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
130
		},
131
		squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
132
	}
133
134
	// Add the created and expired conditions to the SELECT query.
135
	return sl.Where(createdWhere).Where(expiredWhere)
136
}
137
138
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
139
// It constructs a query to delete expired records from the specified table
140
// based on the provided value, which represents a transaction ID.
141
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
142
	// Create a Squirrel DELETE builder for the specified table.
143
	deleteBuilder := squirrel.Delete(table)
144
145
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
146
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
147
148
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
149
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
150
151
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
152
	return deleteBuilder.Where(expiredNotActiveExpr).Where(beforeExpr)
153
}
154
155
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
156
// It constructs a query to delete expired records from the specified table for a specific tenant
157
// based on the provided value, which represents a transaction ID.
158
func GenerateGCQueryForTenant(table, tenantID string, value uint64) squirrel.DeleteBuilder {
159
	// Create a Squirrel DELETE builder for the specified table.
160
	deleteBuilder := squirrel.Delete(table)
161
162
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
163
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
164
165
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
166
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
167
168
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
169
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredNotActiveExpr).Where(beforeExpr)
170
}
171
172
// HandleError records an error in the given span, logs the error, and returns a standardized error.
173
// This function is used for consistent error handling across different parts of the application.
174
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
175
	// Check if the error is context-related
176
	if IsContextRelatedError(ctx, err) {
177
		slog.DebugContext(ctx, "A context-related error occurred",
178
			slog.String("error", err.Error()))
179
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
180
	}
181
182
	// Check if the error is serialization-related
183
	if IsSerializationRelatedError(err) {
184
		slog.DebugContext(ctx, "A serialization-related error occurred",
185
			slog.String("error", err.Error()))
186
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
187
	}
188
189
	// For all other types of errors, log them at the error level and record them in the span
190
	slog.ErrorContext(ctx, "An operational error occurred",
191
		slog.Any("error", err))
192
	span.RecordError(err)
193
	span.SetStatus(codes.Error, err.Error())
194
195
	// Return a new error with the standard error code provided
196
	return errors.New(errorCode.String())
197
}
198
199
// IsContextRelatedError reports whether err should be treated as a
// cancellation rather than an operational failure. That is the case when:
//   - ctx has been cancelled or its deadline has passed, or
//   - err wraps context.Canceled or context.DeadlineExceeded, or
//   - err's message contains "conn closed".
//
// A nil err is never context-related by itself; for nil the result depends
// only on ctx (previously a nil err with a live ctx caused a panic).
func IsContextRelatedError(ctx context.Context, err error) bool {
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}
	if err == nil {
		// Guard: err.Error() below would panic on a nil error.
		return false
	}
	// NOTE(review): "conn closed" is matched by message text, presumably the
	// database driver's closed-connection error — confirm against the driver.
	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		strings.Contains(err.Error(), "conn closed")
}
211
212
// IsSerializationRelatedError reports whether err's message contains
// "could not serialize" or "duplicate key value" — typically PostgreSQL's
// wording for serialization failures and unique-constraint violations.
// Matching is by message text only. A nil err returns false instead of
// panicking.
func IsSerializationRelatedError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
220
221
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
222
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
223
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
224
	// Calculate the base backoff with bit shifting for better performance
225
	baseBackoff := 20 * time.Millisecond
226
	if retries > 0 {
227
		// Use bit shifting instead of math.Pow for better performance
228
		shift := min(retries, 5) // Cap at 2^5 = 32, so max backoff is 640ms
229
		baseBackoff = baseBackoff << shift
230
	}
231
232
	// Cap at 1 second
233
	if baseBackoff > time.Second {
234
		baseBackoff = time.Second
235
	}
236
237
	// Generate jitter using crypto/rand
238
	jitter := time.Duration(secureRandomFloat64() * float64(baseBackoff) * 0.5)
239
	nextBackoff := baseBackoff + jitter
240
241
	// Log the retry wait
242
	slog.WarnContext(ctx, "waiting before retry",
243
		slog.String("tenant_id", tenantID),
244
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
245
246
	// Wait or exit on context cancellation
247
	select {
248
	case <-time.After(nextBackoff):
249
	case <-ctx.Done():
250
	}
251
}
252
253
// secureRandomFloat64 returns a uniformly distributed float64 in [0, 1) drawn
// from crypto/rand.
//
// Only the top 53 bits of a random uint64 are used. A float64 mantissa holds
// 53 bits, so dividing a full 64-bit value by 2^64 (the previous approach)
// rounds the largest inputs up to exactly 1.0 and breaks the half-open range.
// With 53 bits the largest result is (2^53-1)/2^53, which is exact and < 1.
//
// If crypto/rand reports an error, 0.5 is returned so callers still get a
// mid-range jitter value rather than none.
func secureRandomFloat64() float64 {
	var buf [8]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return 0.5
	}
	return float64(binary.BigEndian.Uint64(buf[:])>>11) / (1 << 53)
}
264
265
// EnsureDBVersion checks the version of the given database connection
266
// and returns an error if the version is not supported.
267
func EnsureDBVersion(db *pgxpool.Pool) (version string, err error) {
268
	err = db.QueryRow(context.Background(), "SHOW server_version_num;").Scan(&version)
269
	if err != nil {
270
		return version, err
271
	}
272
	v, err := strconv.Atoi(version)
273
	if v < earliestPostgresVersion {
274
		err = fmt.Errorf("unsupported postgres version: %s, expected >= %d", version, earliestPostgresVersion)
275
	}
276
	return version, err
277
}
278