Completed
Push — master ( fb87d8...76ab95 )
by Tinghui
01:10
created

load_and_test()   A

Complexity

Conditions 1

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 14
rs 9.4285
1
import os
2
import pickle
3
import logging
4
import argparse
5
from datetime import datetime
6
from pyActLearn.learning.decision_tree import DecisionTree
7
from pyActLearn.CASAS.data import CASASData
8
from pyActLearn.CASAS.fuel import CASASFuel
9
from pyActLearn.performance.record import LearningResult
10
from pyActLearn.performance import get_confusion_matrix
11
12
# Module-level logger named with the conventional ``__name__`` (``__file__`` is a
# path whose dots would be read as logger-hierarchy separators). Handlers and level
# are configured by ``logging.basicConfig`` when the script runs as ``__main__``.
logger = logging.getLogger(__name__)
13
14
15 View Code Duplication
def training_and_test(token, train_data, test_data, num_classes, result):
    """Fit a decision tree on one split and evaluate it on another.

    The fitted model (exported via ``export_to_dict``) and its confusion matrix
    on the test split are recorded in ``result`` under ``token``.

    Args:
        token (:obj:`str`): token representing this run; used as the record key
        train_data (:obj:`tuple` of :obj:`numpy.array`): Tuple of training feature and label
        test_data (:obj:`tuple` of :obj:`numpy.array`): Tuple of testing feature and label
        num_classes (:obj:`int`): Number of classes
        result (:obj:`pyActLearn.performance.record.LearningResult`): LearningResult object to hold learning result

    Returns:
        The predictions produced by ``DecisionTree.classify`` on the test features.
    """
    fit_features = train_data[0]
    # The tree's input width is taken from the second dimension of the training features.
    model = DecisionTree(fit_features.shape[1], num_classes, log_level=logging.WARNING)
    model.build(fit_features, train_data[1].flatten())

    # Classify the held-out split.
    predictions = model.classify(test_data[0])

    # Score the predictions and store the model alongside its confusion matrix.
    matrix = get_confusion_matrix(
        num_classes=num_classes,
        label=test_data[1].flatten(),
        predicted=predictions,
    )
    result.add_record(model.export_to_dict(), key=token, confusion_matrix=matrix)
    return predictions
34
35
36
def load_and_test(token, test_data, num_classes, result):
    """Restore a previously stored decision tree and classify a split with it.

    The model is rebuilt from the ``'model'`` entry of the record stored in
    ``result`` under ``token``. Unlike ``training_and_test``, nothing is added
    to ``result`` here (no confusion matrix is recomputed).

    Args:
        token (:obj:`str`): token representing this run; key of the stored record
        test_data (:obj:`tuple` of :obj:`numpy.array`): Tuple of testing feature and label
        num_classes (:obj:`int`): Number of classes
        result (:obj:`pyActLearn.performance.record.LearningResult`): LearningResult object to hold learning result

    Returns:
        The predictions produced by ``DecisionTree.classify`` on the test features.
    """
    eval_features = test_data[0]
    model = DecisionTree(eval_features.shape[1], num_classes, log_level=logging.WARNING)
    stored_record = result.get_record_by_key(token)
    model.load_from_dict(stored_record['model'])
    return model.classify(eval_features)
