Passed
Pull Request — master (#422)
by
unknown
02:18
created

ospd_openvas.mqtt   A

Complexity

Total Complexity 18

Size/Duplication

Total Lines 134
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 91
dl 0
loc 134
rs 10
c 0
b 0
f 0
wmc 18

8 Methods

Rating   Name   Duplication   Size   Complexity  
A OpenvasMQTTHandler.report_results() 0 13 2
A MQTTHandler.__init__() 0 5 1
B OpenvasMQTTHandler.insert_result() 0 40 6
A MQTTHandler.on_message() 0 3 1
A OpenvasMQTTHandler.set_status() 0 4 1
A MQTTHandler.publish() 0 4 1
A OpenvasMQTTHandler.on_message() 0 14 3
A OpenvasMQTTHandler.__init__() 0 23 3
1
import json
2
import logging
3
4
from threading import Timer
5
from queue import SimpleQueue
6
from types import FunctionType
7
from typing import Optional
8
9
import paho.mqtt.client as mqtt
10
11
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
12
13
14
class MQTTHandler:
    """Simple Handler for MQTT traffic.

    Wraps a paho MQTT client. The client is created with this handler as
    its ``userdata``, connected to ``host``, and has its network loop
    started via ``loop_start()`` before the constructor returns.
    """

    def __init__(self, client_id: str, host: str):
        client = self.client = mqtt.Client(client_id, userdata=self)
        # Register the message callback up front; nothing can be dispatched
        # before the network loop is started at the end of the constructor.
        client.on_message = self.on_message
        client.connect(host)
        client.loop_start()

    def publish(self, topic, msg):
        """Publish ``msg`` on ``topic`` through the wrapped client."""
        client = self.client
        client.publish(topic, msg)

    @staticmethod
    def on_message(client, userdata, msg):
        """Default message callback: ignore every incoming message."""
        return None
31
32
33
class OpenvasMQTTHandler(MQTTHandler):
    """MQTT Handler for Openvas related messages.

    Results published on ``scanner/results`` are buffered per scan id and
    handed to ``report_result_function(results, scan_id)`` in batches. A
    batch is flushed once no new result for that scan has arrived for
    ``RESULT_MIN_DELAY`` seconds, or at the latest roughly
    ``RESULT_MAX_DELAY`` seconds after the batch's first result arrived,
    whichever comes first.
    """

    # Quiet period (seconds) after the last result before a batch is flushed.
    RESULT_MIN_DELAY = 0.5
    # Upper bound (seconds) a batch may wait while results keep arriving.
    RESULT_MAX_DELAY = 10

    def __init__(
        self,
        host: str,
        report_result_function: Optional[FunctionType] = None,
        report_stat_function: Optional[FunctionType] = None,
    ):
        super().__init__(client_id="ospd-openvas", host=host)

        # Set userdata so the static on_message callback can reach the handler
        self.client.user_data_set(self)

        # Enable result handling when function is given
        if report_result_function:
            self.res_fun = report_result_function
            self.result_timer_min = {}
            self.result_timer_max = {}
            self.result_dict = {}
            # Subscribe only after the buffers exist: the network loop is
            # already running, so a message may be dispatched right away.
            self.client.subscribe("scanner/results")

        # Enable status handling when function is given
        if report_stat_function:
            self.stat_fun = report_stat_function
            self.client.subscribe("scanner/status")

    def insert_result(self, result: dict) -> None:
        """Queue ``result`` for its scan and (re)arm the flush timers.

        The ``scan_id`` key is removed from ``result``; the remaining dict
        is queued for that scan.

        Raises:
            KeyError: if ``result`` has no ``scan_id`` key.
        """
        scan_id = result.pop("scan_id")

        # Every new result restarts the quiet-period timer.
        pending = self.result_timer_min.get(scan_id)
        if pending is not None:
            pending.cancel()

        # Enqueue before (re)arming the timers, so any flush that can see
        # the new timers is guaranteed to also see this result.
        result_queue = self.result_dict.get(scan_id)
        if result_queue is None:
            result_queue = self.result_dict[scan_id] = SimpleQueue()
        result_queue.put(result)

        # Start the upper-bound timer only if none is currently running.
        max_timer = self.result_timer_max.get(scan_id)
        if max_timer is None or not max_timer.is_alive():
            max_timer = Timer(
                self.RESULT_MAX_DELAY, self._flush_results, [scan_id]
            )
            self.result_timer_max[scan_id] = max_timer
            max_timer.start()

        min_timer = Timer(
            self.RESULT_MIN_DELAY, self._flush_results, [scan_id]
        )
        self.result_timer_min[scan_id] = min_timer
        min_timer.start()

    def _flush_results(self, scan_id: str) -> None:
        """Timer callback: report every result queued for ``scan_id``.

        The current timers are looked up at fire time (not captured when
        they were created), so the flush always cancels the live sibling
        timer. Cancelling the timer that is running this callback is a
        no-op. Empty batches are not reported.
        """
        for timers in (self.result_timer_min, self.result_timer_max):
            timer = timers.get(scan_id)
            if timer is not None:
                timer.cancel()

        results = self._drain_queue(self.result_dict[scan_id])
        if results:
            self.res_fun(results, scan_id)

    @staticmethod
    def _drain_queue(result_queue: SimpleQueue) -> list:
        """Return all items currently in ``result_queue`` without blocking."""
        from queue import Empty

        results = []
        while True:
            try:
                results.append(result_queue.get_nowait())
            except Empty:
                # Non-blocking get: a concurrent flush that emptied the queue
                # first cannot make this call hang.
                return results

    @staticmethod
    def report_results(
        res_fun,
        result_queue: SimpleQueue,
        scan_id: str,
        timer_to_reset: Optional[Timer] = None,
    ):
        """Drain ``result_queue`` and pass its items to ``res_fun``.

        Args:
            res_fun: callable invoked as ``res_fun(results_list, scan_id)``.
            result_queue: queue holding the results to report.
            scan_id: id of the scan the results belong to.
            timer_to_reset: optional timer to cancel before reporting. It is
                only cancelled, never joined: joining could deadlock when
                two timers flush and reset each other at the same time, and
                fails outright for a timer that was never started.
        """
        if timer_to_reset is not None:
            timer_to_reset.cancel()
        res_fun(OpenvasMQTTHandler._drain_queue(result_queue), scan_id)

    def set_status(self, status: dict) -> None:
        """Handle a status message: pop its ``scan_id`` and log it.

        NOTE(review): ``self.stat_fun`` is stored in ``__init__`` but is not
        called here; confirm whether status updates should be forwarded.

        Raises:
            KeyError: if ``status`` has no ``scan_id`` key.
        """
        # Get Scan ID
        scan_id = status.pop("scan_id")
        logger.debug("Got status update from: %s", scan_id)

    @staticmethod
    def on_message(client, userdata, msg):
        """Dispatch an incoming MQTT message to the handler in ``userdata``.

        The payload must be a JSON object carrying a ``scan_id``. Malformed
        messages are logged and dropped rather than raised, so they do not
        propagate out of the client's callback.
        """
        logger.debug("Got MQTT message in topic %s", msg.topic)
        try:
            # Load msg as dictionary
            json_data = json.loads(msg.payload)
        except ValueError:
            # Covers JSONDecodeError and undecodable bytes (UnicodeDecodeError)
            logger.error("Got MQTT message in non-json format.")
            return

        if not isinstance(json_data, dict):
            logger.error(
                "Got MQTT message in topic %s with non-object JSON payload.",
                msg.topic,
            )
            return

        try:
            # Dispatch by topic
            if msg.topic == "scanner/results":
                userdata.insert_result(json_data)
            elif msg.topic == "scanner/status":
                userdata.set_status(json_data)
        except KeyError:
            logger.error(
                "Got MQTT message in topic %s without scan_id.", msg.topic
            )
134