Completed
Push — master ( 618bcd...4a6efe )
by Bart
01:11
created

Stations   B

Complexity

Total Complexity 48

Size/Duplication

Total Lines 152
Duplicated Lines 0 %

Importance

Changes 4
Bugs 0 Features 0
Metric Value
wmc 48
c 4
b 0
f 0
dl 0
loc 152
rs 8.4864

12 Methods

Rating   Name   Duplication   Size   Complexity  
A get_station_code() 0 5 3
A find_station() 0 5 3
A get_stations_for_types() 0 7 4
A get_missing_destinations() 0 7 3
A create_traveltimes_data() 0 10 4
A update_station_data() 0 15 3
B __init__() 0 12 6
A __iter__() 0 2 1
A __len__() 0 2 1
C recreate_missing_destinations() 0 17 8
A travel_times_from_json() 0 8 4
D create_trip_data_from_station() 0 42 8

How to fix   Complexity   

Complex Class

Complex classes like Stations often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from datetime import datetime
2
from enum import Enum
3
import json
4
import os
5
import requests
6
7
import ns_api
8
9
from nsmaps.local_settings import USERNAME, APIKEY
10
from nsmaps.logger import logger
11
12
13
class StationType(Enum):
    """Station categories.

    Member names are compared against ``nsstation.stationtype`` strings
    elsewhere in this module, so they must keep their exact spelling.
    """

    stoptreinstation = 1
    megastation = 2
    knooppuntIntercitystation = 3
    sneltreinstation = 4
    intercitystation = 5
    knooppuntStoptreinstation = 6
    facultatiefStation = 7
    knooppuntSneltreinstation = 8
22
23
24
class Station(object):
    """A single station: a wrapped NS API station object plus local data.

    Attributes are read from the wrapped ``nsstation`` object (``names``,
    ``code``, ``country``, ``lat``, ``lon``); travel-time files are looked
    up below ``data_dir``.
    """

    def __init__(self, nsstation, data_dir, travel_time_min=None):
        """Store the wrapped station object and where its data files live.

        :param nsstation: station object exposing ``names``, ``code``,
            ``country``, ``lat`` and ``lon``.
        :param data_dir: base directory containing the ``traveltimes``
            sub-directory.
        :param travel_time_min: travel time in minutes, or ``None`` when
            not (yet) known.
        """
        self.nsstation = nsstation
        self.data_dir = data_dir
        self.travel_time_min = travel_time_min

    def get_name(self):
        """Return the long name of the station."""
        return self.nsstation.names['long']

    def get_code(self):
        """Return the station code."""
        return self.nsstation.code

    def get_country_code(self):
        """Return the country code of the station."""
        return self.nsstation.country

    def get_lat(self):
        """Return the latitude converted to ``float``."""
        return float(self.nsstation.lat)

    def get_lon(self):
        """Return the longitude converted to ``float``."""
        return float(self.nsstation.lon)

    def get_travel_time_filepath(self):
        """Return the path of the travel-times JSON file for this station."""
        # Join every component separately instead of embedding a '/' so the
        # path is built with the platform's separator throughout.
        filename = 'traveltimes_from_' + self.get_code() + '.json'
        return os.path.join(self.data_dir, 'traveltimes', filename)

    def has_travel_time_data(self):
        """Return ``True`` when the travel-times file exists on disk."""
        return os.path.exists(self.get_travel_time_filepath())

    def __str__(self):
        return '{} ({}), travel time: {}'.format(
            self.get_name(), self.get_code(), self.travel_time_min)
