Passed
Pull Request — master (#422)
by
unknown
01:32
created

ospd_openvas.mqtt   A

Complexity

Total Complexity 18

Size/Duplication

Total Lines 136
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 87
dl 0
loc 136
rs 10
c 0
b 0
f 0
wmc 18

7 Methods

Rating   Name   Duplication   Size   Complexity  
A MQTTHandler.__init__() 0 5 1
A MQTTHandler.on_message() 0 3 1
A MQTTHandler.publish() 0 4 1
A OpenvasMQTTHandler.report_results() 0 16 4
B OpenvasMQTTHandler.insert_result() 0 50 6
A OpenvasMQTTHandler.on_message() 0 14 3
A OpenvasMQTTHandler.__init__() 0 17 2
1
from abc import abstractstaticmethod
2
import json
3
import logging
4
5
from threading import Timer
6
from queue import SimpleQueue
7
from types import FunctionType
8
9
import paho.mqtt.client as mqtt
10
11
logger = logging.getLogger(__name__)
12
13
14
class MQTTHandler:
    """Simple Handler for MQTT traffic.

    Wraps a paho MQTT client: connects it to a broker, registers
    :meth:`on_message` as the message callback and starts the client's
    network loop. Subclasses override :meth:`on_message` to process
    incoming messages.
    """

    def __init__(self, client_id: str, host: str):
        """Connect to the broker and start the client's network loop.

        Arguments:
            client_id: Identifier passed to the MQTT client.
            host: Address of the MQTT broker to connect to.
        """
        # The handler itself is stored as userdata so that the static
        # on_message callback can reach the instance.
        self.client = mqtt.Client(client_id, userdata=self)
        self.client.connect(host)
        self.client.on_message = self.on_message
        self.client.loop_start()

    def publish(self, topic, msg):
        """Publish Messages via MQTT

        Arguments:
            topic: Topic the message is published on.
            msg: Payload handed to the client's ``publish`` call.
        """
        self.client.publish(topic, msg)

    # ``abc.abstractstaticmethod`` is deprecated since Python 3.3 and is not
    # enforced here anyway because this class does not use ``ABCMeta``. A
    # plain static method that raises keeps the "must be overridden"
    # contract without relying on the deprecated decorator.
    @staticmethod
    def on_message(client, userdata, msg):
        """Handle an incoming MQTT message; must be overridden.

        Arguments:
            client: The MQTT client that received the message.
            userdata: The user data set on the client (the handler instance).
            msg: The received message.

        Raises:
            NotImplementedError: Always, the base class handles no messages.
        """
        raise NotImplementedError()
31
32
33
class OpenvasMQTTHandler(MQTTHandler):
    """MQTT Handler for Openvas related messages.

    Results received on the ``scanner/results`` topic are collected per scan
    ID and handed to the report function in batches. A batch is reported
    once no new result arrived for ``RESULT_TIMEOUT_MIN`` seconds, or at the
    latest ``RESULT_TIMEOUT_MAX`` seconds after the max timer was started.
    """

    # Seconds without a new result after which collected results are reported.
    RESULT_TIMEOUT_MIN = 0.5
    # Upper bound in seconds for how long results are held back.
    RESULT_TIMEOUT_MAX = 10

    def __init__(
        self,
        host: str,
        report_result_function: FunctionType,
    ):
        """Connect to the broker and, if a report function is given,
        subscribe to the result topic.

        Arguments:
            host: Address of the MQTT broker to connect to.
            report_result_function: Callable invoked as
                ``report_result_function(results_list, scan_id)``. Result
                handling stays disabled when it is falsy.
        """
        super().__init__(client_id="ospd-openvas", host=host)

        # Set userdata to access handler
        self.client.user_data_set(self)

        # Initialise all result state before subscribing: the network loop
        # is already running, so a message may be delivered as soon as the
        # subscription is active.
        self.res_fun = report_result_function
        self.result_timer_min = {}
        self.result_timer_max = {}
        self.result_dict = {}

        # Enable result handling when function is given
        if report_result_function:
            self.client.subscribe("scanner/results")

    def insert_result(self, result: dict) -> None:
        """Queue a result for its scan and (re)start the report timers.

        The result is put into the queue belonging to its scan ID. Queued
        results are reported after ``RESULT_TIMEOUT_MIN`` seconds without new
        incoming results or after a maximum of ``RESULT_TIMEOUT_MAX``
        seconds.

        Arguments:
            result: Result dictionary; its ``scan_id`` entry is removed and
                used to select the queue.

        Raises:
            KeyError: If ``result`` has no ``scan_id`` entry.
        """
        # Get scan ID
        scan_id = result.pop("scan_id")

        # Reset min timer
        previous_min = self.result_timer_min.get(scan_id)
        if previous_min is not None:
            previous_min.cancel()

        # Init result queue
        if scan_id not in self.result_dict:
            self.result_dict[scan_id] = SimpleQueue()

        # The result must be queued before the new min timer is created, so
        # that whoever cancels that timer afterwards also drains this result.
        self.result_dict[scan_id].put(result)

        # Start max timer if it is not running
        max_timer = self.result_timer_max.get(scan_id)
        if max_timer is None or not max_timer.is_alive():
            # The min timer for this result does not exist yet, so the max
            # timer resolves the min timer to cancel when it fires instead
            # of capturing the stale, already cancelled one here.
            max_timer = Timer(
                self.RESULT_TIMEOUT_MAX,
                self._report_on_max_timeout,
                [scan_id],
            )
            self.result_timer_max[scan_id] = max_timer
            max_timer.start()

        # Start min timer
        min_timer = Timer(
            self.RESULT_TIMEOUT_MIN,
            self.report_results,
            [
                self.res_fun,
                self.result_dict[scan_id],
                scan_id,
                max_timer,
            ],
        )
        self.result_timer_min[scan_id] = min_timer
        min_timer.start()

    def _report_on_max_timeout(self, scan_id: str) -> None:
        """Report the queued results of a scan when the max timer fires."""
        pending_min = self.result_timer_min.get(scan_id)
        if pending_min is not None:
            pending_min.cancel()
        # No timer is handed over for joining: the min timer joins the max
        # timer, so joining in this direction as well could leave both timer
        # threads waiting for each other.
        self.report_results(self.res_fun, self.result_dict[scan_id], scan_id)

    def report_results(
        self,
        res_fun,
        result_queue: SimpleQueue,
        scan_id: str,
        timer_to_reset: Timer = None,
    ):
        """Report results with given res_fun.

        Arguments:
            res_fun: Callable invoked as ``res_fun(results_list, scan_id)``.
                It is called even when the queue turned out to be empty.
            result_queue: Queue holding the results to report; it is drained.
            scan_id: ID of the scan the results belong to.
            timer_to_reset: Optional timer that is cancelled before and
                joined after reporting.
        """
        # Imported locally so the block does not depend on a changed
        # module-level import list.
        from queue import Empty

        if timer_to_reset:
            timer_to_reset.cancel()

        # Drain without blocking: with empty()/get() a second consumer could
        # take the last item in between and leave get() waiting forever.
        results_list = []
        while True:
            try:
                results_list.append(result_queue.get_nowait())
            except Empty:
                break

        res_fun(results_list, scan_id)
        if timer_to_reset:
            timer_to_reset.join()

    @staticmethod
    def on_message(client, userdata, msg):
        """Insert results

        Decodes the message payload as JSON and forwards messages of the
        ``scanner/results`` topic to ``userdata.insert_result``. Payloads
        that are not valid JSON, and result messages that are not a JSON
        object with a ``scan_id``, are logged and dropped instead of raising
        out of the callback.

        Arguments:
            client: The MQTT client that received the message.
            userdata: The handler instance set as user data on the client.
            msg: The received message; ``topic`` and ``payload`` are read.
        """
        logger.debug("Got MQTT message in topic %s", msg.topic)
        try:
            # Load msg as dictionary
            json_data = json.loads(msg.payload)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # json.loads raises UnicodeDecodeError for bytes payloads that
            # are not valid UTF-8/16/32.
            logger.error("Got MQTT message in non-json format.")
            logger.debug("Got: %s", msg.payload)
            return

        # Test for different plugins
        if msg.topic == "scanner/results":
            if not isinstance(json_data, dict) or "scan_id" not in json_data:
                logger.error("Got MQTT result message without a scan_id.")
                logger.debug("Got: %s", msg.payload)
                return
            userdata.insert_result(json_data)
136