Passed
Pull Request — master (#534)
by
unknown
03:45
created

Client.create_session()   B

Complexity

Conditions 5

Size

Total Lines 44

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 34
CRAP Score 5.0005

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 5
c 1
b 0
f 0
dl 0
loc 44
ccs 34
cts 35
cp 0.9714
crap 5.0005
rs 8.0894
1 1
from __future__ import division  # support for python2
2 1
from threading import Thread, Condition
3 1
import logging
4 1
try:
5 1
    from urllib.parse import urlparse
6 1
except ImportError:  # support for python2
7 1
    from urlparse import urlparse
8
9 1
from opcua import ua
10 1
from opcua.client.ua_client import UaClient
11 1
from opcua.common.xmlimporter import XmlImporter
12 1
from opcua.common.xmlexporter import XmlExporter
13 1
from opcua.common.node import Node
14 1
from opcua.common.manage_nodes import delete_nodes
15 1
from opcua.common.subscription import Subscription
16 1
from opcua.common import utils
17 1
from opcua.crypto import security_policies
18 1
from opcua.common.shortcuts import Shortcuts
19 1
from opcua.common.structures import load_type_definitions
20 1
use_crypto = True
21 1
try:
22 1
    from opcua.crypto import uacrypto
23
except ImportError:
24
    logging.getLogger(__name__).warning("cryptography is not installed, use of crypto disabled")
25
    use_crypto = False
26
27
28 1
class KeepAlive(Thread):

    """
    Background thread used by Client to keep the connection open.

    OPCUA defines timeout both for sessions and secure channel.
    Once per period the thread renews the secure channel and reads
    the server state node.
    """

    def __init__(self, client, timeout):
        """
        :param client: Client instance whose secure channel is renewed
        :param timeout: Period between two renewals, in milliseconds.
            A value of 0 is replaced by one hour.
        """
        Thread.__init__(self)
        self.logger = logging.getLogger(__name__)

        self.client = client
        self._dostop = False
        self._cond = Condition()
        self.timeout = timeout

        # some server support no timeout, but we do not trust them
        if self.timeout == 0:
            self.timeout = 3600000  # 1 hour

    def run(self):
        """
        Renew the secure channel every ``timeout`` milliseconds until
        stop() is called.
        """
        self.logger.debug("starting keepalive thread with period of %s milliseconds", self.timeout)
        server_state = self.client.get_node(ua.FourByteNodeId(ua.ObjectIds.Server_ServerStatus_State))
        while True:
            with self._cond:
                # The stop flag is read while holding the lock: a stop()
                # issued just before we start waiting can then no longer
                # slip between the check and the wait and be missed.
                if not self._dostop:
                    self._cond.wait(self.timeout / 1000)
                if self._dostop:
                    break
            self.logger.debug("renewing channel")
            self.client.open_secure_channel(renew=True)
            val = server_state.get_value()
            self.logger.debug("server state is: %s ", val)
        self.logger.debug("keepalive thread has stopped")

    def stop(self):
        """
        Ask the thread to terminate and wake it up if it is waiting.
        Does not wait for the thread to finish.
        """
        self.logger.debug("stoping keepalive thread")
        with self._cond:
            # flag is set under the lock, see run()
            self._dostop = True
            self._cond.notify_all()
71
72
73 1
class Client(object):
74
75
    """
76
    High level client to connect to an OPC-UA server.
77
78
    This class makes it easy to connect and browse address space.
79
    It attempts to expose as much functionality as possible
80
    but if you want more flexibility it is possible and advised to
81
    use UaClient object, available as self.uaclient
82
    which offers the raw OPC-UA services interface.
83
    """
84
85 1
    def __init__(self, url, timeout=4):
        """
        Set up the client state. No connection is made here.

        :param url: url of the server.
            If you are unsure of the url, give at least hostname
            and port and call get_endpoints. A username and password
            embedded in the url are used as initial credentials.

        :param timeout:
            Each request sent to the server expects an answer within this
            time. The timeout is specified in seconds.
        """
        self.logger = logging.getLogger(__name__)

        # server address; credentials found in the url are the initial ones
        parsed_url = urlparse(url)
        self.server_url = parsed_url
        self._username = parsed_url.username
        self._password = parsed_url.password

        # how this application describes itself to the server
        self.name = "Pure Python Client"
        self.description = self.name
        self.application_uri = "urn:freeopcua:client"
        self.product_uri = "urn:freeopcua.github.no:client"

        # security settings and timeouts
        self.security_policy = ua.SecurityPolicy()
        self.secure_channel_id = None
        self.secure_channel_timeout = 3600000  # 1 hour
        self.session_timeout = 3600000  # 1 hour
        self._policy_ids = []

        # low level client and user credentials
        self.uaclient = UaClient(timeout)
        self.user_certificate = None
        self.user_private_key = None

        # session state
        self._server_nonce = None
        self._session_counter = 1
        self.keepalive = None

        self.nodes = Shortcuts(self.uaclient)
        self.max_messagesize = 0  # No limits
        self.max_chunkcount = 0  # No limits
119
120
121 1
    def __enter__(self):
122 1
        self.connect()
123 1
        return self
124
125 1
    def __exit__(self, exc_type, exc_value, traceback):
126 1
        self.disconnect()
127
128 1
    @staticmethod
    def find_endpoint(endpoints, security_mode, policy_uri):
        """
        Return the first endpoint whose url uses the OPC TCP scheme and
        whose security mode and policy URI match the requested ones.

        :raises ua.UaError: if no endpoint matches
        """
        for candidate in endpoints:
            if not candidate.EndpointUrl.startswith(ua.OPC_TCP_SCHEME):
                continue
            if candidate.SecurityMode == security_mode and candidate.SecurityPolicyUri == policy_uri:
                return candidate
        raise ua.UaError("No matching endpoints: {0}, {1}".format(security_mode, policy_uri))
139
140 1
    def set_user(self, username):
141
        """
142
        Set user name for the connection.
143
        initial user from the URL will be overwritten
144
        """
145
        self._username = username
146
147 1
    def set_password(self, pwd):
148
        """
149
        Set user password for the connection.
150
        initial password from the URL will be overwritten
151
        """
152
        self._password = pwd   
153
154 1
    def set_security_string(self, string):
155
        """
156
        Set SecureConnection mode. String format:
157
        Policy,Mode,certificate,private_key[,server_private_key]
158
        where Policy is Basic128Rsa15 or Basic256,
159
            Mode is Sign or SignAndEncrypt
160
            certificate, private_key and server_private_key are
161
                paths to .pem or .der files
162
        Call this before connect()
163
        """
164 1
        if not string:
165
            return
166 1
        parts = string.split(',')
167 1
        if len(parts) < 4:
168
            raise ua.UaError('Wrong format: `{0}`, expected at least 4 comma-separated values'.format(string))
169 1
        policy_class = getattr(security_policies, 'SecurityPolicy' + parts[0])
170 1
        mode = getattr(ua.MessageSecurityMode, parts[1])
171 1
        return self.set_security(policy_class, parts[2], parts[3],
172
                                 parts[4] if len(parts) >= 5 else None, mode)
173
174 1
    def set_security(self, policy, certificate_path, private_key_path,
                     server_certificate_path=None,
                     mode=ua.MessageSecurityMode.SignAndEncrypt):
        """
        Set SecureConnection mode.
        Call this before connect()

        :param policy: security policy class; it is instantiated with the
            server certificate, our certificate, our private key and mode
        :param certificate_path: path of our certificate
        :param private_key_path: path of our private key
        :param server_certificate_path: path of the server certificate.
            If None, the server is contacted and the certificate of the
            endpoint matching mode and policy is used.
        :param mode: message security mode
        """
        if server_certificate_path is not None:
            server_cert = uacrypto.load_certificate(server_certificate_path)
        else:
            # load certificate from server's list of endpoints
            available = self.connect_and_get_server_endpoints()
            matching = Client.find_endpoint(available, mode, policy.URI)
            server_cert = uacrypto.x509_from_der(matching.ServerCertificate)
        client_cert = uacrypto.load_certificate(certificate_path)
        private_key = uacrypto.load_private_key(private_key_path)
        self.security_policy = policy(server_cert, client_cert, private_key, mode)
        self.uaclient.set_security(self.security_policy)
192
193 1
    def load_client_certificate(self, path):
        """
        Load our certificate from a file, either pem or der,
        and keep it as the user certificate.
        """
        certificate = uacrypto.load_certificate(path)
        self.user_certificate = certificate
198
199 1
    def load_private_key(self, path):
        """
        Load the user private key from a file.
        This is used for authenticating using certificate.
        """
        private_key = uacrypto.load_private_key(path)
        self.user_private_key = private_key
204
205 1
    def connect_and_get_server_endpoints(self):
206
        """
207
        Connect, ask server for endpoints, and disconnect
208
        """
209 1
        self.connect_socket()
210 1
        self.send_hello()
211 1
        self.open_secure_channel()
212 1
        endpoints = self.get_endpoints()
213 1
        self.close_secure_channel()
214 1
        self.disconnect_socket()
215 1
        return endpoints
216
217 1
    def connect_and_find_servers(self):
218
        """
219
        Connect, ask server for a list of known servers, and disconnect
220
        """
221
        self.connect_socket()
222
        self.send_hello()
223
        self.open_secure_channel()  # spec says it should not be necessary to open channel
224
        servers = self.find_servers()
225
        self.close_secure_channel()
226
        self.disconnect_socket()
227
        return servers
228
229 1
    def connect_and_find_servers_on_network(self):
230
        """
231
        Connect, ask server for a list of known servers on network, and disconnect
232
        """
233
        self.connect_socket()
234
        self.send_hello()
235
        self.open_secure_channel()
236
        servers = self.find_servers_on_network()
237
        self.close_secure_channel()
238
        self.disconnect_socket()
239
        return servers
240
241 1
    def connect(self):
242
        """
243
        High level method
244
        Connect, create and activate session
245
        """
246 1
        self.connect_socket()
247 1
        self.send_hello()
248 1
        self.open_secure_channel()
249 1
        self.create_session()
250 1
        self.activate_session(username=self._username, password=self._password, certificate=self.user_certificate)
251
252 1
    def disconnect(self):
253
        """
254
        High level method
255
        Close session, secure channel and socket
256
        """
257 1
        try:
258 1
            self.close_session()
259 1
            self.close_secure_channel()
260
        finally:
261 1
            self.disconnect_socket()
262
263 1
    def connect_socket(self):
264
        """
265
        connect to socket defined in url
266
        """
267 1
        self.uaclient.connect_socket(self.server_url.hostname, self.server_url.port)
268
269 1
    def disconnect_socket(self):
270 1
        self.uaclient.disconnect_socket()
271
272 1
    def send_hello(self):
273
        """
274
        Send OPC-UA hello to server
275
        """
276 1
        ack = self.uaclient.send_hello(self.server_url.geturl(), self.max_messagesize, self.max_chunkcount)
277
        # FIXME check ack
278
279 1
    def open_secure_channel(self, renew=False):
        """
        Open secure channel, if renew is True, renew channel.

        The lifetime revised by the server is stored in
        secure_channel_timeout.
        """
        token_request = ua.SecurityTokenRequestType
        policy = self.security_policy

        params = ua.OpenSecureChannelParameters()
        params.ClientProtocolVersion = 0
        params.RequestType = token_request.Renew if renew else token_request.Issue
        params.SecurityMode = policy.Mode
        params.RequestedLifetime = self.secure_channel_timeout
        # length should be equal to the length of key of symmetric encryption
        client_nonce = utils.create_nonce(policy.symmetric_key_size)
        # this nonce is used to create a symmetric key
        params.ClientNonce = client_nonce

        result = self.uaclient.open_secure_channel(params)
        self.security_policy.make_symmetric_key(client_nonce, result.ServerNonce)
        self.secure_channel_timeout = result.SecurityToken.RevisedLifetime
295
296 1
    def close_secure_channel(self):
297 1
        return self.uaclient.close_secure_channel()
298
299 1
    def get_endpoints(self):
        """
        Send a GetEndpoints request for our server url and return the answer
        """
        request = ua.GetEndpointsParameters()
        request.EndpointUrl = self.server_url.geturl()
        return self.uaclient.get_endpoints(request)
303
304 1
    def register_server(self, server, discovery_configuration=None):
        """
        Register a server to the discovery server we are connected to.
        If discovery_configuration is provided, the newer register_server2
        service call is used.

        :param server: object describing the server to register. Its
            application_uri, product_uri, endpoint, application_type and
            name attributes are read.
        """
        registration = ua.RegisteredServer()
        registration.ServerUri = server.application_uri
        registration.ProductUri = server.product_uri
        registration.DiscoveryUrls = [server.endpoint.geturl()]
        registration.ServerType = server.application_type
        registration.ServerNames = [ua.LocalizedText(server.name)]
        registration.IsOnline = True

        if not discovery_configuration:
            return self.uaclient.register_server(registration)

        request = ua.RegisterServer2Parameters()
        request.Server = registration
        request.DiscoveryConfiguration = discovery_configuration
        return self.uaclient.register_server2(request)
323
324 1
    def find_servers(self, uris=None):
        """
        Send a FindServer request to the server. The answer should be a
        list of servers the server knows about.
        A list of uris can be provided, only servers having matching uris
        will be returned.
        """
        request = ua.FindServersParameters()
        request.EndpointUrl = self.server_url.geturl()
        request.ServerUris = [] if uris is None else uris
        return self.uaclient.find_servers(request)
336
337 1
    def find_servers_on_network(self):
        """
        Send a FindServersOnNetwork request with default parameters
        and return the answer
        """
        request = ua.FindServersOnNetworkParameters()
        return self.uaclient.find_servers_on_network(request)
340
341 1
    def create_session(self):
        """
        Send a CreateSessionRequest to server with reasonable parameters.
        If you want to modify settings look at the code of this method
        and make your own.

        The requested session timeout is taken from self.session_timeout
        (one hour unless changed), which is afterwards replaced by the
        timeout revised by the server. A KeepAlive thread is started for
        the new session, after stopping the one of a previous session.

        :return: the answer of the server to the CreateSessionRequest
        :raises ua.UaError: if the certificate sent by the server is not the
            one configured in the security policy, or if no endpoint
            matches the security mode and policy in use
        """
        desc = ua.ApplicationDescription()
        desc.ApplicationUri = self.application_uri
        desc.ProductUri = self.product_uri
        desc.ApplicationName = ua.LocalizedText(self.name)
        desc.ApplicationType = ua.ApplicationType.Client

        params = ua.CreateSessionParameters()
        nonce = utils.create_nonce(32)  # at least 32 random bytes for server to prove possession of private key (specs part 4, 5.6.2.2)
        params.ClientNonce = nonce
        params.ClientCertificate = self.security_policy.client_certificate
        params.ClientDescription = desc
        params.EndpointUrl = self.server_url.geturl()
        params.SessionName = self.description + " Session" + str(self._session_counter)
        # honour the configurable attribute instead of a hard-coded hour
        params.RequestedSessionTimeout = self.session_timeout
        params.MaxResponseMessageSize = 0  # means no max size
        response = self.uaclient.create_session(params)

        # the server signs our certificate (if any) followed by our nonce
        if self.security_policy.client_certificate is None:
            data = nonce
        else:
            data = self.security_policy.client_certificate + nonce
        self.security_policy.asymmetric_cryptography.verify(data, response.ServerSignature.Signature)
        self._server_nonce = response.ServerNonce

        if not self.security_policy.server_certificate:
            self.security_policy.server_certificate = response.ServerCertificate
        elif self.security_policy.server_certificate != response.ServerCertificate:
            raise ua.UaError("Server certificate mismatch")

        # remember PolicyId's: we will use them in activate_session()
        ep = Client.find_endpoint(response.ServerEndpoints, self.security_policy.Mode, self.security_policy.URI)
        self._policy_ids = ep.UserIdentityTokens
        self.session_timeout = response.RevisedSessionTimeout

        # In case there was an existing session, reap its KeepAlive-Thread.
        # Explicit check: errors raised inside stop() are not hidden.
        if self.keepalive is not None:
            self.keepalive.stop()
        self.keepalive = KeepAlive(self, min(self.session_timeout, self.secure_channel_timeout) * 0.7)  # 0.7 is from spec
        self.keepalive.start()
        return response
385
386 1
    def server_policy_id(self, token_type, default):
387
        """
388
        Find PolicyId of server's UserTokenPolicy by token_type.
389
        Return default if there's no matching UserTokenPolicy.
390
        """
391 1
        for policy in self._policy_ids:
392 1
            if policy.TokenType == token_type:
393 1
                return policy.PolicyId
394
        return default
395
396 1
    def server_policy_uri(self, token_type):
397
        """
398
        Find SecurityPolicyUri of server's UserTokenPolicy by token_type.
399
        If SecurityPolicyUri is empty, use default SecurityPolicyUri
400
        of the endpoint
401
        """
402 1
        for policy in self._policy_ids:
403 1
            if policy.TokenType == token_type:
404 1
                if policy.SecurityPolicyUri:
405
                    return policy.SecurityPolicyUri
406
                else:   # empty URI means "use this endpoint's policy URI"
407 1
                    return self.security_policy.URI
408
        return self.security_policy.URI
409
410 1
    def activate_session(self, username=None, password=None, certificate=None):
        """
        Activate session using either a certificate, a username and
        password, or anonymously when neither is given.
        A certificate takes precedence over a username.
        """
        params = ua.ActivateSessionParameters()

        # data to sign: server certificate followed by the last server nonce
        challenge = b""
        for part in (self.security_policy.server_certificate, self._server_nonce):
            if part is not None:
                challenge += part
        params.ClientSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        params.ClientSignature.Signature = self.security_policy.asymmetric_cryptography.signature(challenge)
        params.LocaleIds.append("en")

        if certificate:
            self._add_certificate_auth(params, certificate, challenge)
        elif username:
            self._add_user_auth(params, username, password)
        else:
            self._add_anonymous_auth(params)
        return self.uaclient.activate_session(params)
430
431 1
    def _add_anonymous_auth(self, params):
        """
        Put an anonymous identity token into params
        """
        token = ua.AnonymousIdentityToken()
        params.UserIdentityToken = token
        token.PolicyId = self.server_policy_id(ua.UserTokenType.Anonymous, b"anonymous")
434
435 1
    def _add_certificate_auth(self, params, certificate, challenge):
        """
        Put an X509 identity token for certificate into params, together
        with the signature of challenge made with the user private key
        """
        token = ua.X509IdentityToken()
        params.UserIdentityToken = token
        token.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, b"certificate_basic256")
        token.CertificateData = uacrypto.der_from_x509(certificate)
        # specs part 4, 5.6.3.1: the data to sign is created by appending
        # the last serverNonce to the serverCertificate
        signed_challenge = uacrypto.sign_sha1(self.user_private_key, challenge)
        signature = ua.SignatureData()
        params.UserTokenSignature = signature
        signature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
        signature.Signature = signed_challenge
