Completed
Pull Request — master (#59)
by Gonzalo
57s
created

GitHubRepo.labels()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 4
rs 10
1
# -*- coding: utf-8 -*-
2
# -----------------------------------------------------------------------------
3
# Copyright (c) The Spyder Development Team
4
#
5
# Licensed under the terms of the MIT License
6
# (See LICENSE.txt for details)
7
# -----------------------------------------------------------------------------
8
"""Github repo wrapper."""
9
10
from __future__ import print_function
11
12
# Standard library imports
13
import datetime
14
import sys
15
import time
16
17
# Local imports
18
from loghub.external.github import ApiError, ApiNotFoundError, GitHub
19
20
21
class GitHubRepo(object):
    """Github repository wrapper.

    Wraps the `GitHub` API client for a single ``organization/name``
    repository and exposes helpers to query tags, labels, milestones,
    pull requests and issues. Unrecoverable problems (unknown repository,
    exhausted rate limit, unknown tag or milestone) are reported on stdout
    and terminate the process with ``sys.exit(1)``.
    """

    def __init__(self, username=None, password=None, token=None, repo=None):
        """Create the API client and validate the target repository.

        `repo` is split on ``/`` into the organization (or user) and the
        repository name, so it must contain exactly one slash.
        """
        self._username = username
        self._password = password
        self._token = token

        self.gh = GitHub(
            username=username,
            password=password,
            access_token=token, )
        repo_organization, repo_name = repo.split('/')
        self._repo_organization = repo_organization
        self._repo_name = repo_name
        self.repo = self.gh.repos(repo_organization)(repo_name)

        # Check username and repo name
        self._check_user()
        self._check_repo_name()

    def _check_user(self):
        """Check if the supplied organization/user is valid, exit if not."""
        try:
            self.gh.users(self._repo_organization).get()
        except ApiNotFoundError:
            print('LOGHUB: Organization/user `{}` seems to be '
                  'invalid.\n'.format(self._repo_organization))
            sys.exit(1)
        except ApiError:
            # A rate-limit problem is reported (and exits) first; any other
            # API error at this point is attributed to bad credentials.
            self._check_rate()
            print('LOGHUB: The credentials seems to be invalid!\n')
            sys.exit(1)

    def _check_repo_name(self):
        """Check if the supplied repository exists, exit if it does not."""
        try:
            self.repo.get()
        except ApiNotFoundError:
            print('LOGHUB: Repository `{0}` for organization/username `{1}` '
                  'seems to be invalid.\n'.format(self._repo_name,
                                                  self._repo_organization))
            sys.exit(1)
        except ApiError:
            # Only a rate-limit problem is fatal here; other API errors are
            # deliberately ignored.
            self._check_rate()

    def _check_rate(self):
        """Exit with an explanation if the API rate limit is exhausted."""
        if self.gh.x_ratelimit_remaining == 0:
            # The reset value is formatted with `time.gmtime`, i.e. in UTC.
            reset_struct = time.gmtime(self.gh.x_ratelimit_reset)
            reset_format = time.strftime('%Y/%m/%d %H:%M', reset_struct)
            print('LOGHUB: GitHub API rate limit exceeded!')
            print('LOGHUB: GitHub API rate limit resets on '
                  '{}'.format(reset_format))

            # Suggest authenticating only when no usable credentials were
            # supplied: neither a full user/password pair nor a token.
            has_login = self._username and self._password
            if not has_login and not self._token:
                print('LOGHUB: Try running loghub with user/password or '
                      'a valid token.\n')
            sys.exit(1)

    def _closed_date(self, issue):
        """Return the close date of `issue`, or None if it has none."""
        closed_at = issue.get('closed_at')
        return self.str_to_date(closed_at) if closed_at else None

    def _filter_since(self, issues, since):
        """Filter out all issues closed before the `since` date.

        The list is modified in place and also returned. Issues without a
        close date (e.g. still open) cannot be compared and are kept.
        """
        if since:
            since_date = self.str_to_date(since)
            kept = []
            for issue in issues:
                close_date = self._closed_date(issue)
                if close_date is None or close_date >= since_date:
                    kept.append(issue)
            issues[:] = kept
        return issues

    def _filter_until(self, issues, until):
        """Filter out all issues closed after the `until` date.

        The list is modified in place and also returned. Issues without a
        close date (e.g. still open) cannot be compared and are kept.
        """
        if until:
            until_date = self.str_to_date(until)
            kept = []
            for issue in issues:
                close_date = self._closed_date(issue)
                if close_date is None or close_date <= until_date:
                    kept.append(issue)
            issues[:] = kept
        return issues

    def _filter_by_branch(self, issues, issue, branch):
        """Filter prs by merge status and the branch they were merged into.

        Removes `issue` (a pull request) from `issues` when it was not
        merged, or when `branch` is given and differs from the PR base
        branch. The list is modified in place and also returned.
        """
        number = issue['number']

        if not self.is_merged(number):
            if issue in issues:
                issues.remove(issue)
            # Already discarded: do not spend an API call on the PR data.
            return issues

        if branch:
            # Get PR info and get base branch
            pr_data = self.pr(number)
            base_ref = pr_data['base']['ref']

            if base_ref != branch and issue in issues:
                issues.remove(issue)

        return issues

    # NOTE: the misspelled name is kept so existing callers keep working.
    def _filer_closed_prs(self, issues, branch):
        """Filter out unmerged PRs and PRs merged into another branch.

        Every issue additionally gets a ``loghub_label_names`` key holding
        the list of its label names.
        """
        for issue in issues[:]:
            pr = issue.get('pull_request', '')

            # Add label names inside additional key
            issue['loghub_label_names'] = [
                label['name'] for label in issue.get('labels') or []
            ]

            if pr:
                issues = self._filter_by_branch(issues, issue, branch)

        return issues

    def tags(self):
        """Return all tags."""
        self._check_rate()
        return self.repo('git')('refs')('tags').get()

    def tag(self, tag_name):
        """Get tag information, exit if `tag_name` is not an existing tag."""
        refs = self.tags()
        ref_name = 'refs/tags/{tag}'.format(tag=tag_name)
        sha = None

        tags = []
        for ref in refs:
            if 'object' in ref and ref['ref'] == ref_name:
                sha = ref['object']['sha']
            tags.append(ref['ref'].split('/')[-1])

        if sha is None:
            print("LOGHUB: You didn't pass a valid tag name!")
            print('LOGHUB: The available tags are: {0}\n'.format(tags))
            sys.exit(1)

        return self.repo('git')('tags')(sha).get()

    def labels(self):
        """Return labels for the repo."""
        self._check_rate()
        return self.repo.labels.get()

    def set_labels(self, labels):
        """Update or create labels on the repo.

        `labels` is an iterable of dicts with the keys ``old_name``,
        ``new_name`` and ``color``. Each label is first patched under its
        old name; if that fails, it is created under the new name.
        """
        self._check_rate()
        for label in labels:
            new_name = label['new_name']
            old_name = label['old_name']
            color = label['color']
            try:
                self.repo.labels(old_name).patch(name=new_name, color=color)
                print('Updated label: "{0}" -> "{1}" (#{2})'.format(
                    old_name, new_name, color))
            except ApiError:
                try:
                    self.repo.labels.post(name=new_name, color=color)
                    print('Created label: "{0}" (#{1})'.format(new_name,
                                                               color))
                except ApiError:
                    print('\nLabel "{0}" already exists!'.format(new_name))

    def milestones(self):
        """Return all milestones."""
        self._check_rate()
        return self.repo.milestones.get(state='all')

    def milestone(self, milestone_title):
        """Return milestone with given title, exit if there is none."""
        milestones = self.milestones()

        for milestone in milestones:
            if milestone['title'] == milestone_title:
                return milestone

        milestone_titles = [item['title'] for item in milestones]
        print("LOGHUB: You didn't pass a valid milestone name!")
        print('LOGHUB: The available milestones are: {0}\n'
              ''.format(milestone_titles))
        sys.exit(1)

    def pr(self, pr_number):
        """Get PR information."""
        self._check_rate()
        return self.repo('pulls')(str(pr_number)).get()

    def issues(self,
               milestone=None,
               state=None,
               assignee=None,
               creator=None,
               mentioned=None,
               labels=None,
               sort=None,
               direction=None,
               since=None,
               until=None,
               branch=None):
        """Return Issues and Pull Requests.

        All arguments except `until` and `branch` are forwarded to the
        issues endpoint, which is read page by page (100 items per page)
        until an empty page is returned. `since` and `until` are then
        applied to the close date of each result, and pull requests are
        dropped when they were not merged or, if `branch` is given, were
        merged into a different branch.
        """
        self._check_rate()
        page = 1
        issues = []
        while True:
            result = self.repo.issues.get(page=page,
                                          per_page=100,
                                          milestone=milestone,
                                          state=state,
                                          assignee=assignee,
                                          creator=creator,
                                          mentioned=mentioned,
                                          labels=labels,
                                          sort=sort,
                                          direction=direction,
                                          since=since)
            if not result:
                break
            issues += result
            page += 1

        # If since was provided, filter the issue
        issues = self._filter_since(issues, since)

        # If until was provided, filter the issue
        issues = self._filter_until(issues, until)

        # If it is a pr check if it is merged or closed, removed closed ones
        issues = self._filer_closed_prs(issues, branch)

        return issues

    def is_merged(self, pr):
        """
        Return whether a PR was merged, or if it was closed and discarded.

        https://developer.github.com/v3/pulls/#get-if-a-pull-request-has-been-merged
        """
        self._check_rate()
        merged = True
        try:
            self.repo('pulls')(str(pr))('merge').get()
        except Exception:
            # Best effort: any failure of the request counts as "not merged".
            merged = False
        return merged

    @staticmethod
    def str_to_date(string):
        """Convert ISO date string to datetime object.

        Expects ``YYYY-MM-DDTHH:MM:SS`` with an optional trailing ``Z``;
        the result is a naive `datetime.datetime`.
        """
        parts = string.split('T')
        date_parts = parts[0]
        time_parts = parts[1]
        # Strip the UTC designator only when present, so timestamps written
        # without it do not lose their last digit.
        if time_parts[-1:] in ('Z', 'z'):
            time_parts = time_parts[:-1]
        year, month, day = [int(i) for i in date_parts.split('-')]
        hour, minutes, seconds = [int(i) for i in time_parts.split(':')]
        return datetime.datetime(year, month, day, hour, minutes, seconds)
280