Completed
Pull Request — master (#166)
by
unknown
47s
created

TestSlurm._test_param()   B

Complexity

Conditions 6

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 20
rs 8
1
from glob import glob
2
import os
3
import time
4
import unittest
5
6
from subprocess import Popen, PIPE
7
8
# Batch-script templates used by TestSlurm._test_param.  Each template has a
# single ``{}`` placeholder that receives the directive under test, placed
# right after the fixed header directives.
# NOTE(review): ``%A`` / ``%a`` in the output file names look like Slurm
# filename patterns (array job id / array task id); presumably the ``#PBS``
# directives are handled by a Slurm PBS-compatibility layer — confirm.
pbs_string = """\
#!/usr/bin/env /bin/bash

#PBS -N arrayJob
#PBS -o arrayJob_%A_%a.out
#PBS -l walltime=01:00:00
{}

######################
# Begin work section #
######################

echo "My SLURM_ARRAY_JOB_ID:" $SLURM_ARRAY_JOB_ID
echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
nvidia-smi
"""

# Native ``#SBATCH`` variant of the template above (not used by the tests in
# this module, which all pass ``pbs_string``).
# NOTE(review): on Linux the kernel hands everything after the interpreter
# path of a shebang to the interpreter as ONE argument, so ``env`` receives
# "-i /bin/zsh" as a single argument, which GNU env may reject.  ``env -S``
# (newer coreutils) exists for this case — confirm before relying on it.
sbatch_string = """\
#!/usr/bin/env -i /bin/zsh

#SBATCH --job-name=arrayJob
#SBATCH --output=arrayJob_%A_%a.out
#SBATCH --time=01:00:00
#SBATCH --gres=gpu
#SBATCH --constraint=gpu6gb
{}

######################
# Begin work section #
######################

echo "My SLURM_ARRAY_JOB_ID:" $SLURM_ARRAY_JOB_ID
echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
nvidia-smi
"""
class TestSlurm(unittest.TestCase):
    """Integration tests that submit real jobs with ``sbatch``.

    Each test renders a batch script from a template, submits it, and reads
    the queued job back with ``squeue -O <field>`` to check that the
    requested option was applied.  Requires ``sbatch`` and ``squeue`` on PATH
    and a cluster that accepts the values used below.
    """

    def tearDown(self):
        """Remove the generated batch script and any job output files."""
        # Match only the templates' ``arrayJob_%A_%a.out`` output pattern; a
        # bare ``*.out`` glob would also delete unrelated files in the CWD.
        for file_name in glob('arrayJob_*.out') + ["test.pbs"]:
            try:
                os.remove(file_name)
            except FileNotFoundError:
                # e.g. a test failed before ``test.pbs`` was written.
                pass

    def _test_param(self, param_array, com, flag, string=pbs_string):
        """Submit one job per value and check that squeue reports it back.

        For every ``param`` in ``param_array`` the directive ``com`` is
        rendered with that value, inserted into the ``string`` template,
        written to ``test.pbs`` and submitted with ``sbatch``.  The job's
        ``flag`` field is then read with ``squeue -O`` and every returned
        row must equal ``param``.

        Args:
            param_array: Values to test; one job is submitted per value.
            com: Directive template with one ``{}`` placeholder, e.g.
                ``"#SBATCH --qos={}"``.
            flag: ``squeue -O`` field name compared against ``param``.
            string: Batch-script template with one ``{}`` placeholder that
                receives the rendered directive.
        """
        for param in param_array:
            # Render the directive into the chosen template exactly once.
            # Formatting the result again into ``pbs_string`` would nest one
            # script inside another and make ``string`` meaningless.
            command = string.format(com.format(param))
            with open("test.pbs", "w") as text_file:
                text_file.write(command)

            process = Popen(["sbatch", "test.pbs"], stdout=PIPE, stderr=PIPE)
            out, err = process.communicate()
            out = out.decode()
            print(out)
            self.assertIn("Submitted batch job", out, msg=err.decode())
            # Output is "Submitted batch job <id>"; the id is the last word.
            job_id = out.split()[-1]

            # Brief pause so the freshly submitted job is visible to squeue.
            time.sleep(0.25)
            process = Popen(
                ["squeue", "-u", os.environ.get("USER", ""),
                 "-j", job_id, "-O", flag],
                stdout=PIPE, stderr=PIPE,
            )
            out, err = process.communicate()
            # Skip the header row; strip surrounding whitespace from values
            # and ignore blank lines.
            job_params = [
                line.strip()
                for line in out.decode().split("\n")[1:]
                if line.strip()
            ]
            # Without this check an empty result (job not found, unknown
            # field, squeue error) would make the comparison below pass
            # vacuously.
            self.assertTrue(
                job_params,
                msg="squeue returned no rows for job {}: {}".format(
                    job_id, err.decode()),
            )
            self.assertSequenceEqual(job_params, [param] * len(job_params))

    def test_priority(self):
        """QOS set via ``#SBATCH --qos`` is reported in squeue's ``qos``."""
        self._test_param(
            ['high', 'low'],
            "#SBATCH --qos={}",
            "qos",
            pbs_string
        )

    def test_gres(self):
        """``#PBS -l naccelerators`` is reported in squeue's ``gres``."""
        self._test_param(
            ['k80'],
            "#PBS -l naccelerators={}",
            "gres",
            pbs_string
        )

    def test_memory(self):
        """``#PBS -l mem`` is reported in squeue's ``minmemory``."""
        self._test_param(
            ['2G', '4G'],
            "#PBS -l mem={}",
            "minmemory",
            pbs_string
        )

    def test_nb_cpus(self):
        """``#PBS -l mppdepth`` is reported in squeue's ``numcpus``."""
        # Slurm-native alternative: "#SBATCH --cpus-per-task={}".
        self._test_param(
            ["2", "3"],
            "#PBS -l mppdepth={}",
            "numcpus",
            pbs_string
        )

    def test_constraint(self):
        """``#PBS -l proc`` is reported in squeue's ``feature``."""
        self._test_param(
            ["gpu6gb", "gpu8gb"],
            "#PBS -l proc={}",
            "feature",
            pbs_string
        )
if __name__ == "__main__":
    # Script entry point: run this module's tests with the unittest runner.
    unittest.main()