Passed
Pull Request — master (#2433)
by Tolga
03:35
created

utils.SnapshotQuery   A

Complexity

Conditions 1

Size

Total Lines 25
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 11
nop 2
dl 0
loc 25
rs 9.85
c 0
b 0
f 0
1
package utils
2
3
import (
4
	"context"
5
	"crypto/rand"
6
	"encoding/binary"
7
	"errors"
8
	"fmt"
9
	"log/slog"
10
	"math"
11
	"strings"
12
	"time"
13
14
	"go.opentelemetry.io/otel/codes"
15
16
	"go.opentelemetry.io/otel/trace"
17
18
	"github.com/Masterminds/squirrel"
19
20
	base "github.com/Permify/permify/pkg/pb/base/v1"
21
)
22
23
// SQL statement templates used by the database utilities in this package.
// Positional placeholders ($1, $2, ...) are bound by the caller at execution time.
const (
	// TransactionTemplate inserts a row into transactions for the tenant bound
	// to $1 and returns the id of the new row.
	TransactionTemplate = `INSERT INTO transactions (tenant_id) VALUES ($1) RETURNING id`

	// InsertTenantTemplate inserts a tenant with id $1 and name $2 and returns
	// its created_at value.
	InsertTenantTemplate = `INSERT INTO tenants (id, name) VALUES ($1, $2) RETURNING created_at`

	// DeleteTenantTemplate deletes the tenant whose id is $1 and returns the
	// name and created_at of the removed row.
	DeleteTenantTemplate = `DELETE FROM tenants WHERE id = $1 RETURNING name, created_at`

	// DeleteAllByTenantTemplate deletes every row belonging to tenant $1 from a
	// table. The %s verb is a table-name slot that must be filled in (for
	// example with fmt.Sprintf) before the statement is executed.
	DeleteAllByTenantTemplate = `DELETE FROM %s WHERE tenant_id = $1`
)
29
30
// SnapshotQuery adds conditions to a SELECT query for checking transaction visibility based on created and expired transaction IDs.
31
// The query checks if transactions are visible in a snapshot associated with the provided value.
32
func SnapshotQuery(sl squirrel.SelectBuilder, value uint64) squirrel.SelectBuilder {
33
	// Convert the value to a string once to reduce redundant calls to fmt.Sprintf.
34
	valStr := fmt.Sprintf("'%v'::xid8", value)
35
36
	// Create a subquery for the snapshot associated with the provided value.
37
	snapshotQuery := fmt.Sprintf("(select snapshot from transactions where id = %s)", valStr)
38
39
	// Create an expression to check if a transaction with a specific created_tx_id is visible in the snapshot.
40
	visibilityExpr := squirrel.Expr(fmt.Sprintf("pg_visible_in_snapshot(created_tx_id, %s) = true", snapshotQuery))
41
	// Create an expression to check if the created_tx_id is equal to the provided value.
42
	createdExpr := squirrel.Expr(fmt.Sprintf("created_tx_id = %s", valStr))
43
	// Use OR condition for the created expressions.
44
	createdWhere := squirrel.Or{visibilityExpr, createdExpr}
45
46
	// Create an expression to check if a transaction with a specific expired_tx_id is not visible in the snapshot.
47
	expiredVisibilityExpr := squirrel.Expr(fmt.Sprintf("pg_visible_in_snapshot(expired_tx_id, %s) = false", snapshotQuery))
48
	// Create an expression to check if the expired_tx_id is equal to zero.
49
	expiredZeroExpr := squirrel.Expr("expired_tx_id = '0'::xid8")
50
	// Create an expression to check if the expired_tx_id is not equal to the provided value.
51
	expiredNotExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id <> %s", valStr))
52
	// Use AND condition for the expired expressions, checking both visibility and non-equality with value.
53
	expiredWhere := squirrel.And{squirrel.Or{expiredVisibilityExpr, expiredZeroExpr}, expiredNotExpr}
54
55
	// Add the created and expired conditions to the SELECT query.
56
	return sl.Where(createdWhere).Where(expiredWhere)
57
}
58
59
// GenerateGCQuery generates a Squirrel DELETE query builder for garbage collection.
60
// It constructs a query to delete expired records from the specified table
61
// based on the provided value, which represents a transaction ID.
62
func GenerateGCQuery(table string, value uint64) squirrel.DeleteBuilder {
63
	// Convert the provided value into a string format suitable for our SQL query, formatted as a transaction ID.
64
	valStr := fmt.Sprintf("'%v'::xid8", value)
65
66
	// Create a Squirrel DELETE builder for the specified table.
67
	deleteBuilder := squirrel.Delete(table)
68
69
	// Create an expression to check if 'expired_tx_id' is not equal to '0' (not expired).
70
	expiredZeroExpr := squirrel.Expr("expired_tx_id <> '0'::xid8")
71
72
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
73
	beforeExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id < %s", valStr))
74
75
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data.
76
	return deleteBuilder.Where(expiredZeroExpr).Where(beforeExpr)
77
}
78
79
// GenerateGCQueryForTenant generates a Squirrel DELETE query builder for tenant-aware garbage collection.
80
// It constructs a query to delete expired records from the specified table for a specific tenant
81
// based on the provided value, which represents a transaction ID.
82
func GenerateGCQueryForTenant(table, tenantID string, value uint64) squirrel.DeleteBuilder {
83
	// Convert the provided value into a string format suitable for our SQL query, formatted as a transaction ID.
84
	valStr := fmt.Sprintf("'%v'::xid8", value)
85
86
	// Create a Squirrel DELETE builder for the specified table.
87
	deleteBuilder := squirrel.Delete(table)
88
89
	// Create an expression to check if 'expired_tx_id' is not equal to '0' (not expired).
90
	expiredZeroExpr := squirrel.Expr("expired_tx_id <> '0'::xid8")
91
92
	// Create an expression to check if 'expired_tx_id' is less than the provided value (before the cutoff).
93
	beforeExpr := squirrel.Expr(fmt.Sprintf("expired_tx_id < %s", valStr))
94
95
	// Add the WHERE clauses to the DELETE query builder to filter and delete expired data for the specific tenant.
96
	return deleteBuilder.Where(squirrel.Eq{"tenant_id": tenantID}).Where(expiredZeroExpr).Where(beforeExpr)
97
}
98
99
// HandleError records an error in the given span, logs the error, and returns a standardized error.
100
// This function is used for consistent error handling across different parts of the application.
101
func HandleError(ctx context.Context, span trace.Span, err error, errorCode base.ErrorCode) error {
102
	// Check if the error is context-related
103
	if IsContextRelatedError(ctx, err) {
104
		slog.DebugContext(ctx, "A context-related error occurred",
105
			slog.String("error", err.Error()))
106
		return errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
107
	}
108
109
	// Check if the error is serialization-related
110
	if IsSerializationRelatedError(err) {
111
		slog.DebugContext(ctx, "A serialization-related error occurred",
112
			slog.String("error", err.Error()))
113
		return errors.New(base.ErrorCode_ERROR_CODE_SERIALIZATION.String())
114
	}
115
116
	// For all other types of errors, log them at the error level and record them in the span
117
	slog.ErrorContext(ctx, "An operational error occurred",
118
		slog.Any("error", err))
119
	span.RecordError(err)
120
	span.SetStatus(codes.Error, err.Error())
121
122
	// Return a new error with the standard error code provided
123
	return errors.New(errorCode.String())
124
}
125
126
// IsContextRelatedError reports whether a failure should be attributed to
// context cancellation, a deadline being exceeded, or a closed connection.
//
// It returns true when either:
//
//   - ctx itself is already cancelled or past its deadline (regardless of err), or
//   - err wraps context.Canceled or context.DeadlineExceeded, or its message
//     contains "conn closed".
//
// A nil err on a live context yields false instead of panicking.
func IsContextRelatedError(ctx context.Context, err error) bool {
	// The context's own state takes precedence over whatever err says.
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
		return true
	}

	// Guard against nil before calling err.Error() below.
	if err == nil {
		return false
	}

	return errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		// Closed connections are detected by message text.
		strings.Contains(err.Error(), "conn closed")
}
138
139
// IsSerializationRelatedError reports whether err looks like a serialization
// failure, typically raised by concurrent database transactions.
//
// Detection is purely textual: err is considered serialization-related when
// its message contains "could not serialize" or "duplicate key value"
// (presumably the database driver's error text — confirm if the driver changes).
// A nil err yields false instead of panicking.
func IsSerializationRelatedError(err error) bool {
	if err == nil {
		return false
	}

	// Render the message once; both checks inspect the same text.
	msg := err.Error()
	return strings.Contains(msg, "could not serialize") ||
		strings.Contains(msg, "duplicate key value")
}
147
148
// WaitWithBackoff implements an exponential backoff strategy with jitter for retries.
149
// It waits for a calculated duration or until the context is cancelled, whichever comes first.
150
func WaitWithBackoff(ctx context.Context, tenantID string, retries int) {
151
	// Calculate the base backoff
152
	backoff := time.Duration(math.Min(float64(20*time.Millisecond)*math.Pow(2, float64(retries)), float64(1*time.Second)))
153
154
	// Generate jitter using crypto/rand
155
	jitter := time.Duration(secureRandomFloat64() * float64(backoff) * 0.5)
156
	nextBackoff := backoff + jitter
157
158
	// Log the retry wait
159
	slog.WarnContext(ctx, "waiting before retry",
160
		slog.String("tenant_id", tenantID),
161
		slog.Int64("backoff_duration", nextBackoff.Milliseconds()))
162
163
	// Wait or exit on context cancellation
164
	select {
165
	case <-time.After(nextBackoff):
166
	case <-ctx.Done():
167
	}
168
}
169
170
// secureRandomFloat64 returns a float64 in the half-open range [0, 1) drawn
// from crypto/rand.
//
// Eight random bytes are read and interpreted as a big-endian uint64; only the
// top 53 bits are kept and divided by 2^53. A float64 mantissa holds exactly
// 53 bits, so every kept value converts without rounding and the result can
// never reach 1.0 (dividing the full 64-bit value by 2^64 rounds up to exactly
// 1.0 for inputs close to the maximum).
//
// If reading random bytes fails, 0 is returned, which callers treat as
// "no jitter".
func secureRandomFloat64() float64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0 // Default to 0 jitter on error
	}
	// Keep 53 bits of entropy: exactly representable, strictly below 2^53.
	return float64(binary.BigEndian.Uint64(b[:])>>11) / (1 << 53)
}
178