Completed
Pull Request — master (#81)
by
unknown
01:15
created

compute_pvals()   A

Complexity

Conditions 2

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 2
c 2
b 0
f 0
dl 0
loc 14
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
4
"""
python %prog study.file population.file gene-association.file

This program returns P-values for functional enrichment in a cluster of
study genes using Fisher's exact test, corrected for multiple testing
(including Bonferroni, Holm, Sidak, and false discovery rate).
"""
11
12
from __future__ import absolute_import
13
14
__copyright__ = "Copyright (C) 2010-2017, H Tang et al., All rights reserved."
15
__author__ = "various"
16
17
import sys
18
import collections as cx
19
import datetime
20
from functools import partial
21
import multiprocessing
22
23
24
from goatools.multiple_testing import Methods, Bonferroni, Sidak, HolmBonferroni, FDR, calc_qval
25
from goatools.ratio import get_terms, count_terms, is_ratio_different
26
import goatools.wr_tbl as RPT
27
from goatools.pvalcalc import FisherFactory
28
29
30
def compute_pvals(allterms, calc_pvalue, go2studyitems, go2popitems,
                  study_n, pop_n):
    """Compute one uncorrected p-value per GO term.

    For every term in `allterms`, the number of study items and population
    items annotated to the term is looked up (0 when the term is absent from
    a mapping) and handed to `calc_pvalue(study_count, study_n, pop_count, pop_n)`.

    Returns a dict mapping each term to the value returned by `calc_pvalue`.
    """
    no_items = frozenset()

    def _pval_for(goid):
        """Return the p-value for a single GO id."""
        n_study = len(go2studyitems.get(goid, no_items))
        n_pop = len(go2popitems.get(goid, no_items))
        return calc_pvalue(n_study, study_n, n_pop, pop_n)

    return {goid: _pval_for(goid) for goid in allterms}
44
45
46
class GOEnrichmentRecord(object):
    """Represents one result (from a single GOTerm) in the GOEnrichmentStudy.

    Every keyword argument given to the constructor becomes an attribute.
    `ratio_in_study` and `ratio_in_pop` are (count, total) pairs which are
    additionally unpacked into `study_count`/`study_n` and `pop_count`/`pop_n`.
    """
    namespace2NS = cx.OrderedDict([
        ('biological_process', 'BP'),
        ('molecular_function', 'MF'),
        ('cellular_component', 'CC')])

    # Fields seen in every enrichment result
    _fldsdefprt = [
        "GO",
        "NS",
        "enrichment",
        "name",
        "ratio_in_study",
        "ratio_in_pop",
        "p_uncorrected",
        "depth",
        "study_count",
        "study_items"]
    # printf-style formats; one per entry of _fldsdefprt, in the same order
    _fldsdeffmt = ["%s"]*3 + ["%-30s"] + ["%d/%d"] * 2 + ["%.3g"] + ["%d"] * 2 + ["%15s"]

    _flds = set(_fldsdefprt).intersection(
        set(['study_items', 'study_count', 'study_n', 'pop_items', 'pop_count', 'pop_n']))

    def __init__(self, **kwargs):
        """Store all keyword arguments as attributes and classify the record.

        The study/population counts read by `_init_enrichment` are set from the
        `ratio_in_study` and `ratio_in_pop` keywords (or from explicit
        `study_count`, `study_n`, `pop_count`, `pop_n` keywords).
        """
        # Methods seen in current enrichment result
        self._methods = []
        for k, v in kwargs.items():
            setattr(self, k, v)
            if k == 'ratio_in_study':
                setattr(self, 'study_count', v[0])
                setattr(self, 'study_n', v[1])
            if k == 'ratio_in_pop':
                setattr(self, 'pop_count', v[0])
                setattr(self, 'pop_n', v[1])
        self._init_enrichment()
        self.goterm = None  # the reference to the GOTerm

    def get_method_name(self):
        """Return name of first method in the _methods list."""
        return self._methods[0].fieldname

    def get_pvalue(self):
        """Returns pval for 1st method, if it exists. Else returns uncorrected pval."""
        if self._methods:
            return getattr(self, "p_{m}".format(m=self.get_method_name()))
        return getattr(self, "p_uncorrected")

    def set_corrected_pval(self, nt_method, pvalue):
        """Add object attribute based on method name."""
        self._methods.append(nt_method)
        fieldname = "".join(["p_", nt_method.fieldname])
        setattr(self, fieldname, pvalue)

    def __str__(self, indent=False):
        """Return the record as one tab-separated line.

        Columns are the default fields, then one corrected p-value per method,
        then the comma-separated study items. If `indent` is true, the line is
        prefixed with the dots from `get_indent_dots`.
        """
        items_fld = self._fldsdefprt[-1]
        # str() each item so that non-string ids (e.g. integers) can be joined,
        # matching what _get_rpt_fmt does for the same field.
        items_str = ", ".join(str(v) for v in sorted(getattr(self, items_fld, set())))
        field_data = [getattr(self, f, "n.a.") for f in self._fldsdefprt[:-1]] + \
                     [getattr(self, "p_{}".format(m.fieldname)) for m in self._methods] + \
                     [items_str]
        fldsdeffmt = self._fldsdeffmt
        # Concatenation builds a new list, so the class-level formats are not mutated below
        field_formatter = fldsdeffmt[:-1] + ["%.3g"]*len(self._methods) + [fldsdeffmt[-1]]
        self._chk_fields(field_data, field_formatter)

        # default formatting only works for non-"n.a" data
        for i, f in enumerate(field_data):
            if f == "n.a.":
                field_formatter[i] = "%s"

        # print dots to show the level of the term
        dots = self.get_indent_dots() if indent else ""
        prtdata = "\t".join(a % b for (a, b) in zip(field_formatter, field_data))
        return "".join([dots, prtdata])

    def get_indent_dots(self):
        """Get a string of dots ("....") representing the level of the GO term."""
        return "." * self.goterm.level if self.goterm is not None else ""

    @staticmethod
    def _chk_fields(field_data, field_formatter):
        """Check that expected fields are present.

        Raises Exception if the number of values and formatters differ.
        """
        if len(field_data) == len(field_formatter):
            return
        len_dat = len(field_data)
        len_fmt = len(field_formatter)
        msg = [
            "FIELD DATA({d}) != FORMATTER({f})".format(d=len_dat, f=len_fmt),
            "DAT({N}): {D}".format(N=len_dat, D=field_data),
            "FMT({N}): {F}".format(N=len_fmt, F=field_formatter)]
        raise Exception("\n".join(msg))

    def __repr__(self):
        return "GOEnrichmentRecord({GO})".format(GO=self.GO)

    def set_goterm(self, goid):
        """Set goterm and copy GOTerm's name and namespace.

        `goid` is a mapping looked up with this record's GO id; if the id is
        missing, name becomes "n.a." and NS becomes "XX".
        """
        self.goterm = goid.get(self.GO, None)
        present = self.goterm is not None
        self.name = self.goterm.name if present else "n.a."
        self.NS = self.namespace2NS[self.goterm.namespace] if present else "XX"

    def _init_enrichment(self):
        """Mark as 'enriched' or 'purified'."""
        # A zero total (e.g. an empty study) is treated as a ratio of 0
        # instead of raising ZeroDivisionError.
        study_ratio = (1.0 * self.study_count / self.study_n) if self.study_n else 0.0
        pop_ratio = (1.0 * self.pop_count / self.pop_n) if self.pop_n else 0.0
        self.enrichment = 'e' if study_ratio > pop_ratio else 'p'

    def update_remaining_fldsdefprt(self, min_ratio=None):
        """Finish updating self (GOEnrichmentRecord) field, is_ratio_different."""
        self.is_ratio_different = is_ratio_different(min_ratio, self.study_count,
                                                     self.study_n, self.pop_count, self.pop_n)


    # -------------------------------------------------------------------------------------
    # Methods for getting flat namedtuple values from GOEnrichmentRecord object
    def get_prtflds_default(self):
        """Get default fields: the fixed fields with corrected p-values before study_items."""
        return self._fldsdefprt[:-1] + \
               ["p_{M}".format(M=m.fieldname) for m in self._methods] + \
               [self._fldsdefprt[-1]]

    def get_prtflds_all(self):
        """When converting to a namedtuple, get all possible fields in their original order."""
        flds = []
        dont_add = set(['_parents', '_methods'])
        # Fields: GO NS enrichment name ratio_in_study ratio_in_pop p_uncorrected
        #         depth study_count p_sm_bonferroni p_fdr_bh study_items
        self._flds_append(flds, self.get_prtflds_default(), dont_add)
        # Fields: GO NS goterm
        #         ratio_in_pop pop_n pop_count pop_items name
        #         ratio_in_study study_n study_count study_items
        #         _methods enrichment p_uncorrected p_sm_bonferroni p_fdr_bh
        self._flds_append(flds, vars(self).keys(), dont_add)
        # Fields: name level is_obsolete namespace id depth parents children _parents alt_ids
        # Only available once a GO term has been attached; vars(None) would raise.
        if self.goterm is not None:
            self._flds_append(flds, vars(self.goterm).keys(), dont_add)
        return flds

    @staticmethod
    def _flds_append(flds, addthese, dont_add):
        """Retain order of fields as we add them once to the list."""
        for fld in addthese:
            if fld not in flds and fld not in dont_add:
                flds.append(fld)

    def get_field_values(self, fldnames, rpt_fmt=True):
        """Get flat namedtuple fields for one GOEnrichmentRecord.

        Each field is read from this record first and from the attached GO
        term second. With `rpt_fmt`, values are passed through `_get_rpt_fmt`.

        Raises Exception (via `_err_fld`) if a field is found in neither place.
        """
        row = []
        # Loop through each user field desired
        for fld in fldnames:
            # 1. Check the GOEnrichmentRecord's attributes
            val = getattr(self, fld, None)
            if val is None:
                # 2. Check the GO object for the field
                val = getattr(self.goterm, fld, None)
            if val is None:
                # 3. Field not found, raise Exception
                self._err_fld(fld, fldnames)
            # Format only after the value is known to exist; formatting None
            # would fail for ratio and item fields.
            if rpt_fmt:
                val = self._get_rpt_fmt(fld, val)
                assert not isinstance(val, list), \
                   "UNEXPECTED LIST: FIELD({F}) VALUE({V}) FMT({P})".format(
                       P=rpt_fmt, F=fld, V=val)
            row.append(val)
        return row

    @staticmethod
    def _get_rpt_fmt(fld, val):
        """Return values in a format amenable to printing in a table."""
        if fld.startswith("ratio_"):
            return "{N}/{TOT}".format(N=val[0], TOT=val[1])
        elif fld in ('study_items', 'pop_items', 'alt_ids'):
            return ", ".join([str(v) for v in sorted(val)])
        return val

    def _err_fld(self, fld, fldnames):
        """Unrecognized field. Print detailed Failure message.

        Always raises Exception; the message lists the known fields and marks
        each user-provided field that is not among them.
        """
        msg = ['ERROR. UNRECOGNIZED FIELD({F})'.format(F=fld)]
        # list(...) so the concatenation below works with Python 3 dict views;
        # no GO fields are known when no GO term is attached.
        go_flds = list(vars(self.goterm)) if self.goterm is not None else []
        actual_flds = set(self.get_prtflds_default() + go_flds)
        bad_flds = set(fldnames).difference(actual_flds)
        if bad_flds:
            msg.append("\nGOEA RESULT FIELDS: {}".format(" ".join(self._fldsdefprt)))
            msg.append("GO FIELDS: {}".format(" ".join(go_flds)))
            msg.append("\nFATAL: {N} UNEXPECTED FIELDS({F})\n".format(
                N=len(bad_flds), F=" ".join(sorted(bad_flds))))
            msg.append("  {N} User-provided fields:".format(N=len(fldnames)))
            for idx, usr_fld in enumerate(fldnames, 1):
                mrk = "ERROR -->" if usr_fld in bad_flds else ""
                msg.append("  {M:>9} {I:>2}) {F}".format(M=mrk, I=idx, F=usr_fld))
        raise Exception("\n".join(msg))
239
240
241
class GOEnrichmentStudy(object):
    """Runs Fisher's exact test, as well as multiple corrections.

    The population, associations and GO DAG are fixed at construction;
    `run_study` can then be called for one or more study sets.
    """
    # Default Excel table column widths for GOEA results
    default_fld2col_widths = {
        'NS'        :  3,
        'GO'        : 12,
        'level'     :  3,
        'enrichment':  1,
        'name'      : 60,
        'ratio_in_study':  8,
        'ratio_in_pop'  : 12,
        'study_items'   : 15,
    }

    def __init__(self, pop, assoc, obo_dag, propagate_counts=True, alpha=.05, methods=None, **kws):
        """Store the population data and pre-compute the GO-to-population-items mapping.

        Keyword `log` selects the log stream (default sys.stdout). All keywords
        are also forwarded to FisherFactory to obtain the p-value calculator.
        If `methods` is None, bonferroni, sidak and holm are used.
        """
        self.log = kws['log'] if 'log' in kws else sys.stdout
        # Dispatch table keyed by the multiple-test method's source
        self._run_multitest = {
            'local':lambda iargs: self._run_multitest_local(iargs),
            'statsmodels':lambda iargs: self._run_multitest_statsmodels(iargs)}
        self.pop = pop
        self.pop_n = len(pop)
        self.assoc = assoc
        self.obo_dag = obo_dag
        self.alpha = alpha
        if methods is None:
            methods = ["bonferroni", "sidak", "holm"]
        self.methods = Methods(methods)
        self.pval_obj = FisherFactory(**kws).pval_obj

        if propagate_counts:
            sys.stderr.write("Propagating term counts to parents ..\n")
            obo_dag.update_association(assoc)
        self.go2popitems = get_terms("population", pop, assoc, obo_dag, self.log)

    def run_study(self, study, **kws):
        """Run Gene Ontology Enrichment Study (GOEA) on study ids.

        Keywords `methods`, `alpha` and `log` override the values given at
        construction; `keep_if` is an optional predicate used to filter records.

        Returns a list of GOEnrichmentRecord objects sorted by NS, then by
        uncorrected p-value (an empty list if there are no results).
        """
        # Key-word arguments:
        methods = Methods(kws['methods']) if 'methods' in kws else self.methods
        alpha = kws['alpha'] if 'alpha' in kws else self.alpha
        log = kws['log'] if 'log' in kws else self.log
        # Calculate uncorrected pvalues
        results = self._get_pval_uncorr(study, log)
        if not results:
            return []

        # Do multipletest corrections on uncorrected pvalues and update results
        self._run_multitest_corr(results, methods, alpha, study, log)

        for rec in results:
            # get go term for name and level
            rec.set_goterm(self.obo_dag)

        # 'keep_if' can be used to keep only significant GO terms. Example:
        #     >>> keep_if = lambda nt: nt.p_fdr_bh < 0.05 # if results are significant
        #     >>> goea_results = goeaobj.run_study(geneids_study, keep_if=keep_if)
        if 'keep_if' in kws:
            keep_if = kws['keep_if']
            results = [r for r in results if keep_if(r)]

        # Default sort order: First, sort by BP, MF, CC. Second, sort by pval
        results.sort(key=lambda r: [r.NS, r.p_uncorrected])

        if log is not None:
            log.write("  {MSG}\n".format(MSG="\n  ".join(self.get_results_msg(results, study))))

        return results # list of GOEnrichmentRecord objects

    def run_study_nts(self, study, **kws):
        """Run GOEA on study ids. Return results as a list of namedtuples."""
        goea_results = self.run_study(study, **kws)
        return get_goea_nts_all(goea_results)

    def get_results_msg(self, results, study):
        """Return summary for GOEA results as a list of message lines."""
        # To convert msg list to string: "\n".join(msg)
        msg = []
        if results:
            stu_items, num_gos_stu = self.get_item_cnt(results, "study_items")
            pop_items, num_gos_pop = self.get_item_cnt(results, "pop_items")
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} study items".format(
                N=len(stu_items), NT=len(set(study)), M=num_gos_stu))
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} population items".format(
                N=len(pop_items), NT=self.pop_n, M=num_gos_pop))
        return msg

    @staticmethod
    def _get_num_procs():
        """Return the number of worker processes to use (always at least 1)."""
        try:
            # -1 avoids freezing of the machine; never go below one worker,
            # because a pool with zero processes cannot be created.
            return max(1, multiprocessing.cpu_count() - 1)
        except NotImplementedError:
            return 1

    def _get_pval_uncorr(self, study, log=sys.stdout):
        """Calculate the uncorrected pvalues for study items.

        The GO terms are split into one fragment per worker process and the
        p-values of each fragment are computed by `compute_pvals` in a
        multiprocessing pool.

        Returns a list with one GOEnrichmentRecord per GO term found in the
        study or in the population.
        """
        if log is not None:
            log.write("Calculating uncorrected p-values using {PFNC}\n".format(PFNC=self.pval_obj.name))
        go2studyitems = get_terms("study", study, self.assoc, self.obo_dag, log)
        pop_n, study_n = self.pop_n, len(study)
        allterms = list(set(go2studyitems.keys()).union(
            set(self.go2popitems.keys())))

        n_procs = self._get_num_procs()
        fragments = [allterms[i::n_procs] for i in range(n_procs)]

        # if self.pval_obj.log is a file handle, which we can not serialize, so we could
        # not transfer self.pval_obj.calc_pvalue to another python process with multiprocessing.
        # Therefore we "patch" the object, which is restored again in the finally clause.
        old_log = self.pval_obj.log
        self.pval_obj.log = None
        try:
            remote_func = partial(compute_pvals, calc_pvalue=self.pval_obj.calc_pvalue,
                                  go2studyitems=go2studyitems,
                                  go2popitems=self.go2popitems, study_n=study_n, pop_n=pop_n)
            pool = multiprocessing.Pool(n_procs)
            try:
                all_p_values = pool.map(remote_func, fragments)
            finally:
                # Always release the worker processes, even if map() raised
                pool.close()
                pool.join()
        finally:
            # restore patched file handle
            self.pval_obj.log = old_log

        results = []

        for p_values in all_p_values:

            for term, p_value in p_values.items():

                study_items = go2studyitems.get(term, set())
                study_count = len(study_items)
                pop_items = self.go2popitems.get(term, set())
                pop_count = len(pop_items)

                one_record = GOEnrichmentRecord(
                    GO=term,
                    p_uncorrected=p_value,
                    study_items=study_items,
                    pop_items=pop_items,
                    ratio_in_study=(study_count, study_n),
                    ratio_in_pop=(pop_count, pop_n))

                results.append(one_record)

        return results

    def _run_multitest_corr(self, results, usr_methods, alpha, study, log):
        """Do multiple-test corrections on uncorrected pvalues.

        Each method in `usr_methods` is dispatched by its `source` and the
        corrected p-values are stored on the records in `results`.
        """
        assert 0 < alpha < 1, "Test-wise alpha must fall between (0, 1)"
        pvals = [r.p_uncorrected for r in results]
        NtMt = cx.namedtuple("NtMt", "results pvals alpha nt_method study")

        for nt_method in usr_methods:
            ntmt = NtMt(results, pvals, alpha, nt_method, study)
            if log is not None:
                log.write("Running multitest correction: {MSRC} {METHOD}\n".format(
                    MSRC=ntmt.nt_method.source, METHOD=ntmt.nt_method.method))
            self._run_multitest[nt_method.source](ntmt)

    def _run_multitest_statsmodels(self, ntmt):
        """Use multitest methods that have been implemented in statsmodels."""
        # Only load statsmodels if it is used
        multipletests = self.methods.get_statsmodels_multipletests()
        results = multipletests(ntmt.pvals, ntmt.alpha, ntmt.nt_method.method)
        pvals_corrected = results[1] # reject_lst, pvals_corrected, alphacSidak, alphacBonf
        self._update_pvalcorr(ntmt, pvals_corrected)

    def _run_multitest_local(self, ntmt):
        """Use multitest methods that have been implemented locally.

        An unrecognized method name leaves the records unchanged.
        """
        corrected_pvals = None
        method = ntmt.nt_method.method
        if method == "bonferroni":
            corrected_pvals = Bonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "sidak":
            corrected_pvals = Sidak(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "holm":
            corrected_pvals = HolmBonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "fdr":
            # get the empirical p-value distributions for FDR
            term_pop = getattr(self, 'term_pop', None)
            if term_pop is None:
                term_pop = count_terms(self.pop, self.assoc, self.obo_dag)
            p_val_distribution = calc_qval(len(ntmt.study),
                                           self.pop_n,
                                           self.pop, self.assoc,
                                           term_pop, self.obo_dag)
            corrected_pvals = FDR(p_val_distribution,
                                  ntmt.results, ntmt.alpha).corrected_pvals

        self._update_pvalcorr(ntmt, corrected_pvals)

    @staticmethod
    def _update_pvalcorr(ntmt, corrected_pvals):
        """Add data members to store multiple test corrections."""
        if corrected_pvals is None:
            return
        for rec, val in zip(ntmt.results, corrected_pvals):
            rec.set_corrected_pval(ntmt.nt_method, val)

    # Methods for writing results into tables: text, tab-separated, Excel spreadsheets
    def wr_txt(self, fout_txt, goea_results, prtfmt=None, **kws):
        """Print GOEA results to text file. Nothing is written if there are no results."""
        if not goea_results:
            sys.stdout.write("      0 GOEA results. NOT WRITING {FOUT}\n".format(FOUT=fout_txt))
            return
        with open(fout_txt, 'w') as prt:
            data_nts = self.prt_txt(prt, goea_results, prtfmt, **kws)
            log = self.log if self.log is not None else sys.stdout
            log.write("  {N:>5} GOEA results for {CUR:5} study items. WROTE: {F}\n".format(
                N=len(data_nts),
                CUR=len(get_study_items(goea_results)),
                F=fout_txt))

    def prt_txt(self, prt, goea_results, prtfmt=None, **kws):
        """Print GOEA results in text format. Returns the namedtuples that were printed."""
        if prtfmt is None:
            prtfmt = "{GO} {NS} {p_uncorrected:5.2e} {study_count:>5} {name}\n"
        prtfmt = self.adjust_prtfmt(prtfmt)
        prt_flds = RPT.get_fmtflds(prtfmt)
        data_nts = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_txt(prt, data_nts, prtfmt, prt_flds, **kws)
        return data_nts

    def wr_xlsx(self, fout_xlsx, goea_results, **kws):
        """Write a xlsx file."""
        # kws: prt_if indent
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        xlsx_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        if 'fld2col_widths' not in kws:
            # Fields without a default width get a width of 8
            kws['fld2col_widths'] = {f:self.default_fld2col_widths.get(f, 8) for f in prt_flds}
        RPT.wr_xlsx(fout_xlsx, xlsx_data, **kws)

    def wr_tsv(self, fout_tsv, goea_results, **kws):
        """Write tab-separated table data to file"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.wr_tsv(fout_tsv, tsv_data, **kws)

    def prt_tsv(self, prt, goea_results, **kws):
        """Write tab-separated table data"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_tsv(prt, tsv_data, prt_flds, **kws)

    @staticmethod
    def adjust_prtfmt(prtfmt):
        """Adjust format_strings for legal values."""
        # Hyphens are not valid in format field names; use the underscore form
        prtfmt = prtfmt.replace("{p_holm-sidak", "{p_holm_sidak")
        prtfmt = prtfmt.replace("{p_simes-hochberg", "{p_simes_hochberg")
        return prtfmt

    @staticmethod
    def get_NS2nts(results, fldnames=None, **kws):
        """Get namedtuples of GOEA results, split into BP, MF, CC."""
        NS2nts = cx.defaultdict(list)
        nts = get_goea_nts_all(results, fldnames, **kws)
        for nt in nts:
            NS2nts[nt.NS].append(nt)
        return NS2nts

    @staticmethod
    def get_item_cnt(results, attrname="study_items"):
        """Get all study or population items (e.g., geneids).

        Returns (items, go_cnt): the union of the `attrname` sets over all
        records and the number of records whose set is not empty.
        """
        items = set()
        go_cnt = 0
        for rec in results:
            if hasattr(rec, attrname):
                items_cur = getattr(rec, attrname)
                # Only count GO term if there are items in the set.
                if len(items_cur) != 0:
                    items |= items_cur
                    go_cnt += 1
        return items, go_cnt

    @staticmethod
    def get_prtflds_default(results):
        """Get default fields names. Used in printing GOEA results.

           Researchers can control which fields they want to print in the GOEA results
           or they can use the default fields.
        """
        if results:
            return results[0].get_prtflds_default()
        return []

    @staticmethod
    def print_summary(results, min_ratio=None, indent=False, pval=0.05):
        """Print summary.

        Prints a header, the field names, and every record whose uncorrected
        p-value is below `pval` (if `pval` is not None) and whose
        `is_ratio_different` flag is set.
        """
        from .version import __version__ as version

        # Header contains provenance and parameters
        print("# Generated by GOATOOLS v{0} ({1})".format(version, datetime.date.today()))
        print("# min_ratio={0} pval={1}".format(min_ratio, pval))

        # field names for output
        if results:
            print("\t".join(GOEnrichmentStudy.get_prtflds_default(results)))

        for rec in results:
            # calculate some additional statistics
            # (over_under, is_ratio_different)
            rec.update_remaining_fldsdefprt(min_ratio=min_ratio)

            if pval is not None and rec.p_uncorrected >= pval:
                continue

            if rec.is_ratio_different:
                print(rec.__str__(indent=indent))

    def wr_py_goea_results(self, fout_py, goea_results, **kws):
        """Save GOEA results into Python package containing list of namedtuples.

        Keywords: `var_name` (default "goea_results"), `docstring` and `sortby`
        (default: sort by uncorrected p-value). Nothing is written if
        `goea_results` is empty.
        """
        var_name = kws.get("var_name", "goea_results")
        docstring = kws.get("docstring", "")
        sortby = kws.get("sortby", None)
        if goea_results:
            from goatools.nt_utils import wr_py_nts
            nts_goea = goea_results
            # If list has GOEnrichmentRecords or verbose namedtuples, exclude some fields.
            if hasattr(goea_results[0], "_fldsdefprt") or hasattr(goea_results[0], 'goterm'):
                # Exclude some attributes from the namedtuple when saving results
                # to a Python file because the information is redundant or verbose.
                nts_goea = get_goea_nts_prt(goea_results)
            docstring = "\n".join([docstring, "# {VER}\n\n".format(VER=self.obo_dag.version)])
            assert hasattr(nts_goea[0], '_fields')
            if sortby is None:
                sortby = lambda nt: getattr(nt, 'p_uncorrected')
            nts_goea = sorted(nts_goea, key=sortby)
            wr_py_nts(fout_py, nts_goea, docstring, var_name)
565
566
def get_study_items(goea_results):
    """Return the union of the `study_items` sets of all GOEA results."""
    all_items = set()
    for items in (result.study_items for result in goea_results):
        all_items |= items
    return all_items
572
573
def get_goea_nts_prt(goea_results, fldnames=None, **usr_kws):
    """Return list of namedtuples removing fields which are redundant or verbose.

    Unless the caller says otherwise, the fields goterm, parents, children and
    id are excluded and report formatting (`rpt_fmt`) is switched on.
    """
    # Work on a copy so the caller's keyword dict is never modified
    kws = dict(usr_kws)
    kws.setdefault('not_fldnames', ['goterm', 'parents', 'children', 'id'])
    kws.setdefault('rpt_fmt', True)
    return get_goea_nts_all(goea_results, fldnames, **kws)
581
582
def get_goea_nts_all(goea_results, fldnames=None, **kws):
    """Get namedtuples containing user-specified (or default) data from GOEA results.

        Reformats data from GOEnrichmentRecord objects into lists of
        namedtuples so the generic table writers may be used.

        Keywords:
            keep_if      - predicate on the namedtuple; only matches are returned
            rpt_fmt      - passed to get_field_values (default False)
            indent       - prefix the GO field with the record's indent dots
            not_fldnames - field names to exclude

        Returns an empty list if goea_results is empty.
    """
    data_nts = [] # A list of namedtuples containing GOEA results
    if not goea_results:
        return data_nts
    keep_if = kws.get('keep_if', None)
    rpt_fmt = kws.get('rpt_fmt', False)
    indent = kws.get('indent', False)
    # I. FIELD (column) NAMES
    not_fldnames = kws.get('not_fldnames', None)
    if fldnames is None:
        fldnames = get_fieldnames(goea_results[0])
    # Ia. Explicitly exclude specific fields from named tuple
    if not_fldnames is not None:
        fldnames = [f for f in fldnames if f not in not_fldnames]
    nttyp = cx.namedtuple("NtGoeaResults", " ".join(fldnames))
    goid_idx = fldnames.index("GO") if 'GO' in fldnames else None
    # Indentation is written into the GO column, so it is only possible
    # when that column was requested; otherwise it is skipped.
    do_indent = indent and goid_idx is not None
    # II. Loop through GOEA results stored in a GOEnrichmentRecord object
    for goerec in goea_results:
        vals = get_field_values(goerec, fldnames, rpt_fmt)
        if do_indent:
            vals[goid_idx] = "".join([goerec.get_indent_dots(), vals[goid_idx]])
        ntobj = nttyp._make(vals)
        if keep_if is None or keep_if(ntobj):
            data_nts.append(ntobj)
    return data_nts
612
613
def get_field_values(item, fldnames, rpt_fmt=None):
    """Return the values of `fldnames` for a namedtuple or a GOEnrichmentRecord.

    Returns None if `item` is neither.
    """
    is_record = hasattr(item, "_fldsdefprt")  # GOEnrichmentRecord marker
    if not is_record and not hasattr(item, "_fields"):  # not a namedtuple either
        return None
    if is_record:
        return item.get_field_values(fldnames, rpt_fmt)
    return [getattr(item, name) for name in fldnames]
619
620
def get_fieldnames(item):
    """Return the field names of a namedtuple or a GOEnrichmentRecord.

    Returns None if `item` is neither.
    """
    is_record = hasattr(item, "_fldsdefprt")  # GOEnrichmentRecord marker
    if not is_record and not hasattr(item, "_fields"):  # not a namedtuple either
        return None
    return item.get_prtflds_all() if is_record else item._fields
626
627
# Copyright (C) 2010-2017, H Tang et al., All rights reserved.
628