Passed
Pull Request — master (#120)
by Olivier
02:31
created

asyncua.client.client.Client.write_values()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 3
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import asyncio
2
import logging
3
from typing import Union, Coroutine
4
from urllib.parse import urlparse
5
6
from asyncua import ua
7
from .ua_client import UaClient
8
from ..common.xmlimporter import XmlImporter
9
from ..common.xmlexporter import XmlExporter
10
from ..common.node import Node
11
from ..common.manage_nodes import delete_nodes
12
from ..common.subscription import Subscription
13
from ..common.shortcuts import Shortcuts
14
from ..common.structures import load_type_definitions, load_enums
15
from ..common.utils import create_nonce
16
from ..common.ua_utils import value_to_datavalue
17
from ..crypto import uacrypto, security_policies
18
19
# Module-level logger used by the Client methods below (connect, disconnect, find_endpoint, ...).
_logger = logging.getLogger(__name__)
20
21
22
class Client:
    """
    High level client to connect to an OPC-UA server.

    This class makes it easy to connect and browse address space.
    It attempts to expose as much functionality as possible
    but if you want more flexibility it is possible and advised to
    use UaClient object, available as self.uaclient
    which offers the raw OPC-UA services interface.
    """
    def __init__(self, url: str, timeout: int = 4, loop=None):
        """
        :param url: url of the server.
            if you are unsure of url, write at least hostname
            and port and call get_endpoints

        :param timeout:
            Each request sent to the server expects an answer within this
            time. The timeout is specified in seconds.
        :param loop: event loop used for background tasks;
            defaults to asyncio.get_event_loop()
        """
        self.loop = loop or asyncio.get_event_loop()
        self.server_url = urlparse(url)
        # take initial username and password from the url
        self._username = self.server_url.username
        self._password = self.server_url.password
        self.name = "Pure Python Async. Client"
        self.description = self.name
        self.application_uri = "urn:freeopcua:client"
        self.product_uri = "urn:freeopcua.github.io:client"
        self.security_policy = ua.SecurityPolicy()
        self.secure_channel_id = None
        # requested channel lifetime in ms; overwritten by the server-revised
        # lifetime each time open_secure_channel() succeeds
        self.secure_channel_timeout = 3600000  # 1 hour
        # requested session timeout in ms (sent by create_session()); overwritten
        # by the server-revised timeout once the session is created
        self.session_timeout = 3600000  # 1 hour
        self._policy_ids = []
        self.uaclient: UaClient = UaClient(timeout, loop=self.loop)
        self.user_certificate = None
        self.user_private_key = None
        self._server_nonce = None
        self._session_counter = 1
        self.nodes = Shortcuts(self.uaclient)
        self.max_messagesize = 0  # No limits
        self.max_chunkcount = 0  # No limits
        # background task started by create_session(), stopped by close_session()
        self._renew_channel_task = None

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.disconnect()

    def __str__(self):
        return f"Client({self.server_url.geturl()})"
    __repr__ = __str__

    @staticmethod
    def find_endpoint(endpoints, security_mode, policy_uri):
        """
        Find endpoint with required security mode and policy URI.

        Only endpoints whose URL uses the opc.tcp scheme are considered.

        :raises ua.UaError: if no endpoint matches
        """
        _logger.info("find_endpoint %r %r %r", endpoints, security_mode, policy_uri)
        for ep in endpoints:
            if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == security_mode and ep.SecurityPolicyUri == policy_uri):
                return ep
        raise ua.UaError("No matching endpoints: {0}, {1}".format(security_mode, policy_uri))

    def set_user(self, username: str):
        """
        Set user name for the connection.
        initial user from the URL will be overwritten
        """
        self._username = username

    def set_password(self, pwd: str):
        """
        Set user password for the connection.
        initial password from the URL will be overwritten

        :raises TypeError: if pwd is not a str
        """
        if not isinstance(pwd, str):
            raise TypeError(f"Password must be a string, got {pwd} of type {type(pwd)}")
        self._password = pwd

    async def set_security_string(self, string: str):
        """
        Set SecureConnection mode. String format:
        Policy,Mode,certificate,private_key[,server_private_key]
        where Policy is Basic128Rsa15, Basic256 or Basic256Sha256,
            Mode is Sign or SignAndEncrypt
            certificate, private_key and server_private_key are
                paths to .pem or .der files
        Call this before connect()
        """
        if not string:
            return
        parts = string.split(",")
        if len(parts) < 4:
            raise ua.UaError("Wrong format: `{}`, expected at least 4 comma-separated values".format(string))
        policy_class = getattr(security_policies, "SecurityPolicy{}".format(parts[0]))
        mode = getattr(ua.MessageSecurityMode, parts[1])
        return await self.set_security(policy_class, parts[2], parts[3], parts[4] if len(parts) >= 5 else None, mode)

    async def set_security(self,
                           policy,
                           certificate_path: str,
                           private_key_path: str,
                           server_certificate_path: str = None,
                           mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt):
        """
        Set SecureConnection mode.
        Call this before connect()

        If server_certificate_path is None, the server certificate is fetched
        from the server's endpoint list (this opens and closes a connection).
        """
        if server_certificate_path is None:
            # load certificate from server's list of endpoints
            endpoints = await self.connect_and_get_server_endpoints()
            endpoint = Client.find_endpoint(endpoints, mode, policy.URI)
            server_cert = uacrypto.x509_from_der(endpoint.ServerCertificate)
        else:
            server_cert = await uacrypto.load_certificate(server_certificate_path)
        cert = await uacrypto.load_certificate(certificate_path)
        pk = await uacrypto.load_private_key(private_key_path)
        self.security_policy = policy(server_cert, cert, pk, mode)
        self.uaclient.set_security(self.security_policy)

    async def load_client_certificate(self, path: str):
        """
        load our certificate from file, either pem or der
        """
        self.user_certificate = await uacrypto.load_certificate(path)

    async def load_private_key(self, path: str):
        """
        Load user private key. This is used for authenticating using certificate
        """
        self.user_private_key = await uacrypto.load_private_key(path)

    async def connect_and_get_server_endpoints(self):
        """
        Connect, ask server for endpoints, and disconnect
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            endpoints = await self.get_endpoints()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return endpoints

    async def connect_and_find_servers(self):
        """
        Connect, ask server for a list of known servers, and disconnect
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()  # spec says it should not be necessary to open channel
            servers = await self.find_servers()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return servers

    async def connect_and_find_servers_on_network(self):
        """
        Connect, ask server for a list of known servers on network, and disconnect
        """
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            servers = await self.find_servers_on_network()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()
        return servers

    async def connect(self):
        """
        High level method
        Connect, create and activate session

        If any step fails, the connection is torn down (best effort) and the
        original exception is re-raised.
        """
        _logger.info("connect")
        await self.connect_socket()
        try:
            await self.send_hello()
            await self.open_secure_channel()
            await self.create_session()
        except Exception:
            # clean up open socket
            self.disconnect_socket()
            raise
        try:
            await self.activate_session(username=self._username, password=self._password, certificate=self.user_certificate)
        except Exception:
            # create_session() already started the renew task and the socket is
            # open: tear everything down so a failed activation leaks nothing
            try:
                await self.disconnect()
            except Exception:
                _logger.exception("Cleanup after failed session activation failed")
            raise

    async def disconnect(self):
        """
        High level method
        Close session, secure channel and socket
        """
        _logger.info("disconnect")
        try:
            await self.close_session()
            await self.close_secure_channel()
        finally:
            self.disconnect_socket()

    async def connect_socket(self):
        """
        connect to socket defined in url
        """
        await self.uaclient.connect_socket(self.server_url.hostname, self.server_url.port)

    def disconnect_socket(self):
        """
        Close the underlying socket via the UaClient.
        """
        self.uaclient.disconnect_socket()

    async def send_hello(self):
        """
        Send OPC-UA hello to server

        :raises ua.UaStatusCodeError: if the UaClient returns one instead of an ack
        """
        ack = await self.uaclient.send_hello(self.server_url.geturl(), self.max_messagesize, self.max_chunkcount)
        if isinstance(ack, ua.UaStatusCodeError):
            raise ack

    async def open_secure_channel(self, renew=False):
        """
        Open secure channel, if renew is True, renew channel

        On success, self.secure_channel_timeout is replaced by the lifetime
        revised by the server.
        """
        params = ua.OpenSecureChannelParameters()
        params.ClientProtocolVersion = 0
        params.RequestType = ua.SecurityTokenRequestType.Issue
        if renew:
            params.RequestType = ua.SecurityTokenRequestType.Renew
        params.SecurityMode = self.security_policy.Mode
        params.RequestedLifetime = self.secure_channel_timeout
        # length should be equal to the length of key of symmetric encryption
        params.ClientNonce = create_nonce(self.security_policy.symmetric_key_size)
        result = await self.uaclient.open_secure_channel(params)
        self.secure_channel_timeout = result.SecurityToken.RevisedLifetime

    async def close_secure_channel(self):
        """
        Close the secure channel via the UaClient.
        """
        return await self.uaclient.close_secure_channel()

    async def get_endpoints(self) -> list:
        """
        Ask the server for its endpoints (GetEndpoints service) for our server URL.
        """
        params = ua.GetEndpointsParameters()
        params.EndpointUrl = self.server_url.geturl()
        return await self.uaclient.get_endpoints(params)

    async def register_server(self, server, discovery_configuration=None):
        """
        register a server to discovery server
        if discovery_configuration is provided, the newer register_server2 service call is used
        """
        serv = ua.RegisteredServer()
        serv.ServerUri = server.get_application_uri()
        serv.ProductUri = server.product_uri
        serv.DiscoveryUrls = [server.endpoint.geturl()]
        serv.ServerType = server.application_type
        serv.ServerNames = [ua.LocalizedText(server.name)]
        serv.IsOnline = True
        if discovery_configuration:
            params = ua.RegisterServer2Parameters()
            params.Server = serv
            params.DiscoveryConfiguration = discovery_configuration
            return await self.uaclient.register_server2(params)
        return await self.uaclient.register_server(serv)

    async def find_servers(self, uris=None):
        """
        send a FindServer request to the server. The answer should be a list of
        servers the server knows about
        A list of uris can be provided, only server having matching uris will be returned
        """
        if uris is None:
            uris = []
        params = ua.FindServersParameters()
        params.EndpointUrl = self.server_url.geturl()
        params.ServerUris = uris
        return await self.uaclient.find_servers(params)

    async def find_servers_on_network(self):
        """
        Send a FindServersOnNetwork request with default parameters.
        """
        params = ua.FindServersOnNetworkParameters()
        return await self.uaclient.find_servers_on_network(params)

    async def create_session(self):
        """
        send a CreateSessionRequest to server with reasonable parameters.
        If you want o modify settings look at code of this methods
        and make your own

        The requested session timeout is taken from self.session_timeout (ms).
        On success this starts the background channel-renewal task.

        :raises ua.UaError: if the server certificate does not match the one
            already configured, or no endpoint matches the security policy
        """
        desc = ua.ApplicationDescription()
        desc.ApplicationUri = self.application_uri
        desc.ProductUri = self.product_uri
        desc.ApplicationName = ua.LocalizedText(self.name)
        desc.ApplicationType = ua.ApplicationType.Client
        params = ua.CreateSessionParameters()
        # at least 32 random bytes for server to prove possession of private key (specs part 4, 5.6.2.2)
        nonce = create_nonce(32)
        params.ClientNonce = nonce
        params.ClientCertificate = self.security_policy.client_certificate
        params.ClientDescription = desc
        params.EndpointUrl = self.server_url.geturl()
        params.SessionName = f"{self.description} Session{self._session_counter}"
        # Requested maximum number of milliseconds that a Session should remain open without activity
        params.RequestedSessionTimeout = self.session_timeout
        params.MaxResponseMessageSize = 0  # means no max size
        response = await self.uaclient.create_session(params)
        if self.security_policy.client_certificate is None:
            data = nonce
        else:
            data = self.security_policy.client_certificate + nonce
        self.security_policy.asymmetric_cryptography.verify(data, response.ServerSignature.Signature)
        self._server_nonce = response.ServerNonce
        if not self.security_policy.server_certificate:
            self.security_policy.server_certificate = response.ServerCertificate
        elif self.security_policy.server_certificate != response.ServerCertificate:
            raise ua.UaError("Server certificate mismatch")
        # remember PolicyId's: we will use them in activate_session()
        ep = Client.find_endpoint(response.ServerEndpoints, self.security_policy.Mode, self.security_policy.URI)
        self._policy_ids = ep.UserIdentityTokens
        #  Actual maximum number of milliseconds that a Session shall remain open without activity
        self.session_timeout = response.RevisedSessionTimeout
        self._renew_channel_task = self.loop.create_task(self._renew_channel_loop())
        return response

    async def _renew_channel_loop(self):
        """
        Renew the SecureChannel before the SessionTimeout will happen.
        In theory we could do that only if no session activity
        but it does not cost much..

        Runs until cancelled (see close_session()).
        """
        try:
            while True:
                # 0.7 is from spec. 0.001 is because asyncio.sleep expects time in seconds.
                # Recomputed every iteration because each renewal may revise
                # self.secure_channel_timeout.
                duration = min(self.session_timeout, self.secure_channel_timeout) * 0.7 * 0.001
                await asyncio.sleep(duration)
                _logger.debug("renewing channel")
                await self.open_secure_channel(renew=True)
                val = await self.nodes.server_state.read_value()
                _logger.debug("server state is: %s ", val)
        except asyncio.CancelledError:
            pass

    def server_policy_id(self, token_type, default):
        """
        Find PolicyId of server's UserTokenPolicy by token_type.
        Return default if there's no matching UserTokenPolicy.
        """
        for policy in self._policy_ids:
            if policy.TokenType == token_type:
                return policy.PolicyId
        return default

    def server_policy_uri(self, token_type):
        """
        Find SecurityPolicyUri of server's UserTokenPolicy by token_type.
        If SecurityPolicyUri is empty, use default SecurityPolicyUri
        of the endpoint
        """
        for policy in self._policy_ids:
            if policy.TokenType == token_type:
                if policy.SecurityPolicyUri:
                    return policy.SecurityPolicyUri
                # empty URI means "use this endpoint's policy URI"
                return self.security_policy.URI
        return self.security_policy.URI

    async def activate_session(self, username: str = None, password: str = None, certificate=None):
        """
        Activate session using either username and password or private_key

        Anonymous authentication is used when neither username nor certificate
        is given; a certificate takes precedence over a username.
        """
        params = ua.ActivateSessionParameters()
        challenge = b""
        if self.security_policy.server_certificate is not None:
            challenge += self.security_policy.server_certificate
        if self._server_nonce is not None:
            challenge += self._server_nonce
        if self.security_policy.AsymmetricSignatureURI:
            params.ClientSignature.Algorithm = self.security_policy.AsymmetricSignatureURI
        else:
            params.ClientSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.ClientSignature.Signature = self.security_policy.asymmetric_cryptography.signature(challenge)
        params.LocaleIds.append("en")
        if not username and not certificate:
            self._add_anonymous_auth(params)
        elif certificate:
            self._add_certificate_auth(params, certificate, challenge)
        else:
            self._add_user_auth(params, username, password)
        return await self.uaclient.activate_session(params)

    def _add_anonymous_auth(self, params):
        """Attach an AnonymousIdentityToken to the ActivateSession params."""
        params.UserIdentityToken = ua.AnonymousIdentityToken()
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Anonymous, "anonymous")

    def _add_certificate_auth(self, params, certificate, challenge):
        """Attach an X509IdentityToken signed with self.user_private_key."""
        params.UserIdentityToken = ua.X509IdentityToken()
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, "certificate_basic256")
        params.UserIdentityToken.CertificateData = uacrypto.der_from_x509(certificate)
        # specs part 4, 5.6.3.1: the data to sign is created by appending
        # the last serverNonce to the serverCertificate
        sig = uacrypto.sign_sha1(self.user_private_key, challenge)
        params.UserTokenSignature = ua.SignatureData()
        params.UserTokenSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.UserTokenSignature.Signature = sig

    def _add_user_auth(self, params, username: str, password: str):
        """
        Attach a UserNameIdentityToken built from the given credentials.

        The password argument (not self._password) is used, so explicit
        activate_session(username, password) calls send what was passed in.
        """
        params.UserIdentityToken = ua.UserNameIdentityToken()
        params.UserIdentityToken.UserName = username
        policy_uri = self.server_policy_uri(ua.UserTokenType.UserName)
        if not policy_uri or policy_uri == security_policies.POLICY_NONE_URI:
            # see specs part 4, 7.36.3: if the token is NOT encrypted,
            # then the password only contains UTF-8 encoded password
            # and EncryptionAlgorithm is null
            if password:
                _logger.warning("Sending plain-text password")
                params.UserIdentityToken.Password = password.encode("utf8")
            params.UserIdentityToken.EncryptionAlgorithm = None
        elif password:
            data, uri = self._encrypt_password(password, policy_uri)
            params.UserIdentityToken.Password = data
            params.UserIdentityToken.EncryptionAlgorithm = uri
        params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, "username_basic256")

    def _encrypt_password(self, password: str, policy_uri):
        """
        Encrypt password (plus server nonce, if any) with the server's public key.

        :return: (encrypted data, encryption algorithm uri)
        """
        pubkey = uacrypto.x509_from_der(self.security_policy.server_certificate).public_key()
        # see specs part 4, 7.36.3: if the token is encrypted, password
        # shall be converted to UTF-8 and serialized with server nonce
        passwd = password.encode("utf8")
        if self._server_nonce is not None:
            passwd += self._server_nonce
        etoken = ua.ua_binary.Primitives.Bytes.pack(passwd)
        data, uri = security_policies.encrypt_asymmetric(pubkey, etoken, policy_uri)
        return data, uri

    async def close_session(self):
        """
        Close session

        Stops the background channel-renewal task first, if create_session()
        started one, so this is safe to call even when no session was created.
        """
        task = self._renew_channel_task
        if task is not None:
            self._renew_channel_task = None
            task.cancel()
            await task
        return await self.uaclient.close_session(True)

    def get_root_node(self):
        """Return the Node of the address-space RootFolder."""
        return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.RootFolder))

    def get_objects_node(self):
        """Return the Node of the ObjectsFolder."""
        _logger.info("get_objects_node")
        return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder))

    def get_server_node(self):
        """Return the Node of the Server object."""
        return self.get_node(ua.FourByteNodeId(ua.ObjectIds.Server))

    def get_node(self, nodeid: Union[ua.NodeId, str]) -> Node:
        """
        Get node using NodeId object or a string representing a NodeId.
        """
        return Node(self.uaclient, nodeid)

    async def create_subscription(self, period, handler):
        """
        Create a subscription.
        Returns a Subscription object which allows to subscribe to events or data changes on server.

        :param period: Either a publishing interval in milliseconds or a `CreateSubscriptionParameters` instance.
        The second option should be used, if the asyncua-server has problems with the default options.
        :param handler: Class instance with data_change and/or event methods (see `SubHandler`
        base class for details). Remember not to block the main event loop inside the handler methods.
        """
        if isinstance(period, ua.CreateSubscriptionParameters):
            params = period
        else:
            params = ua.CreateSubscriptionParameters()
            params.RequestedPublishingInterval = period
            params.RequestedLifetimeCount = 10000
            params.RequestedMaxKeepAliveCount = 3000
            params.MaxNotificationsPerPublish = 10000
            params.PublishingEnabled = True
            params.Priority = 0
        subscription = Subscription(self.uaclient, params, handler)
        await subscription.init()
        return subscription

    def get_namespace_array(self) -> Coroutine:
        """Return a coroutine reading the server's NamespaceArray value."""
        ns_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_NamespaceArray))
        return ns_node.read_value()

    async def get_namespace_index(self, uri):
        """
        Return the index of uri in the server's NamespaceArray.

        :raises ValueError: if uri is not registered (from list.index)
        """
        uries = await self.get_namespace_array()
        _logger.info("get_namespace_index %s %r", type(uries), uries)
        return uries.index(uri)

    def delete_nodes(self, nodes, recursive=False) -> Coroutine:
        """Return a coroutine deleting nodes (optionally recursively) on the server."""
        return delete_nodes(self.uaclient, nodes, recursive)

    def import_xml(self, path=None, xmlstring=None) -> Coroutine:
        """
        Import nodes defined in xml
        """
        importer = XmlImporter(self)
        return importer.import_xml(path, xmlstring)

    async def export_xml(self, nodes, path):
        """
        Export defined nodes to xml
        """
        exp = XmlExporter(self)
        await exp.build_etree(nodes)
        await exp.write_xml(path)

    async def register_namespace(self, uri):
        """
        Register a new namespace. Nodes should in custom namespace, not 0.
        This method is mainly implemented for symetry with server

        :return: index of uri in the NamespaceArray (existing or newly appended)
        """
        ns_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_NamespaceArray))
        uries = await ns_node.read_value()
        if uri in uries:
            return uries.index(uri)
        uries.append(uri)
        await ns_node.write_value(uries)
        return len(uries) - 1

    def load_type_definitions(self, nodes=None) -> Coroutine:
        """
        Load custom types (custom structures/extension objects) definition from server
        Generate Python classes for custom structures/extension objects defined in server
        These classes will available in ua module
        """
        return load_type_definitions(self, nodes)

    def load_enums(self) -> Coroutine:
        """
        generate Python enums for custom enums on server.
        This enums will be available in ua module
        """
        return load_enums(self)

    async def register_nodes(self, nodes):
        """
        Register nodes for faster read and write access (if supported by server)
        Rmw: This call modifies the nodeid of the nodes, the original nodeid is
        available as node.basenodeid
        """
        nodeids = [node.nodeid for node in nodes]
        nodeids = await self.uaclient.register_nodes(nodeids)
        for node, nodeid in zip(nodes, nodeids):
            node.basenodeid = node.nodeid
            node.nodeid = nodeid
        return nodes

    async def unregister_nodes(self, nodes):
        """
        Unregister nodes

        Restores node.nodeid from node.basenodeid for nodes that were registered.
        """
        nodeids = [node.nodeid for node in nodes]
        await self.uaclient.unregister_nodes(nodeids)
        for node in nodes:
            if not node.basenodeid:
                continue
            node.nodeid = node.basenodeid
            node.basenodeid = None

    async def read_values(self, nodes):
        """
        Read the value of multiple nodes in one ua call.

        :return: list of values, in the same order as nodes
        """
        nodeids = [node.nodeid for node in nodes]
        results = await self.uaclient.get_attributes(nodeids, ua.AttributeIds.Value)
        return [result.Value.Value for result in results]

    async def write_values(self, nodes, values):
        """
        Write values to multiple nodes in one ua call

        :param nodes: iterable of Node objects to write to
        :param values: iterable of values, one per node, in the same order;
            each is wrapped with value_to_datavalue()
        :raises ValueError: if nodes and values have different lengths
            (otherwise some writes could be silently dropped or misaligned)
        Any bad status code returned for a write is raised by result.check().
        """
        nodes = list(nodes)
        values = list(values)
        if len(nodes) != len(values):
            raise ValueError(f"Got {len(nodes)} nodes but {len(values)} values")
        nodeids = [node.nodeid for node in nodes]
        dvs = [value_to_datavalue(val) for val in values]
        results = await self.uaclient.set_attributes(nodeids, dvs, ua.AttributeIds.Value)
        for result in results:
            result.check()

    get_values = read_values  # legacy compatibility
    set_values = write_values  # legacy compatibility
605