53
54
55
class Stations(object):
    """Collection of stations with ``country == 'NL'`` fetched from the NS API.

    Besides acting as an iterable container of :class:`Station` objects, the
    class reads and writes the JSON data files (station overview and
    per-station travel times) below ``data_dir``.
    """

    def __init__(self, data_dir, test=False):
        """Fetch all stations from the NS API and keep those in 'NL'.

        :param data_dir: base directory for the generated data files.
        :param test: when true, only the first six stations returned by the
            API (indices 0-5) and the station with code ``'UT'`` are
            considered, which keeps test runs small.
        """
        self.data_dir = data_dir
        self.stations = []
        nsapi = ns_api.NSAPI(USERNAME, APIKEY)
        nsapi_stations = nsapi.get_stations()
        for i, nsapi_station in enumerate(nsapi_stations):
            if test and i > 5 and nsapi_station.code != 'UT':
                continue
            if nsapi_station.country != 'NL':
                continue
            self.stations.append(Station(nsapi_station, data_dir))

    def __iter__(self):
        return iter(self.stations)

    def __len__(self):
        return len(self.stations)

    def find_station(self, name):
        """Return the first station whose long name equals ``name``, else ``None``."""
        for station in self.stations:
            if station.get_name() == name:
                return station
        return None

    def travel_times_from_json(self, filename):
        """Set ``travel_time_min`` on known stations from a travel-times file.

        The file must contain a ``'stations'`` list whose entries have a
        ``'name'`` and a ``'travel_time_min'``. Entries whose name matches
        no station in this collection are ignored; stations not mentioned
        in the file keep their current ``travel_time_min``.
        """
        with open(filename, encoding='utf-8') as file:
            travel_times = json.load(file)['stations']
        for travel_time in travel_times:
            station = self.find_station(travel_time['name'])
            if station:
                station.travel_time_min = int(travel_time['travel_time_min'])

    def update_station_data(self, filename_out):
        """Write an overview of all stations as JSON to ``data_dir/filename_out``.

        ``travel_times_available`` is only true when both the travel-times
        file and the ``contours_<code>.json`` file exist for a station.
        """
        data = {'stations': []}
        for station in self.stations:
            travel_times_available = station.has_travel_time_data()
            contour_available = os.path.exists(
                os.path.join(self.data_dir, 'contours_' + station.get_code() + '.json'))
            data['stations'].append({'names': station.nsstation.names,
                                     'id': station.get_code(),
                                     'lon': station.get_lon(),
                                     'lat': station.get_lat(),
                                     'type': station.nsstation.stationtype,
                                     'travel_times_available': travel_times_available and contour_available})
        json_data = json.dumps(data, indent=4, sort_keys=True, ensure_ascii=False)
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened with an explicit UTF-8 encoding rather than the locale's.
        with open(os.path.join(self.data_dir, filename_out), 'w', encoding='utf-8') as fileout:
            fileout.write(json_data)

    def get_stations_for_types(self, station_types):
        """Return the stations whose type name is one of ``station_types``.

        :param station_types: iterable of enum members; their ``name`` is
            compared with each station's ``nsstation.stationtype``.
        :return: matching stations in collection order, each at most once.
        """
        # Collect the names once: this avoids appending a station repeatedly
        # when a type is listed twice and also works for one-shot iterables.
        type_names = {station_type.name for station_type in station_types}
        return [station for station in self.stations
                if station.nsstation.stationtype in type_names]

    def create_traveltimes_data(self, stations_from, timestamp):
        """Create the travel-times file for every station in ``stations_from``.

        Existing files are never overwritten; such stations are skipped.

        timestamp format: DD-MM-YYYY hh:mm
        """
        for station_from in stations_from:
            filename_out = station_from.get_travel_time_filepath()
            if os.path.exists(filename_out):
                logger.warning('File ' + filename_out + ' already exists. Will not overwrite. Return.')
                continue
            json_data = self.create_trip_data_from_station(station_from, timestamp)
            with open(filename_out, 'w', encoding='utf-8') as fileout:
                fileout.write(json_data)

    def get_station_code(self, station_name):
        """Return the code of the station named ``station_name``, else ``None``."""
        station = self.find_station(station_name)
        if station is None:
            return None
        return station.get_code()

    @staticmethod
    def _shortest_trip(trips):
        """Set ``travel_time_min`` on every trip and return the shortest one.

        ``trip.travel_time_planned`` is parsed as ``H:MM``. On a tie the
        earliest trip in ``trips`` wins. ``trips`` must not be empty.
        """
        for trip in trips:
            planned = datetime.strptime(trip.travel_time_planned, "%H:%M").time()
            trip.travel_time_min = planned.hour * 60 + planned.minute
        return min(trips, key=lambda trip: trip.travel_time_min)

    def create_trip_data_from_station(self, station_from, timestamp):
        """Return JSON with the shortest travel time from ``station_from`` to every other station.

        Destinations for which the trip lookup fails or yields no trips are
        logged (on failure) and left out of the result.

        timestamp format: DD-MM-YYYY hh:mm
        """
        via = ""
        data = {'stations': []}
        data['stations'].append({'name': station_from.get_name(),
                                 'id': station_from.get_code(),
                                 'travel_time_min': 0,
                                 'travel_time_planned': "0:00"})
        nsapi = ns_api.NSAPI(USERNAME, APIKEY)
        for station in self.stations:
            if station.get_code() == station_from.get_code():
                continue
            try:
                trips = nsapi.get_trips(timestamp, station_from.get_code(), via, station.get_code())
            except TypeError:
                # this is a bug in ns-api, should return empty trips in case there are no results
                logger.error('Error while trying to get trips for destination: ' + station.get_name() + ', from: ' + station_from.get_name())
                continue
            except requests.exceptions.HTTPError:
                # 500: Internal Server Error does always happen for some stations (example are Eijs-Wittem and Kerkrade-West)
                logger.error('HTTP Error while trying to get trips for destination: ' + station.get_name() + ', from: ' + station_from.get_name())
                continue

            if not trips:
                continue

            shortest_trip = self._shortest_trip(trips)
            logger.info(shortest_trip.departure + ' - ' + shortest_trip.destination)
            data['stations'].append({'name': shortest_trip.destination,
                                     'id': self.get_station_code(shortest_trip.destination),
                                     'travel_time_min': shortest_trip.travel_time_min,
                                     'travel_time_planned': shortest_trip.travel_time_planned})
        return json.dumps(data, indent=4, sort_keys=True, ensure_ascii=False)

    def get_missing_destinations(self, filename_json):
        """Return the stations that have no travel time in ``filename_json``.

        As a side effect, every station's ``travel_time_min`` is replaced by
        the value from the file, or ``None`` when the file has no entry.
        """
        # Reset first: otherwise travel times loaded from a previously
        # inspected file would hide destinations missing from this one.
        for station in self.stations:
            station.travel_time_min = None
        self.travel_times_from_json(filename_json)
        return [station for station in self.stations if station.travel_time_min is None]

    def recreate_missing_destinations(self, departure_timestamp, dry_run=False):
        """Regenerate travel-times files that lack one or more destinations.

        Only stations that already have a travel-times file are checked.
        Destinations in the hard-coded ignore list do not count as missing.

        :param departure_timestamp: passed on to
            ``create_trip_data_from_station`` (format: DD-MM-YYYY hh:mm).
        :param dry_run: when true, missing destinations are only logged and
            no file is rewritten.
        """
        ignore_station_ids = ['HRY', 'WTM', 'KRW', 'VMW', 'RTST', 'WIJ', 'SPV', 'SPH']
        for station in self.stations:
            if not station.has_travel_time_data():
                continue
            stations_missing = self.get_missing_destinations(station.get_travel_time_filepath())
            stations_missing_filtered = []
            for station_missing in stations_missing:
                if station_missing.get_code() not in ignore_station_ids:
                    stations_missing_filtered.append(station_missing)
                    logger.info(station.get_name() + ' has missing station: ' + station_missing.get_name())
            if not stations_missing_filtered:
                logger.info('No missing destinations for ' + station.get_name() + ' with ' + str(len(ignore_station_ids)) + ' ignored.')
            elif not dry_run:
                json_data = self.create_trip_data_from_station(station, departure_timestamp)
                with open(station.get_travel_time_filepath(), 'w', encoding='utf-8') as fileout:
                    fileout.write(json_data)
207
208