Passed
Push — master ( 97820f...f3de37 )
by Olivier
02:29
created

BinaryServer._await_closing_tasks()   A

Complexity

Conditions 5

Size

Total Lines 13
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 12
nop 2
dl 0
loc 13
rs 9.3333
c 0
b 0
f 0
1
"""
2
Socket server forwarding requests to the internal server
3
"""
4
import logging
5
import asyncio
6
from typing import Optional
7
8
from ..ua.ua_binary import header_from_binary
9
from ..common.utils import Buffer, NotEnoughData
10
from .uaprocessor import UaProcessor
11
from .internal_server import InternalServer
12
13
logger = logging.getLogger(__name__)
14
15
16
class OPCUAProtocol(asyncio.Protocol):
    """
    Instantiated for every connection.

    Accumulates the received bytes, cuts them into complete messages
    (header + body), queues those messages and lets a background task hand
    them one at a time to a ``UaProcessor``.
    """

    def __init__(self, iserver: "InternalServer", policies, clients, closing_tasks):
        """
        :param iserver: internal server; its ``loop`` and ``asyncio_transports`` are used here
        :param policies: handed to ``UaProcessor.set_policies`` when a connection is made
        :param clients: shared list in which this protocol registers itself while connected
        :param closing_tasks: shared list receiving the task that closes the processor
            once the connection is lost
        """
        self.peer_name = None
        self.transport = None
        self.processor = None
        # bytes received so far that do not yet form a complete message
        self._buffer = b''
        self.iserver: InternalServer = iserver
        self.policies = policies
        self.clients = clients
        self.closing_tasks = closing_tasks
        # queue of (header, buffer) tuples; (None, None) asks the consumer to stop
        self.messages = asyncio.Queue()
        # task running _process_received_message_loop()
        self._task = None

    def __str__(self):
        # The processor only exists once connection_made() has run; do not
        # crash when the protocol is printed/logged before that.
        session = getattr(self.processor, 'session', None)
        return f'OPCUAProtocol({self.peer_name}, {session})'

    __repr__ = __str__

    def connection_made(self, transport):
        """Set up the processor for the new connection and start the message consumer task."""
        self.peer_name = transport.get_extra_info('peername')
        logger.info('New connection from %s', self.peer_name)
        self.transport = transport
        self.processor = UaProcessor(self.iserver, self.transport)
        self.processor.set_policies(self.policies)
        self.iserver.asyncio_transports.append(transport)
        self.clients.append(self)
        self._task = self.iserver.loop.create_task(self._process_received_message_loop())

    def connection_lost(self, ex):
        """Unregister the connection, schedule the processor shutdown and stop the consumer task."""
        logger.info('Lost connection from %s, %s', self.peer_name, ex)
        self.transport.close()
        self.iserver.asyncio_transports.remove(self.transport)
        # processor.close() is a coroutine and cannot be awaited from this
        # callback: run it as a task and let the owner of closing_tasks await it.
        closing_task = self.iserver.loop.create_task(self.processor.close())
        self.closing_tasks.append(closing_task)
        if self in self.clients:
            self.clients.remove(self)
        # wake up the consumer with the stop sentinel, then cancel it
        self.messages.put_nowait((None, None))
        self._task.cancel()

    def data_received(self, data):
        """
        Append ``data`` to the pending bytes and queue every complete message found.

        Returns early, keeping the pending bytes, when the header or the body
        is not completely received yet. A header announcing a negative body
        size is treated as a protocol violation: the pending bytes are dropped
        and the connection is closed.
        """
        self._buffer += data
        # try to parse the incoming data
        while self._buffer:
            try:
                buf = Buffer(self._buffer)
                try:
                    header = header_from_binary(buf)
                except NotEnoughData:
                    logger.debug('Not enough data while parsing header from client, waiting for more')
                    return
                message_size = header.header_size + header.body_size
                if header.body_size < 0 or message_size <= 0:
                    # The sizes come from the peer. With such values the slice
                    # below would not consume the message (offset 0 loops
                    # forever, a negative offset corrupts the stream).
                    logger.error('Got malformed header %s from %s, closing connection', header, self.peer_name)
                    self._buffer = b''
                    self.transport.close()
                    return
                if len(buf) < header.body_size:
                    logger.debug('We did not receive enough data from client. Need %s got %s', header.body_size,
                                 len(buf))
                    return
                # we have a complete message
                self.messages.put_nowait((header, buf))
                self._buffer = self._buffer[message_size:]
            except Exception:
                logger.exception('Exception raised while parsing message from client')
                return

    async def _process_received_message_loop(self):
        """
        Take message from the queue and try to process it.

        Runs until the ``(None, None)`` sentinel is received. An exception
        raised while processing one message is logged and does not stop the loop.
        """
        while True:
            header, buf = await self.messages.get()
            if header is None and buf is None:
                # Connection was closed, end task
                break
            try:
                await self._process_one_msg(header, buf)
            except Exception:
                logger.exception('Exception raised while processing message from client')

    async def _process_one_msg(self, header, buf):
        """Let the processor handle one message; close the transport when it returns a falsy value."""
        logger.debug('_process_received_message %s %s', header.body_size, len(buf))
        ret = await self.processor.process(header, buf)
        if not ret:
            logger.info('processor returned False, we close connection from %s', self.peer_name)
            self.transport.close()
