Passed
Pull Request — master (#114)
by Aldo
06:05
created

TraceManager._run_traces()   A

Complexity

Conditions 5

Size

Total Lines 21
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 9
CRAP Score 6.1384

Importance

Changes 0
Metric Value
cc 5
eloc 14
nop 1
dl 0
loc 21
ccs 9
cts 14
cp 0.6429
crap 6.1384
rs 9.2333
c 0
b 0
f 0
1
"""
2
    Trace Manager Class
3
"""
4
5
6 1
import dill
7 1
import time
8 1
from janus import Queue
9 1
from _thread import start_new_thread as new_thread
10 1
from collections import defaultdict
11 1
from typing import Optional
12
13 1
from kytos.core import log
14 1
from napps.amlight.sdntrace import settings
15 1
from napps.amlight.sdntrace.shared.switches import Switches
16 1
from napps.amlight.sdntrace.shared.colors import Colors
17 1
from napps.amlight.sdntrace.tracing.tracer import TracePath
18 1
from napps.amlight.sdntrace.tracing.trace_pkt import process_packet
19 1
from napps.amlight.sdntrace.tracing.trace_entries import TraceEntries
20 1
from napps.amlight.sdntrace.tracing.trace_msg import TraceMsg
21
22
23 1
class TraceManager(object):
    """Manage all trace requests.

    Requests are registered by ``new_trace`` (or ``rest_new_trace``), their
    ids are pushed to ``_request_queue``, and a background thread
    (``_run_traces``) picks them up and starts one ``TracePath`` per request,
    each in its own thread, throttled by ``limit_traces_reached()``.
    """

    def __init__(self, controller):
        """Initialize the TraceManager and start the request-reading thread.

        Args:
            controller: Kytos.core.controller object
        """
        # Controller
        self.controller = controller

        # Trace ID used to distinguish each trace; incremented by get_id()
        self._id = 30000

        # Trace queues
        # Accepted but not yet started requests: {trace_id: TraceEntries}
        self._request_dict = dict()
        # Queue of trace ids to start; (re)created by run_traces()
        self._request_queue = None
        # Finished traces: {trace_id: result} (a dict despite the name)
        self._results_queue = dict()
        # Traces currently executing: {trace_id: TracePath}
        self._running_traces: dict[int, TracePath] = dict()

        # Counters
        self._total_traces_requested = 0

        # PacketIn queue with Probes, one Queue per request id
        self._trace_pkt_in = defaultdict(Queue)

        self._is_tracing_running = False

        self._async_loop = None

        # To start traces
        self.run_traces()

    def stop_traces(self):
        """Stop reading requests and signal every running trace to end.

        Closes the request queue (if the reader loop is running), marks each
        running trace as ended and closes its PacketIn probe queue, if any.
        """
        if self._is_tracing_running:
            self._is_tracing_running = False
            self._request_queue.close()
        # Iterate over a snapshot: _spawn_trace threads insert into
        # _running_traces concurrently, and mutating a dict while iterating
        # its live view raises RuntimeError.
        for trace_obj in list(self._running_traces.values()):
            trace_obj.trace_ended = True
            if trace_obj.id in self._trace_pkt_in:
                self._trace_pkt_in[trace_obj.id].close()

    def is_tracing_running(self):
        """Return True while the request-reading loop should keep running."""
        return self._is_tracing_running

    def run_traces(self):
        """Create the request queue and start the _run_traces thread."""
        self._request_queue = Queue()
        self._is_tracing_running = True
        new_thread(self._run_traces, ())

    def _run_traces(self):
        """Thread loop reading trace ids from ``self._request_queue``.

        For each id it starts a new thread running ``_spawn_trace`` and then
        removes the request from ``self._request_dict``. While the parallel
        trace limit is reached it sleeps one second before checking again.
        The loop ends once ``is_tracing_running()`` returns False.
        """
        while self.is_tracing_running():
            try:
                if self.limit_traces_reached():
                    # Wait for traces to end
                    time.sleep(1)
                    continue
                request_id = self._request_queue.sync_q.get()
                entries = self._request_dict[request_id]
                new_thread(self._spawn_trace, (request_id, entries))
                # After starting traces for new requests,
                # remove them from self._request_dict
                del self._request_dict[request_id]
            except RuntimeError as error:
                # Presumably raised by the request queue after stop_traces()
                # closed it. Only treat it as shutdown when tracing was
                # actually stopped; otherwise keep serving requests.
                if not self.is_tracing_running():
                    log.warning(
                        "Ignored trace request while sdntrace was shutting down."
                    )
                    break
                log.error("Trace Error: %s", error)
            except Exception as error:  # pylint: disable=broad-except
                # Keep the reader loop alive whatever a single request does.
                log.error("Trace Error: %s", error)

    def _spawn_trace(self, trace_id, trace_entries):
        """Instantiate a TracePath for a queued request and run it.

        Args:
            trace_id: trace request id
            trace_entries: TraceEntries class
        """
        log.info("Creating task to trace request id %s...", trace_id)
        tracer = TracePath(self, trace_id, trace_entries)
        # Register before running so get_result() reports "trace in process"
        # and limit_traces_reached() counts this trace.
        self._running_traces[trace_id] = tracer
        tracer.tracepath()

    def add_result(self, trace_id, result):
        """Save a trace result and drop the trace from the running set.

        Args:
            trace_id: trace ID
            result: trace result generated using tracer
        """
        self._results_queue[trace_id] = result
        self._running_traces.pop(trace_id, None)

    def avoid_duplicated_request(self, entries):
        """Check whether a pending request has the same entries.

        Args:
            entries: entries provided by user via REST.
        Return:
            True: if exists a similar request
            False: otherwise
        """
        # Snapshot the values: _run_traces deletes entries from another
        # thread, so indexing the live dict after copying it could KeyError.
        # NOTE(review): rest_new_trace passes the raw REST dict while
        # _request_dict stores TraceEntries objects; unless TraceEntries
        # defines __eq__ against dicts this never matches — confirm.
        pending = list(self._request_dict.values())
        return any(entries == queued for queued in pending)

    @staticmethod
    async def is_entry_valid(entries):
        """Validate all params provided, including whether the dpid exists.

        Args:
            entries: dictionary with user request
        Returns:
            TraceEntries class on success
            Error msg (str) otherwise
        """
        try:
            trace_entries = TraceEntries()
            trace_entries.load_entries(entries)
        except ValueError as msg:
            return str(msg)

        init_switch = Switches().get_switch(trace_entries.dpid)
        # get_switch signals "not found" by returning a bool
        if isinstance(init_switch, bool):
            return "Unknown Switch"
        color = await Colors().aget_switch_color(init_switch.dpid)

        if not color:
            return "Switch not Colored"

        # TODO: get Coloring API to confirm color_field

        return trace_entries

    def get_id(self):
        """Generate a new trace id. Useful in case of parallel requests.

        Returns:
            integer to be the new request/trace id
        """
        self._id += 1
        return self._id

    def get_result(self, trace_id):
        """Used by external apps to get a trace result using the trace ID.

        Args:
            trace_id: trace id, as int or a string convertible to int.
        Returns:
            result from self._results_queue, or a msg dict depending on the
            status (trace in process, trace pending, or unknown trace id)
        """
        try:
            # Convert inside the try so a non-numeric id is reported as
            # unknown instead of raising ValueError to the REST layer.
            trace_id = int(trace_id)
            return self._results_queue[trace_id]
        except (ValueError, KeyError):
            if trace_id in self._running_traces:
                return {'msg': 'trace in process'}
            if trace_id in self._request_dict:
                return {'msg': 'trace pending'}
            return {'msg': 'unknown trace id'}

    def get_results(self):
        """Used by external apps to get all trace results.

        Returns:
            dict of {trace_id: result}
        """
        return self._results_queue

    def limit_traces_reached(self):
        """Control the number of traces running in parallel.

        Protects the switches and avoids DoS.

        Returns:
            True: if the number of traces running is equal/more
                than settings.PARALLEL_TRACES
            False: if it is not.
        """
        return len(self._running_traces) >= settings.PARALLEL_TRACES

    async def new_trace(self, trace_entries):
        """Register an external trace request and queue its id.

        Args:
            trace_entries: TraceEntries Class
        Returns:
            int with the request/trace id
        """
        trace_id = self.get_id()

        # Add to request_queue
        self._request_dict[trace_id] = trace_entries
        try:
            await self._request_queue.async_q.put(trace_id)
        except RuntimeError as error:
            # Presumably the queue was closed by stop_traces(); the request
            # stays in _request_dict but will not be started.
            log.warning("Could not queue trace request %s: %s", trace_id, error)

        # Statistics
        self._total_traces_requested += 1

        return trace_id

    def number_pending_requests(self):
        """Return how many requests are still waiting to be traced.

        Returns:
            length of self._request_dict
        """
        return len(self._request_dict)

    def get_unpickled_packet_eth(self, ethernet) -> "Optional[TraceMsg]":
        """Unpickle the PACKET_IN ethernet payload, or None on unpickling error."""
        # NOTE(security): this unpickles bytes received from the network;
        # dill/pickle can execute arbitrary code on crafted input. Review
        # whether probe payloads can come from untrusted hosts.
        try:
            msg = dill.loads(process_packet(ethernet))
        except dill.UnpicklingError as err:
            log.error("Error getting msg from PacketIn: %s", err)
            return None
        return msg

    async def queue_probe_packet(self, event, ethernet, in_port, switch):
        """Add a received tracing probe PacketIn to its trace_pkt_in queue.

        Used by sdntrace.packet_in_handler; only tracing probes get here.

        Args:
            event: PacketIn msg
            ethernet: ethernet frame
            in_port: in_port
            switch: kytos.core.switch.Switch() class
        """
        # Use the helper so unpickling errors are logged and the probe is
        # dropped, instead of raising out of the handler.
        msg = self.get_unpickled_packet_eth(ethernet)
        if msg is None:
            return
        request_id = msg.request_id

        if request_id in self._results_queue:
            # Trace already finished: late probes are ignored.
            return

        pkt_in = {
            "dpid": switch.dpid,
            "in_port": in_port,
            "msg": msg,
            "ethernet": ethernet,
            "event": event,
        }
        # This queue stores all PacketIn message received
        try:
            await self._trace_pkt_in[request_id].async_q.put(pkt_in)
        except RuntimeError:
            # If queue was closed do nothing
            pass

    # REST calls

    async def rest_new_trace(self, entries: dict):
        """Used for the REST PUT call.

        Args:
            entries: user provided parameters to trace
        Returns:
            {'result': {'trace_id': id}} on success, or
            {'result': {'error': msg}} if entries are invalid or duplicated
        """
        result = dict()
        trace_entries = await self.is_entry_valid(entries)
        if not isinstance(trace_entries, TraceEntries):
            result['result'] = {'error': trace_entries}
            return result

        if self.avoid_duplicated_request(entries):
            result['result'] = {'error': "Duplicated Trace Request ignored"}
            return result

        trace_id = await self.new_trace(trace_entries)
        result['result'] = {'trace_id': trace_id}
        return result

    def rest_get_result(self, trace_id):
        """Used for the REST GET call.

        Returns:
            get_result(trace_id)
        """
        return self.get_result(trace_id)

    def rest_list_results(self):
        """Used for the REST GET call.

        Returns:
            get_results()
        """
        return self.get_results()

    def rest_list_stats(self):
        """Export some info about the TraceManager.

        Returns:
            dict with the total number of requests, number of running
            traces, number of pending traces and 'list_of_pending_traces'.
        """
        stats = dict()
        stats['number_of_requests'] = self._total_traces_requested
        stats['number_of_running_traces'] = len(self._running_traces)
        stats['number_of_pending_traces'] = len(self._request_dict)
        # NOTE(review): this exposes the finished results, not the pending
        # requests the key name suggests (those are the keys of
        # self._request_dict). Kept as-is to preserve the REST output.
        stats['list_of_pending_traces'] = self._results_queue

        return stats
351