Passed
Push — master ( b9a458...e79a13 )
by Tolga
01:29 queued 13s
created

pkg/balancer/balancer.go   A

Size/Duplication

Total Lines 250
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
cc 27
eloc 149
dl 0
loc 250
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
C balancer.*Balancer.UpdateSubConnState 0 51 9
A balancer.*Balancer.ResolverError 0 15 3
F balancer.*Balancer.UpdateClientConnState 0 126 14
A balancer.*Balancer.Close 0 1 1
1
package balancer
2
3
import (
4
	"errors"
5
	"fmt"
6
	"log/slog"
7
8
	"google.golang.org/grpc/balancer"
9
	"google.golang.org/grpc/balancer/base"
10
	"google.golang.org/grpc/connectivity"
11
	"google.golang.org/grpc/resolver"
12
13
	"github.com/Permify/permify/pkg/consistent"
14
)
15
16
type Balancer struct {
17
	// Current overall connectivity state of the balancer.
18
	state connectivity.State
19
20
	// The ClientConn to communicate with the gRPC client.
21
	clientConn balancer.ClientConn
22
23
	// Current picker used to select SubConns for requests.
24
	picker balancer.Picker
25
26
	// Evaluates connectivity state transitions for SubConns.
27
	connectivityEvaluator *balancer.ConnectivityStateEvaluator
28
29
	// Map of resolver addresses to SubConns.
30
	addressSubConns *resolver.AddressMap
31
32
	// Tracks the connectivity state of each SubConn.
33
	subConnStates map[balancer.SubConn]connectivity.State
34
35
	// Configuration for consistent hashing and replication.
36
	config *Config
37
38
	// Consistent hashing mechanism to distribute requests.
39
	consistent *consistent.Consistent
40
41
	// Hasher used by the consistent hashing mechanism.
42
	hasher consistent.Hasher
43
44
	// Stores the last resolver error encountered.
45
	lastResolverError error
46
47
	// Stores the last connection error encountered.
48
	lastConnectionError error
49
}
50
51
func (b *Balancer) ResolverError(err error) {
52
	b.lastResolverError = err
53
	if b.addressSubConns.Len() == 0 {
54
		b.state = connectivity.TransientFailure
55
		b.picker = base.NewErrPicker(errors.Join(b.lastConnectionError, b.lastResolverError))
56
	}
57
58
	if b.state != connectivity.TransientFailure {
59
		return
60
	}
61
62
	// Update the balancer state and picker.
63
	b.clientConn.UpdateState(balancer.State{
64
		ConnectivityState: b.state,
65
		Picker:            b.picker,
66
	})
67
}
68
69
func (b *Balancer) UpdateClientConnState(s balancer.ClientConnState) error {
70
	// Log the new ClientConn state.
71
	slog.Info("Received new ClientConn state",
72
		slog.Any("state", s),
73
	)
74
75
	// Reset any existing resolver error.
76
	b.lastResolverError = nil
77
78
	// Handle changes to the balancer configuration.
79
	if s.BalancerConfig != nil {
80
		svcConfig := s.BalancerConfig.(*Config)
81
		if b.config == nil || svcConfig.ReplicationFactor != b.config.ReplicationFactor {
82
			slog.Info("Updating consistent hashing configuration",
83
				slog.Int("partition_count", svcConfig.PartitionCount),
84
				slog.Int("replication_factor", svcConfig.ReplicationFactor),
85
				slog.Float64("load", svcConfig.Load),
86
				slog.Int("picker_width", svcConfig.PickerWidth),
87
			)
88
			b.consistent = consistent.New(consistent.Config{
89
				PartitionCount:    svcConfig.PartitionCount,
90
				ReplicationFactor: svcConfig.ReplicationFactor,
91
				Load:              svcConfig.Load,
92
				PickerWidth:       svcConfig.PickerWidth,
93
				Hasher:            b.hasher,
94
			})
95
			b.config = svcConfig
96
		}
97
	}
98
99
	// Check if the consistent hashing configuration exists.
100
	if b.consistent == nil {
101
		slog.Error("No consistent hashing configuration found")
102
		b.picker = base.NewErrPicker(errors.Join(b.lastConnectionError, b.lastResolverError))
103
		b.clientConn.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker})
104
		return fmt.Errorf("no consistent hashing configuration found")
105
	}
106
107
	// Maintain a set of addresses provided by the resolver.
108
	addrsSet := resolver.NewAddressMap()
109
	for _, addr := range s.ResolverState.Addresses {
110
		addrsSet.Set(addr, nil)
111
112
		// Add new SubConns for addresses that are not already tracked.
113
		if _, ok := b.addressSubConns.Get(addr); !ok {
114
			sc, err := b.clientConn.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{HealthCheckEnabled: false})
115
			if err != nil {
116
				slog.Warn("Failed to create new SubConn",
117
					slog.String("address", addr.Addr),
118
					slog.String("server_name", addr.ServerName),
119
					slog.String("error", err.Error()),
120
				)
121
				continue
122
			}
123
124
			b.addressSubConns.Set(addr, sc)
125
			b.subConnStates[sc] = connectivity.Idle
126
			b.connectivityEvaluator.RecordTransition(connectivity.Shutdown, connectivity.Idle)
127
			sc.Connect()
128
129
			b.consistent.Add(ConsistentMember{
130
				SubConn: sc,
131
				name:    fmt.Sprintf("%s|%s", addr.ServerName, addr.Addr),
132
			})
133
		}