102
103
104
class BinaryServer:
    """
    Socket server creating one ``OPCUAProtocol`` per incoming connection.

    A background task periodically awaits the tasks that protocols queue in
    ``closing_tasks`` when their connection is lost.
    """

    def __init__(self, internal_server: "InternalServer", hostname, port):
        """
        :param internal_server: internal server handed to every protocol; its ``loop``
            is used to create the socket server and the background task
        :param hostname: host name or address to listen on
        :param port: port to listen on, 0 for a dynamically assigned port
        """
        self.logger = logging.getLogger(__name__)
        self.hostname = hostname
        self.port = port
        self.iserver: InternalServer = internal_server
        self._server: Optional[asyncio.AbstractServer] = None
        self._policies = []
        # protocols of the currently connected clients
        self.clients = []
        # tasks closing the processors of lost connections, awaited by the cleanup task
        self.closing_tasks = []
        # task running _await_closing_tasks(), created by start()
        self.cleanup_task = None

    def set_policies(self, policies):
        """Set the policies handed to the protocols created from now on."""
        self._policies = policies

    def _make_protocol(self):
        """Protocol Factory"""
        return OPCUAProtocol(
            iserver=self.iserver,
            policies=self._policies,
            clients=self.clients,
            closing_tasks=self.closing_tasks,
        )

    async def start(self):
        """Create the listening socket server and start the periodic cleanup task."""
        self._server = await self.iserver.loop.create_server(self._make_protocol, self.hostname, self.port)
        # get the port and the hostname from the created server socket
        # only relevant for dynamic port assignment (when self.port == 0)
        if self.port == 0 and len(self._server.sockets) == 1:
            # will work for AF_INET and AF_INET6 socket names
            # these are the only families supported by the create_server call
            sockname = self._server.sockets[0].getsockname()
            self.hostname = sockname[0]
            self.port = sockname[1]
        self.logger.info('Listening on %s:%s', self.hostname, self.port)
        self.cleanup_task = self.iserver.loop.create_task(self._await_closing_tasks())

    async def stop(self):
        """
        Close all client transports, finish the pending closing tasks and close the server.

        Safe to call when ``start()`` was never called.
        """
        self.logger.info('Closing asyncio socket server')
        # iterate over a snapshot: a transport whose close() synchronously
        # notifies its protocol would otherwise mutate the list being iterated
        for transport in list(self.iserver.asyncio_transports):
            transport.close()

        # stop cleanup process and run it a last time
        if self.cleanup_task is not None:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
            self.cleanup_task = None
        await self._await_closing_tasks(recursive=False)

        if self._server:
            self.iserver.loop.call_soon(self._server.close)
            await self._server.wait_closed()

    async def _await_closing_tasks(self, recursive=True):
        """
        Await every task currently queued in ``closing_tasks``.

        :param recursive: when true, never return: sleep 10 seconds after each
            pass and start again, until the surrounding task is cancelled.
            When false, do a single pass and return.
        :raises asyncio.CancelledError: when cancelled while awaiting a task or sleeping
        """
        # Implemented as a loop: awaiting itself after each pass would nest one
        # more coroutine every 10 seconds and eventually raise RecursionError.
        while True:
            while self.closing_tasks:
                task = self.closing_tasks.pop()
                try:
                    await task
                except asyncio.CancelledError:
                    # this means a stop request has been sent, it should not be caught
                    raise
                except Exception:
                    self.logger.exception("Unexpected crash in BinaryServer._await_closing_tasks")
            if not recursive:
                return
            await asyncio.sleep(10)
171