postgres.*Watch.Watch   F
last analyzed

Complexity

Conditions 14

Size

Total Lines 103
Code Lines 56

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 14
eloc 56
nop 3
dl 0
loc 103
rs 3.6
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

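Commonly applied refactorings include Extract Method; when a method carries many parameters or temporary variables, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object are also worth considering.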

Complexity

Complex classes like postgres.*Watch.Watch often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
package postgres
2
3
import (
4
	"context"
5
	"errors"
6
	"fmt"
7
	"log/slog"
8
	"time"
9
10
	"github.com/jackc/pgx/v5"
11
12
	"google.golang.org/protobuf/encoding/protojson"
13
14
	"google.golang.org/protobuf/types/known/anypb"
15
16
	"github.com/Masterminds/squirrel"
17
18
	"github.com/Permify/permify/internal/storage"
19
	"github.com/Permify/permify/internal/storage/postgres/snapshot"
20
	"github.com/Permify/permify/internal/storage/postgres/types"
21
	db "github.com/Permify/permify/pkg/database/postgres"
22
	base "github.com/Permify/permify/pkg/pb/base/v1"
23
)
24
25
// Watch is an implementation of the storage.Watch interface, which is used
26
type Watch struct {
27
	// database is a pointer to a Postgres database instance, which is used
28
	// to perform operations on the relationship data.
29
	database *db.Postgres
30
31
	// txOptions holds the configuration for database transactions, such as
32
	// isolation level and read-only mode, to be applied when performing
33
	// operations on the relationship data.
34
	txOptions pgx.TxOptions
35
}
36
37
// NewWatcher returns a new instance of the Watch.
38
func NewWatcher(database *db.Postgres) *Watch {
39
	return &Watch{
40
		database:  database,
41
		txOptions: pgx.TxOptions{IsoLevel: pgx.ReadCommitted, AccessMode: pgx.ReadOnly},
42
	}
43
}
44
45
// Watch returns a channel that emits a stream of changes to the relationship tuples in the database.
46
func (w *Watch) Watch(ctx context.Context, tenantID, snap string) (<-chan *base.DataChanges, <-chan error) {
47
	// Create channels for changes and errors.
48
	changes := make(chan *base.DataChanges, w.database.GetWatchBufferSize())
49
	errs := make(chan error, 1)
50
51
	var sleep *time.Timer
52
	const maxSleepDuration = 2 * time.Second
53
	const defaultSleepDuration = 100 * time.Millisecond
54
	sleepDuration := defaultSleepDuration
55
56
	slog.DebugContext(ctx, "watching for changes in the database", slog.Any("tenant_id", tenantID), slog.Any("snapshot", snap))
57
58
	// Decode the snapshot value.
59
	// The snapshot value represents a point in the history of the database.
60
	st, err := snapshot.EncodedToken{Value: snap}.Decode()
61
	if err != nil {
62
		// If there is an error in decoding the snapshot, send the error and return.
63
		errs <- err
64
65
		slog.Error("failed to decode snapshot", slog.Any("error", err))
66
67
		return changes, errs
68
	}
69
70
	// Start a goroutine to watch for changes in the database.
71
	go func() {
72
		// Ensure to close the channels when we're done.
73
		defer close(changes)
74
		defer close(errs)
75
76
		// Get the transaction ID from the snapshot.
77
		cr := st.(snapshot.Token).Value.Uint
78
79
		// Continuously watch for changes.
80
		for {
81
			// Get the list of recent transaction IDs.
82
			recentIDs, err := w.getRecentXIDs(ctx, cr, tenantID)
83
			if err != nil {
84
				// If there is an error in getting recent transaction IDs, send the error and return.
85
86
				slog.Error("error getting recent transaction", slog.Any("error", err))
87
88
				errs <- err
89
				return
90
			}
91
92
			// Process each recent transaction ID.
93
			for _, id := range recentIDs {
94
				// Get the changes in the database associated with the current transaction ID.
95
				updates, err := w.getChanges(ctx, id, tenantID)
96
				if err != nil {
97
					// If there is an error in getting the changes, send the error and return.
98
					slog.ErrorContext(ctx, "failed to get changes for transaction", slog.Any("id", id), slog.Any("error", err))
99
					errs <- err
100
					return
101
				}
102
103
				// Send the changes, but respect the context cancellation.
104
				select {
105
				case <-ctx.Done(): // If the context is done, send an error and return.
106
					slog.ErrorContext(ctx, "context canceled, stopping watch")
107
					errs <- errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
108
					return
109
				case changes <- updates: // Send updates to the changes channel.
110
					slog.DebugContext(ctx, "sent updates to the changes channel for transaction", slog.Any("id", id))
111
				}
112
113
				// Update the transaction ID for the next round.
114
				cr = id.Uint
115
				sleepDuration = defaultSleepDuration
116
			}
117
118
			if len(recentIDs) == 0 {
119
120
				if sleep == nil {
121
					sleep = time.NewTimer(sleepDuration)
122
				} else {
123
					sleep.Reset(sleepDuration)
124
				}
125
126
				// Increase the sleep duration exponentially, but cap it at maxSleepDuration.
127
				if sleepDuration < maxSleepDuration {
128
					sleepDuration *= 2
129
				} else {
130
					sleepDuration = maxSleepDuration
131
				}
132
133
				select {
134
				case <-ctx.Done(): // If the context is done, send an error and return.
135
					slog.ErrorContext(ctx, "context canceled, stopping watch")
136
					errs <- errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
137
					return
138
				case <-sleep.C: // If the timer is done, continue the loop.
139
					slog.DebugContext(ctx, "no recent transaction IDs, waiting for changes")
140
				}
141
			}
142
		}
143
	}()
144
145
	slog.DebugContext(ctx, "watch started successfully")
146
147
	// Return the channels that the caller will listen to for changes and errors.
148
	return changes, errs
149
}
150
151
// getRecentXIDs fetches a list of XID8 identifiers from the 'transactions' table
152
// for all transactions committed after a specified XID value.
153
//
154
// Parameters:
155
//   - ctx:       A context to control the execution lifetime.
156
//   - value:     The transaction XID after which we need the changes.
157
//   - tenantID:  The ID of the tenant to filter the transactions for.
158
//
159
// Returns:
160
//   - A slice of XID8 identifiers.
161
//   - An error if the query fails to execute, or other error occurs during its execution.
162
func (w *Watch) getRecentXIDs(ctx context.Context, value uint64, tenantID string) ([]types.XID8, error) {
163
	// Convert the value to a string formatted as a Postgresql XID8 type.
164
	valStr := fmt.Sprintf("'%v'::xid8", value)
165
166
	subquery := fmt.Sprintf("(select pg_xact_commit_timestamp(id::xid) from transactions where id = %s)", valStr)
167
168
	// Build the main query to get transactions committed after the one with a given XID,
169
	// still visible in the current snapshot, ordered by their commit timestamps.
170
	builder := w.database.Builder.Select("id").
171
		From(TransactionsTable).
172
		Where(fmt.Sprintf("pg_xact_commit_timestamp(id::xid) > (%s)", subquery)).
173
		Where("id < pg_snapshot_xmin(pg_current_snapshot())").
174
		Where(squirrel.Eq{"tenant_id": tenantID}).
175
		OrderBy("pg_xact_commit_timestamp(id::xid)")
176
177
	// Convert the builder to a SQL query and arguments.
178
	query, args, err := builder.ToSql()
179
	if err != nil {
180
181
		slog.ErrorContext(ctx, "error while building sql query", slog.Any("error", err))
182
183
		return nil, err
184
	}
185
186
	slog.DebugContext(ctx, "executing SQL query to get recent transaction", slog.Any("query", query), slog.Any("arguments", args))
187
188
	// Execute the SQL query.
189
	rows, err := w.database.ReadPool.Query(ctx, query, args...)
190
	if err != nil {
191
		slog.ErrorContext(ctx, "failed to execute SQL query", slog.Any("error", err))
192
		return nil, err
193
	}
194
	defer rows.Close()
195
196
	// Loop through the rows and append XID8 values to the results.
197
	var xids []types.XID8
198
	for rows.Next() {
199
		var xid types.XID8
200
		err := rows.Scan(&xid)
201
		if err != nil {
202
			slog.ErrorContext(ctx, "error while scanning row", slog.Any("error", err))
203
			return nil, err
204
		}
205
		xids = append(xids, xid)
206
	}
207
208
	// Check for errors that could have occurred during iteration.
209
	err = rows.Err()
210
	if err != nil {
211
212
		slog.ErrorContext(ctx, "failed to iterate over rows", slog.Any("error", err))
213
214
		return nil, err
215
	}
216
217
	slog.DebugContext(ctx, "successfully retrieved recent transaction", slog.Any("ids", xids))
218
	return xids, nil
219
}
220
221
// getChanges is a method that retrieves the changes that occurred in the relation tuples within a specified transaction.
222
//
223
// ctx: The context.Context instance for managing the life-cycle of this function.
224
// value: The ID of the transaction for which to retrieve the changes.
225
// tenantID: The ID of the tenant for which to retrieve the changes.
226
//
227
// This method returns a TupleChanges instance that encapsulates the changes in the relation tuples within the specified
228
// transaction, or an error if something went wrong during execution.
229
func (w *Watch) getChanges(ctx context.Context, value types.XID8, tenantID string) (*base.DataChanges, error) {
230
	// Initialize a new TupleChanges instance.
231
	changes := &base.DataChanges{}
232
233
	slog.DebugContext(ctx, "retrieving changes for transaction", slog.Any("id", value), slog.Any("tenant_id", tenantID))
234
235
	// Construct the SQL SELECT statement for retrieving the changes from the RelationTuplesTable.
236
	tbuilder := w.database.Builder.Select("entity_type, entity_id, relation, subject_type, subject_id, subject_relation, expired_tx_id").
237
		From(RelationTuplesTable).
238
		Where(squirrel.Eq{"tenant_id": tenantID}).Where(squirrel.Or{
239
		squirrel.Eq{"created_tx_id": value},
240
		squirrel.Eq{"expired_tx_id": value},
241
	})
242
243
	// Generate the SQL query and arguments.
244
	tquery, targs, err := tbuilder.ToSql()
245
	if err != nil {
246
		slog.ErrorContext(ctx, "error while building sql query for relation tuples", slog.Any("error", err))
247
		return nil, err
248
	}
249
250
	slog.DebugContext(ctx, "executing sql query for relation tuples", slog.Any("query", tquery), slog.Any("arguments", targs))
251
252
	// Execute the SQL query and retrieve the result rows.
253
	var trows pgx.Rows
254
	trows, err = w.database.ReadPool.Query(ctx, tquery, targs...)
255
	if err != nil {
256
		slog.ErrorContext(ctx, "failed to execute sql query for relation tuples", slog.Any("error", err))
257
		return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
258
	}
259
	// Ensure the rows are closed after processing.
260
	defer trows.Close()
261
262
	abuilder := w.database.Builder.Select("entity_type, entity_id, attribute, value, expired_tx_id").
263
		From(AttributesTable).
264
		Where(squirrel.Eq{"tenant_id": tenantID}).Where(squirrel.Or{
265
		squirrel.Eq{"created_tx_id": value},
266
		squirrel.Eq{"expired_tx_id": value},
267
	})
268
269
	aquery, aargs, err := abuilder.ToSql()
270
	if err != nil {
271
		slog.ErrorContext(ctx, "error while building SQL query for attributes", slog.Any("error", err))
272
		return nil, err
273
	}
274
275
	slog.DebugContext(ctx, "executing sql query for attributes", slog.Any("query", aquery), slog.Any("arguments", aargs))
276
277
	var arows pgx.Rows
278
	arows, err = w.database.ReadPool.Query(ctx, aquery, aargs...)
279
	if err != nil {
280
		slog.ErrorContext(ctx, "error while executing SQL query for attributes", slog.Any("error", err))
281
		return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
282
	}
283
	// Ensure the rows are closed after processing.
284
	defer arows.Close()
285
286
	// Set the snapshot token for the changes.
287
	changes.SnapToken = snapshot.Token{Value: value}.Encode().String()
288
289
	// Iterate through the result rows.
290
	for trows.Next() {
291
		var expiredXID types.XID8
292
293
		rt := storage.RelationTuple{}
294
		// Scan the result row into a RelationTuple instance.
295
		err = trows.Scan(&rt.EntityType, &rt.EntityID, &rt.Relation, &rt.SubjectType, &rt.SubjectID, &rt.SubjectRelation, &expiredXID)
296
		if err != nil {
297
			slog.ErrorContext(ctx, "error while scanning row for relation tuples", slog.Any("error", err))
298
			return nil, err
299
		}
300
301
		// Determine the operation type based on the expired transaction ID.
302
		op := base.DataChange_OPERATION_CREATE
303
		if expiredXID.Uint == value.Uint {
304
			op = base.DataChange_OPERATION_DELETE
305
		}
306
307
		// Append the change to the list of changes.
308
		changes.DataChanges = append(changes.DataChanges, &base.DataChange{
309
			Operation: op,
310
			Type: &base.DataChange_Tuple{
311
				Tuple: rt.ToTuple(),
312
			},
313
		})
314
	}
315
316
	// Iterate through the result rows.
317
	for arows.Next() {
318
		var expiredXID types.XID8
319
320
		rt := storage.Attribute{}
321
322
		var valueStr string
323
324
		// Scan the result row into a RelationTuple instance.
325
		err = arows.Scan(&rt.EntityType, &rt.EntityID, &rt.Attribute, &valueStr, &expiredXID)
326
		if err != nil {
327
			slog.ErrorContext(ctx, "error while scanning row for attributes", slog.Any("error", err))
328
			return nil, err
329
		}
330
331
		// Unmarshal the JSON data from `valueStr` into `rt.Value`.
332
		rt.Value = &anypb.Any{}
333
		err = protojson.Unmarshal([]byte(valueStr), rt.Value)
334
		if err != nil {
335
			slog.ErrorContext(ctx, "failed to unmarshal attribute value", slog.Any("error", err))
336
			return nil, err
337
		}
338
339
		// Determine the operation type based on the expired transaction ID.
340
		op := base.DataChange_OPERATION_CREATE
341
		if expiredXID.Uint == value.Uint {
342
			op = base.DataChange_OPERATION_DELETE
343
		}
344
345
		// Append the change to the list of changes.
346
		changes.DataChanges = append(changes.DataChanges, &base.DataChange{
347
			Operation: op,
348
			Type: &base.DataChange_Attribute{
349
				Attribute: rt.ToAttribute(),
350
			},
351
		})
352
	}
353
354
	slog.DebugContext(ctx, "successfully retrieved changes for transaction", slog.Any("id", value))
355
356
	// Return the changes and no error.
357
	return changes, nil
358
}
359