Passed
Pull Request — master (#1470)
by Tolga
02:39
created

context.*ContextualAttributes.QuerySingleAttribute   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
package context
2
3
import (
4
	"sort"
5
6
	"golang.org/x/exp/slices"
7
8
	"github.com/Permify/permify/internal/storage/context/utils"
9
	"github.com/Permify/permify/pkg/database"
10
	base "github.com/Permify/permify/pkg/pb/base/v1"
11
)
12
13
// ContextualAttributes - A collection of attributes with context.
14
type ContextualAttributes struct {
15
	Attributes []*base.Attribute
16
}
17
18
// NewContextualAttributes - Creates a new collection of attributes with context.
19
func NewContextualAttributes(attributes ...*base.Attribute) *ContextualAttributes {
20
	return &ContextualAttributes{
21
		Attributes: attributes,
22
	}
23
}
24
25
// QuerySingleAttribute filters the attributes based on the provided filter,
26
// and returns the first attribute from the filtered attributes, if any exist.
27
// If no attributes match the filter, it returns nil.
28
func (c *ContextualAttributes) QuerySingleAttribute(filter *base.AttributeFilter) (*base.Attribute, error) {
29
	filtered := c.filterAttributes(filter, "", "")
30
	if len(filtered) > 0 {
31
		return filtered[0], nil
32
	}
33
	return nil, nil
34
}
35
36
// QueryAttributes filters the attributes based on the provided filter,
37
// and returns an iterator to traverse through the filtered attributes.
38
func (c *ContextualAttributes) QueryAttributes(filter *base.AttributeFilter, pagination database.CursorPagination) (*database.AttributeIterator, error) {
39
	// Sort tuples based on the provided order field
40
	sort.Slice(c.Attributes, func(i, j int) bool {
41
		switch pagination.Sort() {
42
		case "entity_id":
43
			return c.Attributes[i].GetEntity().GetId() < c.Attributes[j].GetEntity().GetId()
44
		default:
45
			return false // If no valid order is provided, no sorting is applied
46
		}
47
	})
48
49
	cursor := ""
50
	if pagination.Cursor() != "" {
51
		t, err := utils.EncodedContinuousToken{Value: pagination.Cursor()}.Decode()
52
		if err != nil {
53
			return nil, err
54
		}
55
		cursor = t.(utils.ContinuousToken).Value
56
	}
57
58
	filtered := c.filterAttributes(filter, cursor, pagination.Sort())
59
	return database.NewAttributeIterator(filtered...), nil
60
}
61
62
// filterTuples applies the provided filter to c's Tuples and returns a slice of Tuples that match the filter.
63
func (c *ContextualAttributes) filterAttributes(filter *base.AttributeFilter, cursor, order string) []*base.Attribute {
64
	var filtered []*base.Attribute // Initialize a slice to hold the filtered tuples
65
66
	// Iterate over the tuples
67
	for _, attribute := range c.Attributes {
68
69
		// Skip tuples that come before the cursor based on the specified order field
70
		if cursor != "" && !isAttributeAfterCursor(attribute, cursor, order) {
71
			continue
72
		}
73
74
		// If a tuple matches the Entity, Relation, and Subject filters, add it to the filtered slice
75
		if matchesEntityFilterForAttributes(attribute, filter.GetEntity()) &&
76
			matchesAttributeFilter(attribute, filter.GetAttributes()) {
77
			filtered = append(filtered, attribute)
78
		}
79
	}
80
81
	return filtered // Return the filtered tuples
82
}
83
84
// isAfterCursor checks if the tuple's ID (based on the order field) comes after the cursor.
85
func isAttributeAfterCursor(attr *base.Attribute, cursor, order string) bool {
86
	switch order {
87
	case "entity_id":
88
		return attr.GetEntity().GetId() >= cursor
89
	default:
90
		// If the order field is not recognized, default to not skipping any tuples
91
		return true
92
	}
93
}
94
95
// matchesEntityFilterForAttributes checks if an attribute matches an entity filter.
96
func matchesEntityFilterForAttributes(attribute *base.Attribute, filter *base.EntityFilter) bool {
97
	typeMatches := filter.GetType() == "" || attribute.GetEntity().GetType() == filter.GetType()
98
	idMatches := len(filter.GetIds()) == 0 || slices.Contains(filter.GetIds(), attribute.GetEntity().GetId())
99
	return typeMatches && idMatches
100
}
101
102
// matchesAttributeFilter checks if an attribute matches the provided filter.
103
func matchesAttributeFilter(attribute *base.Attribute, filter []string) bool {
104
	return len(filter) == 0 || slices.Contains(filter, attribute.GetAttribute())
105
}
106