Passed
Pull Request — master (#1706)
by
unknown
03:55
created

postgres.*Watch.getBatchChanges   F

Complexity

Conditions 13

Size

Total Lines 138
Code Lines 77

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 13
eloc 77
nop 3
dl 0
loc 138
rs 3.6927
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex classes like postgres.*Watch.getBatchChanges often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
package postgres
2
3
import (
4
	"context"
5
	"errors"
6
	"fmt"
7
	"log/slog"
8
	"strings"
9
	"time"
10
11
	"github.com/jackc/pgx/v5"
12
13
	"github.com/golang/protobuf/jsonpb"
14
15
	"google.golang.org/protobuf/types/known/anypb"
16
17
	"github.com/Masterminds/squirrel"
18
19
	"github.com/Permify/permify/internal/storage"
20
	"github.com/Permify/permify/internal/storage/postgres/snapshot"
21
	"github.com/Permify/permify/internal/storage/postgres/types"
22
	db "github.com/Permify/permify/pkg/database/postgres"
23
	base "github.com/Permify/permify/pkg/pb/base/v1"
24
)
25
26
// Watch is an implementation of the storage.Watch interface, which is used to stream relation tuple and attribute changes read from Postgres.
27
type Watch struct {
28
	// database is a pointer to a Postgres database instance, which is used
29
	// to perform operations on the relationship data.
30
	database *db.Postgres
31
32
	// txOptions holds the configuration for database transactions, such as
33
	// isolation level and read-only mode, to be applied when performing
34
	// operations on the relationship data.
35
	txOptions pgx.TxOptions
36
}
37
38
// NewWatcher returns a new instance of the Watch.
39
func NewWatcher(database *db.Postgres) *Watch {
40
	return &Watch{
41
		database:  database,
42
		txOptions: pgx.TxOptions{IsoLevel: pgx.ReadCommitted, AccessMode: pgx.ReadOnly},
43
	}
44
}
45
46
// Watch returns a channel that emits a stream of changes to the relationship tuples in the database.
47
func (w *Watch) Watch(ctx context.Context, tenantID, snap string) (<-chan *base.DataChanges, <-chan error) {
48
	// Create channels for changes and errors.
49
	changes := make(chan *base.DataChanges, w.database.GetWatchBufferSize())
50
	errs := make(chan error, 1)
51
	sleepDuration := 100 * time.Millisecond
52
	const maxSleepDuration = 2 * time.Second
53
	var sleep *time.Timer
54
55
	slog.DebugContext(ctx, "watching for changes in the database", slog.Any("tenant_id", tenantID), slog.Any("snapshot", snap))
56
57
	// Decode the snapshot value.
58
	// The snapshot value represents a point in the history of the database.
59
	st, err := snapshot.EncodedToken{Value: snap}.Decode()
60
	if err != nil {
61
		// If there is an error in decoding the snapshot, send the error and return.
62
		errs <- err
63
		slog.Error("failed to decode snapshot", slog.Any("error", err))
64
		return changes, errs
65
	}
66
67
	// Start a goroutine to watch for changes in the database.
68
	go func() {
69
		// Ensure to close the channels when we're done.
70
		defer close(changes)
71
		defer close(errs)
72
73
		// Get the transaction ID from the snapshot.
74
		cr := st.(snapshot.Token).Value.Uint
75
76
		// Continuously watch for changes.
77
		for {
78
			// Get the list of recent transaction IDs.
79
			recentIDs, err := w.getRecentXIDs(ctx, cr, tenantID, w.database.GetWatchBufferSize())
80
			if err != nil {
81
				// If there is an error in getting recent transaction IDs, send the error and return.
82
				slog.Error("error getting recent transaction", slog.Any("error", err))
83
				errs <- err
84
				return
85
			}
86
87
			// Process each recent transaction ID.
88
			if len(recentIDs) > 0 {
89
				updatesBatch, err := w.getBatchChanges(ctx, recentIDs, tenantID)
90
				if err != nil {
91
					slog.ErrorContext(ctx, "failed to get batch changes", slog.Any("error", err))
92
					errs <- err
93
					return
94
				}
95
96
				// Send the batch of updates.
97
				select {
98
				case changes <- updatesBatch:
99
					slog.DebugContext(ctx, "sent batch updates to the changes channel")
100
				case <-ctx.Done():
101
					slog.ErrorContext(ctx, "context canceled, stopping watch")
102
					errs <- errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
103
					return
104
				}
105
106
				// Update the last transaction ID processed.
107
				cr = recentIDs[len(recentIDs)-1].Uint
108
109
				// Reset sleep duration if changes were found.
110
				sleepDuration = 100 * time.Millisecond
111
112
			} else {
113
				// If no recent transaction IDs were found, use exponential backoff.
114
115
				// Initialize the timer if it's the first iteration, or reset it.
116
				if sleep == nil {
117
					sleep = time.NewTimer(sleepDuration)
118
				} else {
119
					if !sleep.Stop() {
120
						<-sleep.C // Drain the channel to avoid a deadlock.
121
					}
122
					sleep.Reset(sleepDuration)
123
				}
124
125
				// Increase the sleep duration exponentially, but cap it at maxSleepDuration.
126
				if sleepDuration < maxSleepDuration {
127
					sleepDuration *= 2
128
				} else {
129
					sleepDuration = maxSleepDuration
130
				}
131
132
				// Wait for the timer or context cancellation.
133
				select {
134
				case <-sleep.C:
135
					slog.DebugContext(ctx, "no recent transaction IDs, waiting for changes")
136
				case <-ctx.Done():
137
					slog.ErrorContext(ctx, "context canceled, stopping watch")
138
					errs <- errors.New(base.ErrorCode_ERROR_CODE_CANCELLED.String())
139
					return
140
				}
141
			}
142
		}
143
	}()
144
145
	slog.DebugContext(ctx, "watch started successfully")
146
147
	// Return the channels that the caller will listen to for changes and errors.
148
	return changes, errs
149
}
150
151
// getRecentXIDs fetches a list of XID8 identifiers from the 'transactions' table
152
// for all transactions committed after a specified XID value.
153
//
154
// Parameters:
155
//   - ctx:       A context to control the execution lifetime.
156
//   - value:     The transaction XID after which we need the changes.
157
//   - tenantID:  The ID of the tenant to filter the transactions for.
158
//
159
// Returns:
160
//   - A slice of XID8 identifiers.
161
//   - An error if the query fails to execute, or other error occurs during its execution.
162
func (w *Watch) getRecentXIDs(ctx context.Context, value uint64, tenantID string, limit int) ([]types.XID8, error) {
163
	// Convert the value to a string formatted as a Postgresql XID8 type.
164
	valStr := fmt.Sprintf("'%v'::xid8", value)
165
166
	subquery := fmt.Sprintf("(select pg_xact_commit_timestamp(id::xid) from transactions where id = %s)", valStr)
167
168
	// Build the main query to get transactions committed after the one with a given XID,
169
	// still visible in the current snapshot, ordered by their commit timestamps.
170
	builder := w.database.Builder.Select("id").
171
		From(TransactionsTable).
172
		Where(fmt.Sprintf("pg_xact_commit_timestamp(id::xid) > (%s)", subquery)).
173
		Where("id < pg_snapshot_xmin(pg_current_snapshot())").
174
		Where(squirrel.Eq{"tenant_id": tenantID}).
175
		OrderBy("pg_xact_commit_timestamp(id::xid)").
176
		Limit(uint64(limit))
177
178
	// Convert the builder to a SQL query and arguments.
179
	query, args, err := builder.ToSql()
180
	if err != nil {
181
182
		slog.ErrorContext(ctx, "error while building sql query", slog.Any("error", err))
183
184
		return nil, err
185
	}
186
187
	slog.DebugContext(ctx, "executing SQL query to get recent transaction", slog.Any("query", query), slog.Any("arguments", args))
188
189
	// Execute the SQL query.
190
	rows, err := w.database.ReadPool.Query(ctx, query, args...)
191
	if err != nil {
192
193
		slog.ErrorContext(ctx, "failed to execute sql query", slog.Any("error", err))
194
195
		return nil, err
196
	}
197
	defer rows.Close()
198
199
	// Loop through the rows and append XID8 values to the results.
200
	var xids []types.XID8
201
	for rows.Next() {
202
		var xid types.XID8
203
		err := rows.Scan(&xid)
204
		if err != nil {
205
206
			slog.ErrorContext(ctx, "error while scanning row", slog.Any("error", err))
207
208
			return nil, err
209
		}
210
		xids = append(xids, xid)
211
	}
212
213
	// Check for errors that could have occurred during iteration.
214
	err = rows.Err()
215
	if err != nil {
216
217
		slog.ErrorContext(ctx, "failed to iterate over rows", slog.Any("error", err))
218
219
		return nil, err
220
	}
221
222
	slog.DebugContext(ctx, "successfully retrieved recent transaction", slog.Any("ids", xids))
223
	return xids, nil
224
}
225
226
func (w *Watch) getBatchChanges(ctx context.Context, values []types.XID8, tenantID string) (*base.DataChanges, error) {
227
	// Initialize a new TupleChanges instance.
228
	changes := &base.DataChanges{}
229
230
	// Log the batch of XID8 values.
231
	slog.DebugContext(ctx, "retrieving changes for transactions", slog.Any("ids", values), slog.Any("tenant_id", tenantID))
232
233
	// Convert the XID8 values into a slice of interface{} for the query.
234
	xidInterfaces := make([]interface{}, len(values))
235
	for i, xid := range values {
236
		xidInterfaces[i] = xid
237
	}
238
239
	fmt.Println(xidInterfaces...)
240
241
	// Construct the SQL SELECT statement for retrieving the changes from the RelationTuplesTable.
242
	tbuilder := w.database.Builder.Select("entity_type, entity_id, relation, subject_type, subject_id, subject_relation, expired_tx_id").
243
		From(RelationTuplesTable).
244
		Where(squirrel.Eq{"tenant_id": tenantID}).
245
		Where(squirrel.Or{
246
			squirrel.Eq{"created_tx_id": xidInterfaces},
247
			squirrel.Eq{"expired_tx_id": xidInterfaces},
248
		})
249
250
	// Generate the SQL query and arguments.
251
	tquery, targs, err := tbuilder.ToSql()
252
	if err != nil {
253
		slog.ErrorContext(ctx, "error while building sql query for relation tuples", slog.Any("error", err))
254
		return nil, err
255
	}
256
257
	slog.DebugContext(ctx, "executing sql query for relation tuples", slog.Any("query", tquery), slog.Any("arguments", targs))
258
259
	// Execute the SQL query and retrieve the result rows.
260
	trows, err := w.database.ReadPool.Query(ctx, tquery, targs...)
261
	if err != nil {
262
		slog.ErrorContext(ctx, "failed to execute sql query for relation tuples", slog.Any("error", err))
263
		return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
264
	}
265
	defer trows.Close()
266
267
	// Construct the SQL SELECT statement for retrieving changes from the AttributesTable.
268
	abuilder := w.database.Builder.Select("entity_type, entity_id, attribute, value, expired_tx_id").
269
		From(AttributesTable).
270
		Where(squirrel.Eq{"tenant_id": tenantID}).
271
		Where(squirrel.Or{
272
			squirrel.Eq{"created_tx_id": xidInterfaces},
273
			squirrel.Eq{"expired_tx_id": xidInterfaces},
274
		})
275
276
	// Generate the SQL query and arguments for attributes.
277
	aquery, aargs, err := abuilder.ToSql()
278
	if err != nil {
279
		slog.ErrorContext(ctx, "error while building SQL query for attributes", slog.Any("error", err))
280
		return nil, err
281
	}
282
283
	slog.DebugContext(ctx, "executing sql query for attributes", slog.Any("query", aquery), slog.Any("arguments", aargs))
284
285
	// Execute the SQL query and retrieve the result rows.
286
	arows, err := w.database.ReadPool.Query(ctx, aquery, aargs...)
287
	if err != nil {
288
		slog.ErrorContext(ctx, "error while executing SQL query for attributes", slog.Any("error", err))
289
		return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
290
	}
291
	defer arows.Close()
292
293
	// Set the snapshot token for the changes (encode the first XID as a reference).
294
	changes.SnapToken = snapshot.Token{Value: values[0]}.Encode().String()
295
296
	// Iterate through the result rows for relation tuples.
297
	for trows.Next() {
298
		var expiredXID types.XID8
299
300
		rt := storage.RelationTuple{}
301
		// Scan the result row into a RelationTuple instance.
302
		err = trows.Scan(&rt.EntityType, &rt.EntityID, &rt.Relation, &rt.SubjectType, &rt.SubjectID, &rt.SubjectRelation, &expiredXID)
303
		if err != nil {
304
			slog.ErrorContext(ctx, "error while scanning row for relation tuples", slog.Any("error", err))
305
			return nil, err
306
		}
307
308
		// Determine the operation type based on the expired transaction ID.
309
		op := base.DataChange_OPERATION_CREATE
310
		if containsXID(expiredXID, values) {
311
			op = base.DataChange_OPERATION_DELETE
312
		}
313
314
		// Append the change to the list of changes.
315
		changes.DataChanges = append(changes.DataChanges, &base.DataChange{
316
			Operation: op,
317
			Type: &base.DataChange_Tuple{
318
				Tuple: rt.ToTuple(),
319
			},
320
		})
321
	}
322
323
	// Iterate through the result rows for attributes.
324
	for arows.Next() {
325
		var expiredXID types.XID8
326
		rt := storage.Attribute{}
327
		var valueStr string
328
329
		// Scan the result row into an Attribute instance.
330
		err = arows.Scan(&rt.EntityType, &rt.EntityID, &rt.Attribute, &valueStr, &expiredXID)
331
		if err != nil {
332
			slog.ErrorContext(ctx, "error while scanning row for attributes", slog.Any("error", err))
333
			return nil, err
334
		}
335
336
		// Unmarshal the JSON data from `valueStr` into `rt.Value`.
337
		rt.Value = &anypb.Any{}
338
		unmarshaler := &jsonpb.Unmarshaler{}
339
		err = unmarshaler.Unmarshal(strings.NewReader(valueStr), rt.Value)
340
		if err != nil {
341
			slog.ErrorContext(ctx, "failed to unmarshal attribute value", slog.Any("error", err))
342
			return nil, err
343
		}
344
345
		// Determine the operation type based on the expired transaction ID.
346
		op := base.DataChange_OPERATION_CREATE
347
		if containsXID(expiredXID, values) {
348
			op = base.DataChange_OPERATION_DELETE
349
		}
350
351
		// Append the change to the list of changes.
352
		changes.DataChanges = append(changes.DataChanges, &base.DataChange{
353
			Operation: op,
354
			Type: &base.DataChange_Attribute{
355
				Attribute: rt.ToAttribute(),
356
			},
357
		})
358
	}
359
360
	slog.DebugContext(ctx, "successfully retrieved changes for transactions", slog.Any("ids", values))
361
362
	// Return the changes and no error.
363
	return changes, nil
364
}
365
366
// Helper function to check if an expiredXID is in the list of XID8 values.
367
func containsXID(xid types.XID8, xidList []types.XID8) bool {
368
	for _, x := range xidList {
369
		if x.Uint == xid.Uint {
370
			return true
371
		}
372
	}
373
	return false
374
}
375