Passed
Push — master ( 8c3de8...fb82f9 )
by Tolga
01:29 queued 14s
created

utils.IsSerializationRelatedError   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 5
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"math"
11
	"strings"
12
	"time"
13
14
	"go.opentelemetry.io/otel/codes"
15
16
	"go.opentelemetry.io/otel/trace"
17
18
	"github.com/Masterminds/squirrel"
19
20
	base "github.com/Permify/permify/pkg/pb/base/v1"
21
)
22
23
// SQL statement templates using PostgreSQL-style $n placeholders.
const (
	// TransactionTemplate inserts a row into transactions for the tenant bound
	// to $1 and returns the generated id.
	TransactionTemplate = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id`

	// InsertTenantTemplate creates a tenant from ($1 = id, $2 = name) and
	// returns the new row's created_at value.
	InsertTenantTemplate = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`

	// DeleteTenantTemplate removes the tenant whose id is $1 and returns the
	// deleted row's name and created_at.
	DeleteTenantTemplate = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`

	// DeleteAllByTenantTemplate deletes every row of a table that belongs to the
	// tenant bound to $1. The %s verb is the table name, presumably substituted
	// with fmt.Sprintf; it is not escaped, so it must come from trusted code.
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)
29
30
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
31
// The query checks if transactions are visible in a snapshot associated with the provided value.
32
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64) squirrel.SelectBuilder {
33
	// Convert the value to a string once to reduce redundant calls to fmt.Sprintf.
34
	valStr := fmt.Sprintf("'%v'::xid8", value)
35
36
	// Create a subquery for the snapshot associated with the provided value.
37
	snapshotQuery := fmt.Sprintf("(select snapshot from transactions where id = %s)", valStr)
38
39
	// Create an expression to check if a transaction with a specific created_tx_id is visible in the snapshot.
40
	visibilityExpr := squirrel.Expr(fmt.Sprintf("pg_visible_in_snapshot(created_tx_id, %s) = true", snapshotQuery))
41
	// Create an expression to check if the created_tx_id is equal to the provided value.
42
	createdExpr := squirrel.Expr(fmt.Sprintf("created_tx_id = %s", valStr))
43
	// Use OR condition for the created expressions.
44
	createdWhere := squirrel.Or{visibilityExpr, createdExpr}
45
46
	// Create an expression to check if a transaction with a specific expired_tx_id is not visible in the snapshot.
47
	expiredVisibilityExpr := squirrel.Expr(fmt.Sprintf("pg_visible_in_snapshot(expired_tx_id, %s) = false", snapshotQuery))
48
	// Create an expression to check if the expired_tx_id is equal to zero.
49
	expiredZeroExpr := squirrel.Expr("expired_tx_id = '0'::xid8")
50
	// Create an expression to check if the expired_tx_id is not equal to the provided value.
51
	expiredNotExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id <> %s", valStr))
52
	// Use AND condition for the expired expressions, checking both visibility and non-equality with value.
53
	expiredWhere := squirrel.And{squirrel.Or{expiredVisibilityExpr, expiredZeroExpr}, expiredNotExpr}
54
55
	// Add the created and expired conditions to the SELECT query.
56
	return sl.Where(createdWhere).Where(expiredWhere)
57
}
58
59
// snapshotQuery returns two raw SQL predicates that filter rows by visibility
// at the snapshot stored for the transaction whose id is value.
//
// The first result accepts rows whose created_tx_id is visible in the
// snapshot or equals value. The second accepts rows whose expired_tx_id is
// '0' or not visible in the snapshot, provided it is also not value itself.
// Each result is wrapped in parentheses so it can be combined with AND.
func snapshotQuery(value uint64) (string, string) {
	txID := fmt.Sprintf("'%v'::xid8", value)
	snapshot := fmt.Sprintf("(SELECT snapshot FROM transactions WHERE id = %s)", txID)

	created := fmt.Sprintf(
		"(pg_visible_in_snapshot(created_tx_id, %s) = true OR created_tx_id = %s)",
		snapshot, txID,
	)

	expired := fmt.Sprintf(
		"((pg_visible_in_snapshot(expired_tx_id, %s) = false OR expired_tx_id = '0'::xid8) AND expired_tx_id <> %s)",
		snapshot, txID,
	)

	return created, expired
}
86
87
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
88
// It constructs a query to delete expired records from the specified table
89
// based on the provided value, which represents a transaction ID.
90
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
91
	// Convert the provided value into a string format suitable for our SQL query, formatted as a transaction ID.
92
	valStr := fmt.Sprintf("'%v'::xid8", value)
93
94
	// Create a Squirrel DELETE builder for the specified table.
95
	deleteBuilder := squirrel.Delete(table)
96
97
	// Create an expression to check if 'expired_tx_id' is not equal to '0' (not expired).
98
	expiredZeroExpr := squirrel.Expr("expired_tx_id <> '0'::xid8")
99
100
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
101
	beforeExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id < %s", valStr))
102
103
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
104
	return deleteBuilder.Where(expiredZeroExpr).Where(beforeExpr)
105
}
106
107
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
108
// It constructs a query to delete expired records from the specified table for a specific tenant
109
// based on the provided value, which represents a transaction ID.
110
func GenerateGCQueryForTenant(table string, tenantID string, value uint64) squirrel.DeleteBuilder {
111
	// Convert the provided value into a string format suitable for our SQL query, formatted as a transaction ID.
112
	valStr := fmt.Sprintf("'%v'::xid8", value)
113
114
	// Create a Squirrel DELETE builder for the specified table.
115
	deleteBuilder := squirrel.Delete(table)
116
117
	// Create an expression to check if 'expired_tx_id' is not equal to '0' (not expired).
118
	expiredZeroExpr := squirrel.Expr("expired_tx_id <> '0'::xid8")
119
120
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
121
	beforeExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id < %s", valStr))
122
123
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
124
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredZeroExpr).Where(beforeExpr)
125
}
126
127
// HandleError records an error in the given span, logs the error, and returns a standardized error.
128
// This function is used for consistent error handling across different parts of the application.
129
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
130
	// Check if the error is context-related
131
	if IsContextRelatedError(ctx, err) {
132
		slog.DebugContext(ctx, "A context-related error occurred",
133
			slog.String("error", err.Error()))
134
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
135
	}
136
137
	// Check if the error is serialization-related
138
	if IsSerializationRelatedError(err) {
139
		slog.DebugContext(ctx, "A serialization-related error occurred",
140
			slog.String("error", err.Error()))
141
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
142
	}
143
144
	// For all other types of errors, log them at the error level and record them in the span
145
	slog.ErrorContext(ctx, "An operational error occurred",
146
		slog.Any("error", err))
147
	span.RecordError(err)
148
	span.SetStatus(codes.Error, err.Error())
149
150
	// Return a new error with the standard error code provided
151
	return errors.New(errorCode.String())
152
}
153
154
// IsContextRelatedError reports whether err should be treated as a
// cancellation rather than an operational failure.
//
// It returns true when any of these holds:
//   - ctx itself has been cancelled or has exceeded its deadline;
//   - err is, or wraps, context.Canceled or context.DeadlineExceeded;
//   - err's message contains "conn closed".
//
// A nil err is considered context-related only when ctx is already done;
// otherwise it returns false instead of panicking.
func IsContextRelatedError(ctx context.Context, err error) bool {
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}
	if err == nil {
		return false
	}
	// NOTE(review): "conn closed" is presumably the database driver's
	// closed-connection message; it is matched as text, so it depends on the
	// driver's wording.
	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		strings.Contains(err.Error(), "conn closed")
}
166
167
// IsSerializationRelatedError reports whether err's message indicates a
// database serialization failure or a duplicate-key violation.
//
// The substrings match PostgreSQL's messages for serialization failures
// ("could not serialize access ...") and unique-constraint violations
// ("duplicate key value violates ..."). It returns false for a nil err
// instead of panicking.
func IsSerializationRelatedError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	// NOTE(review): duplicate-key errors are presumably grouped here because a
	// concurrent transaction can insert the same row, making the conflict
	// retryable like a serialization failure; confirm with callers.
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
175
176
// WaitWithBackoff blocks before the next retry for an exponentially growing,
// randomly jittered delay, or until ctx is done, whichever comes first.
//
// The base delay is 20ms * 2^retries, capped at 1s. A random jitter of up to
// 50% of the base delay is added, so the total wait lies in [base, 1.5*base).
// The chosen delay is logged at warn level together with tenantID.
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
	// Computed in float64 so large retry counts saturate at the cap instead of
	// overflowing a time.Duration.
	backoff := time.Duration(math.Min(float64(20*time.Millisecond)*math.Pow(2, float64(retries)), float64(1*time.Second)))

	jitter := time.Duration(secureRandomFloat64() * float64(backoff) * 0.5)
	nextBackoff := backoff + jitter

	slog.WarnContext(ctx, "waiting before retry",
		slog.String("tenant_id", tenantID),
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))

	// An explicit timer can be stopped when ctx finishes first. time.After
	// keeps its timer alive until it fires (before Go 1.23).
	timer := time.NewTimer(nextBackoff)
	defer timer.Stop()

	select {
	case <-timer.C:
	case <-ctx.Done():
	}
}

// secureRandomFloat64 returns a float64 in [0, 1) drawn from crypto/rand.
// If reading random bytes fails it returns 0, which makes WaitWithBackoff
// apply no jitter.
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0 // Default to 0 jitter on error
	}
	return uint64ToUnitFloat64(binary.BigEndian.Uint64(b[:]))
}

// uint64ToUnitFloat64 maps a 64-bit value onto [0, 1) using its top 53 bits.
func uint64ToUnitFloat64(u uint64) float64 {
	// A float64 mantissa holds only 53 significant bits. Dividing the full
	// 64-bit value by 2^64 can round up to exactly 1.0. Keeping the top 53 bits
	// makes every result exact and strictly below 1.
	return float64(u>>11) / (1 << 53)
}
206