Completed
Push — master ( 784599...e17a09 )
by
unknown
18s queued 12s
created

ospd_openvas.mqtt.MQTTHandler.publish()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import json
import logging
from abc import abstractstaticmethod
from queue import SimpleQueue
from threading import Lock, Timer
from types import FunctionType
from typing import Any, Callable, Optional

import paho.mqtt.client as mqtt
10
11
# Module-level logger, named after this module so its output can be
# configured and filtered independently.
logger = logging.getLogger(name=__name__)
12
13
14
class MQTTHandler:
    """Simple Handler for MQTT traffic.

    Creates a paho MQTT (protocol v5) client, connects it to ``host`` and
    starts the client's background network loop. Incoming messages are
    dispatched to :meth:`on_message`, which subclasses must override.
    """

    def __init__(self, client_id: str, host: str):
        """Create the client, connect to ``host`` and start its network loop.

        Args:
            client_id: MQTT client identifier.
            host: Broker host passed to ``Client.connect``.
        """
        # The handler itself is passed as userdata so that the static
        # on_message callback can reach handler state via its ``userdata``.
        self.client = mqtt.Client(
            client_id, userdata=self, protocol=mqtt.MQTTv5
        )
        self.client.connect(host)
        self.client.on_message = self.on_message
        self.client.loop_start()

    def publish(self, topic: str, msg):
        """Publish ``msg`` on ``topic`` via MQTT.

        Returns:
            The message-info object returned by the client's ``publish``,
            so callers can inspect ``rc`` or wait for delivery.
        """
        info = self.client.publish(topic, msg)
        # publish() does not raise on failure (e.g. not connected); it
        # reports the problem through ``rc``, so surface it in the log.
        if info.rc != mqtt.MQTT_ERR_SUCCESS:
            logger.warning(
                "Failed to publish message on topic %s (rc=%s).",
                topic,
                info.rc,
            )
        else:
            logger.debug("Published message on topic %s.", topic)
        return info

    # A plain staticmethod: abc.abstractstaticmethod is deprecated and had
    # no effect here anyway, because this class does not use ABCMeta.
    @staticmethod
    def on_message(client, userdata, msg):
        """Handle an incoming MQTT message; must be overridden.

        Registered as the client's ``on_message`` callback in ``__init__``.
        ``userdata`` is the handler instance that owns the client.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()
33
34
35
class OpenvasMQTTHandler(MQTTHandler):
    """MQTT Handler for Openvas related messages.

    When a result callback is given, subscribes to ``scanner/results`` and
    batches incoming results per scan. The first result of a batch starts a
    timer; when it fires, all results collected for that scan so far are
    handed to the callback in a single call.
    """

    # Seconds to wait after the first result of a batch before reporting it.
    RESULT_FLUSH_DELAY = 0.25

    def __init__(
        self,
        host: str,
        report_result_function: Optional[Callable[[list, str], Any]],
    ):
        """Connect to ``host`` and, if a callback is given, subscribe to results.

        Args:
            host: MQTT broker host.
            report_result_function: Called as ``fn(results_list, scan_id)``
                with each batch of results. If falsy, no subscription to
                ``scanner/results`` is made.
        """
        # Initialise batching state *before* connecting/subscribing: once
        # subscribed, on_message can run on the network loop thread at any
        # moment and needs these attributes to exist.
        self.res_fun = report_result_function
        # scan_id -> SimpleQueue of results waiting for a pending flush.
        # An entry exists exactly while a flush timer for that scan is
        # pending; it is removed when the timer fires.
        self.result_dict = {}
        # Guards result_dict against concurrent access from the network
        # loop thread (insert_result) and timer threads (_flush_results).
        self._result_lock = Lock()

        # The base class already registers this handler as client userdata.
        super().__init__(client_id="ospd-openvas", host=host)

        # Enable result handling when function is given
        if report_result_function:
            self.client.subscribe("scanner/results")

    def insert_result(self, result: dict) -> None:
        """Queue ``result`` for batched reporting.

        Removes the ``scan_id`` key from ``result`` (mutating it) and appends
        the remainder to that scan's pending batch. If no batch is pending
        for the scan, a new one is created and a timer is started that
        reports it after ``RESULT_FLUSH_DELAY`` seconds.

        Raises:
            KeyError: if ``result`` has no ``scan_id`` key.
        """
        scan_id = result.pop("scan_id")

        # Looking up/creating the batch and deciding whether a timer is
        # needed must be atomic with respect to _flush_results. Otherwise a
        # result added while a timer thread is draining could be stranded
        # with no timer left to report it.
        with self._result_lock:
            batch = self.result_dict.get(scan_id)
            new_batch = batch is None
            if new_batch:
                batch = SimpleQueue()
                self.result_dict[scan_id] = batch
            batch.put(result)

        if new_batch:
            Timer(
                self.RESULT_FLUSH_DELAY,
                self._flush_results,
                [scan_id],
            ).start()

    def _flush_results(self, scan_id: str) -> None:
        """Detach the pending batch for ``scan_id`` and report it.

        Runs on the timer thread started by insert_result. The batch is
        removed from result_dict under the lock, so results arriving
        afterwards start a new batch (and timer) instead of being lost. This
        also keeps result_dict from growing with finished scans.
        """
        with self._result_lock:
            batch = self.result_dict.pop(scan_id, None)
        if batch is not None:
            self.report_results(self.res_fun, batch, scan_id)

    @staticmethod
    def report_results(
        res_fun: Callable[[list, str], Any],
        result_queue: SimpleQueue,
        scan_id: str,
    ):
        """Drain ``result_queue`` into a list and call ``res_fun(list, scan_id)``."""

        # Create and fill result list
        results_list = []
        while not result_queue.empty():
            results_list.append(result_queue.get())

        # Insert results into scan table
        res_fun(results_list, scan_id)

    @staticmethod
    def on_message(client, userdata, msg):
        """Decode an incoming message and insert results.

        ``userdata`` is the handler instance. Messages on ``scanner/results``
        must be JSON objects carrying a ``scan_id`` key. Malformed payloads
        are logged and dropped rather than raised, because an exception
        escaping this callback would propagate into the MQTT network loop.
        """
        try:
            # Load msg as dictionary. json.loads on bytes raises
            # UnicodeDecodeError (not JSONDecodeError) for invalid UTF-8.
            json_data = json.loads(msg.payload)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.error("Got MQTT message in non-json format.")
            logger.debug("Got: %s", msg.payload)
            return

        # Test for different plugins
        if msg.topic == "scanner/results":
            if not isinstance(json_data, dict) or "scan_id" not in json_data:
                logger.error("Got MQTT result message without scan_id.")
                logger.debug("Got: %s", msg.payload)
                return
            userdata.insert_result(json_data)
108