Completed
Pull Request — master (#83)
by
unknown
01:28
created

compute_pvals()   A

Complexity

Conditions 2

Size

Total Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 2
c 1
b 0
f 0
dl 0
loc 17
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
4
"""
5
python %prog study.file population.file gene-association.file
6
7
This program returns P-values for functional enrichment in a cluster of
8
study genes using Fisher's exact test, with correction for multiple testing
(including Bonferroni, Holm, Sidak, and false discovery rate).
10
"""
11
12
from __future__ import absolute_import
13
14
__copyright__ = "Copyright (C) 2010-2017, H Tang et al., All rights reserved."
15
__author__ = "various"
16
17
import sys
18
import collections as cx
19
import datetime
20
from functools import partial
21
import multiprocessing
22
23
24
from goatools.multiple_testing import Methods, Bonferroni, Sidak, HolmBonferroni, FDR, calc_qval
25
from goatools.ratio import get_terms, count_terms, is_ratio_different
26
import goatools.wr_tbl as RPT
27
from goatools.pvalcalc import FisherFactory
28
from .multiprocessing_tools import p_map
29
30
31
32
33
def compute_pvals(allterms, calc_pvalue, go2studyitems, go2popitems,
                  study_n, pop_n):
    """Calculate uncorrected p-values for a batch of GO terms.

    GOEnrichmentStudy._get_pval_uncorr binds every argument except ``allterms``
    with functools.partial and maps the result over fragments of GO terms, so
    this function depends only on its arguments.

    Args:
        allterms: Iterable of GO ids to evaluate.
        calc_pvalue: Callable ``(study_count, study_n, pop_count, pop_n) -> pval``.
        go2studyitems: Dict GO id -> study items associated with that GO id.
        go2popitems: Dict GO id -> population items associated with that GO id.
        study_n: Number of study items.
        pop_n: Number of population items.

    Returns:
        dict mapping each GO id in ``allterms`` to its uncorrected p-value.
    """
    results = {}
    for term in allterms:
        # A term missing from a mapping has no associated items (count 0)
        study_count = len(go2studyitems.get(term, ()))
        pop_count = len(go2popitems.get(term, ()))
        results[term] = calc_pvalue(study_count, study_n, pop_count, pop_n)
    return results
50
51
52
class GOEnrichmentRecord(object):
    """Represents one result (from a single GOTerm) in the GOEnrichmentStudy.

    Attributes are created from the constructor's keyword arguments (e.g. GO,
    p_uncorrected, study_items, pop_items, ratio_in_study, ratio_in_pop).
    Corrected p-values are added later as ``p_<method fieldname>`` attributes
    by set_corrected_pval().
    """
    namespace2NS = cx.OrderedDict([
        ('biological_process', 'BP'),
        ('molecular_function', 'MF'),
        ('cellular_component', 'CC')])

    # Fields seen in every enrichment result
    _fldsdefprt = [
        "GO",
        "NS",
        "enrichment",
        "name",
        "ratio_in_study",
        "ratio_in_pop",
        "p_uncorrected",
        "depth",
        "study_count",
        "study_items"]
    # printf-style formats, one per entry of _fldsdefprt (same order)
    _fldsdeffmt = ["%s"]*3 + ["%-30s"] + ["%d/%d"] * 2 + ["%.3g"] + ["%d"] * 2 + ["%15s"]

    _flds = set(_fldsdefprt).intersection(
        set(['study_items', 'study_count', 'study_n', 'pop_items', 'pop_count', 'pop_n']))

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute.

        ``ratio_in_study`` and ``ratio_in_pop`` are (count, total) pairs. They are
        also unpacked into study_count/study_n and pop_count/pop_n. Both must be
        given because _init_enrichment() reads the unpacked values.
        """
        # Methods seen in current enrichment result
        self._methods = []
        for k, v in kwargs.items():
            setattr(self, k, v)
            if k == 'ratio_in_study':
                setattr(self, 'study_count', v[0])
                setattr(self, 'study_n', v[1])
            if k == 'ratio_in_pop':
                setattr(self, 'pop_count', v[0])
                setattr(self, 'pop_n', v[1])
        self._init_enrichment()
        self.goterm = None  # the reference to the GOTerm

    def get_method_name(self):
        """Return name of first method in the _methods list."""
        return self._methods[0].fieldname

    def get_pvalue(self):
        """Returns pval for 1st method, if it exists. Else returns uncorrected pval."""
        if self._methods:
            return getattr(self, "p_{m}".format(m=self.get_method_name()))
        return getattr(self, "p_uncorrected")

    def set_corrected_pval(self, nt_method, pvalue):
        """Add object attribute based on method name."""
        self._methods.append(nt_method)
        fieldname = "".join(["p_", nt_method.fieldname])
        setattr(self, fieldname, pvalue)

    def __str__(self, indent=False):
        """Return a tab-separated line of the default fields plus corrected p-values."""
        field_data = [getattr(self, f, "n.a.") for f in self._fldsdefprt[:-1]] + \
                     [getattr(self, "p_{}".format(m.fieldname)) for m in self._methods] + \
                     [", ".join(sorted(getattr(self, self._fldsdefprt[-1], set())))]
        fldsdeffmt = self._fldsdeffmt
        field_formatter = fldsdeffmt[:-1] + ["%.3g"]*len(self._methods) + [fldsdeffmt[-1]]
        self._chk_fields(field_data, field_formatter)

        # default formatting only works for non-"n.a" data
        for i, f in enumerate(field_data):
            if f == "n.a.":
                field_formatter[i] = "%s"

        # print dots to show the level of the term
        dots = self.get_indent_dots() if indent else ""
        prtdata = "\t".join(a % b for (a, b) in zip(field_formatter, field_data))
        return "".join([dots, prtdata])

    def get_indent_dots(self):
        """Get a string of dots ("....") representing the level of the GO term."""
        return "." * self.goterm.level if self.goterm is not None else ""

    @staticmethod
    def _chk_fields(field_data, field_formatter):
        """Check that expected fields are present."""
        if len(field_data) == len(field_formatter):
            return
        len_dat = len(field_data)
        len_fmt = len(field_formatter)
        msg = [
            "FIELD DATA({d}) != FORMATTER({f})".format(d=len_dat, f=len_fmt),
            "DAT({N}): {D}".format(N=len_dat, D=field_data),
            "FMT({N}): {F}".format(N=len_fmt, F=field_formatter)]
        raise Exception("\n".join(msg))

    def __repr__(self):
        return "GOEnrichmentRecord({GO})".format(GO=self.GO)

    def set_goterm(self, goid):
        """Set goterm and copy GOTerm's name and namespace.

        Args:
            goid: Mapping of GO id -> GO term object (e.g. the GO DAG).
        """
        self.goterm = goid.get(self.GO, None)
        present = self.goterm is not None
        self.name = self.goterm.name if present else "n.a."
        self.NS = self.namespace2NS[self.goterm.namespace] if present else "XX"

    def _init_enrichment(self):
        """Mark as 'enriched' ('e') or 'purified' ('p').

        Compares study_count/study_n with pop_count/pop_n by cross-multiplying.
        This is equivalent for positive totals, exact for integer counts, and
        avoids a ZeroDivisionError when a total is 0 (e.g. an empty study).
        """
        study_side = self.study_count * self.pop_n
        pop_side = self.pop_count * self.study_n
        self.enrichment = 'e' if study_side > pop_side else 'p'

    def update_remaining_fldsdefprt(self, min_ratio=None):
        """Finish updating self (GOEnrichmentRecord) field, is_ratio_different."""
        self.is_ratio_different = is_ratio_different(min_ratio, self.study_count,
                                                     self.study_n, self.pop_count, self.pop_n)


    # -------------------------------------------------------------------------------------
    # Methods for getting flat namedtuple values from GOEnrichmentRecord object
    def get_prtflds_default(self):
        """Get default fields."""
        return self._fldsdefprt[:-1] + \
               ["p_{M}".format(M=m.fieldname) for m in self._methods] + \
               [self._fldsdefprt[-1]]

    def get_prtflds_all(self):
        """When converting to a namedtuple, get all possible fields in their original order."""
        flds = []
        dont_add = set(['_parents', '_methods'])
        # Fields: GO NS enrichment name ratio_in_study ratio_in_pop p_uncorrected
        #         depth study_count p_sm_bonferroni p_fdr_bh study_items
        self._flds_append(flds, self.get_prtflds_default(), dont_add)
        # Fields: GO NS goterm
        #         ratio_in_pop pop_n pop_count pop_items name
        #         ratio_in_study study_n study_count study_items
        #         _methods enrichment p_uncorrected p_sm_bonferroni p_fdr_bh
        self._flds_append(flds, vars(self).keys(), dont_add)
        # Fields: name level is_obsolete namespace id depth parents children _parents alt_ids
        # (only available once set_goterm() has found the GO term)
        if self.goterm is not None:
            self._flds_append(flds, vars(self.goterm).keys(), dont_add)
        return flds

    @staticmethod
    def _flds_append(flds, addthese, dont_add):
        """Retain order of fields as we add them once to the list."""
        for fld in addthese:
            if fld not in flds and fld not in dont_add:
                flds.append(fld)

    def get_field_values(self, fldnames, rpt_fmt=True):
        """Get flat namedtuple fields for one GOEnrichmentRecord.

        Each field is looked up first on this record, then on its GO term.

        Args:
            fldnames: Field names to extract, in output order.
            rpt_fmt: If True, format values for printing (see _get_rpt_fmt).

        Returns:
            list of values, one per field name.

        Raises:
            Exception: if a field is found neither on the record nor on the GO term.
        """
        row = []
        # Loop through each user field desired
        for fld in fldnames:
            # 1. Check the GOEnrichmentRecord's attributes
            val = getattr(self, fld, None)
            if val is None:
                # 2. Check the GO object for the field
                val = getattr(self.goterm, fld, None)
            if val is None:
                # 3. Field not found, raise Exception. Check before formatting,
                #    because _get_rpt_fmt cannot format None for ratio/list fields.
                self._err_fld(fld, fldnames, row)
            if rpt_fmt:
                val = self._get_rpt_fmt(fld, val)
                assert not isinstance(val, list), \
                   "UNEXPECTED LIST: FIELD({F}) VALUE({V}) FMT({P})".format(
                       P=rpt_fmt, F=fld, V=val)
            row.append(val)
        return row

    @staticmethod
    def _get_rpt_fmt(fld, val):
        """Return values in a format amenable to printing in a table."""
        if fld.startswith("ratio_"):
            return "{N}/{TOT}".format(N=val[0], TOT=val[1])
        elif fld in ('study_items', 'pop_items', 'alt_ids'):
            return ", ".join([str(v) for v in sorted(val)])
        return val

    def _err_fld(self, fld, fldnames, row=None):
        """Unrecognized field. Raise an Exception with a detailed failure message.

        Args:
            fld: The field name that could not be found.
            fldnames: All field names requested by the caller.
            row: Values already collected for earlier fields (optional). They
                are shown in the message to help locate the failure.

        Raises:
            Exception: always.
        """
        msg = ['ERROR. UNRECOGNIZED FIELD({F})'.format(F=fld)]
        # The GO term is unset until set_goterm() finds it in the DAG
        go_flds = list(vars(self.goterm).keys()) if self.goterm is not None else []
        actual_flds = set(self.get_prtflds_default() + go_flds)
        bad_flds = set(fldnames).difference(actual_flds)
        if bad_flds:
            msg.append("\nGOEA RESULT FIELDS: {}".format(" ".join(self._fldsdefprt)))
            msg.append("GO FIELDS: {}".format(" ".join(go_flds)))
            msg.append("\nFATAL: {N} UNEXPECTED FIELDS({F})\n".format(
                N=len(bad_flds), F=" ".join(bad_flds)))
            msg.append("  {N} User-provided fields:".format(N=len(fldnames)))
            for idx, name in enumerate(fldnames, 1):
                mrk = "ERROR -->" if name in bad_flds else ""
                msg.append("  {M:>9} {I:>2}) {F}".format(M=mrk, I=idx, F=name))
        if row:
            msg.append("  Values found before the failure: {R}".format(R=row))
        raise Exception("\n".join(msg))
245
246
247
class GOEnrichmentStudy(object):
    """Runs Fisher's exact test, as well as multiple-test corrections.

    Uncorrected p-values are computed in worker processes. Corrected p-values
    are then stored on each resulting GOEnrichmentRecord.
    """
    # Default Excel table column widths for GOEA results
    default_fld2col_widths = {
        'NS'        :  3,
        'GO'        : 12,
        'level'     :  3,
        'enrichment':  1,
        'name'      : 60,
        'ratio_in_study':  8,
        'ratio_in_pop'  : 12,
        'study_items'   : 15,
    }

    def __init__(self, pop, assoc, obo_dag, propagate_counts=True, alpha=.05, methods=None, **kws):
        """Prepare a GOEA over a population of items.

        Args:
            pop: Population items; ``len(pop)`` is used as the population size.
            assoc: Associations passed to get_terms() and obo_dag.update_association().
            obo_dag: GO DAG used to look up GO terms.
            propagate_counts: If True, call obo_dag.update_association(assoc) first.
            alpha: Default test-wise alpha for multiple-test corrections.
            methods: Correction method names (default: bonferroni, sidak, holm).
            kws: 'log' (file-like or None; default sys.stdout). All kws are also
                passed to FisherFactory.
        """
        self.log = kws.get('log', sys.stdout)
        # Dispatch table: correction-method source -> implementation
        self._run_multitest = {
            'local':lambda iargs: self._run_multitest_local(iargs),
            'statsmodels':lambda iargs: self._run_multitest_statsmodels(iargs)}
        self.pop = pop
        self.pop_n = len(pop)
        self.assoc = assoc
        self.obo_dag = obo_dag
        self.alpha = alpha
        if methods is None:
            methods = ["bonferroni", "sidak", "holm"]
        self.methods = Methods(methods)
        self.pval_obj = FisherFactory(**kws).pval_obj

        if propagate_counts:
            sys.stderr.write("Propagating term counts to parents ..\n")
            obo_dag.update_association(assoc)
        self.go2popitems = get_terms("population", pop, assoc, obo_dag, self.log)

    def run_study(self, study, **kws):
        """Run Gene Ontology Enrichment Study (GOEA) on study ids.

        Keyword args: methods, alpha, log (override the instance defaults) and
        keep_if (predicate on GOEnrichmentRecord; only matching records are kept).

        Returns:
            list of GOEnrichmentRecord, sorted by namespace then uncorrected p-value.
        """
        # Key-word arguments:
        methods = Methods(kws['methods']) if 'methods' in kws else self.methods
        alpha = kws.get('alpha', self.alpha)
        log = kws.get('log', self.log)
        # Calculate uncorrected pvalues
        results = self._get_pval_uncorr(study, log)
        if not results:
            return []

        # Do multipletest corrections on uncorrected pvalues and update results
        self._run_multitest_corr(results, methods, alpha, study, log)

        for rec in results:
            # get go term for name and level
            rec.set_goterm(self.obo_dag)

        # 'keep_if' can be used to keep only significant GO terms. Example:
        #     >>> keep_if = lambda nt: nt.p_fdr_bh < 0.05 # if results are significant
        #     >>> goea_results = goeaobj.run_study(geneids_study, keep_if=keep_if)
        if 'keep_if' in kws:
            keep_if = kws['keep_if']
            results = [r for r in results if keep_if(r)]

        # Default sort order: First, sort by BP, MF, CC. Second, sort by pval
        results.sort(key=lambda r: [r.NS, r.p_uncorrected])

        if log is not None:
            log.write("  {MSG}\n".format(MSG="\n  ".join(self.get_results_msg(results, study))))

        return results # list of GOEnrichmentRecord objects

    def run_study_nts(self, study, **kws):
        """Run GOEA on study ids. Return results as a list of namedtuples."""
        goea_results = self.run_study(study, **kws)
        return get_goea_nts_all(goea_results)

    def get_results_msg(self, results, study):
        """Return summary for GOEA results."""
        # To convert msg list to string: "\n".join(msg)
        msg = []
        if results:
            stu_items, num_gos_stu = self.get_item_cnt(results, "study_items")
            pop_items, num_gos_pop = self.get_item_cnt(results, "pop_items")
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} study items".format(
                N=len(stu_items), NT=len(set(study)), M=num_gos_stu))
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} population items".format(
                N=len(pop_items), NT=self.pop_n, M=num_gos_pop))
        return msg

    @staticmethod
    def _get_num_procs(num_cpus=None):
        """Return the number of worker processes to use (always at least 1).

        Args:
            num_cpus: CPU count to base the decision on (default: multiprocessing.cpu_count()).
        """
        if num_cpus is None:
            num_cpus = multiprocessing.cpu_count()
        # Leave one CPU free to avoid freezing the machine, but never use 0
        # workers (Pool(0) raises ValueError on single-CPU hosts).
        return max(1, num_cpus - 1)

    def _get_pval_uncorr(self, study, log=sys.stdout):
        """Calculate the uncorrected pvalues for study items.

        The GO terms are split into one fragment per worker process, and
        compute_pvals() is applied to each fragment through p_map().

        Returns:
            list of GOEnrichmentRecord, one per GO term associated with the study
            or the population.
        """
        if log is not None:
            log.write("Calculating uncorrected p-values using {PFNC}\n".format(PFNC=self.pval_obj.name))
        go2studyitems = get_terms("study", study, self.assoc, self.obo_dag, log)
        pop_n, study_n = self.pop_n, len(study)
        allterms = set(go2studyitems.keys()).union(set(self.go2popitems.keys()))

        n_procs = self._get_num_procs()

        allterms = list(allterms)
        fragments = [allterms[i::n_procs] for i in range(n_procs)]

        # bind arguments, so that remote_func only depends on fragment of terms to process:
        calc_pvalue = self.pval_obj.calc_pvalue
        remote_func = partial(compute_pvals, calc_pvalue=calc_pvalue, go2studyitems=go2studyitems,
                              go2popitems=self.go2popitems, study_n=study_n, pop_n=pop_n)

        # if self.pval_obj.log is a file handle, which we can not serialize, we could not transfer
        # self.pval_obj.calc_pvalue to another python process with multiprocessing.  therefore we
        # "patch" the object which will later be restored again.
        old = self.pval_obj.log
        self.pval_obj.log = None
        try:
            pool = multiprocessing.Pool(n_procs)
            try:
                all_p_values = p_map(pool, remote_func, fragments)
            finally:
                # Shut the worker processes down. close()+join() lets already
                # submitted tasks finish (and be serialized) before the log
                # handle is restored below.
                pool.close()
                pool.join()
        finally:
            # restore patched file handle
            self.pval_obj.log = old

        results = []

        for p_values_map in all_p_values:

            for term, p_value in p_values_map.items():

                study_items = go2studyitems.get(term, set())
                study_count = len(study_items)
                pop_items = self.go2popitems.get(term, set())
                pop_count = len(pop_items)

                one_record = GOEnrichmentRecord(
                    GO=term,
                    p_uncorrected=p_value,
                    study_items=study_items,
                    pop_items=pop_items,
                    ratio_in_study=(study_count, study_n),
                    ratio_in_pop=(pop_count, pop_n))

                results.append(one_record)

        return results

    def _run_multitest_corr(self, results, usr_methods, alpha, study, log):
        """Do multiple-test corrections on uncorrected pvalues."""
        assert 0 < alpha < 1, "Test-wise alpha must fall between (0, 1)"
        pvals = [r.p_uncorrected for r in results]
        NtMt = cx.namedtuple("NtMt", "results pvals alpha nt_method study")

        for nt_method in usr_methods:
            ntmt = NtMt(results, pvals, alpha, nt_method, study)
            if log is not None:
                log.write("Running multitest correction: {MSRC} {METHOD}\n".format(
                    MSRC=ntmt.nt_method.source, METHOD=ntmt.nt_method.method))
            self._run_multitest[nt_method.source](ntmt)

    def _run_multitest_statsmodels(self, ntmt):
        """Use multitest methods that have been implemented in statsmodels."""
        # Only load statsmodels if it is used
        multipletests = self.methods.get_statsmodels_multipletests()
        results = multipletests(ntmt.pvals, ntmt.alpha, ntmt.nt_method.method)
        pvals_corrected = results[1] # reject_lst, pvals_corrected, alphacSidak, alphacBonf
        self._update_pvalcorr(ntmt, pvals_corrected)

    def _run_multitest_local(self, ntmt):
        """Use multitest methods that have been implemented locally."""
        corrected_pvals = None
        method = ntmt.nt_method.method
        if method == "bonferroni":
            corrected_pvals = Bonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "sidak":
            corrected_pvals = Sidak(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "holm":
            corrected_pvals = HolmBonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "fdr":
            # get the empirical p-value distributions for FDR
            term_pop = getattr(self, 'term_pop', None)
            if term_pop is None:
                term_pop = count_terms(self.pop, self.assoc, self.obo_dag)
            p_val_distribution = calc_qval(len(ntmt.study),
                                           self.pop_n,
                                           self.pop, self.assoc,
                                           term_pop, self.obo_dag)
            corrected_pvals = FDR(p_val_distribution,
                                  ntmt.results, ntmt.alpha).corrected_pvals

        self._update_pvalcorr(ntmt, corrected_pvals)

    @staticmethod
    def _update_pvalcorr(ntmt, corrected_pvals):
        """Add data members to store multiple test corrections."""
        if corrected_pvals is None:
            return
        for rec, val in zip(ntmt.results, corrected_pvals):
            rec.set_corrected_pval(ntmt.nt_method, val)

    # Methods for writing results into tables: text, tab-separated, Excel spreadsheets
    def wr_txt(self, fout_txt, goea_results, prtfmt=None, **kws):
        """Print GOEA results to text file."""
        if not goea_results:
            sys.stdout.write("      0 GOEA results. NOT WRITING {FOUT}\n".format(FOUT=fout_txt))
            return
        with open(fout_txt, 'w') as prt:
            data_nts = self.prt_txt(prt, goea_results, prtfmt, **kws)
            log = self.log if self.log is not None else sys.stdout
            log.write("  {N:>5} GOEA results for {CUR:5} study items. WROTE: {F}\n".format(
                N=len(data_nts),
                CUR=len(get_study_items(goea_results)),
                F=fout_txt))

    def prt_txt(self, prt, goea_results, prtfmt=None, **kws):
        """Print GOEA results in text format."""
        if prtfmt is None:
            prtfmt = "{GO} {NS} {p_uncorrected:5.2e} {study_count:>5} {name}\n"
        prtfmt = self.adjust_prtfmt(prtfmt)
        prt_flds = RPT.get_fmtflds(prtfmt)
        data_nts = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_txt(prt, data_nts, prtfmt, prt_flds, **kws)
        return data_nts

    def wr_xlsx(self, fout_xlsx, goea_results, **kws):
        """Write a xlsx file."""
        # kws: prt_if indent
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        xlsx_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        if 'fld2col_widths' not in kws:
            kws['fld2col_widths'] = {f:self.default_fld2col_widths.get(f, 8) for f in prt_flds}
        RPT.wr_xlsx(fout_xlsx, xlsx_data, **kws)

    def wr_tsv(self, fout_tsv, goea_results, **kws):
        """Write tab-separated table data to file"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.wr_tsv(fout_tsv, tsv_data, **kws)

    def prt_tsv(self, prt, goea_results, **kws):
        """Write tab-separated table data"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_tsv(prt, tsv_data, prt_flds, **kws)

    @staticmethod
    def adjust_prtfmt(prtfmt):
        """Adjust format_strings for legal values (hyphens are not legal in field names)."""
        prtfmt = prtfmt.replace("{p_holm-sidak", "{p_holm_sidak")
        prtfmt = prtfmt.replace("{p_simes-hochberg", "{p_simes_hochberg")
        return prtfmt

    @staticmethod
    def get_NS2nts(results, fldnames=None, **kws):
        """Get namedtuples of GOEA results, split into BP, MF, CC."""
        NS2nts = cx.defaultdict(list)
        nts = get_goea_nts_all(results, fldnames, **kws)
        for nt in nts:
            NS2nts[nt.NS].append(nt)
        return NS2nts

    @staticmethod
    def get_item_cnt(results, attrname="study_items"):
        """Get all study or population items (e.g., geneids).

        Returns:
            (set of items, number of results having at least one item)
        """
        items = set()
        go_cnt = 0
        for rec in results:
            if hasattr(rec, attrname):
                items_cur = getattr(rec, attrname)
                # Only count GO term if there are items in the set.
                if len(items_cur) != 0:
                    items |= items_cur
                    go_cnt += 1
        return items, go_cnt

    @staticmethod
    def get_prtflds_default(results):
        """Get default fields names. Used in printing GOEA results.

           Researchers can control which fields they want to print in the GOEA results
           or they can use the default fields.
        """
        if results:
            return results[0].get_prtflds_default()
        return []

    @staticmethod
    def print_summary(results, min_ratio=None, indent=False, pval=0.05):
        """Print summary."""
        from .version import __version__ as version

        # Header contains provenance and parameters
        print("# Generated by GOATOOLS v{0} ({1})".format(version, datetime.date.today()))
        print("# min_ratio={0} pval={1}".format(min_ratio, pval))

        # field names for output
        if results:
            print("\t".join(GOEnrichmentStudy.get_prtflds_default(results)))

        for rec in results:
            # calculate some additional statistics
            # (over_under, is_ratio_different)
            rec.update_remaining_fldsdefprt(min_ratio=min_ratio)

            if pval is not None and rec.p_uncorrected >= pval:
                continue

            if rec.is_ratio_different:
                print(rec.__str__(indent=indent))

    def wr_py_goea_results(self, fout_py, goea_results, **kws):
        """Save GOEA results into Python package containing list of namedtuples."""
        var_name = kws.get("var_name", "goea_results")
        docstring = kws.get("docstring", "")
        sortby = kws.get("sortby", None)
        if goea_results:
            from goatools.nt_utils import wr_py_nts
            nts_goea = goea_results
            # If list has GOEnrichmentRecords or verbose namedtuples, exclude some fields.
            if hasattr(goea_results[0], "_fldsdefprt") or hasattr(goea_results[0], 'goterm'):
                # Exclude some attributes from the namedtuple when saving results
                # to a Python file because the information is redundant or verbose.
                nts_goea = get_goea_nts_prt(goea_results)
            docstring = "\n".join([docstring, "# {VER}\n\n".format(VER=self.obo_dag.version)])
            assert hasattr(nts_goea[0], '_fields')
            if sortby is None:
                sortby = lambda nt: getattr(nt, 'p_uncorrected')
            nts_goea = sorted(nts_goea, key=sortby)
            wr_py_nts(fout_py, nts_goea, docstring, var_name)
569
570
def get_study_items(goea_results):
    """Return the union of the study items (e.g., geneids) of all GOEA results.

    Each result must expose ``study_items`` as a set.
    """
    collected = set()
    for record in goea_results:
        collected |= record.study_items
    return collected
576
577
def get_goea_nts_prt(goea_results, fldnames=None, **usr_kws):
    """Return printable namedtuples of GOEA results, dropping redundant or verbose fields.

    Defaults applied unless the caller supplies them:
    ``not_fldnames=['goterm', 'parents', 'children', 'id']`` and ``rpt_fmt=True``.
    The caller's keyword dict is not modified.
    """
    kws = dict(usr_kws)
    kws.setdefault('not_fldnames', ['goterm', 'parents', 'children', 'id'])
    kws.setdefault('rpt_fmt', True)
    return get_goea_nts_all(goea_results, fldnames, **kws)
585
586
def get_goea_nts_all(goea_results, fldnames=None, **kws):
    """Get namedtuples containing user-specified (or default) data from GOEA results.

    Reformats data from GOEnrichmentRecord objects (or namedtuples) into lists of
    namedtuples so the generic table writers may be used.

    Args:
        goea_results: List of GOEnrichmentRecord objects or namedtuples.
        fldnames: Field names to include (default: all fields of the first result).
        kws:
            keep_if: Predicate on each namedtuple; only matching ones are kept.
            rpt_fmt: If True, format values for printing (default False).
            indent: If True, prefix the GO id with dots showing the term level.
                Ignored when there is no "GO" field.
            not_fldnames: Field names to exclude.

    Returns:
        list of NtGoeaResults namedtuples (empty if there are no results).
    """
    data_nts = [] # A list of namedtuples containing GOEA results
    if not goea_results:
        return data_nts
    keep_if = kws.get('keep_if', None)
    rpt_fmt = kws.get('rpt_fmt', False)
    indent = kws.get('indent', False)
    # I. FIELD (column) NAMES
    not_fldnames = kws.get('not_fldnames', None)
    if fldnames is None:
        fldnames = get_fieldnames(goea_results[0])
    # Ia. Explicitly exclude specific fields from named tuple
    if not_fldnames is not None:
        fldnames = [f for f in fldnames if f not in not_fldnames]
    nttyp = cx.namedtuple("NtGoeaResults", " ".join(fldnames))
    goid_idx = fldnames.index("GO") if 'GO' in fldnames else None
    # Indenting prefixes the GO column; without one there is nothing to indent
    # (indexing vals[None] would raise TypeError).
    add_dots = indent and goid_idx is not None
    # II. Loop through GOEA results stored in a GOEnrichmentRecord object
    for goerec in goea_results:
        vals = get_field_values(goerec, fldnames, rpt_fmt)
        if add_dots:
            vals[goid_idx] = "".join([goerec.get_indent_dots(), vals[goid_idx]])
        ntobj = nttyp._make(vals)
        if keep_if is None or keep_if(ntobj):
            data_nts.append(ntobj)
    return data_nts
616
617
def get_field_values(item, fldnames, rpt_fmt=None):
    """Return the values of ``fldnames`` from a GOEnrichmentRecord or a namedtuple.

    GOEnrichmentRecords (recognized by ``_fldsdefprt``) delegate to their own
    get_field_values(). Namedtuples (recognized by ``_fields``) are read
    attribute by attribute. Anything else yields None.
    """
    if hasattr(item, "_fldsdefprt"):
        values = item.get_field_values(fldnames, rpt_fmt)
    elif hasattr(item, "_fields"):
        values = [getattr(item, name) for name in fldnames]
    else:
        values = None
    return values
623
624
def get_fieldnames(item):
    """Return the field names of a GOEnrichmentRecord or a namedtuple.

    A GOEnrichmentRecord (recognized by ``_fldsdefprt``) reports all of its
    printable fields. A namedtuple reports its ``_fields``. Anything else
    yields None.
    """
    is_goea_record = hasattr(item, "_fldsdefprt")
    if is_goea_record:
        return item.get_prtflds_all()
    if not hasattr(item, "_fields"):
        return None
    return item._fields
630
631
# Copyright (C) 2010-2017, H Tang et al., All rights reserved.
632