Completed
Push — develop ( 9ba97b...51e53a )
by Jace
14s queued 11s
created

gitman.models.source.Source._on_post_load()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 2
CRAP Score 1

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 3
ccs 2
cts 2
cp 1
crap 1
rs 10
c 0
b 0
f 0
1 1
import os
2 1
from dataclasses import dataclass, field
3 1
from typing import List, Optional
4
5 1
import log
6 1
7
from .. import common, exceptions, git, shell
8 1
9
10
@dataclass
class Source:
    """A dictionary of `git` and `ln` arguments.

    Each instance describes one dependency: the repository to fetch from
    (`repo`), the revision to check out (`rev`), the directory it lives in
    (`name`, inferred from `repo` when omitted), and optional extras such as
    a symbolic link (`link`), install `scripts`, and `sparse_paths`.
    """

    name: Optional[str]
    type: str
    repo: str
    sparse_paths: List[str] = field(default_factory=list)
    rev: str = 'master'
    link: Optional[str] = None
    scripts: List[str] = field(default_factory=list)

    # Sentinel revisions returned by `identify` (not dataclass fields: no annotation)
    DIRTY = '<dirty>'
    UNKNOWN = '<unknown>'

    def __post_init__(self):
        """Infer a missing name, then validate the required fields.

        Raises:
            exceptions.InvalidConfig: if `name`, `repo`, or `rev` is empty.
        """
        if self.name is None:
            self.name = self._infer_name(self.repo)

        # TODO: Remove this?
        for attr in ['name', 'repo', 'rev']:
            if not getattr(self, attr):
                msg = "'{}' required for {}".format(attr, repr(self))
                raise exceptions.InvalidConfig(msg)

    def _on_post_load(self):
        """Default an empty `type` to 'git'.

        NOTE(review): presumably invoked by the config loader after
        deserialization — confirm against the caller.
        """
        # TODO: Remove this?
        self.type = self.type or 'git'

    def __repr__(self):
        return "<source {}>".format(self)

    def __str__(self):
        pattern = "['{t}'] '{r}' @ '{v}' in '{d}'"
        if self.link:
            pattern += " <- '{s}'"
        return pattern.format(
            t=self.type, r=self.repo, v=self.rev, d=self.name, s=self.link
        )

    # Sources are identified by their directory name alone. Objects without a
    # `name` attribute (e.g. None) are not comparable, so defer to Python's
    # fallback instead of raising AttributeError.

    def __eq__(self, other):
        if not hasattr(other, 'name'):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        if not hasattr(other, 'name'):
            return NotImplemented
        return self.name != other.name

    def __lt__(self, other):
        if not hasattr(other, 'name'):
            return NotImplemented
        return self.name < other.name

    def update_files(
        self,
        force=False,
        force_interactive=False,
        fetch=False,
        clean=True,
        skip_changes=False,
    ):
        """Ensure the source matches the specified revision.

        Clones the repository if its directory does not exist, enters the
        working tree (the current directory is left inside it), checks for
        uncommitted changes, then fetches if needed and updates to `self.rev`.

        Args:
            force: rebuild an invalid repository and skip the
                uncommitted-changes check.
            force_interactive: on uncommitted changes, ask the user whether to
                overwrite them.
            fetch: always fetch, not only when `git.is_fetch_required` says so.
            clean: passed to `git.changes` as `include_untracked` and to
                `git.update` as `clean`.
            skip_changes: on uncommitted changes, report and return instead
                of raising.

        Raises:
            exceptions.InvalidRepository: the working tree is not a valid
                repository and `force` is not set.
            exceptions.UncommittedChanges: uncommitted changes were found and
                none of `force`, `force_interactive`, `skip_changes` is set.
        """
        log.info("Updating source files...")

        # Clone the repository if needed
        assert self.name
        if not os.path.exists(self.name):
            git.clone(
                self.type,
                self.repo,
                self.name,
                sparse_paths=self.sparse_paths,
                rev=self.rev,
            )

        # Enter the working tree
        shell.cd(self.name)
        if not git.valid():
            if force:
                git.rebuild(self.type, self.repo)
                fetch = True
            else:
                # The cwd already *is* the working tree; don't append the name again.
                raise self._invalid_repository_at(os.getcwd())

        # Check for uncommitted changes
        if not force:
            log.debug("Confirming there are no uncommitted changes...")
            if skip_changes:
                if git.changes(
                    self.type, include_untracked=clean, display_status=False
                ):
                    common.show(
                        f'Skipped update due to uncommitted changes in {os.getcwd()}',
                        color='git_changes',
                    )
                    return
            elif force_interactive:
                if git.changes(
                    self.type, include_untracked=clean, display_status=False
                ):
                    common.show(
                        f'Uncommitted changes found in {os.getcwd()}',
                        color='git_changes',
                    )
                    if not self._confirm_overwrite():
                        common.show(
                            f'Skipped update in {os.getcwd()}', color='git_changes'
                        )
                        return
            elif git.changes(self.type, include_untracked=clean):
                raise exceptions.UncommittedChanges(
                    f'Uncommitted changes in {os.getcwd()}'
                )

        # Fetch the desired revision
        if fetch or git.is_fetch_required(self.type, self.rev):
            git.fetch(self.type, self.repo, self.name, rev=self.rev)

        # Update the working tree to the desired revision
        git.update(
            self.type, self.repo, self.name, fetch=fetch, clean=clean, rev=self.rev
        )

    @staticmethod
    def _confirm_overwrite():
        """Prompt until the user answers; True for 'y'/empty (default), False for 'n'."""
        while True:
            answer = str(input("Do you want to overwrite? (Y/N)[Y]: ")).rstrip('\r\n')
            if not answer or answer.lower() == "y":
                return True
            if answer.lower() == "n":
                return False

    def create_link(self, root, force=False):
        """Create a link from the target name to the current directory.

        The link is created at `root/self.link` and points, via a path
        relative to the link's parent directory, at the current working
        directory. Does nothing when `self.link` is empty. An existing
        symlink at the target is replaced; anything else there is passed to
        `shell.rm` only when `force` is set.

        Raises:
            exceptions.UncommittedChanges: a non-symlink already exists at the
                link location and `force` is not set.
        """
        if not self.link:
            return

        log.info("Creating a symbolic link...")

        target = os.path.join(root, self.link)
        source = os.path.relpath(os.getcwd(), os.path.dirname(target))

        if os.path.islink(target):
            os.remove(target)
        elif os.path.exists(target):
            if force:
                shell.rm(target)
            else:
                msg = "Preexisting link location at {}".format(target)
                raise exceptions.UncommittedChanges(msg)

        shell.ln(source, target)

    def run_scripts(self, force=False):
        """Run each configured install script inside the working tree.

        Args:
            force: log and ignore failing scripts instead of raising.

        Raises:
            exceptions.InvalidRepository: the working tree is not valid.
            exceptions.ScriptFailure: a script failed and `force` is not set.
        """
        log.info("Running install scripts...")

        # Enter the working tree
        shell.cd(self.name)
        if not git.valid():
            # The cwd already *is* the working tree; don't append the name again.
            raise self._invalid_repository_at(os.getcwd())

        # Check for scripts
        if not self.scripts or not self.scripts[0]:
            common.show("(no scripts to run)", color='shell_info')
            common.newline()
            return

        # Run all scripts
        for script in self.scripts:
            try:
                lines = shell.call(script, _shell=True)
            except exceptions.ShellError as exc:
                common.show(*exc.output, color='shell_error')
                cmd = exc.program
                if force:
                    log.debug("Ignored error from call to '%s'", cmd)
                else:
                    msg = "Command '{}' failed in {}".format(cmd, os.getcwd())
                    raise exceptions.ScriptFailure(msg) from exc
            else:
                common.show(*lines, color='shell_output')
        common.newline()

    def identify(self, allow_dirty=True, allow_missing=True, skip_changes=False):
        """Get the path and current repository URL and hash.

        Returns a `(path, url, rev)` tuple. When the directory exists the
        current directory is left inside the working tree; `rev` is
        `self.DIRTY` if there are uncommitted changes that are allowed or
        skipped. When the directory is missing and `allow_missing` is set,
        returns `(cwd, '<missing>', self.UNKNOWN)`.

        Raises:
            exceptions.InvalidRepository: the directory is not a valid
                repository, or it is missing and `allow_missing` is False.
            exceptions.UncommittedChanges: there are uncommitted changes and
                neither `allow_dirty` nor `skip_changes` is set.
        """
        assert self.name
        if not os.path.isdir(self.name):
            if allow_missing:
                return os.getcwd(), '<missing>', self.UNKNOWN
            raise self._invalid_repository

        shell.cd(self.name)
        if not git.valid():
            # The cwd already *is* the working tree; don't append the name again.
            raise self._invalid_repository_at(os.getcwd())

        path = os.getcwd()
        url = git.get_url(self.type)
        if git.changes(
            self.type,
            display_status=not allow_dirty and not skip_changes,
            _show=not skip_changes,
        ):
            if allow_dirty:
                common.show(self.DIRTY, color='git_dirty', log=False)
                common.newline()
                return path, url, self.DIRTY

            if skip_changes:
                msg = f"Skipped lock due to uncommitted changes in {os.getcwd()}"
                common.show(msg, color='git_changes')
                common.newline()
                return path, url, self.DIRTY

            msg = "Uncommitted changes in {}".format(os.getcwd())
            raise exceptions.UncommittedChanges(msg)

        rev = git.get_hash(self.type, _show=True)
        common.show(rev, color='git_rev', log=False)
        common.newline()
        return path, url, rev

    def lock(self, rev=None, allow_dirty=False, skip_changes=False):
        """Create a locked source object.

        Return a locked version of the current source if not dirty
        otherwise None. When `rev` is None it is determined with
        `identify(allow_missing=False, ...)`.
        """
        if rev is None:
            _, _, rev = self.identify(
                allow_dirty=allow_dirty, allow_missing=False, skip_changes=skip_changes
            )

        if rev == self.DIRTY:
            return None

        # Copy list fields so the locked source never shares mutable state
        # with this one (None is passed through unchanged).
        return self.__class__(
            type=self.type,
            repo=self.repo,
            name=self.name,
            rev=rev,
            link=self.link,
            scripts=None if self.scripts is None else list(self.scripts),
            sparse_paths=None if self.sparse_paths is None else list(self.sparse_paths),
        )

    @property
    def _invalid_repository(self):
        """Error for an invalid repository located at `cwd/name`.

        Correct only while the current directory is the *parent* of the
        working tree.
        """
        assert self.name
        return self._invalid_repository_at(os.path.join(os.getcwd(), self.name))

    @staticmethod
    def _invalid_repository_at(path):
        """Return an `InvalidRepository` error describing `path`."""
        msg = """

            Not a valid repository: {}
            During install you can rebuild a repo with a missing .git directory using the --force option
            """.format(
            path
        )
        return exceptions.InvalidRepository(msg)

    @staticmethod
    def _infer_name(repo):
        """Guess a directory name from a repository URL or path.

        Uses the last '/'-separated component (ignoring trailing slashes) up
        to its first '.', e.g. 'https://host/org/demo.git' -> 'demo'.
        """
        filename = repo.rstrip('/').split('/')[-1]
        return filename.split('.')[0]
278