Passed
Pull Request — master (#28)
by Vinicius
13:22 queued 10:18
created

build.main.Main._map_endpoints_from_link_ids()   A

Complexity

Conditions 3

Size

Total Lines 12
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 11
CRAP Score 3

Importance

Changes 0
Metric Value
cc 3
eloc 11
nop 2
dl 0
loc 12
rs 9.85
c 0
b 0
f 0
ccs 11
cts 11
cp 1
crap 3
1
"""Main module of kytos/pathfinder Kytos Network Application."""
2
3 1
from threading import Lock
4
5 1
from flask import jsonify, request
6 1
from kytos.core import KytosNApp, log, rest
7 1
from kytos.core.helpers import listen_to
8 1
from napps.kytos.pathfinder.graph import KytosGraph
9
# pylint: disable=import-error
10 1
from werkzeug.exceptions import BadRequest
11
12
13 1
class Main(KytosNApp):
    """
    Main class of kytos/pathfinder NApp.

    This class is the entry point for this napp.
    """

    def setup(self):
        """Create a graph to handle the nodes and edges."""
        self.graph = KytosGraph()
        # Latest topology received via a topology event; None until the
        # first such event is handled by update_topology().
        self._topology = None
        # Held while computing paths and while applying topology or link
        # metadata updates to the graph.
        self._lock = Lock()

    def execute(self):
        """Do nothing."""

    def shutdown(self):
        """Shutdown the napp."""

    def _filter_paths_le_cost(self, paths, max_cost):
        """Filter by paths where the cost is le <= max_cost.

        Args:
            paths: path dicts, each carrying a ``"cost"`` key.
            max_cost: inclusive upper bound; a falsy value (``None``/``0``)
                disables the filter.

        Returns:
            ``paths`` itself when the filter is disabled, otherwise a new
            list holding only the paths whose cost is ``<= max_cost``.
        """
        if not max_cost:
            return paths
        return [path for path in paths if path["cost"] <= max_cost]

    def _map_endpoints_from_link_ids(self, link_ids: list[str]) -> dict:
        """Map endpoints from link ids.

        Args:
            link_ids: ids of links to look up in the current topology.

        Returns:
            A dict keyed by ``(endpoint_id, endpoint_id)`` tuples, holding
            both orientations of every known link, with the link object as
            the value. Ids absent from the topology are skipped, and an
            empty dict is returned when no topology has been received yet.
        """
        # Read the reference once so the whole loop works on one snapshot.
        topology = self._topology
        if topology is None:
            # No topology event handled yet: no link id can be resolved.
            return {}
        endpoints = {}
        for link_id in link_ids:
            try:
                link = topology.links[link_id]
            except KeyError:
                # Unknown link ids simply cannot match any path hop.
                continue
            endpoint_a, endpoint_b = link.endpoint_a, link.endpoint_b
            endpoints[(endpoint_a.id, endpoint_b.id)] = link
            endpoints[(endpoint_b.id, endpoint_a.id)] = link
        return endpoints

    def _find_any_link_ids(
        self, paths: list[dict], link_ids: list[str]
    ) -> set[int]:
        """Find indexes of the paths that contain any of the link ids.

        A path is considered to use a link when two consecutive entries of
        its ``"hops"`` list are that link's endpoint ids, in either order.

        Args:
            paths: path dicts, each carrying a ``"hops"`` list.
            link_ids: ids of the links to look for.

        Returns:
            The set of indexes (into ``paths``) of matching paths.
        """
        endpoints_links = self._map_endpoints_from_link_ids(link_ids)
        if not endpoints_links:
            return set()
        return {
            idx
            for idx, path in enumerate(paths)
            if any(
                hop_pair in endpoints_links
                for hop_pair in zip(path["hops"][:-1], path["hops"][1:])
            )
        }

    def _filter_paths_undesired_links(
        self, paths: list[dict], undesired: list[str]
    ) -> list[dict]:
        """Drop paths that use any of the undesired links (logical OR).

        Args:
            paths: path dicts, each carrying a ``"hops"`` list.
            undesired: link ids; a path is excluded if it traverses at
                least one of them.

        Returns:
            ``paths`` unchanged when ``undesired`` is empty, otherwise the
            paths that traverse none of the undesired links, in order.
        """
        if not undesired:
            return paths
        excluded_indexes = self._find_any_link_ids(paths, undesired)
        return [
            path
            for idx, path in enumerate(paths)
            if idx not in excluded_indexes
        ]

    def _filter_paths_desired_links(
        self, paths: list[dict], desired: list[str]
    ) -> list[dict]:
        """Keep only paths that use at least one desired link (logical OR).

        Args:
            paths: path dicts, each carrying a ``"hops"`` list.
            desired: link ids; a path is kept if it traverses at least one
                of them.

        Returns:
            ``paths`` unchanged when ``desired`` is empty, otherwise the
            paths that traverse any of the desired links, in order.
        """
        if not desired:
            return paths
        included_indexes = self._find_any_link_ids(paths, desired)
        return [
            path for idx, path in enumerate(paths) if idx in included_indexes
        ]

    def _validate_payload(self, data):
        """Validate and normalize the payload of the v2/ POST endpoint.

        ``data`` is updated in place: ``spf_attribute``, ``spf_max_paths``,
        ``mandatory_metrics``, ``flexible_metrics`` and
        ``minimum_flexible_hits`` are always set, and ``spf_max_path_cost``
        is converted to an int (at least 1) when it is given.

        Args:
            data: the decoded JSON request body.

        Returns:
            The same ``data`` dict, normalized.

        Raises:
            BadRequest: if ``desired_links``/``undesired_links`` is given
                but is not a list, if ``spf_attribute`` is not a known
                edge-data callback, or if a numeric field cannot be
                converted to an int.
        """
        for key in ("desired_links", "undesired_links"):
            value = data.get(key)
            if value and not isinstance(value, list):
                raise BadRequest(
                    f"TypeError: {key} is supposed to be a list."
                    f" type: {type(value)}"
                )

        # Fall back to "parameter" when "spf_attribute" is not given,
        # and to "hop" when neither is.
        parameter = data.get("parameter")
        spf_attr = data.get("spf_attribute")
        if not spf_attr:
            spf_attr = parameter or "hop"
        data["spf_attribute"] = spf_attr

        if spf_attr not in self.graph.spf_edge_data_cbs:
            raise BadRequest(
                "Invalid 'spf_attribute'. Valid values: "
                f"{', '.join(self.graph.spf_edge_data_cbs.keys())}"
            )

        try:
            # Defaults to 2 and is never allowed below 1.
            data["spf_max_paths"] = max(int(data.get("spf_max_paths", 2)), 1)
        except (TypeError, ValueError) as err:
            raise BadRequest(
                f"spf_max_paths {data.get('spf_max_paths')} must be an int"
            ) from err

        spf_max_path_cost = data.get("spf_max_path_cost")
        if spf_max_path_cost:
            try:
                spf_max_path_cost = max(int(spf_max_path_cost), 1)
                data["spf_max_path_cost"] = spf_max_path_cost
            except (TypeError, ValueError) as err:
                raise BadRequest(
                    f"spf_max_path_cost {data.get('spf_max_path_cost')} must"
                    " be an int"
                ) from err

        data["mandatory_metrics"] = data.get("mandatory_metrics", {})
        data["flexible_metrics"] = data.get("flexible_metrics", {})

        try:
            minimum_hits = data.get("minimum_flexible_hits")
            if minimum_hits:
                # Clamp to [0, number of flexible metrics].
                minimum_hits = min(
                    len(data["flexible_metrics"]), max(0, int(minimum_hits))
                )
            data["minimum_flexible_hits"] = minimum_hits
        except (TypeError, ValueError) as err:
            raise BadRequest(
                f"minimum_hits {data.get('minimum_flexible_hits')} must be an int"
            ) from err

        return data

    @rest("v2/", methods=["POST"])
    def shortest_path(self):
        """Calculate the best path between the source and destination.

        Reads a JSON object that must contain ``source`` and
        ``destination``. It computes up to ``spf_max_paths`` paths weighted
        by ``spf_attribute``, using the constrained search when mandatory or
        flexible metrics are given. It then applies the max-cost,
        undesired-links and desired-links filters, in that order.

        Returns:
            A JSON response of the form ``{"paths": [...]}``.

        Raises:
            BadRequest: if the body is not a JSON object, ``source`` or
                ``destination`` is missing, a payload field is invalid, or
                the path computation raises a TypeError.
        """
        data = request.get_json()
        if not isinstance(data, dict):
            raise BadRequest("The request body must be a JSON object.")
        data = self._validate_payload(data)
        missing = [key for key in ("source", "destination") if key not in data]
        if missing:
            raise BadRequest(f"Missing required field(s): {', '.join(missing)}")

        desired = data.get("desired_links")
        undesired = data.get("undesired_links")

        spf_attr = data.get("spf_attribute")
        spf_max_paths = data.get("spf_max_paths")
        spf_max_path_cost = data.get("spf_max_path_cost")
        mandatory_metrics = data.get("mandatory_metrics")
        flexible_metrics = data.get("flexible_metrics")
        minimum_hits = data.get("minimum_flexible_hits")
        log.debug(f"POST v2/ payload data: {data}")

        try:
            with self._lock:
                weight = self.graph.spf_edge_data_cbs[spf_attr]
                if mandatory_metrics or flexible_metrics:
                    paths = self.graph.constrained_k_shortest_paths(
                        data["source"],
                        data["destination"],
                        weight=weight,
                        k=spf_max_paths,
                        minimum_hits=minimum_hits,
                        mandatory_metrics=mandatory_metrics,
                        flexible_metrics=flexible_metrics,
                    )
                else:
                    paths = self.graph.k_shortest_paths(
                        data["source"],
                        data["destination"],
                        weight=weight,
                        k=spf_max_paths,
                    )

                paths = self.graph.path_cost_builder(
                    paths,
                    weight=spf_attr,
                )
            log.debug(f"Found paths: {paths}")
        except TypeError as err:
            raise BadRequest(str(err)) from err

        paths = self._filter_paths_le_cost(paths, max_cost=spf_max_path_cost)
        paths = self._filter_paths_undesired_links(paths, undesired)
        paths = self._filter_paths_desired_links(paths, desired)
        log.debug(f"Filtered paths: {paths}")
        return jsonify({"paths": paths})

    @listen_to("kytos.topology.updated", "kytos/topology.topology_loaded")
    def on_topology_updated(self, event):
        """Update the graph when the network topology is updated."""
        self.update_topology(event)

    def update_topology(self, event):
        """Store the event's topology and rebuild the graph from it.

        Events whose content has no ``"topology"`` key are ignored.
        """
        if "topology" not in event.content:
            return
        topology = event.content["topology"]
        with self._lock:
            self._topology = topology
            self.graph.update_topology(topology)
        log.debug("Topology graph updated.")

    @listen_to("kytos/topology.links.metadata.(added|removed)")
    def on_links_metadata_changed(self, event):
        """Update the graph when links' metadata are added or removed."""
        link = event.content["link"]
        with self._lock:
            self.graph.update_link_metadata(link)
        metadata = event.content["metadata"]
        log.debug(f"Topology graph updated link id: {link.id} metadata: {metadata}")
222