Passed
Pull Request — master (#2576)
by Tolga
03:47
created

utils.HandleError   A

Complexity

Conditions 3

Size

Total Lines 23
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 14
nop 4
dl 0
loc 23
rs 9.7
c 0
b 0
f 0
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"strconv"
11
	"strings"
12
	"time"
13
14
	"github.com/jackc/pgx/v5/pgxpool"
15
	"go.opentelemetry.io/otel/codes"
16
17
	"go.opentelemetry.io/otel/trace"
18
19
	"github.com/Masterminds/squirrel"
20
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// SQL statement templates for transaction and tenant bookkeeping.
// DeleteAllByTenantTemplate contains a %s placeholder in the table-name
// position that must be filled in before the statement is executed.
const (
	TransactionTemplate       = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id, snapshot`
	InsertTenantTemplate      = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`
	DeleteTenantTemplate      = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)

// Sentinels for records that have not been expired. The value 2^63-1 is used
// instead of 0 to avoid XID wraparound issues.
const (
	// ActiveRecordTxnID is the numeric sentinel (2^63-1).
	ActiveRecordTxnID = uint64(1<<63 - 1)

	// MaxXID8Value is the same sentinel written as a PostgreSQL xid8 literal.
	MaxXID8Value = "'9223372036854775807'::xid8"
)

// earliestPostgresVersion is the minimum supported PostgreSQL version in
// server_version_num form: 130008 corresponds to 13.8.
const earliestPostgresVersion = 130008
38
39
// createFinalSnapshot creates a final snapshot string for proper transaction visibility.
40
// If xmax != xid, it adds xid to the xip_list to make the snapshot unique.
41
func createFinalSnapshot(snapshotValue string, xid uint64) string {
42
	// Parse snapshot: "xmin:xmax:xip_list"
43
	parts := strings.SplitN(strings.TrimSpace(snapshotValue), ":", 3)
44
	if len(parts) < 2 {
45
		return snapshotValue
46
	}
47
48
	xminStr, xmaxStr := parts[0], parts[1]
49
50
	// If xmax == xid, no need to modify snapshot
51
	if xmaxStr == fmt.Sprintf("%d", xid) {
52
		return snapshotValue
53
	}
54
55
	// Parse existing xip_list
56
	var xips []uint64
57
	if len(parts) == 3 && parts[2] != "" {
58
		for _, xipStr := range strings.Split(parts[2], ",") {
59
			xipStr = strings.TrimSpace(xipStr)
60
			if xipStr == "" {
61
				continue
62
			}
63
			xip, err := strconv.ParseUint(xipStr, 10, 64)
64
			if err != nil {
65
				return snapshotValue
66
			}
67
			// Check if xid is already in xip_list
68
			if xip == xid {
69
				return snapshotValue
70
			}
71
			xips = append(xips, xip)
72
		}
73
	}
74
75
	// Add xid to the list and sort it
76
	xips = append(xips, xid)
77
	sortXips(xips)
78
79
	// Rebuild xip_list string
80
	var xipStrs []string
81
	for _, xip := range xips {
82
		xipStrs = append(xipStrs, fmt.Sprintf("%d", xip))
83
	}
84
	return fmt.Sprintf("%s:%s:%s", xminStr, xmaxStr, strings.Join(xipStrs, ","))
85
}
86
87
// sortXips orders xips from smallest to largest, modifying the slice in place.
// The lists handled here are short, so a simple insertion sort is used.
func sortXips(xips []uint64) {
	for next := 1; next < len(xips); next++ {
		cur := xips[next]
		pos := next
		// Shift larger elements one slot right until cur's position is found.
		for pos > 0 && xips[pos-1] > cur {
			xips[pos] = xips[pos-1]
			pos--
		}
		xips[pos] = cur
	}
}
97
98
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
99
// Optimized version with parameterized queries for security.
100
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64, snapshotValue string) squirrel.SelectBuilder {
101
	// Backward compatibility: if snapshot is empty, use old method
102
	if snapshotValue == "" {
103
		// Create a subquery for the snapshot associated with the provided value.
104
		snapshotQuery := "(select snapshot from transactions where id = ?::xid8)"
105
106
		// Records that were created and are visible in the snapshot
107
		createdWhere := squirrel.Or{
108
			squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", squirrel.Expr(snapshotQuery, value)),
109
			squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
110
		}
111
112
		// Records that are still active (not expired) at snapshot time
113
		expiredWhere := squirrel.And{
114
			squirrel.Or{
115
				squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", squirrel.Expr(snapshotQuery, value)),
116
				squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
117
			},
118
			squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
119
		}
120
121
		// Add the created and expired conditions to the SELECT query.
122
		return sl.Where(createdWhere).Where(expiredWhere)
123
	}
124
125
	// Create final snapshot with proper visibility
126
	finalSnapshot := createFinalSnapshot(snapshotValue, value)
127
128
	// Records that were created and are visible in the snapshot
129
	createdWhere := squirrel.Or{
130
		squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", finalSnapshot),
131
		squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
132
	}
133
134
	// Records that are still active (not expired) at snapshot time
135
	expiredWhere := squirrel.And{
136
		squirrel.Or{
137
			squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", finalSnapshot),
138
			squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
139
		},
140
		squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
141
	}
142
143
	// Add the created and expired conditions to the SELECT query.
144
	return sl.Where(createdWhere).Where(expiredWhere)
145
}
146
147
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
148
// It constructs a query to delete expired records from the specified table
149
// based on the provided value, which represents a transaction ID.
150
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
151
	// Create a Squirrel DELETE builder for the specified table.