445
446 1
    def _add_user_auth(self, params, username, password):
        """
        Put a user name identity token for username into params.

        The password is added only if one is given: as it is when the
        server's policy for user name tokens is empty or "None",
        encrypted otherwise.

        :param params: ActivateSessionParameters to complete
        :param username: user name to send
        :param password: password to send, may be None or empty
        """
        token = ua.UserNameIdentityToken()
        params.UserIdentityToken = token
        token.UserName = username
        policy_uri = self.server_policy_uri(ua.UserTokenType.UserName)
        # The decision is taken on the password argument, which is the
        # value actually sent, and not on the one stored in the client:
        # activate_session() may be called with explicit credentials.
        if not policy_uri or policy_uri == security_policies.POLICY_NONE_URI:
            # see specs part 4, 7.36.3: if the token is NOT encrypted,
            # then the password only contains UTF-8 encoded password
            # and EncryptionAlgorithm is null
            if password:
                self.logger.warning("Sending plain-text password")
                token.Password = password
            token.EncryptionAlgorithm = None
        elif password:
            data, uri = self._encrypt_password(password, policy_uri)
            token.Password = data
            token.EncryptionAlgorithm = uri
        token.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, b"username_basic256")
463
464 1
    def _encrypt_password(self, password, policy_uri):
        """
        Encrypt password with the public key of the server certificate.

        :return: tuple of encrypted data and uri, as returned by
            security_policies.encrypt_asymmetric
        """
        server_cert = uacrypto.x509_from_der(self.security_policy.server_certificate)
        public_key = server_cert.public_key()
        # see specs part 4, 7.36.3: if the token is encrypted, password
        # shall be converted to UTF-8 and serialized with server nonce
        secret = password.encode("utf8")
        if self._server_nonce is not None:
            secret += self._server_nonce
        packed = ua.ua_binary.Primitives.Bytes.pack(secret)
        encrypted, algorithm_uri = security_policies.encrypt_asymmetric(public_key, packed, policy_uri)
        return encrypted, algorithm_uri
