Passed
Push — master ( 3b8832...d0e8fc )
by Tolga
01:28
created

utils.IsSerializationRelatedError   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 5
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"strconv"
11
	"strings"
12
	"time"
13
14
	"github.com/jackc/pgx/v5/pgxpool"
15
	"go.opentelemetry.io/otel/codes"
16
17
	"go.opentelemetry.io/otel/trace"
18
19
	"github.com/Masterminds/squirrel"
20
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// SQL templates for transaction and tenant bookkeeping. The $N markers are PostgreSQL
// positional parameters; DeleteAllByTenantTemplate additionally contains a %s verb,
// presumably filled with a table name via fmt before execution (confirm at call sites).
const (
	TransactionTemplate       = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id, snapshot`
	InsertTenantTemplate      = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`
	DeleteTenantTemplate      = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)

// ActiveRecordTxnID is the expired_tx_id sentinel for records that have not been expired.
// It is the largest signed 64-bit value (2^63-1, the XID8 maximum) rather than 0, to avoid
// XID wraparound issues. MaxXID8Value is the same number written as a PostgreSQL xid8 SQL literal.
const (
	ActiveRecordTxnID = uint64(1<<63 - 1)
	MaxXID8Value      = "'9223372036854775807'::xid8"
)

// earliestPostgresVersion is the minimum accepted server_version_num: 130008, i.e. PostgreSQL 13.8.
const earliestPostgresVersion = 130008
38
39
// createFinalSnapshot creates a final snapshot string for proper transaction visibility.
40
// If xmax != xid, it adds xid to the xip_list to make the snapshot unique.
41
func createFinalSnapshot(snapshotValue string, xid uint64) string {
42
	// Parse snapshot: "xmin:xmax:xip_list"
43
	parts := strings.SplitN(strings.TrimSpace(snapshotValue), ":", 3)
44
	if len(parts) < 2 {
45
		return snapshotValue
46
	}
47
48
	xminStr, xmaxStr := parts[0], parts[1]
49
50
	// Parse xmin and xmax for range validation
51
	xmin, err := strconv.ParseUint(xminStr, 10, 64)
52
	if err != nil {
53
		return snapshotValue
54
	}
55
	xmax, err := strconv.ParseUint(xmaxStr, 10, 64)
56
	if err != nil {
57
		return snapshotValue
58
	}
59
60
	// If xmax == xid, no need to modify snapshot
61
	if xmax == xid {
62
		return snapshotValue
63
	}
64
65
	// Validate xid is in valid range [xmin, xmax)
66
	if xid < xmin || xid >= xmax {
67
		return snapshotValue
68
	}
69
70
	// Parse existing xip_list
71
	var xips []uint64
72
	if len(parts) == 3 && parts[2] != "" {
73
		for _, xipStr := range strings.Split(parts[2], ",") {
74
			xipStr = strings.TrimSpace(xipStr)
75
			if xipStr == "" {
76
				continue
77
			}
78
			xip, err := strconv.ParseUint(xipStr, 10, 64)
79
			if err != nil {
80
				return snapshotValue
81
			}
82
			// Check if xid is already in xip_list
83
			if xip == xid {
84
				return snapshotValue
85
			}
86
			xips = append(xips, xip)
87
		}
88
	}
89
90
	// Add xid to the list and sort it
91
	xips = append(xips, xid)
92
	sortXips(xips)
93
94
	// Rebuild xip_list string
95
	var xipStrs []string
96
	for _, xip := range xips {
97
		xipStrs = append(xipStrs, fmt.Sprintf("%d", xip))
98
	}
99
	return fmt.Sprintf("%s:%s:%s", xminStr, xmaxStr, strings.Join(xipStrs, ","))
100
}
101
102
// sortXips reorders xips in place so that values are in ascending order.
// A nil or empty slice is left untouched.
func sortXips(xips []uint64) {
	// Insertion sort: grow a sorted prefix one element at a time, shifting larger
	// values right to open a slot for the current one.
	for i := 1; i < len(xips); i++ {
		cur := xips[i]
		j := i
		for j > 0 && xips[j-1] > cur {
			xips[j] = xips[j-1]
			j--
		}
		xips[j] = cur
	}
}
112
113
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
114
// Optimized version with parameterized queries for security.
115
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64, snapshotValue string) squirrel.SelectBuilder {
116
	// Backward compatibility: if snapshot is empty, use old method
117
	if snapshotValue == "" {
118
		// Create a subquery for the snapshot associated with the provided value.
119
		snapshotQuery := "(select snapshot from transactions where id = ?::xid8)"
120
121
		// Records that were created and are visible in the snapshot
122
		createdWhere := squirrel.Or{
123
			squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", squirrel.Expr(snapshotQuery, value)),
124
			squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
125
		}
126
127
		// Records that are still active (not expired) at snapshot time
128
		expiredWhere := squirrel.And{
129
			squirrel.Or{
130
				squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", squirrel.Expr(snapshotQuery, value)),
131
				squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
132
			},
133
			squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
134
		}
135
136
		// Add the created and expired conditions to the SELECT query.
137
		return sl.Where(createdWhere).Where(expiredWhere)
138
	}
139
140
	// Create final snapshot with proper visibility
141
	finalSnapshot := createFinalSnapshot(snapshotValue, value)
142
143
	// Records that were created and are visible in the snapshot
144
	createdWhere := squirrel.Or{
145
		squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", finalSnapshot),
146
		squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
147
	}
148
149
	// Records that are still active (not expired) at snapshot time
150
	expiredWhere := squirrel.And{
151
		squirrel.Or{
152
			squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", finalSnapshot),
153
			squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
154
		},
155
		squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
156
	}
157
158
	// Add the created and expired conditions to the SELECT query.
159
	return sl.Where(createdWhere).Where(expiredWhere)
160
}
161
162
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
163
// It constructs a query to delete expired records from the specified table
164
// based on the provided value, which represents a transaction ID.
165
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
166
	// Create a Squirrel DELETE builder for the specified table.