152
	deleteBuilder := squirrel.Delete(table)
153
154
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
155
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
156
157
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
158
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
159
160
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
161
	return deleteBuilder.Where(expiredNotActiveExpr).Where(beforeExpr)
162
}
163
164
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
165
// It constructs a query to delete expired records from the specified table for a specific tenant
166
// based on the provided value, which represents a transaction ID.
167
func GenerateGCQueryForTenant(table, tenantID string, value uint64) squirrel.DeleteBuilder {
168
	// Create a Squirrel DELETE builder for the specified table.
169
	deleteBuilder := squirrel.Delete(table)
170
171
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
172
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
173
174
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
175
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
176
177
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
178
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredNotActiveExpr).Where(beforeExpr)
179
}
180
181
// HandleError records an error in the given span, logs the error, and returns a standardized error.
182
// This function is used for consistent error handling across different parts of the application.
183
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
184
	// Check if the error is context-related
185
	if IsContextRelatedError(ctx, err) {
186
		slog.DebugContext(ctx, "A context-related error occurred",
187
			slog.String("error", err.Error()))
188
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
189
	}
190
191
	// Check if the error is serialization-related
192
	if IsSerializationRelatedError(err) {
193
		slog.DebugContext(ctx, "A serialization-related error occurred",
194
			slog.String("error", err.Error()))
195
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
196
	}
197
198
	// For all other types of errors, log them at the error level and record them in the span
199
	slog.ErrorContext(ctx, "An operational error occurred",
200
		slog.Any("error", err))
201
	span.RecordError(err)
202
	span.SetStatus(codes.Error, err.Error())
203
204
	// Return a new error with the standard error code provided
205
	return errors.New(errorCode.String())
206
}
207
208
// IsContextRelatedError reports whether a failure should be treated as
// context-related. It returns true when:
//   - ctx itself has been canceled or has exceeded its deadline (regardless of
//     err); or
//   - err wraps context.Canceled or context.DeadlineExceeded; or
//   - err's message contains "conn closed".
//
// A nil err is reported as context-related only when ctx is done.
func IsContextRelatedError(ctx context.Context, err error) bool {
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}
	// Guard before calling err.Error(): a nil error would otherwise panic.
	if err == nil {
		return false
	}
	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		strings.Contains(err.Error(), "conn closed")
}
220
221
// IsSerializationRelatedError reports whether err's message indicates a
// transaction serialization failure ("could not serialize") or a unique-key
// conflict ("duplicate key value"). Both are treated the same way here. A nil
// err returns false.
func IsSerializationRelatedError(err error) bool {
	// Guard before calling err.Error(): a nil error would otherwise panic.
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
229
230
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
231
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
232
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
233
	// Calculate the base backoff with bit shifting for better performance
234
	baseBackoff := 20 * time.Millisecond
235
	if retries > 0 {
236
		// Use bit shifting instead of math.Pow for better performance
237
		shift := min(retries, 5) // Cap at 2^5 = 32, so max backoff is 640ms
238
		baseBackoff = baseBackoff << shift
239
	}
240
241
	// Cap at 1 second
242
	if baseBackoff > time.Second {
243
		baseBackoff = time.Second
244
	}
245
246
	// Generate jitter using crypto/rand
247
	jitter := time.Duration(secureRandomFloat64() * float64(baseBackoff) * 0.5)
248
	nextBackoff := baseBackoff + jitter
249
250
	// Log the retry wait
251
	slog.WarnContext(ctx, "waiting before retry",
252
		slog.String("tenant_id", tenantID),
253
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
254
255
	// Wait or exit on context cancellation
256
	select {
257
	case <-time.After(nextBackoff):
258
	case <-ctx.Done():
259
	}
260
}
261
262
// secureRandomFloat64 returns a float64 in [0, 1) built from crypto/rand bytes.
// If reading random bytes fails it returns 0.5, so callers still get a
// mid-range jitter instead of none.
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0.5
	}
	return unitFloat64FromBits(binary.BigEndian.Uint64(b[:]))
}

// unitFloat64FromBits maps a uniformly random uint64 onto [0, 1) using its top
// 53 bits. A float64 mantissa represents those bits exactly, so the result can
// never round up to 1.0. Converting the full 64-bit value directly rounds
// values near MaxUint64 up to 2^64, which yields exactly 1.0.
func unitFloat64FromBits(u uint64) float64 {
	return float64(u>>11) / (1 << 53)
}
273
274
// EnsureDBVersion checks the version of the given database connection
275
// and returns an error if the version is not supported.
276
func EnsureDBVersion(db *pgxpool.Pool) (version string, err error) {
277
	err = db.QueryRow(context.Background(), "SHOW server_version_num;").Scan(&version)
278
	if err != nil {
279
		return version, err
280
	}
281
	v, err := strconv.Atoi(version)
282
	if v < earliestPostgresVersion {
283
		err = fmt.Errorf("unsupported postgres version: %s, expected >= %d", version, earliestPostgresVersion)
284
	}
285
	return version, err
286
}
287