50
51 View Code Duplication
if __name__ == '__main__':
    # Script entry point: prepare output/log/dataset locations, then train on each
    # split and test on the next one, recording results and back-annotations.
    import sys

    parser = argparse.ArgumentParser(description='Run Decision Tree on single resident CASAS datasets.')
    parser.add_argument('-d', '--dataset', help='Directory to original datasets')
    parser.add_argument('-o', '--output', help='Output folder')
    parser.add_argument('--h5py', help='HDF5 dataset folder')
    args = parser.parse_args()

    # Log file name: script name (without extension) plus launch timestamp.
    # NOTE(review): ':' in the timestamp is not a legal filename character on Windows.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    log_filename = script_name + '-%s.log' % datetime.now().strftime('%y%m%d_%H:%M:%S')

    # Setup output directory (defaults to the current directory).
    if args.output is not None:
        output_dir = os.path.abspath(os.path.expanduser(args.output))
        if os.path.exists(output_dir):
            if not os.path.isdir(output_dir):
                sys.exit('Output directory %s is found, but not a directory. Abort.' % output_dir)
        else:
            # makedirs also creates any missing parent directories.
            os.makedirs(output_dir)
    else:
        output_dir = '.'
    log_filename = os.path.join(output_dir, log_filename)

    # Setup logging as early as possible: everything goes to both the log file and stderr.
    logging.basicConfig(level=logging.DEBUG,
                        format='[%(asctime)s] %(name)s:%(levelname)s:%(message)s',
                        handlers=[logging.FileHandler(log_filename),
                                  logging.StreamHandler()])

    # Optional raw CASAS dataset, only needed to (re)build the h5py files.
    casas_data_dir = args.dataset
    if casas_data_dir is not None:
        casas_data_dir = os.path.abspath(os.path.expanduser(casas_data_dir))
        if not os.path.isdir(casas_data_dir):
            sys.exit('CASAS dataset at %s does not exist. Abort.' % casas_data_dir)

    # Locate the h5py dataset (defaults to <output_dir>/h5py).
    if args.h5py is not None:
        h5py_dir = os.path.abspath(os.path.expanduser(args.h5py))
    else:
        h5py_dir = os.path.join(output_dir, 'h5py')
    if os.path.exists(h5py_dir) and not os.path.isdir(h5py_dir):
        sys.exit('h5py dataset location %s is not a directory. Abort.' % h5py_dir)

    if not CASASFuel.files_exist(h5py_dir):
        # Without h5py files we must build them from the raw dataset.
        if casas_data_dir is None:
            sys.exit('h5py dataset not found at %s and no CASAS dataset (-d) given to build it. Abort.'
                     % h5py_dir)
        casas_data = CASASData(path=casas_data_dir)
        casas_data.summary()
        # Decision tree uses statistical features, without per-sensor features or normalization.
        casas_data.populate_feature(method='stat', normalized=False, per_sensor=False)
        casas_data.export_hdf5(h5py_dir)
    casas_fuel = CASASFuel(dir_name=h5py_dir)

    # Prepare learning result: resume from an existing pickle, or start a new one.
    result_pkl_file = os.path.join(output_dir, 'result.pkl')
    if os.path.isfile(result_pkl_file):
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(result_pkl_file, 'rb') as f:
            result = pickle.load(f)
        if result.data != h5py_dir:
            logger.error('Result pickle file found for different dataset %s', result.data)
            sys.exit('Cannot save learning result at %s' % result_pkl_file)
    else:
        result = LearningResult(name='DecisionTree', data=h5py_dir, mode='by_week')

    num_classes = casas_fuel.get_output_dims()
    # Open Fuel and get all splits; the first split seeds the training data.
    split_list = casas_fuel.get_set_list()
    if len(split_list) == 0:
        sys.exit('No dataset splits found in %s. Abort.' % h5py_dir)
    train_name = split_list[0]
    train_set = casas_fuel.get_dataset((train_name,), load_in_memory=True)
    train_set_data = train_set.data_sources

    # Train on split i-1 and test on split i; back-annotate each test split.
    with open(os.path.join(output_dir, 'back_annotated.txt'), 'w') as fp_back_annotated:
        for i in range(1, len(split_list)):
            test_name = split_list[i]
            test_set = casas_fuel.get_dataset((test_name,), load_in_memory=True)
            test_set_data = test_set.data_sources
            logger.info('Training on %s, Testing on %s', train_name, test_name)
            if result.get_record_by_key(test_name) is None:
                # No stored model for this split yet: train and record it.
                prediction = training_and_test(test_name, train_set_data, test_set_data, num_classes, result)
            else:
                # Reuse the stored model instead of retraining.
                prediction = load_and_test(test_name, test_set_data, num_classes, result)
            casas_fuel.back_annotate(fp_back_annotated, prediction=prediction, split_id=i)
            # The current test split becomes the next training split.
            train_name = test_name
            train_set_data = test_set_data

    with open(result_pkl_file, 'wb') as f:
        pickle.dump(obj=result, file=f, protocol=pickle.HIGHEST_PROTOCOL)
    result.export_to_xlsx(os.path.join(output_dir, 'result.xlsx'))
142
143