Passed
Pull Request — master (#51)
by Olivier
02:27
created

asyncua.client.client.Client.__str__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric | Value
cc     | 1
eloc   | 2
nop    | 1
dl     | 0
loc    | 2
rs     | 10
c      | 0
b      | 0
f      | 0
1
import asyncio
2
import logging
3
from typing import Union, Coroutine
4
from urllib.parse import urlparse
5
6
from asyncua import ua
7
from .ua_client import UaClient
8
from ..common.xmlimporter import XmlImporter
9
from ..common.xmlexporter import XmlExporter
10
from ..common.node import Node
11
from ..common.manage_nodes import delete_nodes
12
from ..common.subscription import Subscription
13
from ..common.shortcuts import Shortcuts
14
from ..common.structures import load_type_definitions, load_enums
15
from ..common.utils import create_nonce
16
from ..crypto import uacrypto, security_policies
17
18
# Module-level logger, named after this module (the same logger object that
# logging.getLogger(__name__) returns anywhere else in this module).
_logger = logging.getLogger(name=__name__)


class Client:
    """
    High level client to connect to an OPC-UA server.

    This class makes it easy to connect and browse address space.
    It attempts to expose as much functionality as possible
    but if you want more flexibility it is possible and advised to
    use UaClient object, available as self.uaclient
    which offers the raw OPC-UA services interface.
    """

    def __init__(self, url: str, timeout: int = 4, loop=None):
        """
        :param url: url of the server.
            if you are unsure of url, write at least hostname
            and port and call get_endpoints

        :param timeout:
            Each request sent to the server expects an answer within this
            time. The timeout is specified in seconds.

        :param loop:
            Event loop used by the client, e.g. for the background
            secure-channel renewal task. Defaults to asyncio.get_event_loop().
        """
        self.logger = logging.getLogger(__name__)
        self.loop = loop or asyncio.get_event_loop()
        self.server_url = urlparse(url)
        # take initial username and password from the url
        self._username = self.server_url.username
        self._password = self.server_url.password
        self.name = "Pure Python Async. Client"
        self.description = self.name
        self.application_uri = "urn:freeopcua:client"
        self.product_uri = "urn:freeopcua.github.io:client"
        self.security_policy = ua.SecurityPolicy()
        self.secure_channel_id = None
        self.secure_channel_timeout = 3600000  # 1 hour, in milliseconds
        self.session_timeout = 3600000  # 1 hour, in milliseconds; requested in create_session()
        self._policy_ids = []
        self.uaclient: UaClient = UaClient(timeout, loop=self.loop)
        self.user_certificate = None
        self.user_private_key = None
        self._server_nonce = None
        self._session_counter = 1
        self.nodes = Shortcuts(self.uaclient)
        self.max_messagesize = 0  # No limits
        self.max_chunkcount = 0  # No limits
        # task running _renew_channel_loop(); set by create_session(), None when no session
        self._renew_channel_task = None

    async def __aenter__(self):
        """Connect on entering ``async with`` and return the client itself."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Disconnect on leaving ``async with``; exceptions are not suppressed."""
        await self.disconnect()

    def __str__(self):
        return f"Client({self.server_url.geturl()})"

    __repr__ = __str__

    @staticmethod
    def find_endpoint(endpoints, security_mode, policy_uri):
        """
        Find endpoint with required security mode and policy URI

        :param endpoints: endpoint descriptions, e.g. from get_endpoints()
        :param security_mode: required ua.MessageSecurityMode
        :param policy_uri: required security policy URI
        :return: the first opc.tcp endpoint matching both mode and policy
        :raises ua.UaError: if no endpoint matches
        """
        _logger.info("find_endpoint %r %r %r", endpoints, security_mode, policy_uri)
        for ep in endpoints:
            if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and
                    ep.SecurityMode == security_mode and
                    ep.SecurityPolicyUri == policy_uri):
                return ep
        raise ua.UaError("No matching endpoints: {0}, {1}".format(security_mode, policy_uri))

    def set_user(self, username: str):
        """
        Set user name for the connection.
        initial user from the URL will be overwritten
        """
        self._username = username

    def set_password(self, pwd: str):
        """
        Set user password for the connection.
        initial password from the URL will be overwritten

        :raises TypeError: if pwd is not a str
        """
        if not isinstance(pwd, str):
            raise TypeError(f"Password must be a string, got {pwd} of type {type(pwd)}")
        self._password = pwd

    async def set_security_string(self, string: str):
        """
        Set SecureConnection mode. String format:
        Policy,Mode,certificate,private_key[,server_private_key]
        where Policy is Basic128Rsa15, Basic256 or Basic256Sha256,
            Mode is Sign or SignAndEncrypt
            certificate, private_key and server_private_key are
                paths to .pem or .der files
        Call this before connect()

        An empty string is a no-op.

        :raises ua.UaError: if fewer than 4 comma-separated values are given
        """
        if not string:
            return
        parts = string.split(",")
        if len(parts) < 4:
            raise ua.UaError("Wrong format: `{}`, expected at least 4 comma-separated values".format(string))
        policy_class = getattr(security_policies, "SecurityPolicy{}".format(parts[0]))
        mode = getattr(ua.MessageSecurityMode, parts[1])
        return await self.set_security(
            policy_class, parts[2], parts[3], parts[4] if len(parts) >= 5 else None, mode
        )

    async def set_security(self, policy, certificate_path: str, private_key_path: str,
                           server_certificate_path: str = None,
                           mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt):
        """
        Set SecureConnection mode.
        Call this before connect()

        :param policy: security policy class (called with server cert, cert, key, mode)
        :param certificate_path: path of the client certificate
        :param private_key_path: path of the client private key
        :param server_certificate_path: path of the server certificate; when None,
            a temporary connection is made and the certificate is taken from the
            server endpoint matching ``mode`` and ``policy.URI``
        :param mode: message security mode
        """
        if server_certificate_path is None:
            # load certificate from server's list of endpoints
            endpoints = await self.connect_and_get_server_endpoints()
            endpoint = Client.find_endpoint(endpoints, mode, policy.URI)
            server_cert = uacrypto.x509_from_der(endpoint.ServerCertificate)
        else:
            server_cert = await uacrypto.load_certificate(server_certificate_path)
        cert = await uacrypto.load_certificate(certificate_path)
        pk = await uacrypto.load_private_key(private_key_path)
        self.security_policy = policy(server_cert, cert, pk, mode)
        self.uaclient.set_security(self.security_policy)

    async def load_client_certificate(self, path: str):
        """
        load our certificate from file, either pem or der
        """
        self.user_certificate = await uacrypto.load_certificate(path)

    async def load_private_key(self, path: str):
        """
        Load user private key. This is used for authenticating using certificate
        """
        self.user_private_key = await uacrypto.load_private_key(path)

    async def connect_and_get_server_endpoints(self):
        """
        Connect, ask server for endpoints, and disconnect
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            endpoints = await self.get_endpoints()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return endpoints

    async def connect_and_find_servers(self):
        """
        Connect, ask server for a list of known servers, and disconnect
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()  # spec says it should not be necessary to open channel
            servers = await self.find_servers()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return servers

    async def connect_and_find_servers_on_network(self):
        """
        Connect, ask server for a list of known servers on network, and disconnect
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            servers = await self.find_servers_on_network()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return servers

    async def connect(self):
        """
        High level method
        Connect, create and activate session

        If any step after opening the socket fails (including session
        activation), the channel-renewal task started by create_session()
        is cancelled and the socket is closed before the error is re-raised.
        """
        _logger.info("connect")
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            await self.create_session()
            await self.activate_session(username=self._username, password=self._password, certificate=self.user_certificate)
        except Exception:
            # clean up: stop the renewal loop (if create_session() started it)
            # so it does not keep using a dead connection, then close the socket
            if self._renew_channel_task is not None:
                self._renew_channel_task.cancel()
                self._renew_channel_task = None
            self.disconnect_socket()
            raise

    async def disconnect(self):
        """
        High level method
        Close session, secure channel and socket
        """
        _logger.info("disconnect")
        try:
            await self.close_session()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()

    async def connect_socket(self):
        """
        connect to socket defined in url
        """
        await self.uaclient.connect_socket(self.server_url.hostname, self.server_url.port)

    def disconnect_socket(self):
        """Close the underlying socket of self.uaclient."""
        self.uaclient.disconnect_socket()

    async def send_hello(self):
        """
        Send OPC-UA hello to server

        :raises ua.UaStatusCodeError: if the server answers with an error status
        """
        ack = await self.uaclient.send_hello(self.server_url.geturl(), self.max_messagesize, self.max_chunkcount)
        if isinstance(ack, ua.UaStatusCodeError):
            raise ack

    async def open_secure_channel(self, renew=False):
        """
        Open secure channel, if renew is True, renew channel

        Updates self.secure_channel_timeout with the lifetime revised by the server.
        """
        params = ua.OpenSecureChannelParameters()
        params.ClientProtocolVersion = 0
        if renew:
            params.RequestType = ua.SecurityTokenRequestType.Renew
        else:
            params.RequestType = ua.SecurityTokenRequestType.Issue
        params.SecurityMode = self.security_policy.Mode
        params.RequestedLifetime = self.secure_channel_timeout
        # length should be equal to the length of key of symmetric encryption
        nonce = create_nonce(self.security_policy.symmetric_key_size)
        params.ClientNonce = nonce  # this nonce is used to create a symmetric key
        result = await self.uaclient.open_secure_channel(params)
        self.security_policy.make_symmetric_key(nonce, result.ServerNonce)
        self.secure_channel_timeout = result.SecurityToken.RevisedLifetime

    async def close_secure_channel(self):
        """Close the secure channel through self.uaclient."""
        return await self.uaclient.close_secure_channel()

    async def get_endpoints(self) -> list:
        """Ask the server for its endpoints, using this client's URL."""
        params = ua.GetEndpointsParameters()
        params.EndpointUrl = self.server_url.geturl()
        return await self.uaclient.get_endpoints(params)

    async def register_server(self, server, discovery_configuration=None):
        """
        register a server to discovery server
        if discovery_configuration is provided, the newer register_server2 service call is used
        """
        serv = ua.RegisteredServer()
        serv.ServerUri = server.get_application_uri()
        serv.ProductUri = server.product_uri
        serv.DiscoveryUrls = [server.endpoint.geturl()]
        serv.ServerType = server.application_type
        serv.ServerNames = [ua.LocalizedText(server.name)]
        serv.IsOnline = True
        if discovery_configuration:
            params = ua.RegisterServer2Parameters()
            params.Server = serv
            params.DiscoveryConfiguration = discovery_configuration
            return await self.uaclient.register_server2(params)
        return await self.uaclient.register_server(serv)

    async def find_servers(self, uris=None):
        """
        send a FindServer request to the server. The answer should be a list of
        servers the server knows about
        A list of uris can be provided, only server having matching uris will be returned
        """
        if uris is None:
            uris = []
        params = ua.FindServersParameters()
        params.EndpointUrl = self.server_url.geturl()
        params.ServerUris = uris
        return await self.uaclient.find_servers(params)

    async def find_servers_on_network(self):
        """Send a FindServersOnNetwork request with default parameters."""
        params = ua.FindServersOnNetworkParameters()
        return await self.uaclient.find_servers_on_network(params)

    async def create_session(self):
        """
        send a CreateSessionRequest to server with reasonable parameters.
        If you want to modify settings look at code of this methods
        and make your own

        The requested session timeout is self.session_timeout (milliseconds);
        afterwards it holds the value revised by the server. On success the
        background channel-renewal task is started.

        :raises ua.UaError: if the server certificate differs from the one
            already configured, or no endpoint matches the security policy
        """
        desc = ua.ApplicationDescription()
        desc.ApplicationUri = self.application_uri
        desc.ProductUri = self.product_uri
        desc.ApplicationName = ua.LocalizedText(self.name)
        desc.ApplicationType = ua.ApplicationType.Client
        params = ua.CreateSessionParameters()
        # at least 32 random bytes for server to prove possession of private key (specs part 4, 5.6.2.2)
        nonce = create_nonce(32)
        params.ClientNonce = nonce
        params.ClientCertificate = self.security_policy.client_certificate
        params.ClientDescription = desc
        params.EndpointUrl = self.server_url.geturl()
        params.SessionName = f"{self.description} Session{self._session_counter}"
        # Requested maximum number of milliseconds that a Session should remain open without activity
        params.RequestedSessionTimeout = self.session_timeout
        params.MaxResponseMessageSize = 0  # means no max size
        response = await self.uaclient.create_session(params)
        if self.security_policy.client_certificate is None:
            data = nonce
        else:
            data = self.security_policy.client_certificate + nonce
        self.security_policy.asymmetric_cryptography.verify(data, response.ServerSignature.Signature)
        self._server_nonce = response.ServerNonce
        if not self.security_policy.server_certificate:
            self.security_policy.server_certificate = response.ServerCertificate
        elif self.security_policy.server_certificate != response.ServerCertificate:
            raise ua.UaError("Server certificate mismatch")
        # remember PolicyId's: we will use them in activate_session()
        ep = Client.find_endpoint(response.ServerEndpoints, self.security_policy.Mode, self.security_policy.URI)
        self._policy_ids = ep.UserIdentityTokens
        #  Actual maximum number of milliseconds that a Session shall remain open without activity
        self.session_timeout = response.RevisedSessionTimeout
        self._renew_channel_task = self.loop.create_task(self._renew_channel_loop())
        return response

    async def _renew_channel_loop(self):
        """
        Renew the SecureChannel before the SessionTimeout will happen.
        In theory we could do that only if no session activity
        but it does not cost much..

        Runs until cancelled. Any other error is logged and re-raised, which
        ends the task.
        """
        try:
            while True:
                # 0.7 is from spec. 0.001 is because asyncio.sleep expects time in seconds.
                # Recomputed every round: open_secure_channel() stores the lifetime
                # revised by the server in self.secure_channel_timeout.
                duration = min(self.session_timeout, self.secure_channel_timeout) * 0.7 * 0.001
                await asyncio.sleep(duration)
                self.logger.debug("renewing channel")
                await self.open_secure_channel(renew=True)
                val = await self.nodes.server_state.get_value()
                self.logger.debug("server state is: %s ", val)
        except asyncio.CancelledError:
            pass
        except Exception:
            self.logger.exception("Error while renewing secure channel")
            raise

    def server_policy_id(self, token_type, default):
        """
        Find PolicyId of server's UserTokenPolicy by token_type.
        Return default if there's no matching UserTokenPolicy.
        """
        return next(
            (policy.PolicyId for policy in self._policy_ids if policy.TokenType == token_type),
            default,
        )

    def server_policy_uri(self, token_type):
        """
        Find SecurityPolicyUri of server's UserTokenPolicy by token_type.
        If SecurityPolicyUri is empty, use default SecurityPolicyUri
        of the endpoint
        """
        for policy in self._policy_ids:
            if policy.TokenType == token_type:
                # empty URI means "use this endpoint's policy URI"
                return policy.SecurityPolicyUri or self.security_policy.URI
        return self.security_policy.URI

    async def activate_session(self, username: str = None, password: str = None, certificate=None):
        """
        Activate session using either username and password or private_key

        Anonymous authentication is used when neither username nor certificate
        is given; a certificate takes precedence over a username.
        """
        params = ua.ActivateSessionParameters()
        challenge = b""
        if self.security_policy.server_certificate is not None:
            challenge += self.security_policy.server_certificate
        if self._server_nonce is not None:
            challenge += self._server_nonce
        params.ClientSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.ClientSignature.Signature = self.security_policy.asymmetric_cryptography.signature(challenge)
        params.LocaleIds.append("en")
        if not username and not certificate:
            self._add_anonymous_auth(params)
        elif certificate:
            self._add_certificate_auth(params, certificate, challenge)
        else:
            self._add_user_auth(params, username, password)
        return await self.uaclient.activate_session(params)

    def _add_anonymous_auth(self, params):
        """Fill params with an anonymous identity token."""
        params.UserIdentityToken = ua.AnonymousIdentityToken()
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Anonymous, "anonymous")

    def _add_certificate_auth(self, params, certificate, challenge):
        """Fill params with an X509 identity token signed with self.user_private_key."""
        params.UserIdentityToken = ua.X509IdentityToken()
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, "certificate_basic256")
        params.UserIdentityToken.CertificateData = uacrypto.der_from_x509(certificate)
        # specs part 4, 5.6.3.1: the data to sign is created by appending
        # the last serverNonce to the serverCertificate
        sig = uacrypto.sign_sha1(self.user_private_key, challenge)
        params.UserTokenSignature = ua.SignatureData()
        params.UserTokenSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.UserTokenSignature.Signature = sig

    def _add_user_auth(self, params, username: str, password: str):
        """
        Fill params with a UserName identity token built from the given
        username and password (not from the values stored on the client).

        The password is sent in plain text when the server's token policy URI
        is empty or the None policy; otherwise it is encrypted.
        """
        params.UserIdentityToken = ua.UserNameIdentityToken()
        params.UserIdentityToken.UserName = username
        policy_uri = self.server_policy_uri(ua.UserTokenType.UserName)
        if not policy_uri or policy_uri == security_policies.POLICY_NONE_URI:
            # see specs part 4, 7.36.3: if the token is NOT encrypted,
            # then the password only contains UTF-8 encoded password
            # and EncryptionAlgorithm is null
            if password:
                self.logger.warning("Sending plain-text password")
                params.UserIdentityToken.Password = password.encode("utf8")
            params.UserIdentityToken.EncryptionAlgorithm = None
        elif password:
            data, uri = self._encrypt_password(password, policy_uri)
            params.UserIdentityToken.Password = data
            params.UserIdentityToken.EncryptionAlgorithm = uri
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, "username_basic256")

    def _encrypt_password(self, password: str, policy_uri):
        """Encrypt password (+ server nonce) with the server certificate's public key."""
        pubkey = uacrypto.x509_from_der(self.security_policy.server_certificate).public_key()
        # see specs part 4, 7.36.3: if the token is encrypted, password
        # shall be converted to UTF-8 and serialized with server nonce
        passwd = password.encode("utf8")
        if self._server_nonce is not None:
            passwd += self._server_nonce
        etoken = ua.ua_binary.Primitives.Bytes.pack(passwd)
        data, uri = security_policies.encrypt_asymmetric(pubkey, etoken, policy_uri)
        return data, uri

    async def close_session(self):
        """
        Close session

        Stops the channel-renewal task first, if one is running. This also
        works when no session was ever created.
        """
        task, self._renew_channel_task = self._renew_channel_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # the task was cancelled before it ever ran, so its own
                # CancelledError handler did not get a chance to swallow it
                pass
            except Exception:
                # renewal had already failed (and was logged by the loop);
                # still attempt to close the session on the server
                self.logger.debug("channel renewal task had failed", exc_info=True)
        return await self.uaclient.close_session(True)

    def get_root_node(self):
        """Return the RootFolder node."""
        return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.RootFolder))

    def get_objects_node(self):
        """Return the ObjectsFolder node."""
        self.logger.info("get_objects_node")
        return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder))

    def get_server_node(self):
        """Return the Server object node."""
        return self.get_node(ua.FourByteNodeId(ua.ObjectIds.Server))

    def get_node(self, nodeid: Union[ua.NodeId, str]) -> Node:
        """
        Get node using NodeId object or a string representing a NodeId.
        """
        return Node(self.uaclient, nodeid)

    async def create_subscription(self, period, handler):
        """
        Create a subscription.
        Returns a Subscription object which allows to subscribe to events or data changes on server.

        :param period: Either a publishing interval in milliseconds or a `CreateSubscriptionParameters` instance.
        The second option should be used, if the asyncua-server has problems with the default options.
        :param handler: Class instance with data_change and/or event methods (see `SubHandler`
        base class for details). Remember not to block the main event loop inside the handler methods.
        """
        if isinstance(period, ua.CreateSubscriptionParameters):
            params = period
        else:
            params = ua.CreateSubscriptionParameters()
            params.RequestedPublishingInterval = period
            params.RequestedLifetimeCount = 10000
            params.RequestedMaxKeepAliveCount = 3000
            params.MaxNotificationsPerPublish = 10000
            params.PublishingEnabled = True
            params.Priority = 0
        subscription = Subscription(self.uaclient, params, handler)
        await subscription.init()
        return subscription

    def get_namespace_array(self) -> Coroutine:
        """Return a coroutine reading the server's NamespaceArray value."""
        ns_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_NamespaceArray))
        return ns_node.get_value()

    async def get_namespace_index(self, uri):
        """
        Return the index of uri in the server's namespace array.

        :raises ValueError: if uri is not registered on the server
        """
        uries = await self.get_namespace_array()
        _logger.info("get_namespace_index %s %r", type(uries), uries)
        return uries.index(uri)

    def delete_nodes(self, nodes, recursive=False) -> Coroutine:
        """Return a coroutine deleting nodes (module-level delete_nodes helper)."""
        return delete_nodes(self.uaclient, nodes, recursive)

    def import_xml(self, path=None, xmlstring=None) -> Coroutine:
        """
        Import nodes defined in xml
        """
        importer = XmlImporter(self)
        return importer.import_xml(path, xmlstring)

    async def export_xml(self, nodes, path):
        """
        Export defined nodes to xml
        """
        exp = XmlExporter(self)
        await exp.build_etree(nodes)
        await exp.write_xml(path)

    async def register_namespace(self, uri):
        """
        Register a new namespace. Nodes should be in custom namespace, not 0.
        This method is mainly implemented for symmetry with server

        :return: index of uri in the namespace array (existing or newly appended)
        """
        # NOTE(review): a static-analysis tool flagged this body as duplicated
        # elsewhere in the project; keep both copies in sync.
        ns_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_NamespaceArray))
        uries = await ns_node.get_value()
        if uri in uries:
            return uries.index(uri)
        uries.append(uri)
        await ns_node.set_value(uries)
        return len(uries) - 1

    def load_type_definitions(self, nodes=None) -> Coroutine:
        """
        Load custom types (custom structures/extension objects) definition from server
        Generate Python classes for custom structures/extension objects defined in server
        These classes will available in ua module
        """
        return load_type_definitions(self, nodes)

    def load_enums(self) -> Coroutine:
        """
        generate Python enums for custom enums on server.
        This enums will be available in ua module
        """
        return load_enums(self)

    async def register_nodes(self, nodes):
        """
        Register nodes for faster read and write access (if supported by server)
        Rmw: This call modifies the nodeid of the nodes, the original nodeid is
        available as node.basenodeid
        """
        nodeids = [node.nodeid for node in nodes]
        nodeids = await self.uaclient.register_nodes(nodeids)
        for node, nodeid in zip(nodes, nodeids):
            node.basenodeid = node.nodeid
            node.nodeid = nodeid
        return nodes

    async def unregister_nodes(self, nodes):
        """
        Unregister nodes

        Restores node.nodeid from node.basenodeid for nodes that have one.
        """
        nodeids = [node.nodeid for node in nodes]
        await self.uaclient.unregister_nodes(nodeids)
        for node in nodes:
            if not node.basenodeid:
                continue
            node.nodeid = node.basenodeid
            node.basenodeid = None

    async def get_values(self, nodes):
        """
        Read the value of multiple nodes in one roundtrip.
        """
        nodeids = [node.nodeid for node in nodes]
        results = await self.uaclient.get_attribute(nodeids, ua.AttributeIds.Value)
        return [result.Value.Value for result in results]