Completed · Pull Request — master (#173) · by unknown · 27s · created

TestSlurm._test_param()   B

Complexity

Conditions 6

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 0 Features 0
| Metric | Value |
|--------|-------|
| cc     | 6     |
| c      | 3     |
| b      | 0     |
| f      | 0     |
| dl     | 0     |
| loc    | 19    |
| rs     | 8     |
1
from glob import glob
2
import os
3
import time
4
import unittest
5
from subprocess import Popen, PIPE
6
7
from smartdispatch.utils import get_slurm_cluster_name
8
9
# Job-script template submitted by the tests below. It is filled with
# str.format(), so the single "{}" placeholder receives the scheduler
# directive under test and the script must contain no other braces.
# Every directive line must come before the first command line, otherwise
# the scheduler stops reading directives.
pbs_string = """\
#!/usr/bin/env /bin/bash

#PBS -N arrayJob
#PBS -o arrayJob_%A_%a.out
#PBS -l walltime=01:00:00
{}

######################
# Begin work section #
######################

echo "My SLURM_ARRAY_JOB_ID:" $SLURM_ARRAY_JOB_ID
echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
nvidia-smi
"""
25
26
# Detect which SLURM cluster runs the suite before any test is defined.
# Tests decorated with @unittest.skipIf(to_skip, message) are skipped when
# the cluster is one of the names listed below.
cluster = get_slurm_cluster_name()
message = "Test does not run on cluster {}".format(cluster)
to_skip = cluster in ('graham', 'cedar')
30
31
class TestSlurm(unittest.TestCase):
    """Integration tests checking that job-script directives reach SLURM.

    Each test writes a job script to ``test.pbs`` in the current directory,
    submits it with ``sbatch`` and reads one field back with ``squeue -O``.
    The tests therefore need ``sbatch``, ``squeue`` and ``scancel`` on PATH.
    """

    def tearDown(self):
        """Remove the job script and any ``*.out`` files in the CWD."""
        for file_name in glob('*.out') + ["test.pbs"]:
            # A test can fail before test.pbs is written; don't replace that
            # failure with a cleanup error.
            if os.path.exists(file_name):
                os.remove(file_name)

    def _test_param(self, param_array, command_template, flag, string=pbs_string, output_array=None):
        """Submit one job per parameter and check squeue reports the expected value.

        Args:
            param_array: values substituted into ``command_template``.
            command_template: directive line with one ``{}`` placeholder,
                e.g. ``"#SBATCH --qos={}"``.
            flag: ``squeue -O`` field to read back, e.g. ``"qos"``.
            string: job-script template with one ``{}`` placeholder where the
                directive line is inserted.
            output_array: expected ``squeue`` values, one per entry of
                ``param_array``; defaults to ``param_array`` itself.
        """
        output_array = output_array or param_array
        # zip() would silently drop unmatched entries; make a mismatch loud.
        self.assertEqual(len(param_array), len(output_array))

        for param, output in zip(param_array, output_array):
            # Insert the directive into the chosen template exactly once.
            job_script = string.format(command_template.format(param))
            with open("test.pbs", "w") as text_file:
                text_file.write(job_script)

            process = Popen("sbatch test.pbs", stdout=PIPE, stderr=PIPE, shell=True)
            stdout, stderr = process.communicate()
            stdout = stdout.decode()
            self.assertIn("Submitted batch job", stdout, msg=stderr.decode())
            # sbatch prints "Submitted batch job <id>".
            job_id = stdout.split(" ")[-1].strip()

            try:
                time.sleep(0.25)
                process = Popen("squeue -u $USER -j {} -O {}".format(job_id, flag),
                                stdout=PIPE, stderr=PIPE, shell=True)
                stdout, stderr = process.communicate()
                # The first line is the column header; the rest are values.
                job_params = [line.strip() for line in stdout.decode().split("\n")[1:]
                              if line.strip()]
                # Without this, an empty result (job not listed, bad field
                # name, squeue error) would make the comparison below pass.
                self.assertTrue(
                    job_params,
                    msg="squeue returned no rows for job {}: {}".format(job_id, stderr.decode()),
                )
                self.assertSequenceEqual(job_params, [output] * len(job_params))
            finally:
                # Don't leave test jobs queued or running on the cluster.
                Popen(["scancel", job_id], stdout=PIPE, stderr=PIPE).communicate()

    @unittest.skipIf(to_skip, message)
    def test_priority(self):
        self._test_param(
            ['high', 'low'],
            "#SBATCH --qos={}",
            "qos",
            pbs_string
        )

    def test_gres(self):
        self._test_param(
            ["1", "2"],
            "#PBS -l naccelerators={}",
            "gres",
            pbs_string,
            ["gpu:1", "gpu:2"]
        )

    def test_memory(self):
        self._test_param(
            ["2G", "4G"],
            "#PBS -l mem={}",
            "minmemory",
            pbs_string
        )

    def test_nb_cpus(self):
        self._test_param(
            ["2", "3"],
            "#PBS -l ncpus={}",
            "mincpus",
            pbs_string
        )

    @unittest.skipIf(to_skip, message)
    def test_constraint(self):
        self._test_param(
            ["gpu6gb", "gpu8gb"],
            "#PBS -l proc={}",
            "feature",
            pbs_string
        )
100
if __name__ == "__main__":
    # Run the test cases in this module when executed as a script.
    unittest.main()
102