Passed
Pull Request — master (#2556)
by Tolga
03:34
created

schema.*LinkedSchemaGraph.GetInverseRelation   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 9
nop 2
dl 0
loc 14
rs 9.95
c 0
b 0
f 0
1
package schema
2
3
import (
	"errors"
	"sort"
	"sync"

	"github.com/Permify/permify/pkg/dsl/utils"
	base "github.com/Permify/permify/pkg/pb/base/v1"
)
10
11
// RelationMap represents a cached mapping for O(1) relation lookups between entity types.
12
// Structure: sourceEntityType -> targetEntityType -> relationName
13
type RelationMap map[string]map[string]string
14
15
// LinkedSchemaGraph represents a graph of linked schema objects. The schema object contains definitions for entities,
16
// relationships, and permissions, and the graph is constructed by linking objects together based on their dependencies. The
17
// graph is used by the PermissionEngine to resolve permissions and expand user sets for a given request.
18
type LinkedSchemaGraph struct {
19
	schema      *base.SchemaDefinition
20
	relationMap RelationMap
21
	once        sync.Once
22
}
23
24
// NewLinkedGraph returns a new instance of LinkedSchemaGraph with the specified base.SchemaDefinition as its schema.
25
// The schema object contains definitions for entities, relationships, and permissions, and is used to construct a graph of
26
// linked schema objects. The graph is used by the PermissionEngine to resolve permissions and expand user sets for a
27
// given request.
28
//
29
// Parameters:
30
//   - schema: pointer to the base.SchemaDefinition that defines the schema objects in the graph
31
//
32
// Returns:
33
//   - pointer to a new instance of LinkedSchemaGraph with the specified schema object
34
func NewLinkedGraph(schema *base.SchemaDefinition) *LinkedSchemaGraph {
35
	return &LinkedSchemaGraph{
36
		schema:      schema,
37
		relationMap: make(RelationMap),
38
		once:        sync.Once{},
39
	}
40
}
41
42
// ensureRelationMapInitialized ensures the relation map is initialized exactly once
43
func (g *LinkedSchemaGraph) ensureRelationMapInitialized() {
44
	g.once.Do(func() {
45
		g.relationMap = make(RelationMap)
46
47
		for sourceType, entityDef := range g.schema.EntityDefinitions {
48
			g.relationMap[sourceType] = make(map[string]string)
49
			for relationName, relationDef := range entityDef.Relations {
50
				for _, relRef := range relationDef.RelationReferences {
51
					g.relationMap[sourceType][relRef.GetType()] = relationName
52
				}
53
			}
54
		}
55
	})
56
}
57
58
// LinkedEntranceKind is a string type that represents the kind of LinkedEntrance object. An LinkedEntrance object defines an entry point
59
// into the LinkedSchemaGraph, which is used to resolve permissions and expand user sets for a given request.
60
//
61
// Values:
62
//   - RelationLinkedEntrance: represents an entry point into a relationship object in the schema graph
63
//   - TupleToUserSetLinkedEntrance: represents an entry point into a tuple-to-user-set object in the schema graph
64
//   - ComputedUserSetLinkedEntrance: represents an entry point into a computed user set object in the schema graph
65
type LinkedEntranceKind string
66
67
const (
68
	RelationLinkedEntrance        LinkedEntranceKind = "relation"
69
	TupleToUserSetLinkedEntrance  LinkedEntranceKind = "tuple_to_user_set"
70
	ComputedUserSetLinkedEntrance LinkedEntranceKind = "computed_user_set"
71
	AttributeLinkedEntrance       LinkedEntranceKind = "attribute"
72
)
73
74
// LinkedEntrance represents an entry point into the LinkedSchemaGraph, which is used to resolve permissions and expand user
75
// sets for a given request. The object contains a kind that specifies the type of entry point (e.g. relation, tuple-to-user-set),
76
// an entry point reference that identifies the specific entry point in the graph, and a tuple set relation reference that
77
// specifies the relation to use when expanding user sets for the entry point.
78
//
79
// Fields:
80
//   - Kind: LinkedEntranceKind representing the type of entry point
81
//   - TargetEntrance: pointer to a base.Entrance that identifies the entry point in the schema graph
82
//   - TupleSetRelation: string that specifies the relation to use when expanding user sets for the entry point
83
//   - RelationPath: slice of base.RelationReference that represents the path for nested attributes
84
type LinkedEntrance struct {
85
	Kind             LinkedEntranceKind
86
	TargetEntrance   *base.Entrance
87
	TupleSetRelation string
88
	RelationPath     []*base.RelationReference // Path for nested attributes using protobuf RelationReference
89
}
90
91
// LinkedEntranceKind returns the kind of the LinkedEntrance object. The kind specifies the type of entry point (e.g. relation,
92
// tuple-to-user-set, computed user set).
93
//
94
// Returns:
95
//   - LinkedEntranceKind representing the type of entry point
96
func (re LinkedEntrance) LinkedEntranceKind() LinkedEntranceKind {
97
	return re.Kind
98
}
99
100
// RelationshipLinkedEntrances returns a slice of LinkedEntrance objects that represent entry points into the LinkedSchemaGraph
101
// for the specified target and source relations. The function recursively searches the graph for all entry points that can
102
// be reached from the target relation through the specified source relation. The resulting entry points contain a reference
103
// to the relation object in the schema graph and the relation used to expand user sets for the entry point. If the target or
104
// source relation does not exist in the schema graph, the function returns an error.
105
//
106
// Parameters:
107
//   - target: pointer to a base.RelationReference that identifies the target relation
108
//   - source: pointer to a base.RelationReference that identifies the source relation used to reach the target relation
109
//
110
// Returns:
111
//   - slice of LinkedEntrance objects that represent entry points into the LinkedSchemaGraph, or an error if the target or
112
//     source relation does not exist in the schema graph
113
func (g *LinkedSchemaGraph) LinkedEntrances(target, source *base.Entrance) ([]*LinkedEntrance, error) {
114
	entries, err := g.findEntrance(target, source, map[string]struct{}{})
115
	if err != nil {
116
		return nil, err
117
	}
118
119
	return entries, nil
120
}
121
122
// findEntrance is a recursive helper function that searches the LinkedSchemaGraph for all entry points that can be reached
123
// from the specified target relation through the specified source relation. The function uses a depth-first search to traverse
124
// the schema graph and identify entry points, marking visited nodes in a map to avoid infinite recursion. If the target or
125
// source relation does not exist in the schema graph, the function returns an error. If the source relation is an action
126
// reference, the function recursively searches the graph for entry points reachable from the action child. If the source
127
// relation is a regular relational reference, the function delegates to findRelationEntrance to search for entry points.
128
//
129
// Parameters:
130
//   - target: pointer to a base.RelationReference that identifies the target relation
131
//   - source: pointer to a base.RelationReference that identifies the source relation used to reach the target relation
132
//   - visited: map used to track visited nodes and avoid infinite recursion
133
//
134
// Returns:
135
//   - slice of LinkedEntrance objects that represent entry points into the LinkedSchemaGraph, or an error if the target or
136
//     source relation does not exist in the schema graph
137
func (g *LinkedSchemaGraph) findEntrance(target, source *base.Entrance, visited map[string]struct{}) ([]*LinkedEntrance, error) {
138
	key := utils.Key(target.GetType(), target.GetValue())
139
	if _, ok := visited[key]; ok {
140
		return nil, nil
141
	}
142
	visited[key] = struct{}{}
143
144
	def, ok := g.schema.EntityDefinitions[target.GetType()]
145
	if !ok {
146
		return nil, errors.New("entity definition not found")
147
	}
148
149
	switch def.References[target.GetValue()] {
150
	case base.EntityDefinition_REFERENCE_PERMISSION:
151
		permission, ok := def.Permissions[target.GetValue()]
152
		if !ok {
153
			return nil, errors.New("permission not found")
154
		}
155
		child := permission.GetChild()
156
		if child.GetRewrite() != nil {
157
			return g.findEntranceRewrite(target, source, child.GetRewrite(), visited)
158
		}
159
		return g.findEntranceLeaf(target, source, child.GetLeaf(), visited)
160
	case base.EntityDefinition_REFERENCE_ATTRIBUTE:
161
		attribute, ok := def.Attributes[target.GetValue()]
162
		if !ok {
163
			return nil, errors.New("attribute not found")
164
		}
165
		return []*LinkedEntrance{
166
			{
167
				Kind: AttributeLinkedEntrance,
168
				TargetEntrance: &base.Entrance{
169
					Type:  target.GetType(),
170
					Value: attribute.GetName(),
171
				},
172
			},
173
		}, nil
174
	case base.EntityDefinition_REFERENCE_RELATION:
175
		return g.findRelationEntrance(target, source, visited)
176
	default:
177
		return nil, ErrUnimplemented
178
	}
179
}
180
181
// findRelationEntrance is a helper function that searches the LinkedSchemaGraph for entry points that can be reached from
182
// the specified target relation through the specified source relation. The function only returns entry points that are directly
183
// related to the target relation (i.e. the relation specified by the source reference is one of the relation's immediate children).
184
// The function recursively searches the children of the target relation and returns all reachable entry points. If the target
185
// or source relation does not exist in the schema graph, the function returns an error.
186
//
187
// Parameters:
188
//   - target: pointer to a base.RelationReference that identifies the target relation
189
//   - source: pointer to a base.RelationReference that identifies the source relation used to reach the target relation
190
//   - visited: map used to track visited nodes and avoid infinite recursion
191
//
192
// Returns:
193
//   - slice of LinkedEntrance objects that represent entry points into the LinkedSchemaGraph, or an error if the target or
194
//     source relation does not exist in the schema graph
195
func (g *LinkedSchemaGraph) findRelationEntrance(target, source *base.Entrance, visited map[string]struct{}) ([]*LinkedEntrance, error) {
196
	var res []*LinkedEntrance
197
198
	entity, ok := g.schema.EntityDefinitions[target.GetType()]
199
	if !ok {
200
		return nil, errors.New("entity definition not found")
201
	}
202
203
	relation, ok := entity.Relations[target.GetValue()]
204
	if !ok {
205
		return nil, errors.New("relation definition not found")
206
	}
207
208
	if IsDirectlyRelated(relation, source) {
209
		res = append(res, &LinkedEntrance{
210
			Kind: RelationLinkedEntrance,
211
			TargetEntrance: &base.Entrance{
212
				Type:  target.GetType(),
213
				Value: target.GetValue(),
214
			},
215
		})
216
	}
217
218
	for _, rel := range relation.GetRelationReferences() {
219
		if rel.GetRelation() != "" {
220
			entrances, err := g.findEntrance(&base.Entrance{
221
				Type:  rel.GetType(),
222
				Value: rel.GetRelation(),
223
			}, source, visited)
224
			if err != nil {
225
				return nil, err
226
			}
227
			res = append(res, entrances...)
228
		}
229
	}
230
231
	return res, nil
232
}
233
234
// findEntranceWithLeaf is a helper function that searches the LinkedSchemaGraph for entry points that can be reached from
235
// the specified target relation through an action reference with a leaf child. The function searches for entry points that are
236
// reachable through a tuple-to-user-set or computed-user-set action. If the action child is a tuple-to-user-set action, the
237
// function recursively searches for entry points reachable through the child's tuple set relation and the child's computed user
238
// set relation. If the action child is a computed-user-set action, the function recursively searches for entry points reachable
239
// through the computed user set relation. The function only returns entry points that can be reached from the target relation
240
// using the specified source relation. If the target or source relation does not exist in the schema graph, the function returns
241
// an error.
242
//
243
// Parameters:
244
//   - target: pointer to a base.RelationReference that identifies the target relation
245
//   - source: pointer to a base.RelationReference that identifies the source relation used to reach the target relation
246
//   - leaf: pointer to a base.Leaf object that represents the child of an action reference
247
//   - visited: map used to track visited nodes and avoid infinite recursion
248
//
249
// Returns:
250
//   - slice of LinkedEntrance objects that represent entry points into the LinkedSchemaGraph, or an error if the target or
251
//     source relation does not exist in the schema graph
252
func (g *LinkedSchemaGraph) findEntranceLeaf(target, source *base.Entrance, leaf *base.Leaf, visited map[string]struct{}) ([]*LinkedEntrance, error) {
253
	switch t := leaf.GetType().(type) {
254
	case *base.Leaf_TupleToUserSet:
255
		tupleSet := t.TupleToUserSet.GetTupleSet().GetRelation()
256
		computedUserSet := t.TupleToUserSet.GetComputed().GetRelation()
257
258
		var res []*LinkedEntrance
259
		entityDefinitions, exists := g.schema.EntityDefinitions[target.GetType()]
260
		if !exists {
261
			return nil, errors.New("entity definition not found")
262
		}
263
264
		relations, exists := entityDefinitions.Relations[tupleSet]
265
		if !exists {
266
			return nil, errors.New("relation definition not found")
267
		}
268
269
		// Cache computed relation paths to avoid duplicate BuildRelationPath calls
270
		relationPathCache := make(map[string][]*base.RelationReference)
271
272
		for _, rel := range relations.GetRelationReferences() {
273
			if rel.GetType() == source.GetType() && source.GetValue() == computedUserSet {
274
				res = append(res, &LinkedEntrance{
275
					Kind:             TupleToUserSetLinkedEntrance,
276
					TargetEntrance:   target,
277
					TupleSetRelation: tupleSet,
278
				})
279
			}
280
281
			results, err := g.findEntrance(
282
				&base.Entrance{
283
					Type:  rel.GetType(),
284
					Value: computedUserSet,
285
				},
286
				source,
287
				visited,
288
			)
289
			if err != nil {
290
				return nil, err
291
			}
292
293
			// Populate relation path for nested attributes with caching
294
			for _, result := range results {
295
				if result.Kind == AttributeLinkedEntrance && target.GetType() != rel.GetType() {
296
					cacheKey := target.GetType() + "->" + rel.GetType()
297
					if relationPath, exists := relationPathCache[cacheKey]; exists {
298
						result.RelationPath = relationPath
299
					} else {
300
						relationPath, err := g.BuildRelationPath(target.GetType(), rel.GetType())
301
						if err == nil {
302
							relationPathCache[cacheKey] = relationPath
303
							result.RelationPath = relationPath
304
						}
305
					}
306
				}
307
			}
308
309
			res = append(res, results...)
310
		}
311
		return res, nil
312
	case *base.Leaf_ComputedUserSet:
313
		var entrances []*LinkedEntrance
314
315
		if target.GetType() == source.GetType() && t.ComputedUserSet.GetRelation() == source.GetValue() {
316
			entrances = append(entrances, &LinkedEntrance{
317
				Kind:           ComputedUserSetLinkedEntrance,
318
				TargetEntrance: target,
319
			})
320
		}
321
322
		results, err := g.findEntrance(
323
			&base.Entrance{
324
				Type:  target.GetType(),
325
				Value: t.ComputedUserSet.GetRelation(),
326
			},
327
			source,
328
			visited,
329
		)
330
		if err != nil {
331
			return nil, err
332
		}
333
334
		entrances = append(
335
			entrances,
336
			results...,
337
		)
338
		return entrances, nil
339
	case *base.Leaf_ComputedAttribute:
340
		var entrances []*LinkedEntrance
341
		entrances = append(entrances, &LinkedEntrance{
342
			Kind: AttributeLinkedEntrance,
343
			TargetEntrance: &base.Entrance{
344
				Type:  target.GetType(),
345
				Value: t.ComputedAttribute.GetName(),
346
			},
347
		})
348
		return entrances, nil
349
	case *base.Leaf_Call:
350
		var entrances []*LinkedEntrance
351
		for _, arg := range t.Call.GetArguments() {
352
			computedAttr := arg.GetComputedAttribute()
353
			if computedAttr != nil {
354
				entrances = append(entrances, &LinkedEntrance{
355
					Kind: AttributeLinkedEntrance,
356
					TargetEntrance: &base.Entrance{
357
						Type:  target.GetType(),
358
						Value: computedAttr.GetName(),
359
					},
360
				})
361
			}
362
		}
363
		return entrances, nil
364
	default:
365
		return nil, ErrUndefinedLeafType
366
	}
367
}
368
369
// findEntranceWithRewrite is a helper function that searches the LinkedSchemaGraph for entry points that can be reached from
370
// the specified target relation through an action reference with a rewrite child. The function recursively searches each child of
371
// the rewrite and calls either findEntranceWithRewrite or findEntranceWithLeaf, depending on the child's type. The function
372
// only returns entry points that can be reached from the target relation using the specified source relation. If the target or
373
// source relation does not exist in the schema graph, the function returns an error.
374
//
375
// Parameters:
376
//   - target: pointer to a base.RelationReference that identifies the target relation
377
//   - source: pointer to a base.RelationReference that identifies the source relation used to reach the target relation
378
//   - rewrite: pointer to a base.Rewrite object that represents the child of an action reference
379
//   - visited: map used to track visited nodes and avoid infinite recursion
380
//
381
// Returns:
382
//   - slice of LinkedEntrance objects that represent entry points into the LinkedSchemaGraph, or an error if the target or
383
//     source relation does not exist in the schema graph
384
func (g *LinkedSchemaGraph) findEntranceRewrite(target, source *base.Entrance, rewrite *base.Rewrite, visited map[string]struct{}) (results []*LinkedEntrance, err error) {
385
	var res []*LinkedEntrance
386
	for _, child := range rewrite.GetChildren() {
387
		switch child.GetType().(type) {
388
		case *base.Child_Rewrite:
389
			results, err = g.findEntranceRewrite(target, source, child.GetRewrite(), visited)
390
			if err != nil {
391
				return nil, err
392
			}
393
		case *base.Child_Leaf:
394
			results, err = g.findEntranceLeaf(target, source, child.GetLeaf(), visited)
395
			if err != nil {
396
				return nil, err
397
			}
398
		default:
399
			return nil, errors.New("undefined child type")
400
		}
401
		res = append(res, results...)
402
	}
403
	return res, nil
404
}
405
406
// GetInverseRelation finds which relation connects the given entity types.
407
// Returns the relation name that connects sourceEntityType to targetEntityType.
408
// Uses O(1) cached lookup for optimal performance.
409
func (g *LinkedSchemaGraph) GetInverseRelation(sourceEntityType, targetEntityType string) (string, error) {
410
	g.ensureRelationMapInitialized()
411
412
	sourceRelations, exists := g.relationMap[sourceEntityType]
413
	if !exists {
414
		return "", errors.New("source entity definition not found")
415
	}
416
417
	relationName, exists := sourceRelations[targetEntityType]
418
	if !exists {
419
		return "", errors.New("no relation found connecting source to target entity type")
420
	}
421
422
	return relationName, nil
423
}
424
425
// BuildRelationPath builds the relation path for nested attributes
426
func (g *LinkedSchemaGraph) BuildRelationPath(sourceEntityType, targetEntityType string) ([]*base.RelationReference, error) {
427
	relationName, err := g.GetInverseRelation(sourceEntityType, targetEntityType)
428
	if err != nil {
429
		return nil, err
430
	}
431
432
	return []*base.RelationReference{
433
		{
434
			Type:     sourceEntityType,
435
			Relation: relationName,
436
		},
437
	}, nil
438
}
439