Passed
Push — master ( 00bea4...3b8832 )
by Tolga
01:30
created

utils.createFinalSnapshot   C

Complexity

Conditions 11

Size

Total Lines 49
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
eloc 23
nop 2
dl 0
loc 49
rs 5.4
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex functions like utils.createFinalSnapshot often do several different things. To break such a function down, identify a cohesive piece of logic within it — a group of statements that work on the same variables, such as parsing the snapshot fields or updating the xip_list.

Once you have determined which statements belong together, apply the Extract Function (Extract Method) refactoring and move them into a well-named helper. For complex classes, the analogous refactorings are Extract Class and, where the component makes sense as a subtype, Extract Subclass; Go has no subclassing, so Extract Function is the relevant option here.

1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"strconv"
11
	"strings"
12
	"time"
13
14
	"github.com/jackc/pgx/v5/pgxpool"
15
	"go.opentelemetry.io/otel/codes"
16
17
	"go.opentelemetry.io/otel/trace"
18
19
	"github.com/Masterminds/squirrel"
20
21
	base "github.com/Permify/permify/pkg/pb/base/v1"
22
)
23
24
// SQL statements used by the PostgreSQL storage code.
const (
	// TransactionTemplate inserts a transaction row for a tenant and returns
	// the new row's id and snapshot.
	TransactionTemplate = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id, snapshot`

	// InsertTenantTemplate inserts a tenant (id, name) and returns created_at.
	InsertTenantTemplate = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`

	// DeleteTenantTemplate deletes a tenant by id and returns its name and created_at.
	DeleteTenantTemplate = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`

	// DeleteAllByTenantTemplate is a format string: the %s verb must be filled
	// with a table name before use; $1 is the tenant id.
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)

// Transaction-ID sentinels and version limits.
const (
	// ActiveRecordTxnID is the expired_tx_id stored on rows that have not been
	// expired. It is the largest signed 64-bit value (2^63-1) rather than 0, to
	// avoid XID wraparound issues.
	ActiveRecordTxnID uint64 = 9223372036854775807

	// MaxXID8Value is ActiveRecordTxnID written as an SQL xid8 literal.
	MaxXID8Value = "'9223372036854775807'::xid8"

	// earliestPostgresVersion is the oldest supported PostgreSQL release (13.8),
	// expressed in server_version_num form.
	earliestPostgresVersion = 130008
)
38
39
// createFinalSnapshot returns the snapshot string to bind into visibility
// queries for transaction xid.
//
// snapshotValue is expected in PostgreSQL's pg_snapshot text form
// "xmin:xmax:xip_list". When xmin < xid < xmax and xid is not already listed,
// xid is inserted into xip_list so the snapshot becomes unique to xid. The
// list is kept in ascending order, because PostgreSQL's pg_snapshot input
// rejects an unordered xip_list. In every other case, including input that
// cannot be parsed, snapshotValue is returned unchanged.
func createFinalSnapshot(snapshotValue string, xid uint64) string {
	parts := strings.SplitN(strings.TrimSpace(snapshotValue), ":", 3)
	if len(parts) < 2 {
		// Not a "xmin:xmax[:xip_list]" value; leave it for the database to judge.
		return snapshotValue
	}

	xmin, err := strconv.ParseUint(parts[0], 10, 64)
	if err != nil {
		return snapshotValue
	}
	xmax, err := strconv.ParseUint(parts[1], 10, 64)
	if err != nil {
		return snapshotValue
	}

	// Only touch the snapshot when xmin < xid < xmax. An xid at or beyond xmax
	// lies outside the snapshot's range. An xid at or below xmin is deliberately
	// not added, so a visible transaction is not marked as in progress.
	if xid <= xmin || xid >= xmax {
		return snapshotValue
	}

	var xipList string
	if len(parts) == 3 {
		xipList = parts[2]
	}

	merged, changed := insertSortedXip(xipList, xid)
	if !changed {
		return snapshotValue
	}
	return parts[0] + ":" + parts[1] + ":" + merged
}

// insertSortedXip inserts xid into the comma-separated, ascending xipList.
// It reports changed=false when xid is already present or an entry is not a
// valid uint64. In that case the caller should keep the original snapshot.
func insertSortedXip(xipList string, xid uint64) (merged string, changed bool) {
	xidStr := strconv.FormatUint(xid, 10)
	if xipList == "" {
		return xidStr, true
	}

	entries := strings.Split(xipList, ",")
	out := make([]string, 0, len(entries)+1)
	inserted := false
	for _, entry := range entries {
		entry = strings.TrimSpace(entry)
		v, err := strconv.ParseUint(entry, 10, 64)
		if err != nil || v == xid {
			return "", false
		}
		// Place xid immediately before the first larger entry to keep ascending order.
		if !inserted && v > xid {
			out = append(out, xidStr)
			inserted = true
		}
		out = append(out, entry)
	}
	if !inserted {
		out = append(out, xidStr)
	}
	return strings.Join(out, ","), true
}
92
93
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
94
// Optimized version with parameterized queries for security.
95
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64, snapshotValue string) squirrel.SelectBuilder {
96
	// Backward compatibility: if snapshot is empty, use old method
97
	if snapshotValue == "" {
98
		// Create a subquery for the snapshot associated with the provided value.
99
		snapshotQuery := "(select snapshot from transactions where id = ?::xid8)"
100
101
		// Records that were created and are visible in the snapshot
102
		createdWhere := squirrel.Or{
103
			squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", squirrel.Expr(snapshotQuery, value)),
104
			squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
105
		}
106
107
		// Records that are still active (not expired) at snapshot time
108
		expiredWhere := squirrel.And{
109
			squirrel.Or{
110
				squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", squirrel.Expr(snapshotQuery, value)),
111
				squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
112
			},
113
			squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
114
		}
115
116
		// Add the created and expired conditions to the SELECT query.
117
		return sl.Where(createdWhere).Where(expiredWhere)
118
	}
119
120
	// Create final snapshot with proper visibility
121
	finalSnapshot := createFinalSnapshot(snapshotValue, value)
122
123
	// Records that were created and are visible in the snapshot
124
	createdWhere := squirrel.Or{
125
		squirrel.Expr("pg_visible_in_snapshot(created_tx_id, ?) = true", finalSnapshot),
126
		squirrel.Expr("created_tx_id = ?::xid8", value), // Include current transaction
127
	}
128
129
	// Records that are still active (not expired) at snapshot time
130
	expiredWhere := squirrel.And{
131
		squirrel.Or{
132
			squirrel.Expr("pg_visible_in_snapshot(expired_tx_id, ?) = false", finalSnapshot),
133
			squirrel.Expr("expired_tx_id = ?::xid8", ActiveRecordTxnID), // Never expired
134
		},
135
		squirrel.Expr("expired_tx_id <> ?::xid8", value), // Not expired by current transaction
136
	}
137
138
	// Add the created and expired conditions to the SELECT query.
139
	return sl.Where(createdWhere).Where(expiredWhere)
140
}
141
142
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
143
// It constructs a query to delete expired records from the specified table
144
// based on the provided value, which represents a transaction ID.
145
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
146
	// Create a Squirrel DELETE builder for the specified table.
147
	deleteBuilder := squirrel.Delete(table)
148
149
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
150
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
151
152
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
153
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
154
155
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
156
	return deleteBuilder.Where(expiredNotActiveExpr).Where(beforeExpr)
157
}
158
159
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
160
// It constructs a query to delete expired records from the specified table for a specific tenant
161
// based on the provided value, which represents a transaction ID.
162
func GenerateGCQueryForTenant(table, tenantID string, value uint64) squirrel.DeleteBuilder {
163
	// Create a Squirrel DELETE builder for the specified table.
164
	deleteBuilder := squirrel.Delete(table)
165
166
	// Create an expression to check if 'expired_tx_id' is not equal to ActiveRecordTxnID (expired records).
167
	expiredNotActiveExpr := squirrel.Expr("expired_tx_id <> ?::xid8", ActiveRecordTxnID)
168
169
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
170
	beforeExpr := squirrel.Expr("expired_tx_id < ?::xid8", value)
171
172
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
173
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredNotActiveExpr).Where(beforeExpr)
174
}
175
176
// HandleError records an error in the given span, logs the error, and returns a standardized error.
177
// This function is used for consistent error handling across different parts of the application.
178
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
179
	// Check if the error is context-related
180
	if IsContextRelatedError(ctx, err) {
181
		slog.DebugContext(ctx, "A context-related error occurred",
182
			slog.String("error", err.Error()))
183
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
184
	}
185
186
	// Check if the error is serialization-related
187
	if IsSerializationRelatedError(err) {
188
		slog.DebugContext(ctx, "A serialization-related error occurred",
189
			slog.String("error", err.Error()))
190
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
191
	}
192
193
	// For all other types of errors, log them at the error level and record them in the span
194
	slog.ErrorContext(ctx, "An operational error occurred",
195
		slog.Any("error", err))
196
	span.RecordError(err)
197
	span.SetStatus(codes.Error, err.Error())
198
199
	// Return a new error with the standard error code provided
200
	return errors.New(errorCode.String())
201
}
202
203
// IsContextRelatedError reports whether a failure should be treated as a
// cancellation. That is the case when ctx itself is cancelled or past its
// deadline, when err wraps context.Canceled or context.DeadlineExceeded, or
// when err's message contains "conn closed". A nil err is context-related only
// if ctx is done; it never panics.
func IsContextRelatedError(ctx context.Context, err error) bool {
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}
	if err == nil {
		return false
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	// Closed connections are detected by message text only; presumably the
	// driver exposes no typed sentinel for this — TODO confirm.
	return strings.Contains(err.Error(), "conn closed")
}
215
216
// IsSerializationRelatedError reports whether err's message indicates a
// serialization failure ("could not serialize") or a unique-key conflict
// ("duplicate key value"). A nil err is never serialization-related.
func IsSerializationRelatedError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
224
225
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
226
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
227
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
228
	// Calculate the base backoff with bit shifting for better performance
229
	baseBackoff := 20 * time.Millisecond
230
	if retries > 0 {
231
		// Use bit shifting instead of math.Pow for better performance
232
		shift := min(retries, 5) // Cap at 2^5 = 32, so max backoff is 640ms
233
		baseBackoff = baseBackoff << shift
234
	}
235
236
	// Cap at 1 second
237
	if baseBackoff > time.Second {
238
		baseBackoff = time.Second
239
	}
240
241
	// Generate jitter using crypto/rand
242
	jitter := time.Duration(secureRandomFloat64() * float64(baseBackoff) * 0.5)
243
	nextBackoff := baseBackoff + jitter
244
245
	// Log the retry wait
246
	slog.WarnContext(ctx, "waiting before retry",
247
		slog.String("tenant_id", tenantID),
248
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
249
250
	// Wait or exit on context cancellation
251
	select {
252
	case <-time.After(nextBackoff):
253
	case <-ctx.Done():
254
	}
255
}
256
257
// secureRandomFloat64 returns a uniformly distributed float64 in [0, 1), read
// from crypto/rand. If the system random source fails, it returns 0.5, so
// callers still get a mid-range jitter value instead of none.
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0.5
	}
	// Keep only the top 53 bits (the float64 mantissa width). The conversion is
	// then exact and the result is strictly below 1. Dividing the full 64-bit
	// value by 2^64 can round up to exactly 1.0.
	return float64(binary.BigEndian.Uint64(b[:])>>11) / (1 << 53)
}
268
269
// EnsureDBVersion checks the version of the given database connection
270
// and returns an error if the version is not supported.
271
func EnsureDBVersion(db *pgxpool.Pool) (version string, err error) {
272
	err = db.QueryRow(context.Background(), "SHOW server_version_num;").Scan(&version)
273
	if err != nil {
274
		return version, err
275
	}
276
	v, err := strconv.Atoi(version)
277
	if v < earliestPostgresVersion {
278
		err = fmt.Errorf("unsupported postgres version: %s, expected >= %d", version, earliestPostgresVersion)
279
	}
280
	return version, err
281
}
282