Passed
Push — master ( 92321d...fd0104 )
by Tolga
03:21 queued 14s
created

internal/storage/postgres/watch.go   A

Size/Duplication

Total Lines 359
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
cc 33
eloc 186
dl 0
loc 359
rs 9.76
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
A postgres.NewWatcher 0 4 1
F postgres.*Watch.Watch 0 103 14
D postgres.*Watch.getChanges 0 130 12
B postgres.*Watch.getRecentXIDs 0 57 6
1
package postgres
2
3
import (
4
	"context"
5
	"errors"
6
	"fmt"
7
	"log/slog"
8
	"strings"
9
	"time"
10
11
	"github.com/jackc/pgx/v5"
12
13
	"github.com/golang/protobuf/jsonpb"
14
15
	"google.golang.org/protobuf/types/known/anypb"
16
17
	"github.com/Masterminds/squirrel"
18
19
	"github.com/Permify/permify/internal/storage"
20
	"github.com/Permify/permify/internal/storage/postgres/snapshot"
21
	"github.com/Permify/permify/internal/storage/postgres/types"
22
	db "github.com/Permify/permify/pkg/database/postgres"
23
	base "github.com/Permify/permify/pkg/pb/base/v1"
24
)
25
26
// Watch is an implementation of the storage.Watch interface, which is used
27
type Watch struct {
28
	// database is a pointer to a Postgres database instance, which is used
29
	// to perform operations on the relationship data.
30
	database *db.Postgres
31
32
	// txOptions holds the configuration for database transactions, such as
33
	// isolation level and read-only mode, to be applied when performing
34
	// operations on the relationship data.
35
	txOptions pgx.TxOptions
36
}
37
38
// NewWatcher returns a new instance of the Watch.
39
func NewWatcher(database *db.Postgres) *Watch {
40
	return &Watch{
41
		database:  database,
42
		txOptions: pgx.TxOptions{IsoLevel: pgx.ReadCommitted, AccessMode: pgx.ReadOnly},
43
	}
44
}
45
46
// Watch returns a channel that emits a stream of changes to the relationship tuples in the database.
47
func (w *Watch) Watch(ctx context.Context, tenantID, snap string) (<-chan *base.DataChanges, <-chan error) {
48
	// Create channels for changes and errors.
49
	changes := make(chan *base.DataChanges, w.database.GetWatchBufferSize())
50
	errs := make(chan error, 1)
51
52
	var sleep *time.Timer
53
	const maxSleepDuration = 2 * time.Second
54
	const defaultSleepDuration = 100 * time.Millisecond
55
	sleepDuration := defaultSleepDuration
56
57
	slog.DebugContext(ctx, "watching for changes in the database", slog.Any("tenant_id", tenantID), slog.Any("snapshot", snap))
58
59
	// Decode the snapshot value.
60
	// The snapshot value represents a point in the history of the database.
61
	st, err := snapshot.EncodedToken{Value: snap}.Decode()
62
	if err != nil {
63
		// If there is an error in decoding the snapshot, send the error and return.
64
		errs <- err
65
66
		slog.Error("failed to decode snapshot", slog.Any("error", err))
67
68
		return changes, errs
69
	}
70
71
	// Start a goroutine to watch for changes in the database.
72
	go func() {
73
		// Ensure to close the channels when we're done.
74
		defer close(changes)
75
		defer close(errs)
76
77
		// Get the transaction ID from the snapshot.
78
		cr := st.(snapshot.Token).Value.Uint
79
80
		// Continuously watch for changes.
81
		for {
82
			// Get the list of recent transaction IDs.
83
			recentIDs, err := w.getRecentXIDs(ctx, cr, tenantID)
84
			if err != nil {
85
				// If there is an error in getting recent transaction IDs, send the error and return.
86
87
				slog.Error("error getting recent transaction", slog.Any("error", err))
88
89
				errs <- err
90
				return
91
			}
92
93
			// Process each recent transaction ID.
94
			for _, id := range recentIDs {
95
				// Get the changes in the database associated with the current transaction ID.
96
				updates, err := w.getChanges(ctx, id, tenantID)
97
				if err != nil {
98
					// If there is an error in getting the changes, send the error and return.
99
					slog.ErrorContext(ctx, "failed to get changes for transaction", slog.Any("id", id), slog.Any("error", err))
100
					errs <- err
101
					return
102
				}
103
104
				// Send the changes, but respect the context cancellation.
105
				select {
106
				case <-ctx.Done(): // If the context is done, send an error and return.
107
					slog.ErrorContext(ctx, "context canceled, stopping watch")
108
					errs <- errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
109
					return
110
				case changes <- updates: // Send updates to the changes channel.
111
					slog.DebugContext(ctx, "sent updates to the changes channel for transaction", slog.Any("id", id))
112
				}
113
114
				// Update the transaction ID for the next round.
115
				cr = id.Uint
116
				sleepDuration = defaultSleepDuration
117
			}
118
119
			if len(recentIDs) == 0 {
120
121
				if sleep == nil {
122
					sleep = time.NewTimer(sleepDuration)
123
				} else {
124
					sleep.Reset(sleepDuration)
125
				}
126
127
				// Increase the sleep duration exponentially, but cap it at maxSleepDuration.
128
				if sleepDuration < maxSleepDuration {
129
					sleepDuration *= 2
130
				} else {
131
					sleepDuration = maxSleepDuration
132
				}
133
134
				select {
135
				case <-ctx.Done(): // If the context is done, send an error and return.
136
					slog.ErrorContext(ctx, "context canceled, stopping watch")
137
					errs <- errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
138
					return
139
				case <-sleep.C: // If the timer is done, continue the loop.
140
					slog.DebugContext(ctx, "no recent transaction IDs, waiting for changes")
141
				}
142
			}
143
		}
144
	}()
145
146
	slog.DebugContext(ctx, "watch started successfully")
147
148
	// Return the channels that the caller will listen to for changes and errors.
149
	return changes, errs
150
}
151
152
// getRecentXIDs fetches a list of XID8 identifiers from the 'transactions' table
153
// for all transactions committed after a specified XID value.
154
//
155
// Parameters:
156
//   - ctx:       A context to control the execution lifetime.
157
//   - value:     The transaction XID after which we need the changes.
158
//   - tenantID:  The ID of the tenant to filter the transactions for.
159
//
160
// Returns:
161
//   - A slice of XID8 identifiers.
162
//   - An error if the query fails to execute, or other error occurs during its execution.
163
func (w *Watch) getRecentXIDs(ctx context.Context, value uint64, tenantID string) ([]types.XID8, error) {
164
	// Convert the value to a string formatted as a Postgresql XID8 type.
165
	valStr := fmt.Sprintf("'%v'::xid8", value)
166
167
	subquery := fmt.Sprintf("(select pg_xact_commit_timestamp(id::xid) from transactions where id = %s)", valStr)
168
169
	// Build the main query to get transactions committed after the one with a given XID,
170
	// still visible in the current snapshot, ordered by their commit timestamps.
171
	builder := w.database.Builder.Select("id").
172
		From(TransactionsTable).
173
		Where(fmt.Sprintf("pg_xact_commit_timestamp(id::xid) > (%s)", subquery)).
174
		Where("id < pg_snapshot_xmin(pg_current_snapshot())").
175
		Where(squirrel.Eq{"tenant_id": tenantID}).
176
		OrderBy("pg_xact_commit_timestamp(id::xid)")
177
178
	// Convert the builder to a SQL query and arguments.
179
	query, args, err := builder.ToSql()
180
	if err != nil {
181
182
		slog.ErrorContext(ctx, "error while building sql query", slog.Any("error", err))
183
184
		return nil, err
185
	}
186
187
	slog.DebugContext(ctx, "executing SQL query to get recent transaction", slog.Any("query", query), slog.Any("arguments", args))
188
189
	// Execute the SQL query.
190
	rows, err := w.database.ReadPool.Query(ctx, query, args...)
191
	if err != nil {
192
		slog.ErrorContext(ctx, "failed to execute SQL query", slog.Any("error", err))
193
		return nil, err
194
	}
195
	defer rows.Close()
196
197
	// Loop through the rows and append XID8 values to the results.
198
	var xids []types.XID8
199
	for rows.Next() {
200
		var xid types.XID8
201
		err := rows.Scan(&xid)
202
		if err != nil {
203
			slog.ErrorContext(ctx, "error while scanning row", slog.Any("error", err))
204
			return nil, err
205
		}
206
		xids = append(xids, xid)
207
	}
208
209
	// Check for errors that could have occurred during iteration.
210
	err = rows.Err()
211
	if err != nil {
212
213
		slog.ErrorContext(ctx, "failed to iterate over rows", slog.Any("error", err))
214
215
		return nil, err
216
	}
217
218
	slog.DebugContext(ctx, "successfully retrieved recent transaction", slog.Any("ids", xids))
219
	return xids, nil
220
}
221
222
// getChanges is a method that retrieves the changes that occurred in the relation tuples within a specified transaction.
223
//
224
// ctx: The context.Context instance for managing the life-cycle of this function.
225
// value: The ID of the transaction for which to retrieve the changes.
226
// tenantID: The ID of the tenant for which to retrieve the changes.
227
//
228
// This method returns a TupleChanges instance that encapsulates the changes in the relation tuples within the specified
229
// transaction, or an error if something went wrong during execution.
230
func (w *Watch) getChanges(ctx context.Context, value types.XID8, tenantID string) (*base.DataChanges, error) {
231
	// Initialize a new TupleChanges instance.
232
	changes := &base.DataChanges{}
233
234
	slog.DebugContext(ctx, "retrieving changes for transaction", slog.Any("id", value), slog.Any("tenant_id", tenantID))
235
236
	// Construct the SQL SELECT statement for retrieving the changes from the RelationTuplesTable.
237
	tbuilder := w.database.Builder.Select("entity_type, entity_id, relation, subject_type, subject_id, subject_relation, expired_tx_id").
238
		From(RelationTuplesTable).
239
		Where(squirrel.Eq{"tenant_id": tenantID}).Where(squirrel.Or{
240
		squirrel.Eq{"created_tx_id": value},
241
		squirrel.Eq{"expired_tx_id": value},
242
	})
243
244
	// Generate the SQL query and arguments.
245
	tquery, targs, err := tbuilder.ToSql()
246
	if err != nil {
247
		slog.ErrorContext(ctx, "error while building sql query for relation tuples", slog.Any("error", err))
248
		return nil, err
249
	}
250
251
	slog.DebugContext(ctx, "executing sql query for relation tuples", slog.Any("query", tquery), slog.Any("arguments", targs))
252
253
	// Execute the SQL query and retrieve the result rows.
254
	var trows pgx.Rows
255
	trows, err = w.database.ReadPool.Query(ctx, tquery, targs...)
256
	if err != nil {
257
		slog.ErrorContext(ctx, "failed to execute sql query for relation tuples", slog.Any("error", err))
258
		return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
259
	}
260
	// Ensure the rows are closed after processing.
261
	defer trows.Close()
262
263
	abuilder := w.database.Builder.Select("entity_type, entity_id, attribute, value, expired_tx_id").
264
		From(AttributesTable).
265
		Where(squirrel.Eq{"tenant_id": tenantID}).Where(squirrel.Or{
266
		squirrel.Eq{"created_tx_id": value},
267
		squirrel.Eq{"expired_tx_id": value},
268
	})
269
270
	aquery, aargs, err := abuilder.ToSql()
271
	if err != nil {
272
		slog.ErrorContext(ctx, "error while building SQL query for attributes", slog.Any("error", err))
273
		return nil, err
274
	}
275
276
	slog.DebugContext(ctx, "executing sql query for attributes", slog.Any("query", aquery), slog.Any("arguments", aargs))
277
278
	var arows pgx.Rows
279
	arows, err = w.database.ReadPool.Query(ctx, aquery, aargs...)
280
	if err != nil {
281
		slog.ErrorContext(ctx, "error while executing SQL query for attributes", slog.Any("error", err))
282
		return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
283
	}
284
	// Ensure the rows are closed after processing.
285
	defer arows.Close()
286
287
	// Set the snapshot token for the changes.
288
	changes.SnapToken = snapshot.Token{Value: value}.Encode().String()
289
290
	// Iterate through the result rows.
291
	for trows.Next() {
292
		var expiredXID types.XID8
293
294
		rt := storage.RelationTuple{}
295
		// Scan the result row into a RelationTuple instance.
296
		err = trows.Scan(&rt.EntityType, &rt.EntityID, &rt.Relation, &rt.SubjectType, &rt.SubjectID, &rt.SubjectRelation, &expiredXID)
297
		if err != nil {
298
			slog.ErrorContext(ctx, "error while scanning row for relation tuples", slog.Any("error", err))
299
			return nil, err
300
		}
301
302
		// Determine the operation type based on the expired transaction ID.
303
		op := base.DataChange_OPERATION_CREATE
304
		if expiredXID.Uint == value.Uint {
305
			op = base.DataChange_OPERATION_DELETE
306
		}
307
308
		// Append the change to the list of changes.
309
		changes.DataChanges = append(changes.DataChanges, &base.DataChange{
310
			Operation: op,
311
			Type: &base.DataChange_Tuple{
312
				Tuple: rt.ToTuple(),
313
			},
314
		})
315
	}
316
317
	// Iterate through the result rows.
318
	for arows.Next() {
319
		var expiredXID types.XID8
320
321
		rt := storage.Attribute{}
322
323
		var valueStr string
324
325
		// Scan the result row into a RelationTuple instance.
326
		err = arows.Scan(&rt.EntityType, &rt.EntityID, &rt.Attribute, &valueStr, &expiredXID)
327
		if err != nil {
328
			slog.ErrorContext(ctx, "error while scanning row for attributes", slog.Any("error", err))
329
			return nil, err
330
		}
331
332
		// Unmarshal the JSON data from `valueStr` into `rt.Value`.
333
		rt.Value = &anypb.Any{}
334
		unmarshaler := &jsonpb.Unmarshaler{}
335
		err = unmarshaler.Unmarshal(strings.NewReader(valueStr), rt.Value)
336
		if err != nil {
337
			slog.ErrorContext(ctx, "failed to unmarshal attribute value", slog.Any("error", err))
338
			return nil, err
339
		}
340
341
		// Determine the operation type based on the expired transaction ID.
342
		op := base.DataChange_OPERATION_CREATE
343
		if expiredXID.Uint == value.Uint {
344
			op = base.DataChange_OPERATION_DELETE
345
		}
346
347
		// Append the change to the list of changes.
348
		changes.DataChanges = append(changes.DataChanges, &base.DataChange{
349
			Operation: op,
350
			Type: &base.DataChange_Attribute{
351
				Attribute: rt.ToAttribute(),
352
			},
353
		})
354
	}
355
356
	slog.DebugContext(ctx, "successfully retrieved changes for transaction", slog.Any("id", value))
357
358
	// Return the changes and no error.
359
	return changes, nil
360
}
361