Passed
Push — master ( 4a3463...98fd52 )
by Olivier
02:51 queued 10s
created

SecureConnection.revolve_tokens()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 1
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
import hashlib
2
from datetime import datetime, timedelta
3
import logging
4
import copy
5
6
from ..ua.ua_binary import struct_from_binary, struct_to_binary, header_from_binary, header_to_binary
7
from asyncua import ua
8
9
10
# Module logger (named 'asyncua.uaprotocol') used for the debug/warning
# messages emitted by the secure-channel code in this module.
logger = logging.getLogger('asyncua.uaprotocol')
11
12
13
class MessageChunk(ua.FrozenClass):
    """
    Message Chunk, as described in OPC UA specs Part 6, 6.7.2.

    A chunk holds a MessageHeader, a security header, a SequenceHeader and the
    plaintext Body. The security header is a SymmetricAlgorithmHeader for
    SecureMessage/SecureClose and an AsymmetricAlgorithmHeader for SecureOpen.
    Padding, signing and encryption are applied in to_binary() and undone in
    from_header_and_body().
    """

    def __init__(self, security_policy, body=b'', msg_type=ua.MessageType.SecureMessage, chunk_type=ua.ChunkType.Single):
        """
        :param security_policy: object providing padding/signature/encryption
            primitives for to_binary(). The helpers in this class pass a
            cryptography object here (security_policy.symmetric_cryptography
            or security_policy.asymmetric_cryptography).
        :param body: payload bytes carried by this chunk.
        :param msg_type: ua.MessageType; selects the security header type.
        :param chunk_type: ua.ChunkType of this chunk.
        :raises ua.UaError: if msg_type is not SecureMessage, SecureClose or SecureOpen.
        """
        self.MessageHeader = ua.Header(msg_type, chunk_type)
        if msg_type in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
            self.SecurityHeader = ua.SymmetricAlgorithmHeader()
        elif msg_type == ua.MessageType.SecureOpen:
            self.SecurityHeader = ua.AsymmetricAlgorithmHeader()
        else:
            raise ua.UaError("Unsupported message type: {0}".format(msg_type))
        self.SequenceHeader = ua.SequenceHeader()
        self.Body = body
        self.security_policy = security_policy

    @staticmethod
    def from_binary(security_policy, data):
        """
        Parse a complete chunk (message header followed by its body) from `data`.

        :param security_policy: policy providing symmetric/asymmetric cryptography.
        :param data: buffer positioned at the start of the message header.
        :returns: the decoded MessageChunk.
        """
        h = header_from_binary(data)
        return MessageChunk.from_header_and_body(security_policy, h, data)

    @staticmethod
    def from_header_and_body(security_policy, header, buf):
        """
        Build a MessageChunk from an already-parsed header and the buffer holding its body.

        The security header is read, the remaining bytes are decrypted, and the
        trailing signature (if any) is verified and removed. The padding is then
        stripped and the SequenceHeader and Body are parsed. `buf` is advanced
        past the body.

        :param security_policy: policy providing symmetric/asymmetric cryptography.
        :param header: parsed ua.Header of the chunk.
        :param buf: buffer positioned at the start of the chunk body.
        :returns: the decoded MessageChunk.
        :raises ValueError: if `buf` holds fewer than header.body_size bytes.
        :raises ua.UaError: for message types other than SecureMessage,
            SecureClose and SecureOpen.
        """
        if len(buf) < header.body_size:
            raise ValueError('Full body expected here')
        # Work on a copy limited to this chunk's body, then consume it from buf.
        data = buf.copy(header.body_size)
        buf.skip(header.body_size)
        if header.MessageType in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
            security_header = struct_from_binary(ua.SymmetricAlgorithmHeader, data)
            crypto = security_policy.symmetric_cryptography
        elif header.MessageType == ua.MessageType.SecureOpen:
            security_header = struct_from_binary(ua.AsymmetricAlgorithmHeader, data)
            crypto = security_policy.asymmetric_cryptography
        else:
            raise ua.UaError("Unsupported message type: {0}".format(header.MessageType))
        obj = MessageChunk(crypto)
        obj.MessageHeader = header
        obj.SecurityHeader = security_header
        # The security header is read in clear; everything after it is decrypted.
        decrypted = crypto.decrypt(data.read(len(data)))
        signature_size = crypto.vsignature_size()
        if signature_size > 0:
            # The signature trails the decrypted payload and covers the message
            # header, the security header and the payload itself.
            signature = decrypted[-signature_size:]
            decrypted = decrypted[:-signature_size]
            crypto.verify(header_to_binary(obj.MessageHeader) + struct_to_binary(obj.SecurityHeader) + decrypted, signature)
        data = ua.utils.Buffer(crypto.remove_padding(decrypted))
        obj.SequenceHeader = struct_from_binary(ua.SequenceHeader, data)
        obj.Body = data.read(len(data))
        return obj

    def encrypted_size(self, plain_size):
        """
        Return the size of the encrypted part for `plain_size` bytes of padded plaintext.

        The plaintext is the sequence header, body and padding; the returned
        size includes the signature.

        :raises ua.UaError: if plaintext plus signature is not a whole number
            of plain blocks (i.e. the padding was computed incorrectly).
        """
        size = plain_size + self.security_policy.signature_size()
        pbs = self.security_policy.plain_block_size()
        if size % pbs != 0:
            # Carry the diagnostic in the exception rather than printing it to
            # stdout from library code.
            raise ua.UaError(
                "Encryption error: plain size {0} + signature = {1} is not a multiple "
                "of plain block size {2}".format(plain_size, size, pbs))
        return size // pbs * self.security_policy.encrypted_block_size()

    def to_binary(self):
        """
        Serialize the chunk.

        The output is the message header and security header in clear, followed
        by the encrypted part: sequence header, body, padding and signature.
        MessageHeader.body_size is updated as a side effect.
        """
        security = struct_to_binary(self.SecurityHeader)
        encrypted_part = struct_to_binary(self.SequenceHeader) + self.Body
        encrypted_part += self.security_policy.padding(len(encrypted_part))
        # body_size must be known before the header is serialized and signed.
        self.MessageHeader.body_size = len(security) + self.encrypted_size(len(encrypted_part))
        header = header_to_binary(self.MessageHeader)
        encrypted_part += self.security_policy.signature(header + security + encrypted_part)
        return header + security + self.security_policy.encrypt(encrypted_part)

    @staticmethod
    def max_body_size(crypto, max_chunk_size):
        """
        Return the largest body that fits into one symmetric chunk of `max_chunk_size` bytes.

        Room is reserved for the message header, the symmetric security header,
        the sequence header, the signature and the minimum padding.
        """
        max_encrypted_size = max_chunk_size - ua.Header.max_size() - ua.SymmetricAlgorithmHeader.max_size()
        max_plain_size = (max_encrypted_size // crypto.encrypted_block_size()) * crypto.plain_block_size()
        return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()

    @staticmethod
    def message_to_chunks(security_policy, body, max_chunk_size,
                          message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
        """
        Pack message body (as binary string) into one or more chunks.

        Size of each chunk will not exceed max_chunk_size. SecureOpen messages
        are always packed into a single chunk. An empty body yields one empty
        final chunk.

        Returns a list of MessageChunks. SequenceNumber is not initialized here,
        it must be set by Secure Channel driver.

        :raises ua.UaError: if max_chunk_size leaves no room for a body.
        """
        if message_type == ua.MessageType.SecureOpen:
            # SecureOpen message must be in a single chunk (specs, Part 6, 6.7.2)
            chunk = MessageChunk(security_policy.asymmetric_cryptography, body, message_type, ua.ChunkType.Single)
            chunk.SecurityHeader.SecurityPolicyURI = security_policy.URI
            if security_policy.client_certificate:
                chunk.SecurityHeader.SenderCertificate = security_policy.client_certificate
            if security_policy.server_certificate:
                chunk.SecurityHeader.ReceiverCertificateThumbPrint =\
                    hashlib.sha1(security_policy.server_certificate).digest()
            chunk.MessageHeader.ChannelId = channel_id
            chunk.SequenceHeader.RequestId = request_id
            return [chunk]

        crypto = security_policy.symmetric_cryptography
        max_size = MessageChunk.max_body_size(crypto, max_chunk_size)
        if max_size <= 0:
            # A non-positive step would make range() raise or yield nothing.
            raise ua.UaError("max_chunk_size {0} is too small to carry any message body".format(max_chunk_size))

        # An empty body must still be sent as one (empty) final chunk; a plain
        # range() over len(body) would produce no chunk at all.
        offsets = range(0, len(body), max_size) if body else [0]

        chunks = []
        for start in offsets:
            part = body[start:start + max_size]
            if start + max_size >= len(body):
                chunk_type = ua.ChunkType.Single
            else:
                chunk_type = ua.ChunkType.Intermediate
            chunk = MessageChunk(crypto, part, message_type, chunk_type)
            chunk.SecurityHeader.TokenId = token_id
            chunk.MessageHeader.ChannelId = channel_id
            chunk.SequenceHeader.RequestId = request_id
            chunks.append(chunk)
        return chunks

    def __str__(self):
        return "{0}({1}, {2}, {3}, {4} bytes)".format(self.__class__.__name__,
                                                      self.MessageHeader, self.SequenceHeader,
                                                      self.SecurityHeader, len(self.Body))
    __repr__ = __str__
130
131
132
class SecureConnection:
    """
    Common logic for client and server.

    Keeps the security token state of a secure channel, splits outgoing
    messages into chunks with sequence numbers, and validates and reassembles
    incoming chunks.
    """

    def __init__(self, security_policy):
        # Sequence number of the last chunk sent (incremented before use).
        self._sequence_number = 0
        # Sequence number of the last chunk received; None until the first one arrives.
        self._peer_sequence_number = None
        # Chunks of the message currently being reassembled.
        self._incoming_parts = []
        self.security_policy = security_policy
        # Policy factories registered through set_policy_factories().
        self._policies = []
        self._open = False
        self.security_token = ua.ChannelSecurityToken()
        self.next_security_token = ua.ChannelSecurityToken()
        self.prev_security_token = ua.ChannelSecurityToken()
        self.local_nonce = 0
        self.remote_nonce = 0
        # Whether chunks secured with prev_security_token are still accepted
        # (until 125% of that token's lifetime has elapsed, see _check_sym_header).
        self._allow_prev_token = False
        # Maximum size of outgoing chunks; updated from Hello/Acknowledge messages.
        self._max_chunk_size = 65536

    def set_channel(self, params, requestType, clientNonce):
        """
        Called on client side when getting secure channel data from server.

        For an Issue request the returned token becomes the current one and
        both symmetric keys are derived immediately. For any other request type
        (renewal) the token is stored as next_security_token and only the
        nonces are updated; the keys are re-derived when the tokens are
        revolved (revolve_tokens / _check_sym_header).

        :param params: OpenSecureChannel result (SecurityToken, ServerNonce).
        :param requestType: ua.SecurityTokenRequestType of the request that was sent.
        :param clientNonce: nonce the client sent in its request.
        """
        if requestType == ua.SecurityTokenRequestType.Issue:
            self.security_token = params.SecurityToken
            self.local_nonce = clientNonce
            self.remote_nonce = params.ServerNonce
            self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
            self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
            self._open = True
        else:
            self.next_security_token = params.SecurityToken
            self.local_nonce = clientNonce
            self.remote_nonce = params.ServerNonce

        self._allow_prev_token = True

    def open(self, params, server):
        """
        Called on server side to open secure channel.

        When the channel is not open yet, or an Issue is requested, a new token
        is issued and both symmetric keys are derived. Otherwise a renewed
        token (TokenId + 1) is prepared as next_security_token; its keys are
        derived once the client starts using it (see _check_sym_header).

        :param params: request parameters (ClientNonce, RequestType, RequestedLifetime).
        :param server: object providing get_new_channel_id().
        :returns: ua.OpenSecureChannelResult with ServerNonce and SecurityToken set.
        """

        self.local_nonce = ua.utils.create_nonce(self.security_policy.symmetric_key_size)
        self.remote_nonce = params.ClientNonce
        response = ua.OpenSecureChannelResult()
        response.ServerNonce = self.local_nonce

        if not self._open or params.RequestType == ua.SecurityTokenRequestType.Issue:
            self._open = True
            self.security_token.TokenId = 13  # arbitrary fixed initial token id
            self.security_token.ChannelId = server.get_new_channel_id()
            self.security_token.RevisedLifetime = params.RequestedLifetime
            self.security_token.CreatedAt = datetime.utcnow()

            response.SecurityToken = self.security_token

            self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
            self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
        else:
            self.next_security_token = copy.deepcopy(self.security_token)
            self.next_security_token.TokenId += 1
            self.next_security_token.RevisedLifetime = params.RequestedLifetime
            self.next_security_token.CreatedAt = datetime.utcnow()

            response.SecurityToken = self.next_security_token

        return response

    def close(self):
        """Mark the secure channel as closed."""
        self._open = False

    def is_open(self):
        """Return True while the secure channel is open."""
        return self._open

    def set_policy_factories(self, policies):
        """
        Set a list of available security policies.
        Use this in servers with multiple endpoints with different security.
        """
        self._policies = policies

    @staticmethod
    def _policy_matches(policy, uri, mode=None):
        """Return True if `policy` has the given URI and, when `mode` is given, that mode."""
        return policy.URI == uri and (mode is None or policy.Mode == mode)

    def select_policy(self, uri, peer_certificate, mode=None):
        """
        Select the security policy requested by the peer.

        The first registered policy factory matching `uri`/`mode` creates the
        policy for `peer_certificate`. Without a matching factory the current
        policy is kept if it matches.

        :raises ua.UaError: if neither a factory nor the current policy matches.
        """
        for policy in self._policies:
            if policy.matches(uri, mode):
                self.security_policy = policy.create(peer_certificate)
                return
        if not self._policy_matches(self.security_policy, uri, mode):
            raise ua.UaError("No matching policy: {0}, {1}".format(uri, mode))

    def revolve_tokens(self):
        """
        Revolve security tokens of the security channel.

        Start using the next security token negotiated during the renewal of
        the channel, and re-derive the local symmetric key. The previous token
        is remembered until a chunk secured with the new token is received,
        at which point the remote key is re-derived and the previous token is
        dropped (see _check_sym_header).
        """
        self.prev_security_token = self.security_token
        self.security_token = self.next_security_token
        self.next_security_token = ua.ChannelSecurityToken()
        self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)

    def message_to_binary(self, message, message_type=ua.MessageType.SecureMessage, request_id=0):
        """
        Convert OPC UA secure message to binary.

        The only supported types are SecureOpen, SecureMessage, SecureClose.
        The encoded `message` body is split into chunks using the current
        security token. Each chunk gets the next sequence number, which wraps
        around to 1 once it reaches 2**32.

        :returns: the concatenated binary chunks.
        """
        chunks = MessageChunk.message_to_chunks(
            self.security_policy, message, self._max_chunk_size,
            message_type=message_type,
            channel_id=self.security_token.ChannelId,
            request_id=request_id,
            token_id=self.security_token.TokenId)
        for chunk in chunks:
            self._sequence_number += 1
            if self._sequence_number >= (1 << 32):
                logger.debug("Wrapping sequence number: %d -> 1", self._sequence_number)
                self._sequence_number = 1
            chunk.SequenceHeader.SequenceNumber = self._sequence_number
        return b"".join([chunk.to_binary() for chunk in chunks])

    def _check_sym_header(self, securityHeader):
        """
        Validates the symmetric header of the message chunk and revolves the
        security token if needed.

        Accepted token ids:
        - the current token;
        - the next token, which triggers a revolve;
        - the previous token, while allowed and not older than 125% of its lifetime.

        :raises ua.UaError: for an unknown or expired token id.
        """
        assert isinstance(securityHeader, ua.SymmetricAlgorithmHeader), "Expected SymAlgHeader, got: {0}".format(securityHeader)
        if securityHeader.TokenId != self.security_token.TokenId:
            if securityHeader.TokenId != self.next_security_token.TokenId:
                if self._allow_prev_token and \
                   securityHeader.TokenId == self.prev_security_token.TokenId:
                    # Previous token is still tolerated for 125% of its lifetime.
                    timeout = self.prev_security_token.CreatedAt + timedelta(milliseconds=self.prev_security_token.RevisedLifetime * 1.25)
                    if timeout < datetime.utcnow():
                        raise ua.UaError("Security token id {} has timed out ({} < {})"
                                         .format(securityHeader.TokenId, timeout, datetime.utcnow()))
                    else:
                        return
                raise ua.UaError("Invalid security token id {}, expected {} or {}"
                                 .format(securityHeader.TokenId,
                                         self.security_token.TokenId,
                                         self.next_security_token.TokenId))
            else:
                # Peer switched to the renewed token: adopt it for both directions.
                self.revolve_tokens()
                self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
                self.prev_security_token = ua.ChannelSecurityToken()
        if self.prev_security_token.TokenId != 0:
            # First chunk secured with the current token after we revolved:
            # switch the remote key and forget the previous token.
            self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
            self.prev_security_token = ua.ChannelSecurityToken()

    def _check_incoming_chunk(self, chunk):
        """
        Validate channel id, request id and sequence number of an incoming chunk.

        :raises ValueError: if `chunk` is not a MessageChunk.
        :raises ua.UaError: on channel id, request id or sequence number mismatch.
        """
        if not isinstance(chunk, MessageChunk):
            # Exceptions do not apply %-formatting; format the message explicitly.
            raise ValueError('Expected chunk, got: {0!r}'.format(chunk))
        if chunk.MessageHeader.MessageType != ua.MessageType.SecureOpen:
            if chunk.MessageHeader.ChannelId != self.security_token.ChannelId:
                raise ua.UaError('Wrong channel id {0}, expected {1}'.format(
                    chunk.MessageHeader.ChannelId,
                    self.security_token.ChannelId))
        if self._incoming_parts:
            if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId:
                raise ua.UaError('Wrong request id {0}, expected {1}'.format(
                    chunk.SequenceHeader.RequestId,
                    self._incoming_parts[0].SequenceHeader.RequestId)
                )
        # The sequence number must monotonically increase (but it can wrap around)
        seq_num = chunk.SequenceHeader.SequenceNumber
        if self._peer_sequence_number is not None:
            if seq_num != self._peer_sequence_number + 1:
                wrap_limit = (1 << 32) - 1024
                if seq_num < 1024 and self._peer_sequence_number >= wrap_limit:
                    # The sequence number has wrapped around. See spec. part 6, 6.7.2
                    logger.debug('Sequence number wrapped: %d -> %d', self._peer_sequence_number, seq_num)
                else:
                    # Condition for monotonically increase is not met
                    raise ua.UaError(
                        'Wrong sequence {0} -> {1} (server bug or replay attack)'
                        .format(self._peer_sequence_number, seq_num)
                    )
        self._peer_sequence_number = seq_num

    def receive_from_header_and_body(self, header, body):
        """
        Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA
        specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message
        object, or None (if intermediate chunk is received)
        """
        if header.MessageType == ua.MessageType.SecureOpen:
            data = body.copy(header.body_size)
            security_header = struct_from_binary(ua.AsymmetricAlgorithmHeader, data)
            self.select_policy(security_header.SecurityPolicyURI, security_header.SenderCertificate)
        elif header.MessageType in (ua.MessageType.SecureMessage,
                                    ua.MessageType.SecureClose):
            data = body.copy(header.body_size)
            security_header = struct_from_binary(ua.SymmetricAlgorithmHeader, data)
            self._check_sym_header(security_header)

        if header.MessageType in (ua.MessageType.SecureMessage,
                                  ua.MessageType.SecureOpen,
                                  ua.MessageType.SecureClose):
            chunk = MessageChunk.from_header_and_body(self.security_policy,
                                                      header, body)
            return self._receive(chunk)
        elif header.MessageType == ua.MessageType.Hello:
            msg = struct_from_binary(ua.Hello, body)
            # Hello.ReceiveBufferSize: largest chunk the client can receive,
            # i.e. the limit for chunks we send to it.
            self._max_chunk_size = msg.ReceiveBufferSize
            return msg
        elif header.MessageType == ua.MessageType.Acknowledge:
            msg = struct_from_binary(ua.Acknowledge, body)
            # Acknowledge.ReceiveBufferSize: largest chunk the server can
            # receive, i.e. the limit for chunks we send. (SendBufferSize is the
            # largest chunk the server will send to us, not our send limit.)
            self._max_chunk_size = msg.ReceiveBufferSize
            return msg
        elif header.MessageType == ua.MessageType.Error:
            msg = struct_from_binary(ua.ErrorMessage, body)
            logger.warning("Received an error: %s", msg)
            return msg
        else:
            raise ua.UaError("Unsupported message type {0}".format(header.MessageType))

    def _receive(self, msg):
        """
        Validate and buffer an incoming chunk.

        :returns: a ua.Message once a final chunk completes the message, or
            None for intermediate and aborted chunks.
        :raises ua.UaError: for an unsupported chunk type.
        """
        self._check_incoming_chunk(msg)
        self._incoming_parts.append(msg)
        if msg.MessageHeader.ChunkType == ua.ChunkType.Intermediate:
            return None
        if msg.MessageHeader.ChunkType == ua.ChunkType.Abort:
            err = struct_from_binary(ua.ErrorMessage, ua.utils.Buffer(msg.Body))
            logger.warning("Message %s aborted: %s", msg, err)
            # specs Part 6, 6.7.3 say that aborted message shall be ignored
            # and SecureChannel should not be closed
            self._incoming_parts = []
            return None
        elif msg.MessageHeader.ChunkType == ua.ChunkType.Single:
            message = ua.Message(self._incoming_parts)
            self._incoming_parts = []
            return message
        else:
            raise ua.UaError("Unsupported chunk type: {0}".format(msg))
370