474
475 1
    def close_session(self):
476
        """
477
        Close session
478
        """
479 1
        if self.keepalive:
480 1
            self.keepalive.stop()
481 1
        return self.uaclient.close_session(True)
482
483 1
    def get_root_node(self):
        """
        Return the Node for the RootFolder object id
        """
        nodeid = ua.TwoByteNodeId(ua.ObjectIds.RootFolder)
        return self.get_node(nodeid)
485
486 1
    def get_objects_node(self):
        """
        Return the Node for the ObjectsFolder object id
        """
        nodeid = ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder)
        return self.get_node(nodeid)
488
489 1
    def get_server_node(self):
        """
        Return the Node for the Server object id
        """
        nodeid = ua.FourByteNodeId(ua.ObjectIds.Server)
        return self.get_node(nodeid)
491
492 1
    def get_node(self, nodeid):
        """
        Get node using NodeId object or a string representing a NodeId.
        The returned Node is bound to our UaClient.
        """
        node = Node(self.uaclient, nodeid)
        return node
497
498 1
    def create_subscription(self, period, handler):
        """
        Create a subscription.
        Returns a Subscription object which allows
        to subscribe to events or data on server.

        period argument is either a publishing interval in milliseconds or a
        CreateSubscriptionParameters instance. The second option should be used,
        if the opcua-server has problems with the default options.

        handler argument is a class with data_change and/or event methods.
        These methods will be called when notifications from server are received.
        See example-client.py.
        Do not do expensive/slow or network operation from these methods
        since they are called directly from receiving thread. This is a design choice,
        start another thread if you need to do such a thing.
        """
        params = period
        if not isinstance(params, ua.CreateSubscriptionParameters):
            # only an interval was given: fill in the default options
            params = ua.CreateSubscriptionParameters()
            params.RequestedPublishingInterval = period
            params.RequestedLifetimeCount = 10000
            params.RequestedMaxKeepAliveCount = 3000
            params.MaxNotificationsPerPublish = 10000
            params.PublishingEnabled = True
            params.Priority = 0
        return Subscription(self.uaclient, params, handler)
