Passed
Pull Request — master (#367), created by unknown (02:21)

ha_client.HaManager.reconnect_warm()   B

Complexity

Conditions 8

Size

Total Lines 31
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      8
eloc    23
nop     1
dl      0
loc     31
rs      7.3333
c       0
b       0
f       0
1
from __future__ import annotations
2
3
import asyncio
4
import inspect
5
import logging
6
7
from concurrent.futures import CancelledError, TimeoutError
8
from dataclasses import dataclass, field
9
from enum import IntEnum
10
from functools import partial
11
from itertools import chain
12
from sortedcontainers import SortedDict
13
from asyncua import Node, ua, Client
14
from asyncua.client.ua_client import UASocketProtocol
15
from asyncua.ua.uaerrors import BadSessionClosed, BadSessionNotActivated
16
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Type, Union
17
18
from .reconciliator import Reconciliator
19
from .utils import ClientNotFound, event_wait
20
from .virtual_subscription import TypeSubHandler, VirtualSubscription
21
from ...crypto.uacrypto import CertProperties
22
23
24
# Module-level logger used by HaClient, KeepAlive and HaManager below.
_logger = logging.getLogger(__name__)
25
26
27
class HaMode(IntEnum):
    """
    Redundancy modes from OPC UA Part 4, section 6.6.2.4.5.

    COLD  (6.6.2.4.5.2) - only the active_client is connected; failover is
          managed by promoting another client of the pool to active_client.
    WARM  (6.6.2.4.5.3) - the active client behaves as in COLD mode, while
          secondary clients create the MonitoredItems but disable sampling
          and publishing.
    HOT_A (6.6.2.4.5.4) - the client connects to multiple servers and
          establishes subscription(s) in each; only one is Reporting, the
          others are Sampling only.
    HOT_B (6.6.2.4.5.4) - the client connects to multiple servers and
          establishes subscription(s) in each, all of them Reporting; the
          client handles the concurrent subscription streams itself.
    """

    COLD = 0
    WARM = 1
    HOT_A = 2
    HOT_B = 3
49
50
51
class ConnectionStates(IntEnum):
    """
    Named ServiceLevel values.

    Reference: OPC UA Part 4 - Services Release,
    section 6.6.2.4.2 (ServiceLevel).
    """

    IN_MAINTENANCE = 0
    NO_DATA = 1
    DEGRADED = 2
    HEALTHY = 200
61
62
63
@dataclass
class ServerInfo:
    """Mutable health record for one server URL of the HA set."""

    url: str
    # Starts as NO_DATA (value 1) until something overwrites it.
    status: ConnectionStates = ConnectionStates.NO_DATA
67
68
69
@dataclass(frozen=True, eq=True)
class HaSecurityConfig:
    """
    Immutable bundle of security settings shared by every client of the
    HA set. All fields default to None.
    """

    # Security policy class to apply; None means "no security settings".
    policy: Optional[Type[ua.SecurityPolicy]] = None
    # Client certificate and its private key.
    certificate: Optional[CertProperties] = None
    private_key: Optional[CertProperties] = None
    # Optional certificate of the server.
    server_certificate: Optional[CertProperties] = None
    # Message security mode (e.g. Sign / SignAndEncrypt).
    mode: Optional[ua.MessageSecurityMode] = None
76
77
78
@dataclass(frozen=True, eq=True)
class HaConfig:
    """
    Settings consumed by the HaClient constructor.

    Every timer and timeout below is expressed in seconds.
    """

    ha_mode: HaMode
    # Loop periods of the keepalive, manager and reconciliator tasks.
    keepalive_timer: int = 15
    manager_timer: int = 15
    reconciliator_timer: int = 15
    # Session and per-request timeouts handed to every underlying client.
    session_timeout: int = 60
    request_timeout: int = 30
    # Name given to the sessions opened by the clients.
    session_name: str = "HaClient"
    # Server endpoints of the redundant set; a fresh list per instance.
    urls: List[str] = field(default_factory=list)
93
94
95
class HaClient:
    """
    The HaClient is responsible for managing non-transparent server redundancy.

    The servers must have:
        - Identical NodeIds
        - Identical browse path and AddressSpace structure
        - Identical Service Level logic
        - However nodes in the server local namespace can differ
        - Time synchronization (e.g NTP)

    It creates one OPC-UA client per configured URL and connects to the
    server(s) according to the selected HaMode. Only HaMode.WARM is
    currently supported.
    """

    # Override this if your servers require custom ServiceLevels
    # i.e: You're using an OPC-UA proxy
    HEALTHY_STATE = ConnectionStates.HEALTHY

    def __init__(
        self, config: HaConfig, security: Optional[HaSecurityConfig] = None, loop=None
    ) -> None:
        """
        :param config: HA settings (mode, timers, server URLs).
        :param security: optional security settings; they can also be set
            later with set_security().
        :param loop: event loop used for the background tasks, defaults to
            asyncio.get_event_loop().
        :raises NotImplementedError: if config.ha_mode is not HaMode.WARM.
        """
        self._config: HaConfig = config
        self._keepalive_task: Dict[KeepAlive, asyncio.Task] = {}
        self._manager_task: Dict[HaManager, asyncio.Task] = {}
        self._reconciliator_task: Dict[Reconciliator, asyncio.Task] = {}
        self._gen_sub: Generator[str, None, None] = self.generate_sub_name()

        self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
        self._url_to_reset_lock: asyncio.Lock = self._new_lock(self.loop)
        self._ideal_map_lock: asyncio.Lock = self._new_lock(self.loop)
        self._client_lock: asyncio.Lock = self._new_lock(self.loop)

        self.clients: Dict[Client, ServerInfo] = {}
        self.active_client: Optional[Client] = None
        # full type: Dict[str, SortedDict[str, VirtualSubscription]]
        self.ideal_map: Dict[str, SortedDict] = {}
        self.sub_names: Set[str] = set()
        self.url_to_reset: List[str] = []
        self.is_running = False

        if config.ha_mode != HaMode.WARM:
            # TODO
            # Check if transparent redundancy support exist for the server (nodeid=2035)
            # and prevent using HaClient with such servers.
            raise NotImplementedError(
                f"{config.ha_mode} not currently supported by HaClient"
            )

        for url in self.urls:
            c = Client(url, timeout=self._config.request_timeout, loop=self.loop)
            # timeouts for the session and secure channel are in ms
            c.session_timeout = self._config.session_timeout * 1000
            c.secure_channel_timeout = self._config.request_timeout * 1000
            c.description = self._config.session_name
            self.clients[c] = ServerInfo(url)
            self.ideal_map[url] = SortedDict()

        # security can also be set via the set_security method
        self.security_config: HaSecurityConfig = (
            security if security else HaSecurityConfig()
        )
        self.manager = HaManager(self, self._config.manager_timer)
        self.reconciliator = Reconciliator(self._config.reconciliator_timer, self)

    @staticmethod
    def _new_lock(loop: asyncio.AbstractEventLoop) -> asyncio.Lock:
        """Create an asyncio.Lock for ``loop`` on any supported Python version."""
        import sys

        # The ``loop`` argument was deprecated in 3.8 and removed in 3.10
        # (passing it raises TypeError); from 3.10 on the lock uses the
        # running loop instead.
        if sys.version_info >= (3, 10):
            return asyncio.Lock()
        return asyncio.Lock(loop=loop)

    async def start(self) -> None:
        """Spawn one KeepAlive task per client plus the manager and reconciliator tasks."""
        for client, server in self.clients.items():
            keepalive = KeepAlive(client, server, self._config.keepalive_timer)
            task = self.loop.create_task(keepalive.run())
            self._keepalive_task[keepalive] = task

        task = self.loop.create_task(self.manager.run())
        self._manager_task[self.manager] = task

        task = self.loop.create_task(self.reconciliator.run())
        self._reconciliator_task[self.reconciliator] = task

        self.is_running = True

    async def stop(self):
        """
        Stop the keepalive, manager and reconciliator loops, disconnect every
        client, then cancel and reap the background tasks.
        """
        runners = chain(
            self._keepalive_task, self._manager_task, self._reconciliator_task
        )
        await asyncio.gather(*[runner.stop() for runner in runners])
        disco = [c.disconnect() for c in self.clients]
        await asyncio.gather(*disco, return_exceptions=True)

        tasks = list(
            chain(
                self._keepalive_task.values(),
                self._manager_task.values(),
                self._reconciliator_task.values(),
            )
        )

        for task in tasks:
            task.cancel()
        for task in tasks:
            try:
                await task
            except (CancelledError, asyncio.CancelledError):
                # Since Python 3.8 a cancelled task raises asyncio.CancelledError,
                # which no longer derives from concurrent.futures.CancelledError.
                pass
            except Exception:
                # A loop that crashed earlier re-raises here: log it and keep
                # reaping the other tasks instead of aborting the shutdown.
                _logger.exception("Background task failed before stop")
        self.is_running = False

    def set_security(
        self,
        policy: Type[ua.SecurityPolicy],
        certificate: CertProperties,
        private_key: CertProperties,
        server_certificate: Optional[CertProperties] = None,
        mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt,
    ) -> None:
        """Replace the security settings applied to clients on (re)connection."""
        self.security_config = HaSecurityConfig(
            policy, certificate, private_key, server_certificate, mode
        )

    async def create_subscription(self, period: int, handler: TypeSubHandler) -> str:
        """
        Register a new virtual subscription on every server URL.

        The copy for the active client is publishing/Reporting, the copies for
        the other clients are non-publishing/Disabled.

        :returns: the generated subscription name.
        """
        async with self._ideal_map_lock:
            sub_name = next(self._gen_sub)
            for client in self.clients:
                is_active = client == self.active_client
                vs = VirtualSubscription(
                    period=period,
                    handler=handler,
                    publishing=is_active,
                    monitoring=(
                        ua.MonitoringMode.Reporting
                        if is_active
                        else ua.MonitoringMode.Disabled
                    ),
                )
                url = client.server_url.geturl()
                self.ideal_map[url][sub_name] = vs
            return sub_name

    async def subscribe_data_change(
        self,
        sub_name: str,
        nodes: Union[Iterable[Node], Iterable[str]],
        attr=ua.AttributeIds.Value,
        queuesize=0,
    ) -> None:
        """
        Add ``nodes`` to subscription ``sub_name`` for every server URL.

        Logs a warning and returns as soon as a URL lacks that subscription.
        """
        async with self._ideal_map_lock:
            nodes = [n.nodeid.to_string() if isinstance(n, Node) else n for n in nodes]
            for url in self.urls:
                vs = self.ideal_map[url].get(sub_name)
                if not vs:
                    _logger.warning(
                        f"The subscription specified for the data_change: {sub_name} doesn't exist in ideal_map"
                    )
                    return
                vs.subscribe_data_change(nodes, attr, queuesize)
                await self.hook_on_subscribe(
                    nodes=nodes, attr=attr, queuesize=queuesize, url=url
                )

    async def delete_subscriptions(self, sub_names: List[str]) -> None:
        """Remove the named subscriptions from every URL and free their names."""
        async with self._ideal_map_lock:
            for sub_name in sub_names:
                for url in self.urls:
                    if self.ideal_map[url].get(sub_name):
                        self.ideal_map[url].pop(sub_name)
                    else:
                        _logger.warning(
                            f"No subscription named {sub_name} in ideal_map"
                        )
                # discard(): an unknown name was already reported above and
                # must not abort the deletion of the remaining names.
                self.sub_names.discard(sub_name)

    async def reconnect(self, client: Client) -> None:
        """
        Reconnect a client of the HA set and
        add its URL to the reset list.
        """
        async with self._url_to_reset_lock:
            url = client.server_url.geturl()
            self.url_to_reset.append(url)
        try:
            await client.disconnect()
        except Exception:
            # Best effort: a failing disconnect must not prevent the reconnect.
            _logger.debug("Ignoring error while disconnecting %s", url, exc_info=True)
        await self.hook_on_reconnect(client=client)
        if self.security_config.policy:
            await client.set_security(**self.security_config.__dict__)
        await client.connect()

    async def unsubscribe(self, nodes: Union[Iterable[Node], Iterable[str]]) -> None:
        """Remove ``nodes`` from every virtual subscription monitoring them, on every URL."""
        async with self._ideal_map_lock:
            if not self.urls:
                return
            # Normalise once: ``nodes`` may be a one-shot iterator that a
            # per-subscription comprehension would exhaust after the first pass.
            node_set = {
                n.nodeid.to_string() if isinstance(n, Node) else n for n in nodes
            }
            sub_to_nodes = {}
            # Every URL holds the same subscriptions, so the first one is
            # enough to find which subscriptions reference the nodes.
            first_url = self.urls[0]
            for sub_name, vs in self.ideal_map[first_url].items():
                to_del = node_set & vs.get_nodes()
                if to_del:
                    sub_to_nodes[sub_name] = to_del
            for url in self.urls:
                for sub_name, str_nodes in sub_to_nodes.items():
                    vs = self.ideal_map[url][sub_name]
                    vs.unsubscribe(str_nodes)
                    await self.hook_on_unsubscribe(url=url, nodes=str_nodes)

    async def failover_warm(
        self, primary: Optional[Client], secondaries: Iterable[Client]
    ) -> None:
        """
        Make ``primary`` the active client (Reporting + publishing) and set
        ``secondaries`` to Disabled + non-publishing in the ideal map.
        """
        async with self._ideal_map_lock:
            if primary:
                self._set_monitoring_mode(
                    ua.MonitoringMode.Reporting, clients=[primary]
                )
                self._set_publishing_mode(True, clients=[primary])
            self.active_client = primary
            self._set_monitoring_mode(ua.MonitoringMode.Disabled, clients=secondaries)
            self._set_publishing_mode(False, clients=secondaries)

    def _set_monitoring_mode(
        self, monitoring: ua.MonitoringMode, clients: Iterable[Client]
    ) -> None:
        for client in clients:
            url = client.server_url.geturl()
            for vs in self.ideal_map[url].values():
                vs.monitoring = monitoring

    def _set_publishing_mode(self, publishing: bool, clients: Iterable[Client]) -> None:
        for client in clients:
            url = client.server_url.geturl()
            for vs in self.ideal_map[url].values():
                vs.publishing = publishing

    async def group_clients_by_health(self) -> Tuple[List[Client], List[Client]]:
        """Split clients into (healthy, unhealthy) by comparing status to HEALTHY_STATE."""
        healthy = []
        unhealthy = []
        async with self._client_lock:
            for client, server in self.clients.items():
                if server.status >= self.HEALTHY_STATE:
                    healthy.append(client)
                else:
                    unhealthy.append(client)
            return healthy, unhealthy

    async def get_serving_client(
        self, clients: List[Client], serving_client: Optional[Client]
    ) -> Optional[Client]:
        """
        Returns the client with the higher service level.

        The service level reference is taken from the active_client,
        thus we prevent failing over when multiple clients
        return the same number. Returns None if no candidate reaches
        HEALTHY_STATE.
        """
        async with self._client_lock:
            if serving_client:
                max_slevel = self.clients[serving_client].status
            else:
                max_slevel = ConnectionStates.NO_DATA

            for c in clients:
                c_slevel = self.clients[c].status
                if c_slevel > max_slevel:
                    serving_client = c
                    max_slevel = c_slevel
            return serving_client if max_slevel >= self.HEALTHY_STATE else None

    async def debug_status(self):
        """
        Log the instance attributes (non-dunder, non-method) at DEBUG level
        for troubleshooting purposes.
        """
        # Called on every manager tick: skip the member walk when it would
        # produce no output anyway.
        if not _logger.isEnabledFor(logging.DEBUG):
            return
        for member in inspect.getmembers(self):
            if not member[0].startswith("__") and not inspect.ismethod(member[1]):
                _logger.debug(member)

    def get_client_warm_mode(self) -> List[Client]:
        return list(self.clients)

    def get_clients(self) -> List[Client]:
        """Return the clients relevant for the current HaMode, always as a list."""
        ha_mode = self.ha_mode
        func = f"get_client_{ha_mode}_mode"
        get_clients = getattr(self, func)
        active_clients = get_clients()
        if not isinstance(active_clients, list):
            active_clients = [active_clients]
        return active_clients

    def get_client_by_url(self, url) -> Client:
        """
        :raises ClientNotFound: if no managed client uses ``url``.
        """
        for client, srv_info in self.clients.items():
            if srv_info.url == url:
                return client
        raise ClientNotFound(f"{url} not managed by HaClient")

    @property
    def session_timeout(self) -> int:
        return self._config.session_timeout

    @property
    def ha_mode(self) -> str:
        return self._config.ha_mode.name.lower()

    @property
    def urls(self) -> List[str]:
        return self._config.urls

    def generate_sub_name(self) -> Generator[str, None, None]:
        """
        Asyncio unsafe - yield names for subscriptions.

        Yields the lowest free ``sub_<i>`` (0 <= i < 9999) and records it in
        sub_names. NOTE: if all 9999 names are in use, this loops forever
        without yielding.
        """
        while True:
            for i in range(9999):
                sub_name = f"sub_{i}"
                if sub_name in self.sub_names:
                    continue
                self.sub_names.add(sub_name)
                yield sub_name

    async def hook_on_reconnect(self, **kwargs):
        """Override point, awaited by reconnect() before reconnecting a client."""
        pass

    async def hook_on_subscribe(self, **kwargs):
        """Override point, awaited for each URL after nodes were subscribed."""
        pass

    async def hook_on_unsubscribe(self, **kwargs):
        """Override point, awaited for each URL/subscription after nodes were removed."""
        pass
426
427
428
class KeepAlive:
    """
    Ping the server status regularly to check its health.

    Every ``timer`` seconds the ServerState and ServiceLevel nodes are read.
    ``server.status`` is set to the ServiceLevel when the server is Running,
    and to ConnectionStates.NO_DATA otherwise or on any error.
    """

    def __init__(self, client, server, timer) -> None:
        self.client: Client = client
        self.server: ServerInfo = server
        self.timer: int = timer
        self.stop_event: asyncio.locks.Event = asyncio.Event()
        self.is_running: bool = False

    async def stop(self) -> None:
        """Ask run() to leave its loop at the next wait."""
        self.stop_event.set()

    async def run(self) -> None:
        """Poll the server health until stop() is called."""
        status_node = self.client.nodes.server_state
        slevel_node = self.client.nodes.service_level
        server_info = self.server
        client = self.client
        # wait for HaManager to connect clients
        await asyncio.sleep(3)
        self.is_running = True
        _logger.info(
            f"Starting keepalive loop for {server_info.url}, checking every {self.timer}sec"
        )
        while self.is_running:
            try:
                status, slevel = await client.read_values([status_node, slevel_node])
                if status != ua.ServerState.Running:
                    _logger.info("ServerState is not running")
                    server_info.status = ConnectionStates.NO_DATA
                else:
                    server_info.status = slevel
            except BadSessionNotActivated:
                _logger.warning("Session is not yet activated.")
                server_info.status = ConnectionStates.NO_DATA
            except BadSessionClosed:
                _logger.warning("Session is closed.")
                server_info.status = ConnectionStates.NO_DATA
            except (TimeoutError, asyncio.TimeoutError, CancelledError):
                # On Python 3.8-3.10 asyncio.TimeoutError is a different class
                # from concurrent.futures.TimeoutError; both mean a timed-out read.
                _logger.warning("Timeout when fetching state")
                server_info.status = ConnectionStates.NO_DATA
            except Exception:
                _logger.exception("Unknown exception during keepalive liveness check")
                server_info.status = ConnectionStates.NO_DATA

            _logger.info(f"ServiceLevel for {server_info.url}: {server_info.status}")
            if await event_wait(self.stop_event, self.timer):
                self.is_running = False
                break
480
481
class HaManager:
    """
    The manager handles individual client connections
    according to the selected HaMode.

    On every tick it runs ``update_state_<mode>`` (failover) and then
    ``reconnect_<mode>`` (reconnections), both resolved from
    ``ha_client.ha_mode``.
    """

    def __init__(self, ha_client: HaClient, timer: Optional[int] = None) -> None:
        self.ha_client = ha_client
        self.loop = ha_client.loop
        self.timer = self.set_loop_timer(timer)
        self.stop_event = self._new_event(self.loop)
        self.is_running = False

    @staticmethod
    def _new_event(loop) -> asyncio.Event:
        """Create an asyncio.Event for ``loop`` on any supported Python version."""
        import sys

        # The ``loop`` argument was removed in Python 3.10 (TypeError if passed).
        if sys.version_info >= (3, 10):
            return asyncio.Event()
        return asyncio.Event(loop=loop)

    def set_loop_timer(self, timer: Optional[int]):
        """Use ``timer`` if truthy, otherwise fall back to the session timeout."""
        return timer if timer else int(self.ha_client.session_timeout)

    async def run(self) -> None:
        """Run the failover/reconnect loop every ``timer`` seconds until stop()."""
        ha_mode = self.ha_client.ha_mode
        update_state = getattr(self, f"update_state_{ha_mode}")
        reconnect = getattr(self, f"reconnect_{ha_mode}")
        self.is_running = True

        _logger.info(f"Starting HaManager loop, checking every {self.timer}sec")
        while self.is_running:
            try:
                # failover happens here
                await update_state()
                await reconnect()
                await self.ha_client.debug_status()
            except asyncio.CancelledError:
                raise
            except Exception:
                # Keep supervising: an error in one tick must not silently
                # end failover handling for the lifetime of the client.
                _logger.exception("HaManager iteration failed, retrying at next tick")

            if await event_wait(self.stop_event, self.timer):
                self.is_running = False
                break

    async def stop(self) -> None:
        """Ask run() to leave its loop at the next wait."""
        self.stop_event.set()

    async def update_state_warm(self) -> None:
        """
        Promote the client with the best ServiceLevel to active_client and
        demote all others when it differs from the current active client.
        """
        active_client = self.ha_client.active_client
        clients = self.ha_client.get_clients()
        primary_client = await self.ha_client.get_serving_client(
            list(self.ha_client.clients), active_client
        )
        if primary_client != active_client:
            # disable monitoring and reporting when the service_level goes below 200
            _logger.info(
                f"Failing over active client from {active_client} to {primary_client}"
            )
            # When primary_client is None, removing {None} leaves all clients.
            secondaries = set(clients) - {primary_client}
            await self.ha_client.failover_warm(
                primary=primary_client, secondaries=secondaries
            )

    async def reconnect_warm(self) -> None:
        """
        Reconnect disconnected clients.

        Unhealthy clients are always reconnected; healthy ones only when
        their socket protocol is missing or CLOSED. Reconnections run
        concurrently; failures are logged, never raised.
        """
        healthy, unhealthy = await self.ha_client.group_clients_by_health()

        def is_disconnected(client: Client) -> bool:
            protocol = client.uaclient.protocol
            # pyre-fixme[16]: `Optional` has no attribute `state`.
            return not protocol or protocol.state == UASocketProtocol.CLOSED

        async def reco_resub(client: Client, force: bool):
            if force or is_disconnected(client):
                _logger.info(f"Virtually reconnecting and resubscribing {client}")
                await self.ha_client.reconnect(client=client)

        def log_exception(client: Client, fut: asyncio.Task):
            # exception() raises CancelledError on a cancelled task, which
            # would surface as an error inside this done callback.
            if fut.cancelled():
                return
            exc = fut.exception()
            if exc is not None:
                _logger.warning(f"Error when reconnecting {client}: {exc}")

        tasks = []
        candidates = chain(
            ((client, False) for client in healthy),
            ((client, True) for client in unhealthy),
        )
        for client, force in candidates:
            task = self.loop.create_task(reco_resub(client, force=force))
            task.add_done_callback(partial(log_exception, client))
            tasks.append(task)
        await asyncio.gather(*tasks, return_exceptions=True)
570
571