134
	}
135
136
	// Remove SubConns that are no longer part of the resolved addresses.
137
	for _, addr := range b.addressSubConns.Keys() {
138
		sci, _ := b.addressSubConns.Get(addr)
139
		sc := sci.(balancer.SubConn)
140
		if _, ok := addrsSet.Get(addr); !ok {
141
			slog.Info("Removing SubConn",
142
				slog.String("address", addr.Addr),
143
				slog.String("server_name", addr.ServerName),
144
			)
145
			b.clientConn.RemoveSubConn(sc)
146
			b.addressSubConns.Delete(addr)
147
			b.consistent.Remove(ConsistentMember{
148
				SubConn: sc,
149
				name:    fmt.Sprintf("%s|%s", addr.ServerName, addr.Addr),
150
			}.String())
151
		}
152
	}
153
154
	// Log the current members in the consistent hashing ring.
155
	slog.Info("Current consistent members",
156
		slog.Int("member_count", len(b.consistent.Members())),
157
	)
158
	for _, m := range b.consistent.Members() {
159
		slog.Info("Consistent member", slog.String("member", m.String()))
160
	}
161
162
	// Handle the case where the resolver produces zero addresses.
163
	if len(s.ResolverState.Addresses) == 0 {
164
		err := errors.New("resolver produced zero addresses")
165
		b.ResolverError(err)
166
		slog.Error("Resolver produced zero addresses")
167
		return balancer.ErrBadResolverState
168
	}
169
170
	// Update the picker based on the current balancer state.
171
	if b.state == connectivity.TransientFailure {
172
		slog.Warn("Transient failure detected, using error picker")
173
		b.picker = base.NewErrPicker(errors.Join(b.lastConnectionError, b.lastResolverError))
174
	} else {
175
		width := b.config.PickerWidth
176
		if width < 1 {
177
			width = 1
178
		}
179
		slog.Info("Creating new picker",
180
			slog.Int("width", width),
181
		)
182
		b.picker = &picker{
183
			consistent: b.consistent,
184
			width:      width,
185
		}
186
	}
187
188
	// Update the ClientConn state with the new picker.
189
	slog.Info("Updating ClientConn state",
190
		slog.String("connectivity_state", b.state.String()),
191
	)
192
	b.clientConn.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker})
193
194
	return nil
195
}
196
197
func (b *Balancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
198
	s := state.ConnectivityState
199
	slog.Info("Received SubConn state change",
200
		slog.String("connectivity_state", s.String()),
201
		slog.String("sub_conn", fmt.Sprintf("%p", sc)),
202
	)
203
204
	oldS, ok := b.subConnStates[sc]
205
	if !ok {
206
		slog.Warn("State change for unknown SubConn",
207
			slog.String("connectivity_state", s.String()),
208
			slog.String("sub_conn", fmt.Sprintf("%p", sc)),
209
		)
210
		return
211
	}
212
213
	if oldS == connectivity.TransientFailure && (s == connectivity.Connecting || s == connectivity.Idle) {
214
		if s == connectivity.Idle {
215
			slog.Info("Transitioning SubConn to connecting state",
216
				slog.String("sub_conn", fmt.Sprintf("%p", sc)),
217
			)
218
			sc.Connect()
219
		}
220
		return
221
	}
222
223
	b.subConnStates[sc] = s
224
	switch s {
225
	case connectivity.Idle:
226
		slog.Info("SubConn is idle, initiating connection",
227
			slog.String("sub_conn", fmt.Sprintf("%p", sc)),
228
		)
229
		sc.Connect()
230
	case connectivity.Shutdown:
231
		slog.Info("Removing shutdown SubConn",
232
			slog.String("sub_conn", fmt.Sprintf("%p", sc)),
233
		)
234
		delete(b.subConnStates, sc)
235
	case connectivity.TransientFailure:
236
		slog.Warn("SubConn in transient failure",
237
			slog.String("sub_conn", fmt.Sprintf("%p", sc)),
238
			slog.String("error", state.ConnectionError.Error()),
239
		)
240
		b.lastConnectionError = state.ConnectionError
241
	}
242
243
	b.state = b.connectivityEvaluator.RecordTransition(oldS, s)
244
	slog.Info("Updating ClientConn state",
245
		slog.String("connectivity_state", b.state.String()),
246
	)
247
	b.clientConn.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker})
248
}
249
250
// Close shuts the balancer down. It is deliberately a no-op: the balancer
// starts no goroutines and holds nothing that needs releasing. Its SubConns
// are not removed here either, matching gRPC's base balancer.
func (b *Balancer) Close() {
	// Nothing to release.
}
251