167
	deleteBuilder := squirrel.Delete(table)
168
169
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
170
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
171
172
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
173
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
174
175
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
176
	return deleteBuilder.Where(expiredNotActiveExpr).Where(beforeExpr)
177
}
178
179
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
180
// It constructs a query to delete expired records from the specified table for a specific tenant
181
// based on the provided value, which represents a transaction ID.
182
func GenerateGCQueryForTenant(table, tenantID string, value uint64) squirrel.DeleteBuilder {
183
	// Create a Squirrel DELETE builder for the specified table.
184
	deleteBuilder := squirrel.Delete(table)
185
186
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
187
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
188
189
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
190
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
191
192
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
193
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredNotActiveExpr).Where(beforeExpr)
194
}
195
196
// HandleError records an error in the given span, logs the error, and returns a standardized error.
197
// This function is used for consistent error handling across different parts of the application.
198
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
199
	// Check if the error is context-related
200
	if IsContextRelatedError(ctx, err) {
201
		slog.DebugContext(ctx, "A context-related error occurred",
202
			slog.String("error", err.Error()))
203
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
204
	}
205
206
	// Check if the error is serialization-related
207
	if IsSerializationRelatedError(err) {
208
		slog.DebugContext(ctx, "A serialization-related error occurred",
209
			slog.String("error", err.Error()))
210
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
211
	}
212
213
	// For all other types of errors, log them at the error level and record them in the span
214
	slog.ErrorContext(ctx, "An operational error occurred",
215
		slog.Any("error", err))
216
	span.RecordError(err)
217
	span.SetStatus(codes.Error, err.Error())
218
219
	// Return a new error with the standard error code provided
220
	return errors.New(errorCode.String())
221
}
222
223
// IsContextRelatedError reports whether err should be treated as a cancellation rather than
// an operational failure. It returns true when any of these holds:
//   - ctx has been cancelled or its deadline has passed;
//   - err wraps context.Canceled or context.DeadlineExceeded;
//   - err's message contains "conn closed".
//
// A nil err returns false unless ctx itself is done.
func IsContextRelatedError(ctx context.Context, err error) bool {
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}
	// Guard before calling err.Error(), which would panic on a nil interface.
	if err == nil {
		return false
	}
	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		strings.Contains(err.Error(), "conn closed")
}
235
236
// IsSerializationRelatedError reports whether err's message indicates a serialization
// failure ("could not serialize") or a unique-key conflict ("duplicate key value").
// Both are classified together here. A nil err returns false.
func IsSerializationRelatedError(err error) bool {
	// Guard before calling err.Error(), which would panic on a nil interface.
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
244
245
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
246
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
247
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
248
	// Calculate the base backoff with bit shifting for better performance
249
	baseBackoff := 20 * time.Millisecond
250
	if retries > 0 {
251
		// Use bit shifting instead of math.Pow for better performance
252
		shift := min(retries, 5) // Cap at 2^5 = 32, so max backoff is 640ms
253
		baseBackoff = baseBackoff << shift
254
	}
255
256
	// Cap at 1 second
257
	if baseBackoff > time.Second {
258
		baseBackoff = time.Second
259
	}
260
261
	// Generate jitter using crypto/rand
262
	jitter := time.Duration(secureRandomFloat64() * float64(baseBackoff) * 0.5)
263
	nextBackoff := baseBackoff + jitter
264
265
	// Log the retry wait
266
	slog.WarnContext(ctx, "waiting before retry",
267
		slog.String("tenant_id", tenantID),
268
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
269
270
	// Wait or exit on context cancellation
271
	select {
272
	case <-time.After(nextBackoff):
273
	case <-ctx.Done():
274
	}
275
}
276
277
// secureRandomFloat64 returns a uniformly distributed float64 in [0, 1), drawn from
// crypto/rand. If reading from crypto/rand fails, it returns 0.5 rather than 0, so that
// callers using it for jitter still get a non-zero, mid-range value.
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0.5
	}
	// Keep only the top 53 bits (float64's mantissa width), so the conversion is exact and
	// the result is strictly below 1. Dividing the full 64-bit value by 2^64 would round
	// inputs near MaxUint64 up to exactly 1.0, breaking the [0, 1) contract.
	return float64(binary.BigEndian.Uint64(b[:])>>11) / (1 << 53)
}
288
289
// EnsureDBVersion checks the version of the given database connection
290
// and returns an error if the version is not supported.
291
func EnsureDBVersion(db *pgxpool.Pool) (version string, err error) {
292
	err = db.QueryRow(context.Background(), "SHOW server_version_num;").Scan(&version)
293
	if err != nil {
294
		return version, err
295
	}
296
	v, err := strconv.Atoi(version)
297
	if v < earliestPostgresVersion {
298
		err = fmt.Errorf("unsupported postgres version: %s, expected >= %d", version, earliestPostgresVersion)
299
	}
300
	return version, err
301
}
302