Passed
Pull Request — master (#114)
by Aldo
04:31
created

TraceManager._run_traces()   A

Complexity

Conditions 5

Size

Total Lines 21
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 9
CRAP Score 6.1384

Importance

Changes 0
Metric Value
cc 5
eloc 14
nop 1
dl 0
loc 21
ccs 9
cts 14
cp 0.6429
crap 6.1384
rs 9.2333
c 0
b 0
f 0
1
"""
2
    Trace Manager Class
3
"""
4
5
6 1
import dill
7 1
import time
8 1
from janus import Queue
9 1
from _thread import start_new_thread as new_thread
10 1
from collections import defaultdict
11 1
from typing import Optional
12
13 1
from kytos.core import log
14 1
from napps.amlight.sdntrace import settings
15 1
from napps.amlight.sdntrace.shared.switches import Switches
16 1
from napps.amlight.sdntrace.shared.colors import Colors
17 1
from napps.amlight.sdntrace.tracing.tracer import TracePath
18 1
from napps.amlight.sdntrace.tracing.trace_pkt import process_packet
19 1
from napps.amlight.sdntrace.tracing.trace_entries import TraceEntries
20 1
from napps.amlight.sdntrace.tracing.trace_msg import TraceMsg
21
22
23 1
class TraceManager(object):
24
    """
25
        The TraceManager class is responsible for
26
        managing all trace requests.
27
    """
28
29 1
    def __init__(self, controller):
        """Set up the manager state and start consuming trace requests.

        Args:
            controller: Kytos.core.controller object
        """
        # Controller
        self.controller = controller

        # Trace ID used to distinguish each trace. get_id() increments it
        # before returning, so the first ID handed out is 30001.
        self._id = 30000

        # Trace queues:
        #   _request_dict   - pending requests, trace ID -> entries
        #   _request_queue  - queue of pending trace IDs, created by
        #                     run_traces()
        #   _results_queue  - finished traces, trace ID -> result
        #   _running_traces - traces in progress, trace ID -> the TracePath
        #                     object stored by _spawn_trace()
        self._request_dict = {}
        self._request_queue = None
        self._results_queue = {}
        self._running_traces: dict[int, TracePath] = {}

        # Counters
        self._total_traces_requested = 0

        # PacketIn queue with Probes, one queue per request ID
        self._trace_pkt_in = defaultdict(Queue)

        self._is_tracing_running = False

        self._async_loop = None

        # To start traces
        self.run_traces()
57
58 1
    def stop_traces(self):
        """Stop accepting requests and tell every running trace to end.

        When the manager is running, the running flag is cleared and the
        request queue is closed. Every running trace is then flagged as
        ended and its PacketIn queue, if one exists, is closed.
        """
        if self._is_tracing_running:
            self._is_tracing_running = False
            self._request_queue.close()
        # Iterate over a snapshot: add_result() pops entries from
        # self._running_traces, and a dict that changes size while being
        # iterated raises RuntimeError.
        for trace_obj in list(self._running_traces.values()):
            trace_obj.trace_ended = True
            # Membership test first so the defaultdict does not create a
            # queue only to close it.
            if trace_obj.id in self._trace_pkt_in:
                self._trace_pkt_in[trace_obj.id].close()
66
67 1
    def is_tracing_running(self):
        """Tell whether the manager is currently processing requests.

        Returns:
            The running flag, set by run_traces() and cleared by
            stop_traces().
        """
        running = self._is_tracing_running
        return running
69
70 1
    def run_traces(self):
        """Start the background consumer of trace requests.

        A fresh request queue is created and the manager is marked as
        running before _run_traces is launched in a new thread.
        """
        self._request_queue = Queue()
        self._is_tracing_running = True
        # _run_traces takes no arguments besides self.
        worker_args = ()
        new_thread(self._run_traces, worker_args)
77
78 1
    def _run_traces(self):
        """Thread body that consumes the request queue and starts traces.

        While the manager is running, and as long as the limit of parallel
        traces has not been reached, a request ID is taken from
        self._request_queue and its entries are handed to _spawn_trace in
        a new thread. When the limit is reached the loop sleeps for one
        second before checking again.

        A RuntimeError is treated as the shutdown signal: it is logged as
        a warning and ends this thread. Any other exception is logged as
        an error and the loop continues.
        """
        try:
            while self.is_tracing_running():
                try:
                    if self.limit_traces_reached():
                        # Wait for traces to end
                        time.sleep(1)
                        continue
                    request_id = self._request_queue.sync_q.get()
                    entries = self._request_dict[request_id]
                    new_thread(self._spawn_trace, (request_id, entries))
                    # After starting traces for new requests,
                    # remove them from self._request_dict
                    del self._request_dict[request_id]
                except RuntimeError:
                    # RuntimeError is a subclass of Exception, so without
                    # this clause the broad handler below would swallow it
                    # and the shutdown handler would never run.
                    # Presumably raised by the queue once stop_traces()
                    # closed it - TODO confirm against janus.
                    raise
                except Exception as error:  # pylint: disable=broad-except
                    log.error("Trace Error: %s" % error)
        except RuntimeError:
            log.warning("Ignored trace request while sdntrace was shutting down.")
99
100 1
    def _spawn_trace(self, trace_id, trace_entries):
        """Run one trace request; called in its own thread by _run_traces.

        A TracePath is created for the request, registered in
        self._running_traces under the trace ID and then executed.

        Args:
            trace_id: trace request id
            trace_entries: TraceEntries class
        """
        message = "Creating task to trace request id %s..." % trace_id
        log.info(message)

        path_tracer = TracePath(self, trace_id, trace_entries)
        # Register before running so the trace is counted as active while
        # tracepath() executes.
        self._running_traces[trace_id] = path_tracer
        path_tracer.tracepath()
114
115 1
    def add_result(self, trace_id, result):
        """Store a finished trace result and unregister the running trace.

        Args:
            trace_id: trace ID
            result: trace result generated using tracer
        """
        # Publish the result first, then drop the trace from the running
        # set; a trace that is not registered is silently ignored.
        results = self._results_queue
        results[trace_id] = result
        self._running_traces.pop(trace_id, None)
124
125 1
    def avoid_duplicated_request(self, entries):
        """Check whether a pending request already has the same entries.

        Args:
            entries: entries provided by user via REST.
        Return:
            True: if exists a similar request
            False: otherwise
        """
        # Compare against the values of a snapshot. Looking each key up in
        # the live dict again could raise KeyError when _run_traces deletes
        # the request between the copy and the lookup.
        # NOTE(review): rest_new_trace passes the raw user dict here while
        # new_trace stores the validated entries object; whether the two
        # ever compare equal depends on that object's __eq__ - confirm.
        pending = self._request_dict.copy().values()
        return any(entries == queued for queued in pending)
139
140 1
    @staticmethod
    async def is_entry_valid(entries):
        """Validate a user request, including the requested switch/dpid.

        Args:
            entries: dictionary with user request
        Returns:
            TraceEntries class when the request is valid
            Error msg (str) otherwise
        """
        try:
            validated = TraceEntries()
            validated.load_entries(entries)
        except ValueError as msg:
            # The validation error text becomes the error message.
            return str(msg)

        switch = Switches().get_switch(validated.dpid)
        # A bool result is treated as "switch not found".
        if isinstance(switch, bool):
            return "Unknown Switch"

        switch_color = await Colors().aget_switch_color(switch.dpid)
        if not len(switch_color):
            return "Switch not Colored"

        # TODO: get Coloring API to confirm color_field

        return validated
168
169 1
    def get_id(self):
        """Generate the ID for a new trace. Useful in case of parallel
        requests.

        Returns:
            integer to be the new request/trace id
        """
        self._id = self._id + 1
        return self._id
178
179 1
    def get_result(self, trace_id):
        """Used by external apps to get a trace result using the trace ID.

        Args:
            trace_id: trace ID; anything int() accepts (e.g. a numeric
                string coming from a URL)
        Returns:
            result from self._results_queue
            msg depending of the status (unknown, pending, or active)
        """
        # Convert inside a try of its own: an ID that is not a number
        # cannot belong to any trace, so it is reported as unknown instead
        # of raising to the caller.
        try:
            trace_id = int(trace_id)
        except (TypeError, ValueError):
            return {'msg': 'unknown trace id'}

        try:
            return self._results_queue[trace_id]
        except KeyError:
            pass

        if trace_id in self._running_traces:
            return {'msg': 'trace in process'}
        if trace_id in self._request_dict:
            return {'msg': 'trace pending'}
        return {'msg': 'unknown trace id'}
195
196 1
    def get_results(self):
        """Used by external apps to get all trace results. Useful
        to see all requests and results.

        Returns:
            self._results_queue itself (not a copy): the dict in which
            add_result() stores each result under its trace ID
        """
        all_results = self._results_queue
        return all_results
204
205 1
    def limit_traces_reached(self):
        """Control the number of active traces running in parallel.
        Protects the switches and avoids DoS.

        Returns:
            True: if the number of traces running is equal/more
                than settings.PARALLEL_TRACES
            False: if it is not.
        """
        running_now = len(self._running_traces)
        return bool(running_now >= settings.PARALLEL_TRACES)
217
218 1
    async def new_trace(self, trace_entries):
        """Register an external trace request and queue it for execution.

        Args:
            trace_entries: TraceEntries Class
        Returns:
            int with the request/trace id
        """
        new_id = self.get_id()

        # Record the entries first so the consumer thread can look them up
        # as soon as the ID shows up in the queue.
        self._request_dict[new_id] = trace_entries
        try:
            await self._request_queue.async_q.put(new_id)
        except RuntimeError:
            # Best effort: presumably raised when the queue was already
            # closed by stop_traces() - TODO confirm. The request stays in
            # self._request_dict and the ID is still returned.
            pass

        # Statistics
        self._total_traces_requested = self._total_traces_requested + 1

        return new_id
240
241 1
    def number_pending_requests(self):
        """Used to check if there are entries to be traced.

        Returns:
            length of self._request_dict
        """
        pending = self._request_dict
        return len(pending)
248
249 1
    def get_unpickled_packet_eth(self, ethernet) -> Optional[TraceMsg]:
        """Unpickle the message carried by a PACKET_IN ethernet frame.

        Args:
            ethernet: ethernet frame
        Returns:
            The unpickled message, or None when unpickling fails (the
            error is logged).
        """
        # NOTE(review): dill.loads can execute code embedded in the payload;
        # confirm that only trusted probe frames reach this call.
        try:
            unpickled = dill.loads(process_packet(ethernet))
        except dill.UnpicklingError as err:
            log.error(f"Error getting msg from PacketIn: {err}")
            unpickled = None
        return unpickled
257
258 1
    async def queue_probe_packet(self, event, ethernet, in_port, switch):
        """Used by sdntrace.packet_in_handler. Only tracing probes
        get to this point. Adds the PacketIn msg received to the
        trace_pkt_in queue of the request the probe belongs to.

        The frame is decoded through get_unpickled_packet_eth(), which
        logs unpickling errors and returns None; such frames are dropped
        here. Probes of requests that already have a result are ignored.

        Args:
            event: PacketIn msg
            ethernet: ethernet frame
            in_port: in_port
            switch: kytos.core.switch.Switch() class
        """
        msg = self.get_unpickled_packet_eth(ethernet)
        if msg is None:
            return

        pkt_in = {
            "dpid": switch.dpid,
            "in_port": in_port,
            "msg": msg,
            "ethernet": ethernet,
            "event": event,
        }
        request_id = msg.request_id

        if request_id in self._results_queue:
            # The trace already finished; nobody would read this probe.
            return

        # This queue stores all PacketIn message received
        try:
            await self._trace_pkt_in[request_id].async_q.put(pkt_in)
        except RuntimeError:
            # If queue was close do nothing
            pass
287
288
    # REST calls
289
290 1
    async def rest_new_trace(self, entries: dict):
        """Used for the REST PUT call: validate and register a request.

        Args:
            entries: user provided parameters to trace
        Returns:
            {'result': {'trace_id': <id>}} when the request is accepted
            {'result': {'error': <msg>}} when the entries are invalid or
            the request duplicates a pending one
        """
        validated = await self.is_entry_valid(entries)
        # is_entry_valid returns an error string instead of TraceEntries
        # when validation fails.
        if not isinstance(validated, TraceEntries):
            return {'result': {'error': validated}}

        if self.avoid_duplicated_request(entries):
            return {'result': {'error': "Duplicated Trace Request ignored"}}

        new_id = await self.new_trace(validated)
        return {'result': {'trace_id': new_id}}
312
313 1
    def rest_get_result(self, trace_id):
        """Used for the REST GET call of a single trace.

        Args:
            trace_id: trace ID, forwarded unchanged to get_result
        Returns:
            whatever get_result returns for trace_id
        """
        outcome = self.get_result(trace_id)
        return outcome
320
321 1
    def rest_list_results(self):
        """Used for the REST GET call that lists every result.

        Returns:
            whatever get_results returns
        """
        everything = self.get_results()
        return everything
328
329 1
    def rest_list_stats(self):
        """Used to export some info about the TraceManager.

        Returns:
            dict with the keys
                number_of_requests: total number of requests
                number_of_running_traces: number of active traces
                number_of_pending_traces: number of pending traces
                list_of_pending_traces: contents of self._results_queue
        """
        # NOTE(review): 'list_of_pending_traces' exposes the results dict,
        # not the pending requests - looks unintended; confirm with API
        # consumers before changing the payload.
        return {
            'number_of_requests': self._total_traces_requested,
            'number_of_running_traces': len(self._running_traces),
            'number_of_pending_traces': len(self._request_dict),
            'list_of_pending_traces': self._results_queue,
        }
346