Completed
Pull Request — master (#76)
by Olivier
02:32
created

asyncua.client.client.Client.set_values()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import asyncio
2
import logging
3
from typing import Union, Coroutine
4
from urllib.parse import urlparse
5
6
from asyncua import ua
7
from .ua_client import UaClient
8
from ..common.xmlimporter import XmlImporter
9
from ..common.xmlexporter import XmlExporter
10
from ..common.node import Node
11
from ..common.manage_nodes import delete_nodes
12
from ..common.subscription import Subscription
13
from ..common.shortcuts import Shortcuts
14
from ..common.structures import load_type_definitions, load_enums
15
from ..common.utils import create_nonce
16
from ..common.ua_utils import value_to_datavalue
17
from ..crypto import uacrypto, security_policies
18
19
# Module-level logger, named after this module so its output can be filtered per module.
_logger: logging.Logger = logging.getLogger(__name__)
20
21
22
class Client:
    """
    High level client to connect to an OPC-UA server.

    This class makes it easy to connect and browse address space.
    It attempts to expose as much functionality as possible
    but if you want more flexibility it is possible and advised to
    use UaClient object, available as self.uaclient
    which offers the raw OPC-UA services interface.
    """

    def __init__(self, url: str, timeout: int = 4, loop=None):
        """
        :param url: url of the server.
            if you are unsure of url, write at least hostname
            and port and call get_endpoints

        :param timeout:
            Each request sent to the server expects an answer within this
            time. The timeout is specified in seconds.

        :param loop:
            asyncio event loop to use; defaults to asyncio.get_event_loop()
        """
        self.logger = logging.getLogger(__name__)
        self.loop = loop or asyncio.get_event_loop()
        self.server_url = urlparse(url)
        # take initial username and password from the url
        self._username = self.server_url.username
        self._password = self.server_url.password
        self.name = "Pure Python Async. Client"
        self.description = self.name
        self.application_uri = "urn:freeopcua:client"
        self.product_uri = "urn:freeopcua.github.io:client"
        self.security_policy = ua.SecurityPolicy()
        self.secure_channel_id = None
        self.secure_channel_timeout = 3600000  # 1 hour
        self.session_timeout = 3600000  # 1 hour
        self._policy_ids = []
        self.uaclient: UaClient = UaClient(timeout, loop=self.loop)
        self.user_certificate = None
        self.user_private_key = None
        self._server_nonce = None
        self._session_counter = 1
        self.nodes = Shortcuts(self.uaclient)
        self.max_messagesize = 0  # No limits
        self.max_chunkcount = 0  # No limits
        # background task started by create_session(); None while no session exists
        self._renew_channel_task = None

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.disconnect()

    @staticmethod
    def find_endpoint(endpoints, security_mode, policy_uri):
        """
        Find endpoint with required security mode and policy URI.

        Only endpoints whose URL starts with ua.OPC_TCP_SCHEME are considered.

        :raises ua.UaError: if no endpoint matches
        """
        _logger.info("find_endpoint %r %r %r", endpoints, security_mode, policy_uri)
        for ep in endpoints:
            if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and
                    ep.SecurityMode == security_mode and
                    ep.SecurityPolicyUri == policy_uri):
                return ep
        raise ua.UaError("No matching endpoints: {0}, {1}".format(security_mode, policy_uri))

    def set_user(self, username: str):
        """
        Set user name for the connection.
        initial user from the URL will be overwritten
        """
        self._username = username

    def set_password(self, pwd: str):
        """
        Set user password for the connection.
        initial password from the URL will be overwritten

        :raises TypeError: if pwd is not a str
        """
        if not isinstance(pwd, str):
            raise TypeError(f"Password must be a string, got {pwd} of type {type(pwd)}")
        self._password = pwd

    async def set_security_string(self, string: str):
        """
        Set SecureConnection mode. String format:
        Policy,Mode,certificate,private_key[,server_private_key]
        where Policy is Basic128Rsa15, Basic256 or Basic256Sha256,
            Mode is Sign or SignAndEncrypt
            certificate, private_key and server_private_key are
                paths to .pem or .der files
        Call this before connect()

        An empty string is a no-op.

        :raises ua.UaError: if fewer than 4 comma-separated values are given
        """
        if not string:
            return
        parts = string.split(",")
        if len(parts) < 4:
            raise ua.UaError("Wrong format: `{}`, expected at least 4 comma-separated values".format(string))
        policy_class = getattr(security_policies, "SecurityPolicy{}".format(parts[0]))
        mode = getattr(ua.MessageSecurityMode, parts[1])
        return await self.set_security(
            policy_class, parts[2], parts[3], parts[4] if len(parts) >= 5 else None, mode
        )

    async def set_security(self, policy, certificate_path: str, private_key_path: str,
                           server_certificate_path: str = None,
                           mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt):
        """
        Set SecureConnection mode.
        Call this before connect()

        If server_certificate_path is None, the server certificate is taken from
        the matching endpoint returned by a temporary GetEndpoints connection.
        """
        if server_certificate_path is None:
            # load certificate from server's list of endpoints
            endpoints = await self.connect_and_get_server_endpoints()
            endpoint = Client.find_endpoint(endpoints, mode, policy.URI)
            server_cert = uacrypto.x509_from_der(endpoint.ServerCertificate)
        else:
            server_cert = await uacrypto.load_certificate(server_certificate_path)
        cert = await uacrypto.load_certificate(certificate_path)
        pk = await uacrypto.load_private_key(private_key_path)
        self.security_policy = policy(server_cert, cert, pk, mode)
        self.uaclient.set_security(self.security_policy)

    async def load_client_certificate(self, path: str):
        """
        load our certificate from file, either pem or der
        """
        self.user_certificate = await uacrypto.load_certificate(path)

    async def load_private_key(self, path: str):
        """
        Load user private key. This is used for authenticating using certificate
        """
        self.user_private_key = await uacrypto.load_private_key(path)

    async def _run_on_temporary_channel(self, request):
        """
        Connect socket, send hello, open a secure channel, await ``request()``,
        close the channel and always close the socket.

        :param request: zero-argument coroutine function issuing the actual service call
        :return: whatever ``request()`` returned
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            # for discovery services the spec says it should not be necessary
            # to open a channel, but this client always opens one
            await self.open_secure_channel()
            result = await request()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return result

    async def connect_and_get_server_endpoints(self):
        """
        Connect, ask server for endpoints, and disconnect
        """
        return await self._run_on_temporary_channel(self.get_endpoints)

    async def connect_and_find_servers(self):
        """
        Connect, ask server for a list of known servers, and disconnect
        """
        return await self._run_on_temporary_channel(self.find_servers)

    async def connect_and_find_servers_on_network(self):
        """
        Connect, ask server for a list of known servers on network, and disconnect
        """
        return await self._run_on_temporary_channel(self.find_servers_on_network)

    async def connect(self):
        """
        High level method
        Connect, create and activate session

        If any step fails, everything set up by the previous steps (session,
        secure channel, socket) is torn down before the exception is re-raised.
        """
        _logger.info("connect")
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            try:
                await self.create_session()
                try:
                    await self.activate_session(
                        username=self._username, password=self._password, certificate=self.user_certificate
                    )
                except Exception:
                    # clean up session (also stops the channel renewal task)
                    await self.close_session()
                    raise
            except Exception:
                # clean up secure channel
                await self.close_secure_channel()
                raise
        except Exception:
            # clean up open socket
            self.disconnect_socket()
            raise

    async def disconnect(self):
        """
        High level method
        Close session, secure channel and socket
        """
        _logger.info("disconnect")
        try:
            await self.close_session()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()

    async def connect_socket(self):
        """
        connect to socket defined in url
        """
        await self.uaclient.connect_socket(self.server_url.hostname, self.server_url.port)

    def disconnect_socket(self):
        self.uaclient.disconnect_socket()

    async def send_hello(self):
        """
        Send OPC-UA hello to server

        :raises ua.UaStatusCodeError: if the server answers with an error
        """
        ack = await self.uaclient.send_hello(self.server_url.geturl(), self.max_messagesize, self.max_chunkcount)
        if isinstance(ack, ua.UaStatusCodeError):
            raise ack

    async def open_secure_channel(self, renew=False):
        """
        Open secure channel, if renew is True, renew channel

        Updates self.secure_channel_timeout with the lifetime revised by the server.
        """
        params = ua.OpenSecureChannelParameters()
        params.ClientProtocolVersion = 0
        params.RequestType = ua.SecurityTokenRequestType.Issue
        if renew:
            params.RequestType = ua.SecurityTokenRequestType.Renew
        params.SecurityMode = self.security_policy.Mode
        params.RequestedLifetime = self.secure_channel_timeout
        # length should be equal to the length of key of symmetric encryption
        nonce = create_nonce(self.security_policy.symmetric_key_size)
        params.ClientNonce = nonce  # this nonce is used to create a symmetric key
        result = await self.uaclient.open_secure_channel(params)
        self.security_policy.make_symmetric_key(nonce, result.ServerNonce)
        self.secure_channel_timeout = result.SecurityToken.RevisedLifetime

    async def close_secure_channel(self):
        return await self.uaclient.close_secure_channel()

    async def get_endpoints(self) -> list:
        params = ua.GetEndpointsParameters()
        params.EndpointUrl = self.server_url.geturl()
        return await self.uaclient.get_endpoints(params)

    async def register_server(self, server, discovery_configuration=None):
        """
        register a server to discovery server
        if discovery_configuration is provided, the newer register_server2 service call is used
        """
        serv = ua.RegisteredServer()
        serv.ServerUri = server.get_application_uri()
        serv.ProductUri = server.product_uri
        serv.DiscoveryUrls = [server.endpoint.geturl()]
        serv.ServerType = server.application_type
        serv.ServerNames = [ua.LocalizedText(server.name)]
        serv.IsOnline = True
        if discovery_configuration:
            params = ua.RegisterServer2Parameters()
            params.Server = serv
            params.DiscoveryConfiguration = discovery_configuration
            return await self.uaclient.register_server2(params)
        return await self.uaclient.register_server(serv)

    async def find_servers(self, uris=None):
        """
        send a FindServer request to the server. The answer should be a list of
        servers the server knows about
        A list of uris can be provided, only server having matching uris will be returned
        """
        if uris is None:
            uris = []
        params = ua.FindServersParameters()
        params.EndpointUrl = self.server_url.geturl()
        params.ServerUris = uris
        return await self.uaclient.find_servers(params)

    async def find_servers_on_network(self):
        params = ua.FindServersOnNetworkParameters()
        return await self.uaclient.find_servers_on_network(params)

    async def create_session(self):
        """
        send a CreateSessionRequest to server with reasonable parameters.
        If you want to modify settings look at code of this methods
        and make your own

        On success this also starts the background task that periodically
        renews the secure channel (stopped again by close_session()).

        :raises ua.UaError: on server certificate mismatch or if no endpoint
            matches the current security policy
        """
        desc = ua.ApplicationDescription()
        desc.ApplicationUri = self.application_uri
        desc.ProductUri = self.product_uri
        desc.ApplicationName = ua.LocalizedText(self.name)
        desc.ApplicationType = ua.ApplicationType.Client
        params = ua.CreateSessionParameters()
        # at least 32 random bytes for server to prove possession of private key (specs part 4, 5.6.2.2)
        nonce = create_nonce(32)
        params.ClientNonce = nonce
        params.ClientCertificate = self.security_policy.client_certificate
        params.ClientDescription = desc
        params.EndpointUrl = self.server_url.geturl()
        params.SessionName = f"{self.description} Session{self._session_counter}"
        # Requested maximum number of milliseconds that a Session should remain open without activity
        params.RequestedSessionTimeout = self.session_timeout
        params.MaxResponseMessageSize = 0  # means no max size
        response = await self.uaclient.create_session(params)
        if self.security_policy.client_certificate is None:
            data = nonce
        else:
            data = self.security_policy.client_certificate + nonce
        self.security_policy.asymmetric_cryptography.verify(data, response.ServerSignature.Signature)
        self._server_nonce = response.ServerNonce
        if not self.security_policy.server_certificate:
            self.security_policy.server_certificate = response.ServerCertificate
        elif self.security_policy.server_certificate != response.ServerCertificate:
            raise ua.UaError("Server certificate mismatch")
        # remember PolicyId's: we will use them in activate_session()
        ep = Client.find_endpoint(response.ServerEndpoints, self.security_policy.Mode, self.security_policy.URI)
        self._policy_ids = ep.UserIdentityTokens
        #  Actual maximum number of milliseconds that a Session shall remain open without activity
        self.session_timeout = response.RevisedSessionTimeout
        self._renew_channel_task = self.loop.create_task(self._renew_channel_loop())
        return response

    async def _renew_channel_loop(self):
        """
        Renew the SecureChannel before the SessionTimeout will happen.
        In theory we could do that only if no session activity
        but it does not cost much..
        """
        try:
            while True:
                # 0.7 is from spec. 0.001 is because asyncio.sleep expects time in seconds.
                # Recomputed every iteration: each renewal stores the server's
                # revised lifetime in self.secure_channel_timeout.
                duration = min(self.session_timeout, self.secure_channel_timeout) * 0.7 * 0.001
                await asyncio.sleep(duration)
                self.logger.debug("renewing channel")
                await self.open_secure_channel(renew=True)
                val = await self.nodes.server_state.get_value()
                self.logger.debug("server state is: %s ", val)
        except asyncio.CancelledError:
            pass
        except Exception:
            # Without this the task would die silently and the error would only
            # resurface (aborting close_session) when the task is awaited.
            self.logger.exception("Error while renewing secure channel")

    def server_policy_id(self, token_type, default):
        """
        Find PolicyId of server's UserTokenPolicy by token_type.
        Return default if there's no matching UserTokenPolicy.
        """
        for policy in self._policy_ids:
            if policy.TokenType == token_type:
                return policy.PolicyId
        return default

    def server_policy_uri(self, token_type):
        """
        Find SecurityPolicyUri of server's UserTokenPolicy by token_type.
        If SecurityPolicyUri is empty, use default SecurityPolicyUri
        of the endpoint
        """
        for policy in self._policy_ids:
            if policy.TokenType == token_type:
                if policy.SecurityPolicyUri:
                    return policy.SecurityPolicyUri
                # empty URI means "use this endpoint's policy URI"
                return self.security_policy.URI
        return self.security_policy.URI

    async def activate_session(self, username: str = None, password: str = None, certificate=None):
        """
        Activate session using either username and password or private_key

        Anonymous authentication is used when neither username nor certificate
        is given; a certificate takes precedence over username/password.
        """
        params = ua.ActivateSessionParameters()
        challenge = b""
        if self.security_policy.server_certificate is not None:
            challenge += self.security_policy.server_certificate
        if self._server_nonce is not None:
            challenge += self._server_nonce
        params.ClientSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.ClientSignature.Signature = self.security_policy.asymmetric_cryptography.signature(challenge)
        params.LocaleIds.append("en")
        if not username and not certificate:
            self._add_anonymous_auth(params)
        elif certificate:
            self._add_certificate_auth(params, certificate, challenge)
        else:
            self._add_user_auth(params, username, password)
        return await self.uaclient.activate_session(params)

    def _add_anonymous_auth(self, params):
        params.UserIdentityToken = ua.AnonymousIdentityToken()
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Anonymous, "anonymous")

    def _add_certificate_auth(self, params, certificate, challenge):
        params.UserIdentityToken = ua.X509IdentityToken()
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, "certificate_basic256")
        params.UserIdentityToken.CertificateData = uacrypto.der_from_x509(certificate)
        # specs part 4, 5.6.3.1: the data to sign is created by appending
        # the last serverNonce to the serverCertificate
        sig = uacrypto.sign_sha1(self.user_private_key, challenge)
        params.UserTokenSignature = ua.SignatureData()
        params.UserTokenSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.UserTokenSignature.Signature = sig

    def _add_user_auth(self, params, username: str, password: str):
        """
        Fill params with a UserNameIdentityToken for username/password.

        Uses the ``password`` argument (not the stored self._password) so that
        explicit activate_session(username, password) calls behave as documented.
        """
        params.UserIdentityToken = ua.UserNameIdentityToken()
        params.UserIdentityToken.UserName = username
        policy_uri = self.server_policy_uri(ua.UserTokenType.UserName)
        if not policy_uri or policy_uri == security_policies.POLICY_NONE_URI:
            # see specs part 4, 7.36.3: if the token is NOT encrypted,
            # then the password only contains UTF-8 encoded password
            # and EncryptionAlgorithm is null
            if password:
                self.logger.warning("Sending plain-text password")
                params.UserIdentityToken.Password = password.encode("utf8")
            params.UserIdentityToken.EncryptionAlgorithm = None
        elif password:
            data, uri = self._encrypt_password(password, policy_uri)
            params.UserIdentityToken.Password = data
            params.UserIdentityToken.EncryptionAlgorithm = uri
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, "username_basic256")

    def _encrypt_password(self, password: str, policy_uri):
        pubkey = uacrypto.x509_from_der(self.security_policy.server_certificate).public_key()
        # see specs part 4, 7.36.3: if the token is encrypted, password
        # shall be converted to UTF-8 and serialized with server nonce
        passwd = password.encode("utf8")
        if self._server_nonce is not None:
            passwd += self._server_nonce
        etoken = ua.ua_binary.Primitives.Bytes.pack(passwd)
        data, uri = security_policies.encrypt_asymmetric(pubkey, etoken, policy_uri)
        return data, uri

    async def _stop_renew_channel_task(self):
        """
        Cancel the secure channel renewal task, if one is running, and wait for it.

        Safe to call when no session was ever created (the task is None).
        """
        task, self._renew_channel_task = self._renew_channel_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # raised when the task was cancelled before its coroutine started,
            # i.e. before it could enter its own CancelledError handler
            pass

    async def close_session(self):
        """
        Close session

        Stops the secure channel renewal task (if any), then sends the
        CloseSession request through uaclient.close_session(True).
        """
        await self._stop_renew_channel_task()
        return await self.uaclient.close_session(True)

    def get_root_node(self):
        return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.RootFolder))

    def get_objects_node(self):
        self.logger.info("get_objects_node")
        return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder))

    def get_server_node(self):
        return self.get_node(ua.FourByteNodeId(ua.ObjectIds.Server))

    def get_node(self, nodeid: Union[ua.NodeId, str]) -> Node:
        """
        Get node using NodeId object or a string representing a NodeId.
        """
        return Node(self.uaclient, nodeid)

    async def create_subscription(self, period, handler):
        """
        Create a subscription.
        Returns a Subscription object which allows to subscribe to events or data changes on server.

        :param period: Either a publishing interval in milliseconds or a `CreateSubscriptionParameters` instance.
        The second option should be used, if the asyncua-server has problems with the default options.
        :param handler: Class instance with data_change and/or event methods (see `SubHandler`
        base class for details). Remember not to block the main event loop inside the handler methods.
        """
        if isinstance(period, ua.CreateSubscriptionParameters):
            params = period
        else:
            params = ua.CreateSubscriptionParameters()
            params.RequestedPublishingInterval = period
            params.RequestedLifetimeCount = 10000
            params.RequestedMaxKeepAliveCount = 3000
            params.MaxNotificationsPerPublish = 10000
            params.PublishingEnabled = True
            params.Priority = 0
        subscription = Subscription(self.uaclient, params, handler)
        await subscription.init()
        return subscription

    def _namespace_array_node(self) -> Node:
        """Return the Server_NamespaceArray node (shared by the namespace helpers)."""
        return self.get_node(ua.NodeId(ua.ObjectIds.Server_NamespaceArray))

    def get_namespace_array(self) -> Coroutine:
        return self._namespace_array_node().get_value()

    async def get_namespace_index(self, uri):
        """
        Return the index of uri in the server namespace array.

        :raises ValueError: if uri is not registered on the server
        """
        uries = await self.get_namespace_array()
        _logger.info("get_namespace_index %s %r", type(uries), uries)
        return uries.index(uri)

    def delete_nodes(self, nodes, recursive=False) -> Coroutine:
        return delete_nodes(self.uaclient, nodes, recursive)

    def import_xml(self, path=None, xmlstring=None) -> Coroutine:
        """
        Import nodes defined in xml
        """
        importer = XmlImporter(self)
        return importer.import_xml(path, xmlstring)

    async def export_xml(self, nodes, path):
        """
        Export defined nodes to xml
        """
        exp = XmlExporter(self)
        await exp.build_etree(nodes)
        await exp.write_xml(path)

    async def register_namespace(self, uri):
        """
        Register a new namespace. Nodes should be in a custom namespace, not 0.
        This method is mainly implemented for symmetry with server

        :return: index of uri in the namespace array (existing index if already registered)
        """
        ns_node = self._namespace_array_node()
        uries = await ns_node.get_value()
        if uri in uries:
            return uries.index(uri)
        uries.append(uri)
        await ns_node.set_value(uries)
        return len(uries) - 1

    def load_type_definitions(self, nodes=None) -> Coroutine:
        """
        Load custom types (custom structures/extension objects) definition from server
        Generate Python classes for custom structures/extension objects defined in server
        These classes will available in ua module
        """
        return load_type_definitions(self, nodes)

    def load_enums(self) -> Coroutine:
        """
        generate Python enums for custom enums on server.
        This enums will be available in ua module
        """
        return load_enums(self)

    async def register_nodes(self, nodes):
        """
        Register nodes for faster read and write access (if supported by server)
        Note: This call modifies the nodeid of the nodes, the original nodeid is
        available as node.basenodeid
        """
        nodeids = [node.nodeid for node in nodes]
        nodeids = await self.uaclient.register_nodes(nodeids)
        for node, nodeid in zip(nodes, nodeids):
            node.basenodeid = node.nodeid
            node.nodeid = nodeid
        return nodes

    async def unregister_nodes(self, nodes):
        """
        Unregister nodes

        Restores node.nodeid from node.basenodeid for nodes that were registered.
        """
        nodeids = [node.nodeid for node in nodes]
        await self.uaclient.unregister_nodes(nodeids)
        for node in nodes:
            if not node.basenodeid:
                continue
            node.nodeid = node.basenodeid
            node.basenodeid = None

    async def get_values(self, nodes):
        """
        Read the value of multiple nodes in one ua call.
        """
        nodeids = [node.nodeid for node in nodes]
        results = await self.uaclient.get_attributes(nodeids, ua.AttributeIds.Value)
        return [result.Value.Value for result in results]

    async def set_values(self, nodes, values):
        """
        Write values to multiple nodes in one ua call

        Each returned status code is passed through ``check()``.
        """
        nodeids = [node.nodeid for node in nodes]
        dvs = [value_to_datavalue(val) for val in values]
        results = await self.uaclient.set_attributes(nodeids, dvs, ua.AttributeIds.Value)
        return [result.check() for result in results]
598