Passed
Pull Request — master (#317)
by Olivier
02:41
created

SecureConnection._check_sym_header()   C

Complexity

Conditions 10

Size

Total Lines 38
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 10
eloc 22
nop 2
dl 0
loc 38
rs 5.9999
c 0
b 0
f 0

How to fix   Complexity   

Complexity

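Complex methods like asyncua.common.connection.SecureConnection._check_sym_header() often do a lot of different things (this one has a cyclomatic complexity of 10). To break such a method down, identify cohesive pieces of its logic, such as a self-contained validation step or the construction of an error message, and move each into its own helper (Extract Method). The same idea applies at the class level: to find a cohesive component within a class, look for fields/methods that share the same prefixes or suffixes.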

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import hashlib
2
from datetime import datetime, timedelta
3
import logging
4
import copy
5
6
from asyncua import ua
7
from ..ua.ua_binary import struct_from_binary, struct_to_binary, header_from_binary, header_to_binary
8
9
# Module-level logger. It is registered under the explicit name 'asyncua.uaprotocol'
# rather than __name__ (presumably to group protocol-level messages together — confirm).
logger = logging.getLogger('asyncua.uaprotocol')
10
11
12
class MessageChunk(ua.FrozenClass):
    """
    Message Chunk, as described in OPC UA specs Part 6, 6.7.2.

    A chunk consists of a MessageHeader, a SecurityHeader (symmetric for
    SecureMessage/SecureClose, asymmetric for SecureOpen), a SequenceHeader
    and the (possibly partial) message Body.
    """

    def __init__(self, security_policy, body=b'', msg_type=ua.MessageType.SecureMessage,
                 chunk_type=ua.ChunkType.Single):
        """
        :param security_policy: cryptography object used to pad, sign and encrypt the
            chunk (in this module callers pass ``symmetric_cryptography`` or
            ``asymmetric_cryptography`` of a security policy).
        :param body: body bytes carried by this chunk.
        :param msg_type: SecureMessage, SecureClose or SecureOpen.
        :param chunk_type: chunk type stored in the MessageHeader.
        :raises ua.UaError: for any other message type.
        """
        self.MessageHeader = ua.Header(msg_type, chunk_type)
        if msg_type in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
            self.SecurityHeader = ua.SymmetricAlgorithmHeader()
        elif msg_type == ua.MessageType.SecureOpen:
            self.SecurityHeader = ua.AsymmetricAlgorithmHeader()
        else:
            raise ua.UaError(f"Unsupported message type: {msg_type}")
        self.SequenceHeader = ua.SequenceHeader()
        self.Body = body
        self.security_policy = security_policy

    @staticmethod
    def from_binary(security_policy, data):
        """Parse a MessageHeader from ``data`` and decode the chunk that follows it."""
        h = header_from_binary(data)
        return MessageChunk.from_header_and_body(security_policy, h, data)

    @staticmethod
    def from_header_and_body(security_policy, header, buf):
        """
        Decode a chunk whose MessageHeader has already been parsed.

        Consumes ``header.body_size`` bytes from ``buf``, reads the security header,
        decrypts the remainder, verifies the trailing signature (when the crypto
        reports a non-zero signature size), strips the padding and parses the
        SequenceHeader. What is left becomes the chunk Body.

        :raises ValueError: if ``buf`` holds fewer than ``header.body_size`` bytes.
        :raises ua.UaError: for an unsupported message type.
        """
        if len(buf) < header.body_size:
            raise ValueError('Full body expected here')
        data = buf.copy(header.body_size)
        buf.skip(header.body_size)
        if header.MessageType in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
            security_header = struct_from_binary(ua.SymmetricAlgorithmHeader, data)
            crypto = security_policy.symmetric_cryptography
        elif header.MessageType == ua.MessageType.SecureOpen:
            security_header = struct_from_binary(ua.AsymmetricAlgorithmHeader, data)
            crypto = security_policy.asymmetric_cryptography
        else:
            raise ua.UaError(f"Unsupported message type: {header.MessageType}")
        obj = MessageChunk(crypto)
        obj.MessageHeader = header
        obj.SecurityHeader = security_header
        decrypted = crypto.decrypt(data.read(len(data)))
        signature_size = crypto.vsignature_size()
        if signature_size > 0:
            # The signature trails the decrypted data and covers both headers + payload.
            signature = decrypted[-signature_size:]
            decrypted = decrypted[:-signature_size]
            crypto.verify(header_to_binary(obj.MessageHeader) + struct_to_binary(obj.SecurityHeader) + decrypted,
                          signature)
        data = ua.utils.Buffer(crypto.remove_padding(decrypted))
        obj.SequenceHeader = struct_from_binary(ua.SequenceHeader, data)
        obj.Body = data.read(len(data))
        return obj

    def encrypted_size(self, plain_size):
        """
        Return the encrypted size of ``plain_size`` padded bytes plus the signature.

        :raises ua.UaError: if the signed size is not a multiple of the plain block size.
        """
        size = plain_size + self.security_policy.signature_size()
        pbs = self.security_policy.plain_block_size()
        if size % pbs != 0:
            raise ua.UaError("Encryption error")
        return size // pbs * self.security_policy.encrypted_block_size()

    def to_binary(self):
        """
        Encode the chunk for the wire.

        Layout: MessageHeader + SecurityHeader + encrypt(SequenceHeader + Body +
        padding + signature). ``MessageHeader.body_size`` is updated as a side effect.
        """
        security = struct_to_binary(self.SecurityHeader)
        encrypted_part = struct_to_binary(self.SequenceHeader) + self.Body
        encrypted_part += self.security_policy.padding(len(encrypted_part))
        self.MessageHeader.body_size = len(security) + self.encrypted_size(len(encrypted_part))
        header = header_to_binary(self.MessageHeader)
        # The signature is computed over the plain headers and the padded payload.
        encrypted_part += self.security_policy.signature(header + security + encrypted_part)
        return header + security + self.security_policy.encrypt(encrypted_part)

    @staticmethod
    def max_body_size(crypto, max_chunk_size):
        """
        Return the largest body that fits in a symmetric chunk of ``max_chunk_size`` bytes.

        The result may be zero or negative when ``max_chunk_size`` is too small.
        """
        max_encrypted_size = max_chunk_size - ua.Header.max_size() - ua.SymmetricAlgorithmHeader.max_size()
        max_plain_size = (max_encrypted_size // crypto.encrypted_block_size()) * crypto.plain_block_size()
        return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()

    @staticmethod
    def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage,
                          channel_id=1, request_id=1, token_id=1):
        """
        Pack message body (as binary string) into one or more chunks.

        Size of each chunk will not exceed max_chunk_size. An empty body still
        produces one (final) chunk. Returns a list of MessageChunks.
        SequenceNumber is not initialized here; it must be set by the Secure
        Channel driver.

        :raises ua.UaError: if ``max_chunk_size`` leaves no room for any body bytes.
        """
        if message_type == ua.MessageType.SecureOpen:
            # SecureOpen message must be in a single chunk (specs, Part 6, 6.7.2)
            chunk = MessageChunk(security_policy.asymmetric_cryptography, body, message_type, ua.ChunkType.Single)
            chunk.SecurityHeader.SecurityPolicyURI = security_policy.URI
            if security_policy.host_certificate:
                chunk.SecurityHeader.SenderCertificate = security_policy.host_certificate
            if security_policy.peer_certificate:
                chunk.SecurityHeader.ReceiverCertificateThumbPrint =\
                    hashlib.sha1(security_policy.peer_certificate).digest()
            chunk.MessageHeader.ChannelId = channel_id
            chunk.SequenceHeader.RequestId = request_id
            return [chunk]

        crypto = security_policy.symmetric_cryptography
        max_size = MessageChunk.max_body_size(crypto, max_chunk_size)
        if max_size <= 0:
            # range() would otherwise fail (step 0) or silently yield no chunk (negative step).
            raise ua.UaError(f"max_chunk_size {max_chunk_size} is too small to carry any message body")

        chunks = []
        # max(..., 1) makes an empty body still produce exactly one final chunk.
        for i in range(0, max(len(body), 1), max_size):
            part = body[i:i + max_size]
            if i + max_size >= len(body):
                chunk_type = ua.ChunkType.Single
            else:
                chunk_type = ua.ChunkType.Intermediate
            chunk = MessageChunk(crypto, part, message_type, chunk_type)
            chunk.SecurityHeader.TokenId = token_id
            chunk.MessageHeader.ChannelId = channel_id
            chunk.SequenceHeader.RequestId = request_id
            chunks.append(chunk)
        return chunks

    def __str__(self):
        return f"{self.__class__.__name__}({self.MessageHeader}, {self.SequenceHeader}," \
               f" {self.SecurityHeader}, {len(self.Body)} bytes)"

    __repr__ = __str__
129
130
131
class SecureConnection:
    """
    Common logic for client and server.

    Holds the state of one secure channel: the active security policy, the
    current / next / previous security tokens, the nonces used to derive the
    symmetric keys, the outgoing and incoming sequence numbers and the chunks
    of a message that is still being reassembled.
    """

    def __init__(self, security_policy):
        # Last sequence number we sent; incremented before each outgoing chunk.
        self._sequence_number = 0
        # Last sequence number received from the peer (None until the first chunk).
        self._peer_sequence_number = None
        # Chunks of the incoming message that is not complete yet.
        self._incoming_parts = []
        self.security_policy = security_policy
        # Security policy factories, see set_policy_factories().
        self._policies = []
        self._open = False
        self.security_token = ua.ChannelSecurityToken()
        self.next_security_token = ua.ChannelSecurityToken()
        self.prev_security_token = ua.ChannelSecurityToken()
        self.local_nonce = 0
        self.remote_nonce = 0
        # When True, chunks secured with prev_security_token are still accepted
        # (within the grace period checked by _check_prev_token_not_expired()).
        self._allow_prev_token = False
        # Updated from the peer's Hello / Acknowledge message.
        self._max_chunk_size = 65536
        # Set by start_renewal(), cleared by set_channel().
        self._renewal_started = False

    def start_renewal(self):
        """
        Flag that a security token renewal has been started.

        While the flag is set, _check_sym_header() also accepts the token id
        ``security_token.TokenId + 1``. set_channel() clears it again.
        """
        self._renewal_started = True

    def set_channel(self, params, request_type, client_nonce):
        """
        Called on client side when getting secure channel data from server.

        For an Issue request the received token becomes current and the symmetric
        keys are derived immediately. For a renewal the token is stored as
        ``next_security_token`` and only used after revolve_tokens(). In both
        cases the renewal flag is cleared and the previous token is allowed.
        """
        self._renewal_started = False
        if request_type == ua.SecurityTokenRequestType.Issue:
            self.security_token = params.SecurityToken
            self.local_nonce = client_nonce
            self.remote_nonce = params.ServerNonce
            self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
            self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
            self._open = True
        else:
            self.next_security_token = params.SecurityToken
            self.local_nonce = client_nonce
            self.remote_nonce = params.ServerNonce

        self._allow_prev_token = True

    def open(self, params, server):
        """
        Called on server side to open (or renew) the secure channel.

        If the channel is not open yet, or an Issue is requested, a new current
        token is set up and the symmetric keys are derived right away. Otherwise
        (renewal) a next token with TokenId + 1 is prepared; it becomes current in
        revolve_tokens() once the client starts using it.

        :returns: a ua.OpenSecureChannelResult with the server nonce and the token.
        """
        self.local_nonce = ua.utils.create_nonce(self.security_policy.symmetric_key_size)
        self.remote_nonce = params.ClientNonce
        response = ua.OpenSecureChannelResult()
        response.ServerNonce = self.local_nonce

        if not self._open or params.RequestType == ua.SecurityTokenRequestType.Issue:
            self._open = True
            self.security_token.TokenId = 13  # arbitrary initial token id
            self.security_token.ChannelId = server.get_new_channel_id()
            self.security_token.RevisedLifetime = params.RequestedLifetime
            self.security_token.CreatedAt = datetime.utcnow()

            response.SecurityToken = self.security_token

            self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
            self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
        else:
            self.next_security_token = copy.deepcopy(self.security_token)
            self.next_security_token.TokenId += 1
            self.next_security_token.RevisedLifetime = params.RequestedLifetime
            self.next_security_token.CreatedAt = datetime.utcnow()

            response.SecurityToken = self.next_security_token

        return response

    def close(self):
        """Mark the channel as closed."""
        self._open = False

    def is_open(self):
        """Return True while the channel is open."""
        return self._open

    def set_policy_factories(self, policies):
        """
        Set a list of available security policies.
        Use this in servers with multiple endpoints with different security.
        """
        self._policies = policies

    @staticmethod
    def _policy_matches(policy, uri, mode=None):
        """Return True if ``policy`` has ``uri`` and, when ``mode`` is given, that mode."""
        return policy.URI == uri and (mode is None or policy.Mode == mode)

    def select_policy(self, uri, peer_certificate, mode=None):
        """
        Select the security policy matching ``uri`` (and ``mode`` if given).

        The first factory in ``self._policies`` that matches creates the policy for
        ``peer_certificate``. If no factory matches, the current policy is kept
        provided it matches itself.

        :raises ua.UaError: if neither a factory nor the current policy matches.
        """
        for policy in self._policies:
            if policy.matches(uri, mode):
                self.security_policy = policy.create(peer_certificate)
                return
        if not self._policy_matches(self.security_policy, uri, mode):
            raise ua.UaError(f"No matching policy: {uri}, {mode}")

    def revolve_tokens(self):
        """
        Revolve security tokens of the security channel.

        Start using the next security token negotiated during the renewal of the
        channel, and keep the current one as the previous token so that chunks
        still secured with it can be accepted by _check_sym_header() (when
        ``_allow_prev_token`` is set). New symmetric keys are derived from the
        current nonces.
        """
        self.prev_security_token = self.security_token
        self.security_token = self.next_security_token
        self.next_security_token = ua.ChannelSecurityToken()
        self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
        self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)

    def message_to_binary(self, message, message_type=ua.MessageType.SecureMessage, request_id=0):
        """
        Convert an OPC UA secure message body to its binary wire form.

        The body is split into chunks (MessageChunk.message_to_chunks), each chunk
        is given the next outgoing sequence number, and the encoded chunks are
        concatenated. The only supported types are SecureOpen, SecureMessage and
        SecureClose.
        """
        chunks = MessageChunk.message_to_chunks(self.security_policy, message, self._max_chunk_size,
                                                message_type=message_type, channel_id=self.security_token.ChannelId,
                                                request_id=request_id, token_id=self.security_token.TokenId)
        for chunk in chunks:
            self._sequence_number += 1
            # Sequence numbers are UInt32: wrap back to 1 after 2**32 - 1.
            if self._sequence_number >= (1 << 32):
                logger.debug("Wrapping sequence number: %d -> 1", self._sequence_number)
                self._sequence_number = 1
            chunk.SequenceHeader.SequenceNumber = self._sequence_number
        return b"".join([chunk.to_binary() for chunk in chunks])

    def _check_sym_header(self, security_hdr):
        """
        Validates the symmetric header of the message chunk and revolves the
        security token if needed.

        Accepted token ids, checked in this order:
          * the current token;
          * the next token (this triggers revolve_tokens());
          * current TokenId + 1 while a renewal is in progress;
          * the previous token, if allowed and not past its grace period.

        :raises ua.UaError: if the token id is not accepted or the previous token expired.
        """
        assert isinstance(security_hdr, ua.SymmetricAlgorithmHeader), f"Expected SymAlgHeader, got: {security_hdr}"
        token_id = security_hdr.TokenId

        if token_id == self.security_token.TokenId:
            return

        if token_id == self.next_security_token.TokenId:
            self.revolve_tokens()
            return

        if self._renewal_started and token_id == self.security_token.TokenId + 1:
            return

        if self._allow_prev_token and token_id == self.prev_security_token.TokenId:
            self._check_prev_token_not_expired(token_id)
            return

        raise ua.UaError(f"Invalid security token id {token_id}, expected one of: {self._expected_token_ids()}")

    def _check_prev_token_not_expired(self, token_id):
        """
        Raise ua.UaError if the previous security token is past its grace period.

        From spec, part 4, section 5.5.2.1: Clients should accept Messages secured by an
        expired SecurityToken for up to 25 % of the token lifetime. This should ensure that
        Messages sent by the Server before the token expired are not rejected because of
        network delays.
        """
        token = self.prev_security_token
        # RevisedLifetime is treated as milliseconds; the factor 1.25 adds the 25 % grace period.
        timeout = token.CreatedAt + timedelta(milliseconds=token.RevisedLifetime * 1.25)
        # Read the clock once so the comparison and the error message agree.
        now = datetime.utcnow()
        if timeout < now:
            raise ua.UaError(f"Security token id {token_id} has timed out "
                             f"({timeout} < {now})")

    def _expected_token_ids(self):
        """Return the token ids _check_sym_header() currently accepts (for error messages)."""
        previous = [self.prev_security_token.TokenId] if self._allow_prev_token else []
        renewal = [self.security_token.TokenId + 1] if self._renewal_started else []
        return previous + [self.security_token.TokenId, self.next_security_token.TokenId] + renewal

    def _check_incoming_chunk(self, chunk):
        """
        Validate an incoming chunk against the channel state.

        Checks the object type, the channel id (except for SecureOpen), that all
        chunks of one message share the same request id, and that the peer
        sequence number increases by one (allowing the wrap-around described in
        spec part 6, 6.7.2). On success ``_peer_sequence_number`` is updated.

        :raises ValueError: if ``chunk`` is not a MessageChunk.
        :raises ua.UaError: on a channel id, request id or sequence number mismatch.
        """
        if not isinstance(chunk, MessageChunk):
            raise ValueError(f'Expected chunk, got: {chunk}')
        if chunk.MessageHeader.MessageType != ua.MessageType.SecureOpen:
            if chunk.MessageHeader.ChannelId != self.security_token.ChannelId:
                raise ua.UaError(f'Wrong channel id {chunk.MessageHeader.ChannelId},'
                                 f' expected {self.security_token.ChannelId}')
        if self._incoming_parts:
            if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId:
                raise ua.UaError(f'Wrong request id {chunk.SequenceHeader.RequestId},'
                                 f' expected {self._incoming_parts[0].SequenceHeader.RequestId}')
        # The sequence number must monotonically increase (but it can wrap around)
        seq_num = chunk.SequenceHeader.SequenceNumber
        if self._peer_sequence_number is not None:
            expected_seq_num = self._peer_sequence_number + 1
            if seq_num != expected_seq_num:
                wrap_limit = (1 << 32) - 1024
                if seq_num < 1024 and self._peer_sequence_number >= wrap_limit:
                    # The sequence number has wrapped around. See spec. part 6, 6.7.2
                    logger.debug('Sequence number wrapped: %d -> %d', self._peer_sequence_number, seq_num)
                else:
                    # Condition for monotonically increase is not met
                    raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:"
                                     f" {expected_seq_num}, received: {seq_num},"
                                     f" spec says to close connection")
        self._peer_sequence_number = seq_num

    def receive_from_header_and_body(self, header, body):
        """
        Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA
        specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message
        object, or None (if intermediate chunk is received)

        :raises ua.UaError: for an unsupported message type.
        """
        if header.MessageType == ua.MessageType.SecureOpen:
            data = body.copy(header.body_size)
            security_header = struct_from_binary(ua.AsymmetricAlgorithmHeader, data)

            if not self.is_open():
                # Only call select_policy if the channel isn't open. Otherwise
                # it will break the Secure channel renewal.
                self.select_policy(security_header.SecurityPolicyURI, security_header.SenderCertificate)

        elif header.MessageType in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
            data = body.copy(header.body_size)
            security_header = struct_from_binary(ua.SymmetricAlgorithmHeader, data)
            self._check_sym_header(security_header)

        if header.MessageType in (ua.MessageType.SecureMessage, ua.MessageType.SecureOpen, ua.MessageType.SecureClose):
            chunk = MessageChunk.from_header_and_body(self.security_policy, header, body)
            return self._receive(chunk)
        if header.MessageType == ua.MessageType.Hello:
            msg = struct_from_binary(ua.Hello, body)
            self._max_chunk_size = msg.ReceiveBufferSize
            return msg
        if header.MessageType == ua.MessageType.Acknowledge:
            msg = struct_from_binary(ua.Acknowledge, body)
            self._max_chunk_size = msg.SendBufferSize
            return msg
        if header.MessageType == ua.MessageType.Error:
            msg = struct_from_binary(ua.ErrorMessage, body)
            logger.warning("Received an error: %s", msg)
            return msg
        raise ua.UaError(f"Unsupported message type {header.MessageType}")

    def _receive(self, msg):
        """
        Add a validated chunk to the message being reassembled.

        :returns: None for an Intermediate or Abort chunk, or a ua.Message built from
            all collected chunks when the final chunk arrives.
        :raises ua.UaError: for an unknown chunk type.
        """
        self._check_incoming_chunk(msg)
        self._incoming_parts.append(msg)
        if msg.MessageHeader.ChunkType == ua.ChunkType.Intermediate:
            return None
        if msg.MessageHeader.ChunkType == ua.ChunkType.Abort:
            err = struct_from_binary(ua.ErrorMessage, ua.utils.Buffer(msg.Body))
            logger.warning("Message %s aborted: %s", msg, err)
            # specs Part 6, 6.7.3 say that aborted message shall be ignored
            # and SecureChannel should not be closed
            self._incoming_parts = []
            return None
        if msg.MessageHeader.ChunkType == ua.ChunkType.Single:
            message = ua.Message(self._incoming_parts)
            self._incoming_parts = []
            return message
        raise ua.UaError(f"Unsupported chunk type: {msg}")
377