Passed
Pull Request — master (#422)
by
unknown
01:31
created

ospd_openvas.mqtt   A

Complexity

Total Complexity 13

Size/Duplication

Total Lines 105
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 62
dl 0
loc 105
rs 10
c 0
b 0
f 0
wmc 13

7 Methods

Rating   Name   Duplication   Size   Complexity  
A OpenvasMQTTHandler.report_results() 0 15 2
A MQTTHandler.__init__() 0 7 1
A OpenvasMQTTHandler.insert_result() 0 21 3
A MQTTHandler.on_message() 0 3 1
A MQTTHandler.publish() 0 4 1
A OpenvasMQTTHandler.on_message() 0 13 3
A OpenvasMQTTHandler.__init__() 0 15 2
1
from abc import abstractstaticmethod
2
import json
3
import logging
4
5
from threading import Timer
6
from queue import SimpleQueue
7
from types import FunctionType
8
9
import paho.mqtt.client as mqtt
10
11
logger = logging.getLogger(__name__)
12
13
14
class MQTTHandler:
    """Simple Handler for MQTT traffic.

    Subclasses are expected to override ``on_message``, which is registered
    as the message callback of the underlying paho client.
    """

    def __init__(self, client_id: str, host: str):
        """Create an MQTTv5 client, connect it to ``host`` and start its
        network loop.

        Arguments:
            client_id: Client identifier sent to the broker.
            host: Hostname or address of the MQTT broker.
        """
        # The handler itself is passed as userdata so that the static
        # on_message callback can reach handler state.
        self.client = mqtt.Client(
            client_id, userdata=self, protocol=mqtt.MQTTv5
        )
        # Register the callback before any network activity takes place.
        self.client.on_message = self.on_message
        self.client.connect(host)
        # paho's loop_start() runs the network loop in a background thread.
        self.client.loop_start()

    def publish(self, topic, msg):
        """Publish ``msg`` on ``topic`` via the MQTT client."""
        self.client.publish(topic, msg)
        logger.debug("Published message on topic %s.", topic)

    @staticmethod
    def on_message(client, userdata, msg):
        """Handle an incoming MQTT message; subclasses must override this.

        A plain staticmethod is used because ``abc.abstractstaticmethod`` is
        deprecated and is not enforced on a class that is not an ``abc.ABC``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
33
34
35
class OpenvasMQTTHandler(MQTTHandler):
    """MQTT Handler for Openvas related messages."""

    # Seconds between the first buffered result of a scan and the moment the
    # buffered batch for that scan is handed to the report function.
    REPORT_DELAY = 0.25

    def __init__(
        self,
        host: str,
        report_result_function: FunctionType,
    ):
        """Connect to ``host`` and enable result handling if requested.

        Arguments:
            host: Hostname or address of the MQTT broker.
            report_result_function: Called as
                ``report_result_function(results_list, scan_id)`` with
                batches of received results. If falsy, the handler does not
                subscribe to ``scanner/results``.
        """
        from threading import Lock  # pylint: disable=import-outside-toplevel

        # Initialise all state before super().__init__() starts the network
        # thread, so callbacks never see a half-initialised handler.
        # result_dict maps scan_id -> queue of results not yet reported; it
        # is filled from the MQTT thread and drained from Timer threads, so
        # every access goes through _lock.
        self._lock = Lock()
        self.result_dict = {}
        self.res_fun = report_result_function

        super().__init__(client_id="ospd-openvas", host=host)

        # Set userdata to access handler
        self.client.user_data_set(self)

        # Enable result handling when function is given
        if report_result_function:
            self.client.subscribe("scanner/results")

    def insert_result(self, result: dict) -> None:
        """Buffer a result for its scan and schedule reporting.

        The ``scan_id`` key is removed from ``result`` and selects the
        per-scan buffer. The first result buffered for a scan starts a timer.
        When that timer fires after ``REPORT_DELAY`` seconds, all results
        buffered for the scan up to then are reported as one batch.

        Raises:
            KeyError: if ``result`` has no ``scan_id`` key.
        """
        scan_id = result.pop("scan_id")

        with self._lock:
            result_queue = self.result_dict.get(scan_id)
            if result_queue is None:
                # No batch pending for this scan: open one and schedule it.
                result_queue = SimpleQueue()
                self.result_dict[scan_id] = result_queue
                Timer(
                    self.REPORT_DELAY,
                    self._flush_results,
                    [scan_id],
                ).start()
            result_queue.put(result)

    def _flush_results(self, scan_id) -> None:
        """Detach the pending batch of ``scan_id`` and report it."""
        # Popping the queue under the lock means results arriving afterwards
        # start a fresh batch with its own timer. Otherwise they could land in
        # a queue that is already being drained and never be reported.
        with self._lock:
            result_queue = self.result_dict.pop(scan_id, None)

        # Report outside the lock so a slow res_fun does not block the
        # MQTT thread from buffering new results.
        if result_queue is not None:
            self.report_results(self.res_fun, result_queue, scan_id)

    @staticmethod
    def report_results(
        res_fun: FunctionType,
        result_queue: SimpleQueue,
        scan_id: str,
    ):
        """Drain ``result_queue`` into a list and pass it to ``res_fun``.

        ``res_fun`` is called as ``res_fun(results_list, scan_id)``.
        """
        # Create and fill result list
        results_list = []
        while not result_queue.empty():
            results_list.append(result_queue.get())

        # Insert results into scan table
        res_fun(results_list, scan_id)

    @staticmethod
    def on_message(client, userdata, msg):
        """Insert results received on the ``scanner/results`` topic.

        ``userdata`` is the handler instance set via ``user_data_set``.
        Malformed payloads are logged and dropped instead of raising. An
        exception escaping a paho callback may stop the client's network
        loop.
        """
        if msg.topic != "scanner/results":
            return

        try:
            # Load msg as dictionary
            json_data = json.loads(msg.payload)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.error("Got MQTT message in non-json format.")
            logger.debug("Got: %s", msg.payload)
            return

        if not isinstance(json_data, dict) or "scan_id" not in json_data:
            logger.error("Got MQTT result without scan_id.")
            logger.debug("Got: %s", msg.payload)
            return

        userdata.insert_result(json_data)
105