524
525 1
    def get_namespace_array(self):
        """
        Return the value of the server's NamespaceArray node
        """
        nodeid = ua.NodeId(ua.ObjectIds.Server_NamespaceArray)
        return self.get_node(nodeid).get_value()
528
529 1
    def get_namespace_index(self, uri):
530 1
        uries = self.get_namespace_array()
531 1
        return uries.index(uri)
532
533 1
    def delete_nodes(self, nodes, recursive=False):
        """
        Delete nodes through the delete_nodes helper imported from
        opcua.common.manage_nodes and return its result
        """
        outcome = delete_nodes(self.uaclient, nodes, recursive)
        return outcome
535
536 1
    def import_xml(self, path):
        """
        Import nodes defined in the xml file at path
        """
        return XmlImporter(self).import_xml(path)
542
543 1
    def export_xml(self, nodes, path):
        """
        Export the given nodes to the xml file at path
        """
        exporter = XmlExporter(self)
        exporter.build_etree(nodes)
        return exporter.write_xml(path)
550
551 1
    def register_namespace(self, uri):
        """
        Register a new namespace. Nodes should be in a custom namespace, not 0.
        This method is mainly implemented for symmetry with server.

        :return: index of uri in the namespace array; if uri was not there
            yet, it is appended and the array is written back to the server
        """
        array_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_NamespaceArray))
        namespaces = array_node.get_value()
        if uri not in namespaces:
            namespaces.append(uri)
            array_node.set_value(namespaces)
            return len(namespaces) - 1
        return namespaces.index(uri)
563
564 1
    def load_type_definitions(self, nodes=None):
        """
        Call the load_type_definitions helper imported from
        opcua.common.structures with this client and return its result
        """
        loaded = load_type_definitions(self, nodes)
        return loaded
566
567
568