Passed
Pull Request — master (#422)
by
unknown
01:54
created

ospd_openvas.mqtt   A

Complexity

Total Complexity 14

Size/Duplication

Total Lines 101
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 63
dl 0
loc 101
rs 10
c 0
b 0
f 0
wmc 14

7 Methods

Rating   Name   Duplication   Size   Complexity  
A MQTTHandler.__init__() 0 5 1
A OpenvasMQTTHandler.publish_results() 0 6 2
A OpenvasMQTTHandler.insert_result() 0 23 3
A MQTTHandler.on_message() 0 3 1
A OpenvasMQTTHandler.set_status() 0 4 1
A OpenvasMQTTHandler.on_message() 0 13 3
A OpenvasMQTTHandler.__init__() 0 22 3
1
import json
2
import logging
3
4
from threading import Timer
5
from queue import SimpleQueue
6
7
import paho.mqtt.client as mqtt
8
9
logger = logging.getLogger(__name__)
10
11
12
class MQTTHandler:
    """Simple Handler for MQTT traffic.

    Builds a paho-mqtt client whose userdata is this handler, connects it to
    *host*, installs :meth:`on_message` as the message callback and starts
    the client's network loop via ``loop_start()``.
    """

    def __init__(self, client_id: str, host: str):
        client = mqtt.Client(client_id, userdata=self)
        self.client = client
        client.connect(host)
        client.on_message = self.on_message
        client.loop_start()

    @staticmethod
    def on_message(client, userdata, msg):
        """Default message callback: every message is ignored."""
        return None
24
25
26
class OpenvasMQTTHandler(MQTTHandler):
    """MQTT Handler for Openvas related messages.

    Results arriving on ``scanner/results`` are collected per scan ID. They
    are handed to the result function in batches, once no new result for
    that scan has arrived for ``RESULT_PUBLISH_DELAY`` seconds. Messages on
    ``scanner/status`` are passed to :meth:`set_status`.
    """

    # Quiet period (seconds) after the latest result of a scan before the
    # collected results of that scan are published.
    RESULT_PUBLISH_DELAY = 1

    def __init__(
        self,
        host: str,
        publish_result_function=None,
        publish_stat_function=None,
    ):
        """Connect to the broker and subscribe to the topics that are used.

        Args:
            host: Host of the MQTT broker.
            publish_result_function: Called as ``fn(results, scan_id)`` with
                a list of result dicts whenever a batch is published. If it
                is falsy, ``scanner/results`` is not subscribed to.
            publish_stat_function: If truthy, ``scanner/status`` is
                subscribed to and the function is stored as ``stat_fun``.
                NOTE(review): ``stat_fun`` is not called anywhere in this
                class; status messages are only logged by ``set_status``.
        """
        super().__init__(client_id="ospd-openvas", host=host)

        # Set userdata to access handler from the static on_message callback
        self.client.user_data_set(self)

        # Enable result handling when function is given. All state is set up
        # BEFORE subscribing: the network loop was already started by the
        # base class, so a message may be delivered as soon as subscribe()
        # is called.
        if publish_result_function:
            self.res_fun = publish_result_function
            # scan_id -> Timer that will publish that scan's pending results
            self.result_timer = {}
            # scan_id -> SimpleQueue of results not yet published
            self.result_dict = {}
            self.client.subscribe("scanner/results")

        # Enable status handling when function is given
        if publish_stat_function:
            self.stat_fun = publish_stat_function
            self.client.subscribe("scanner/status")

    def insert_result(self, result: dict) -> None:
        """Queue *result* for its scan and (re)arm that scan's publish timer.

        ``scan_id`` is popped from *result* (the dict is mutated). The rest
        of the dict is queued, and the scan's timer is restarted so the queue
        is flushed ``RESULT_PUBLISH_DELAY`` seconds after the latest result.

        Raises:
            KeyError: if *result* has no ``scan_id`` key.
        """
        scan_id = result.pop("scan_id")

        # Reset pub timer for this scan ID. cancel() does nothing for a timer
        # that has already fired; publish_results copes with that case.
        previous_timer = self.result_timer.get(scan_id)
        if previous_timer is not None:
            previous_timer.cancel()

        # Create the queue for a new scan ID, then append the result
        result_queue = self.result_dict.get(scan_id)
        if result_queue is None:
            result_queue = self.result_dict[scan_id] = SimpleQueue()
        result_queue.put(result)

        # Set timer for publishing the collected results
        timer = Timer(
            self.RESULT_PUBLISH_DELAY,
            self.publish_results,
            [self.res_fun, result_queue, scan_id],
        )
        self.result_timer[scan_id] = timer
        timer.start()

    @staticmethod
    def publish_results(res_fun, result_queue: SimpleQueue, scan_id: str):
        """Drain *result_queue* and call ``res_fun(results, scan_id)``.

        Nothing is published if the queue is already empty. That happens when
        an earlier timer, still running when a new result came in, has
        already drained the newer results.
        """
        results_list = []
        while not result_queue.empty():
            results_list.append(result_queue.get())
        if results_list:
            res_fun(results_list, scan_id)

    def set_status(self, status: dict) -> None:
        """Handle a status message; currently it is only logged.

        ``scan_id`` is popped from *status* (the dict is mutated).

        Raises:
            KeyError: if *status* has no ``scan_id`` key.
        """
        scan_id = status.pop("scan_id")
        logger.debug("Got status update from: %s", scan_id)

    @staticmethod
    def on_message(client, userdata, msg):
        """paho-mqtt message callback: decode JSON and dispatch by topic.

        ``userdata`` is the handler instance (set in ``__init__``). Malformed
        messages are logged and dropped instead of raised, because an
        exception escaping this callback would propagate into paho-mqtt's
        network loop.
        """
        logger.debug("Got MQTT message in topic %s", msg.topic)
        try:
            # Load msg as dictionary. ValueError covers both JSONDecodeError
            # and UnicodeDecodeError (payload bytes that are not valid text).
            json_data = json.loads(msg.payload)
        except ValueError:
            logger.error("Got MQTT message in non-json format.")
            return

        # Both handlers below pop "scan_id", so require a JSON object with it
        if not isinstance(json_data, dict) or "scan_id" not in json_data:
            logger.error(
                "Got MQTT message in topic %s without a scan_id.", msg.topic
            )
            return

        # Test for different plugins
        if msg.topic == "scanner/results":
            userdata.insert_result(json_data)
        elif msg.topic == "scanner/status":
            userdata.set_status(json_data)
101