MDRepository.keys()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
c 1
b 0
f 0
dl 0
loc 2
rs 10
cc 1
1
"""
2
3
This is the implementation of the active repository of SAML metadata. The 'local' and 'remote' pipes operate on this.
4
5
"""
6
from StringIO import StringIO
7
from datetime import datetime
8
import hashlib
9
import urllib
10
from UserDict import DictMixin, UserDict
11
from lxml import etree
12
from lxml.builder import ElementMaker
13
from lxml.etree import DocumentInvalid
14
import os
15
import re
16
from copy import deepcopy
17
from pyff import merge_strategies
18
import pyff.index
19
from pyff.logs import log
20
from pyff.utils import schema, URLFetch, filter_lang, root, duration2timedelta, template
21
import xmlsec
22
from pyff.constants import NS, NF_URI, DIGESTS, EVENT_DROP_ENTITY, EVENT_IMPORTED_METADATA, EVENT_IMPORT_FAIL
23
import traceback
24
import threading
25
from Queue import Queue
26
27
28
__author__ = 'leifj'
29
30
31
def _is_self_signed_err(ebuf):
32
    for e in ebuf:
33
        if e['func'] == 'xmlSecOpenSSLX509StoreVerify' and re.match('err=18', e['message']):
34
            return True
35
    return False
36
37
38
# Disable XML entity resolution globally: metadata may come from untrusted
# sources, and entity expansion enables XXE-style attacks.
etree.set_default_parser(etree.XMLParser(resolve_entities=False))
39
40
41
def _e(error_log, m=None):
42
    def _f(x):
43
        if ":WARNING:" in x:
44
            return False
45
        if m is not None and not m in x:
46
            return False
47
        return True
48
49
    return "\n".join(filter(_f, ["%s" % e for e in error_log]))
50
51
52
class MetadataException(Exception):
    """Raised for any metadata-processing error in this module."""
    pass
54
55
56
class Event(UserDict):
    """A dict-like bag of attributes describing a single observable event."""
    pass
58
59
60
class Observable(object):
    """Minimal publish/subscribe base: callables registered with
    subscribe() are invoked with an Event for every fire()."""

    def __init__(self):
        # observer callables, invoked in subscription order
        self.callbacks = []

    def subscribe(self, callback):
        """Register *callback* to receive future events."""
        self.callbacks.append(callback)

    def fire(self, **attrs):
        """Build an Event from *attrs*, stamp it with the current time,
        and deliver it to every subscribed callback."""
        event = Event(attrs)
        event['time'] = datetime.now()
        for callback in self.callbacks:
            callback(event)
72
73
74
class MDRepository(DictMixin, Observable):
75
    """A class representing a set of SAML Metadata. Instances present as dict-like objects where
76
    the keys are URIs and values are EntitiesDescriptor elements containing sets of metadata.
77
    """
78
79
    def __init__(self, index=pyff.index.MemoryIndex(), metadata_cache_enabled=False, min_cache_ttl="PT5M"):
80
        self.md = {}
81
        self.index = index
82
        self.metadata_cache_enabled = metadata_cache_enabled
83
        self.min_cache_ttl = min_cache_ttl
84
        self.respect_cache_duration = True
85
        self.default_cache_duration = "PT10M"
86
        self.retry_limit = 5
87
88
        super(MDRepository, self).__init__()
89
90
    def is_idp(self, entity):
91
        """Returns True if the supplied EntityDescriptor has an IDPSSODescriptor Role
92
93
:param entity: An EntityDescriptor element
94
        """
95
        return bool(entity.find(".//{%s}IDPSSODescriptor" % NS['md']) is not None)
96
97
    def is_sp(self, entity):
98
        """Returns True if the supplied EntityDescriptor has an SPSSODescriptor Role
99
100
:param entity: An EntityDescriptor element
101
        """
102
        return bool(entity.find(".//{%s}SPSSODescriptor" % NS['md']) is not None)
103
104
    def display(self, entity):
105
        """Utility-method for computing a displayable string for a given entity.
106
107
:param entity: An EntityDescriptor element
108
        """
109
        for displayName in filter_lang(entity.findall(".//{%s}DisplayName" % NS['mdui'])):
110
            return displayName.text
111
112
        for serviceName in filter_lang(entity.findall(".//{%s}ServiceName" % NS['md'])):
113
            return serviceName.text
114
115
        for organizationDisplayName in filter_lang(entity.findall(".//{%s}OrganizationDisplayName" % NS['md'])):
116
            return organizationDisplayName.text
117
118
        for organizationName in filter_lang(entity.findall(".//{%s}OrganizationName" % NS['md'])):
119
            return organizationName.text
120
121
        return entity.get('entityID')
122
123
    def __iter__(self):
        """Yield every EntityDescriptor contained in every stored collection."""
        for collection in self.md.values():
            for entity in collection.findall(".//{%s}EntityDescriptor" % NS['md']):
                yield entity
127
128
    def sha1_id(self, e):
        """Return the sha1-based index identifier ({sha1}<hash>) for entity *e*."""
        return pyff.index.hash_id(e, 'sha1')
130
131
    def search(self, query, path=None, page=None, page_limit=10, entity_filter=None):
132
        """
133
:param query: A string to search for.
134
:param path: The repository collection (@Name) to search in - None for search in all collections
135
:param page:  When using paged search, the page index
136
:param page_limit: When using paged search, the maximum entry per page
137
:param entity_filter: A lookup expression used to filter the entries before search is done.
138
139
Returns a list of dict's for each EntityDescriptor present in the metadata store such
140
that any of the DisplayName, ServiceName, OrganizationName or OrganizationDisplayName
141
elements match the query (as in contains the query as a substring).
142
143
The dict in the list contains three items:
144
145
:param label: A displayable string, useful as a UI label
146
:param value: The entityID of the EntityDescriptor
147
:param id: A sha1-ID of the entityID - on the form {sha1}<sha1-hash-of-entityID>
148
        """
149
150
        def _strings(e):
151
            lst = [e.get('entityID')]
152
            for attr in ['.//{%s}DisplayName' % NS['mdui'],
153
                         './/{%s}ServiceName' % NS['md'],
154
                         './/{%s}OrganizationDisplayName' % NS['md'],
155
                         './/{%s}OrganizationName' % NS['md']]:
156
                lst.extend([x.text.lower() for x in e.findall(attr)])
157
            return filter(lambda s: s is not None, lst)
158
159
        def _match(query, e):
160
            #log.debug("looking for %s in %s" % (query,",".join(_strings(e))))
161
            for qstr in _strings(e):
162
                if query in qstr:
163
                    return True
164
            return False
165
166
        f = []
167
        if path is not None:
168
            f.append(path)
169
        if entity_filter is not None:
170
            f.append(entity_filter)
171
        mexpr = None
172
        if f:
173
            mexpr = "+".join(f)
174
175
        log.debug("mexpr: %s" % mexpr)
176
177
        res = [{'label': self.display(e),
178
                'value': e.get('entityID'),
179
                'id': pyff.index.hash_id(e, 'sha1')}
180
               for e in pyff.index.EntitySet(filter(lambda ent: _match(query, ent), self.lookup(mexpr)))]
181
182
        res.sort(key=lambda i: i['label'])
183
184
        log.debug(res)
185
186
        if page is not None:
187
            total = len(res)
188
            begin = (page - 1) * page_limit
189
            end = begin + page_limit
190
            more = (end < total)
191
            return res[begin:end], more, total
192
        else:
193
            return res
194
195
    def sane(self):
196
        """A very basic test for sanity. An empty metadata set is probably not a sane output of any process.
197
198
:return: True iff there is at least one EntityDescriptor in the active set.
199
        """
200
        return len(self.md) > 0
201
202
    def extensions(self, e):
203
        """Return a list of the Extensions elements in the EntityDescriptor
204
205
:param e: an EntityDescriptor
206
:return: a list
207
        """
208
        ext = e.find(".//{%s}Extensions" % NS['md'])
209
        if ext is None:
210
            ext = etree.Element("{%s}Extensions" % NS['md'])
211
            e.insert(0, ext)
212
        return ext
213
214
    def annotate(self, e, category, title, message, source=None):
215
        """Add an ATOM annotation to an EntityDescriptor or an EntitiesDescriptor. This is a simple way to
216
        add non-normative text annotations to metadata, eg for the purpuse of generating reports.
217
218
:param e: An EntityDescriptor or an EntitiesDescriptor element
219
:param category: The ATOM category
220
:param title: The ATOM title
221
:param message: The ATOM content
222
:param source: An optional source URL. It is added as a <link> element with @rel='saml-metadata-source'
223
        """
224
        if e.tag != "{%s}EntityDescriptor" % NS['md'] and e.tag != "{%s}EntitiesDescriptor" % NS['md']:
225
            raise MetadataException(
226
                "I can only annotate EntityDescriptor or EntitiesDescriptor elements")
227
        subject = e.get('Name', e.get('entityID', None))
228
        atom = ElementMaker(nsmap={
229
                            'atom': 'http://www.w3.org/2005/Atom'}, namespace='http://www.w3.org/2005/Atom')
230
        args = [atom.published("%s" % datetime.now().isoformat()),
231
                atom.link(href=subject, rel="saml-metadata-subject")]
232
        if source is not None:
233
            args.append(atom.link(href=source, rel="saml-metadata-source"))
234
        args.extend([atom.title(title),
235
                     atom.category(term=category),
236
                     atom.content(message, type="text/plain")])
237
        self.extensions(e).append(atom.entry(*args))
238
239
    def _entity_attributes(self, e):
        """Return the mdattr:EntityAttributes extension element of *e*,
        creating it (inside Extensions) if it does not exist."""
        ext = self.extensions(e)
        ea = ext.find(".//{%s}EntityAttributes" % NS['mdattr'])
        if ea is not None:
            return ea
        ea = etree.Element("{%s}EntityAttributes" % NS['mdattr'])
        ext.append(ea)
        return ea
247
248
    def _eattribute(self, e, attr, nf):
        """Return the saml:Attribute with Name *attr* and NameFormat *nf*
        inside the EntityAttributes of *e*, creating it if missing."""
        ea = self._entity_attributes(e)
        existing = ea.xpath(
            ".//saml:Attribute[@NameFormat='%s' and @Name='%s']" % (nf, attr), namespaces=NS)
        # xpath always returns a list; non-empty means the attribute exists
        if existing:
            return existing[0]
        a = etree.Element("{%s}Attribute" % NS['saml'])
        a.set('NameFormat', nf)
        a.set('Name', attr)
        ea.append(a)
        return a
262
263
    def set_entity_attributes(self, e, d, nf=NF_URI):
264
        """Set an entity attribute on an EntityDescriptor
265
266
:param e: The EntityDescriptor element
267
:param d: A dict of attribute-value pairs that should be added as entity attributes
268
:param nf: The nameFormat (by default "urn:oasis:names:tc:SAML:2.0:attrname-format:uri") to use.
269
:raise: MetadataException unless e is an EntityDescriptor element
270
        """
271
        if e.tag != "{%s}EntityDescriptor" % NS['md']:
272
            raise MetadataException(
273
                "I can only add EntityAttribute(s) to EntityDescriptor elements")
274
275
        #log.debug("set %s" % d)
276
        for attr, value in d.iteritems():
277
            #log.debug("set %s to %s" % (attr,value))
278
            a = self._eattribute(e, attr, nf)
279
            # log.debug(etree.tostring(a))
280
            velt = etree.Element("{%s}AttributeValue" % NS['saml'])
281
            velt.text = value
282
            a.append(velt)
283
            # log.debug(etree.tostring(a))
284
285
    def fetch_metadata(self, resources, qsize=5, timeout=120, stats=None, xrd=None):
286
        """Fetch a series of metadata URLs and optionally verify signatures.
287
288
:param resources: A list of triples (url,cert-or-fingerprint,id)
289
:param qsize: The number of parallell downloads to run
290
:param timeout: The number of seconds to wait (120 by default) for each download
291
:param stats: A dictionary used for storing statistics. Useful for cherrypy cpstats
292
293
The list of triples is processed by first downloading the URL. If a cert-or-fingerprint
294
is supplied it is used to validate the signature on the received XML. Two forms of XML
295
is supported: SAML Metadata and XRD.
296
297
SAML metadata is (if valid and contains a valid signature) stored under the 'id'
298
identifier (which defaults to the URL unless provided in the triple.
299
300
XRD elements are processed thus: for all <Link> elements that contain a ds;KeyInfo
301
elements with a X509Certificate and where the <Rel> element contains the string
302
'urn:oasis:names:tc:SAML:2.0:metadata', the corresponding <URL> element is download
303
and verified.
304
        """
305
        if stats is None:
306
            stats = {}
307
308
        def producer(q, resources, cache=self.metadata_cache_enabled):
309
            print resources
310
            for url, verify, id, tries in resources:
311
                log.debug("starting fetcher for '%s'" % url)
312
                thread = URLFetch(
313
                    url, verify, id, enable_cache=cache, tries=tries)
314
                thread.start()
315
                q.put(thread, True)
316
317
        def consumer(q, njobs, stats, next_jobs=None, resolved=None):
318
            if next_jobs is None:
319
                next_jobs = []
320
            if resolved is None:
321
                resolved = set()
322
            nfinished = 0
323
324
            while nfinished < njobs:
325
                info = None
326
                try:
327
                    log.debug("waiting for next thread to finish...")
328
                    thread = q.get(True)
329
                    thread.join(timeout)
330
331
                    if thread.isAlive():
332
                        raise MetadataException(
333
                            "thread timeout fetching '%s'" % thread.url)
334
335
                    info = {
336
                        'Time Spent': thread.time()
337
                    }
338
339
                    if thread.ex is not None:
340
                        raise thread.ex
341
                    else:
342
                        if thread.result is not None:
343
                            info['Bytes'] = len(thread.result)
344
                        else:
345
                            raise MetadataException(
346
                                "empty response fetching '%s'" % thread.url)
347
                        info['Cached'] = thread.cached
348
                        info['Date'] = str(thread.date)
349
                        info['Last-Modified'] = str(thread.last_modified)
350
                        info['Tries'] = thread.tries
351
352
                    xml = thread.result.strip()
353
354
                    if thread.status is not None:
355
                        info['Status'] = thread.resp.status_code
356
357
                    t = self.parse_metadata(
358
                        StringIO(xml), key=thread.verify, base_url=thread.url)
359
                    if t is None:
360
                        self.fire(type=EVENT_IMPORT_FAIL, url=thread.url)
361
                        raise MetadataException(
362
                            "no valid metadata found at '%s'" % thread.url)
363
364
                    relt = root(t)
365
                    if relt.tag in ('{%s}XRD' % NS['xrd'], '{%s}XRDS' % NS['xrd']):
366
                        log.debug("%s looks like an xrd document" % thread.url)
367
                        for xrd in t.xpath("//xrd:XRD", namespaces=NS):
368
                            log.debug("xrd: %s" % xrd)
369
                            for link in xrd.findall(".//{%s}Link[@rel='%s']" % (NS['xrd'], NS['md'])):
370
                                url = link.get("href")
371
                                certs = xmlsec.CertDict(link)
372
                                fingerprints = certs.keys()
373
                                fp = None
374
                                if len(fingerprints) > 0:
375
                                    fp = fingerprints[0]
376
                                log.debug("fingerprint: %s" % fp)
377
                                next_jobs.append((url, fp, url, 0))
378
379
                    elif relt.tag in ('{%s}EntityDescriptor' % NS['md'], '{%s}EntitiesDescriptor' % NS['md']):
380
                        cacheDuration = self.default_cache_duration
381
                        if self.respect_cache_duration:
382
                            cacheDuration = root(t).get(
383
                                'cacheDuration', self.default_cache_duration)
384
                        offset = duration2timedelta(cacheDuration)
385
386
                        if thread.cached:
387
                            if thread.last_modified + offset < datetime.now() - duration2timedelta(self.min_cache_ttl):
388
                                raise MetadataException(
389
                                    "cached metadata expired")
390
                            else:
391
                                log.debug("found cached metadata for '%s' (last-modified: %s)" %
392
                                          (thread.url, thread.last_modified))
393
                                ne = self.import_metadata(t, url=thread.id)
394
                                info['Number of Entities'] = ne
395
                        else:
396
                            log.debug("got fresh metadata for '%s' (date: %s)" % (
397
                                thread.url, thread.date))
398
                            ne = self.import_metadata(t, url=thread.id)
399
                            info['Number of Entities'] = ne
400
                        info['Cache Expiration Time'] = str(
401
                            thread.last_modified + offset)
402
                        certs = xmlsec.CertDict(relt)
403
                        cert = None
404
                        if certs.values():
405
                            cert = certs.values()[0].strip()
406
                        resolved.add((thread.url, cert))
407
                    else:
408
                        raise MetadataException(
409
                            "unknown metadata type for '%s' (%s)" % (thread.url, relt.tag))
410
                except Exception, ex:
411
                    # traceback.print_exc(ex)
412
                    log.warn("problem fetching '%s' (will retry): %s" %
413
                             (thread.url, ex))
414
                    if info is not None:
415
                        info['Exception'] = ex
416
                    if thread.tries < self.retry_limit:
417
                        next_jobs.append(
418
                            (thread.url, thread.verify, thread.id, thread.tries + 1))
419
                    else:
420
                        # traceback.print_exc(ex)
421
                        log.error(
422
                            "retry limit exceeded for %s (last error was: %s)" % (thread.url, ex))
423
                finally:
424
                    nfinished += 1
425
                    if info is not None:
426
                        stats[thread.url] = info
427
428
        resources = [(url, verify, rid, 0) for url, verify, rid in resources]
429
        resolved = set()
430
        cache = True
431
        while len(resources) > 0:
432
            log.debug("fetching %d resources (%s)" %
433
                      (len(resources), repr(resources)))
434
            next_jobs = []
435
            q = Queue(qsize)
436
            prod_thread = threading.Thread(
437
                target=producer, args=(q, resources, cache))
438
            cons_thread = threading.Thread(target=consumer, args=(
439
                q, len(resources), stats, next_jobs, resolved))
440
            prod_thread.start()
441
            cons_thread.start()
442
            prod_thread.join()
443
            cons_thread.join()
444
            log.debug("after fetch: %d jobs to retry" % len(next_jobs))
445
            if len(next_jobs) > 0:
446
                resources = next_jobs
447
                cache = False
448
            else:
449
                resources = []
450
451
        if xrd is not None:
452
            with open(xrd, "w") as fd:
453
                fd.write(template("trust.xrd").render(links=resolved))
454
455
    def parse_metadata(self, fn, key=None, base_url=None, fail_on_error=False, filter_invalid=True):
456
        """Parse a piece of XML and split it up into EntityDescriptor elements. Each such element
457
        is stored in the MDRepository instance.
458
459
:param fn: a file-like object containing SAML metadata
460
:param key: a certificate (file) or a SHA1 fingerprint to use for signature verification
461
:param base_url: use this base url to resolve relative URLs for XInclude processing
462
        """
463
        try:
464
            t = etree.parse(fn, base_url=base_url,
465
                            parser=etree.XMLParser(resolve_entities=False))
466
            t.xinclude()
467
            if filter_invalid:
468
                for e in t.findall('{%s}EntityDescriptor' % NS['md']):
469
                    if not schema().validate(e):
470
                        error = _e(schema().error_log, m=base_url)
471
                        log.debug("removing '%s': schema validation failed (%s)" % (
472
                            e.get('entityID'), error))
473
                        e.getparent().remove(e)
474
                        self.fire(type=EVENT_DROP_ENTITY, url=base_url,
475
                                  entityID=e.get('entityID'), error=error)
476
            else:
477
                # Having removed the invalid entities this should now never
478
                # happen...
479
                schema().assertValid(t)
480
        except DocumentInvalid, ex:
481
            traceback.print_exc()
482
            log.debug("schema validation failed on '%s': %s" % (
483
                base_url, _e(ex.error_log, m=base_url)))
484
            raise MetadataException("schema validation failed")
485
        except Exception, ex:
486
            # log.debug(_e(schema().error_log))
487
            log.error(ex)
488
            if fail_on_error:
489
                raise ex
490
            return None
491
        if key is not None:
492
            try:
493
                log.debug("verifying signature using %s" % key)
494
                refs = xmlsec.verified(t, key)
495
                if len(refs) != 1:
496
                    raise MetadataException(
497
                        "XML metadata contains %d signatures - exactly 1 is required" % len(refs))
498
                t = refs[0]  # prevent wrapping attacks
499
            except Exception, ex:
500
                tb = traceback.format_exc()
501
                print tb
502
                log.error(ex)
503
                return None
504
505
        return t
506
507
    def _index_entity(self, e):
        """Strip any @ID attribute from EntityDescriptor *e*, then add it to the index."""
        if 'ID' in e.attrib:
            del e.attrib['ID']
        self.index.add(e)
512
513
    def import_metadata(self, t, url=None):
514
        """
515
:param t: An EntitiesDescriptor element
516
:param url: An optional URL to used to identify the EntitiesDescriptor in the MDRepository
517
518
Import an EntitiesDescriptor element using the @Name attribute (or the supplied url parameter). All
519
EntityDescriptor elements are stripped of any @ID attribute and are then indexed before the collection
520
is stored in the MDRepository object.
521
        """
522
        if url is None:
523
            top = t.xpath("//md:EntitiesDescriptor", namespaces=NS)
524
            if top is not None and len(top) == 1:
525
                url = top[0].get("Name", None)
526
        if url is None:
527
            raise MetadataException("No collection name found")
528
        self[url] = t
529
        # we always clean incoming ID
530
        # add to the index
531
        ne = 0
532
533
        if t is not None:
534
            if root(t).tag == "{%s}EntityDescriptor" % NS['md']:
535
                self._index_entity(root(t))
536
                ne += 1
537
            else:
538
                for e in t.findall(".//{%s}EntityDescriptor" % NS['md']):
539
                    self._index_entity(e)
540
                    ne += 1
541
542
        self.fire(type=EVENT_IMPORTED_METADATA, size=ne, url=url)
543
        return ne
544
545
    def entities(self, t=None):
546
        """
547
:param t: An EntitiesDescriptor element
548
549
Returns the list of contained EntityDescriptor elements
550
        """
551
        if t is None:
552
            return []
553
        elif root(t).tag == "{%s}EntityDescriptor" % NS['md']:
554
            return [root(t)]
555
        else:
556
            return t.findall(".//{%s}EntityDescriptor" % NS['md'])
557
558
    def load_dir(self, directory, ext=".xml", url=None):
559
        """
560
:param directory: A directory to walk.
561
:param ext: Include files with this extension (default .xml)
562
563
Traverse a directory tree looking for metadata. Files ending in the specified extension are included. Directories
564
starting with '.' are excluded.
565
        """
566
        if url is None:
567
            url = directory
568
        log.debug("walking %s" % directory)
569
        if not directory in self.md:
570
            entities = []
571
            for top, dirs, files in os.walk(directory):
572
                for dn in dirs:
573
                    if dn.startswith("."):
574
                        dirs.remove(dn)
575
                for nm in files:
576
                    log.debug("found file %s" % nm)
577
                    if nm.endswith(ext):
578
                        fn = os.path.join(top, nm)
579
                        try:
580
                            t = self.parse_metadata(fn, fail_on_error=True)
581
                            # local metadata is assumed to be ok
582
                            entities.extend(self.entities(t))
583
                        except Exception, ex:
584
                            log.error(ex)
585
            self.import_metadata(self.entity_set(entities, url))
586
        return self.md[url]
587
588
    def _lookup(self, member, xp=None):
589
        """
590
:param member: Either an entity, URL or a filter expression.
591
592
Find a (set of) EntityDescriptor element(s) based on the specified 'member' expression.
593
        """
594
595
        def _hash(hn, strv):
596
            if hn == 'null':
597
                return strv
598
            if not hasattr(hashlib, hn):
599
                raise MetadataException("Unknown digest mechanism: '%s'" % hn)
600
            hash_m = getattr(hashlib, hn)
601
            h = hash_m()
602
            h.update(strv)
603
            return h.hexdigest()
604
605
        if xp is None:
606
            xp = "//md:EntityDescriptor"
607
        if member is None:
608
            lst = []
609
            for m in self.keys():
610
                log.debug("resolving %s filtered by %s" % (m, xp))
611
                lst.extend(self._lookup(m, xp))
612
            return lst
613
        elif hasattr(member, 'xpath'):
614
            log.debug("xpath filter %s <- %s" % (xp, member))
615
            return member.xpath(xp, namespaces=NS)
616
        elif type(member) is str or type(member) is unicode:
617
            log.debug("string lookup %s" % member)
618
619
            if '+' in member:
620
                member = member.strip('+')
621
                log.debug("lookup intersection of '%s'" %
622
                          ' and '.join(member.split('+')))
623
                hits = None
624
                for f in member.split("+"):
625
                    f = f.strip()
626
                    if hits is None:
627
                        hits = set(self._lookup(f, xp))
628
                    else:
629
                        other = self._lookup(f, xp)
630
                        hits.intersection_update(other)
631
632
                    if not hits:
633
                        log.debug("empty intersection")
634
                        return []
635
636
                if hits is not None and hits:
637
                    return list(hits)
638
                else:
639
                    return []
640
641
            if "!" in member:
642
                (src, xp) = member.split("!")
643
                if len(src) == 0:
644
                    src = None
645
                    log.debug("filtering using %s" % xp)
646
                else:
647
                    log.debug("selecting %s filtered by %s" % (src, xp))
648
                return self._lookup(src, xp)
649
650
            m = re.match("^\{(.+)\}(.+)$", member)
651
            if m is not None:
652
                log.debug("attribute-value match: %s='%s'" %
653
                          (m.group(1), m.group(2)))
654
                return self.index.get(m.group(1), m.group(2).rstrip("/"))
655
656
            m = re.match("^(.+)=(.+)$", member)
657
            if m is not None:
658
                log.debug("attribute-value match: %s='%s'" %
659
                          (m.group(1), m.group(2)))
660
                return self.index.get(m.group(1), m.group(2).rstrip("/"))
661
662
            log.debug("basic lookup %s" % member)
663
            for idx in DIGESTS:
664
                e = self.index.get(idx, member)
665
                if e:
666
                    log.debug("found %s in %s index" % (e, idx))
667
                    return e
668
669
            e = self.get(member, None)
670
            if e is not None:
671
                return self._lookup(e, xp)
672
673
            # hackish but helps save people from their misstakes
674
            e = self.get("%s.xml" % member, None)
675
            if e is not None:
676
                if not "://" in member:  # not an absolute URL
677
                    log.warn(
678
                        "Found %s.xml as an alias - AVOID extensions in 'select as' statements" % member)
679
                return self._lookup(e, xp)
680
681
            if "://" in member:  # looks like a URL and wasn't an entity or collection - recurse away!
682
                log.debug("recursively fetching members from '%s'" % member)
683
                # note that this supports remote lists which may be more rope
684
                # than is healthy
685
                return [self._lookup(line, xp) for line in urllib.urlopen(member).iterlines()]
686
687
            return []
688
        elif hasattr(member, '__iter__') and type(member) is not dict:
689
            if not len(member):
690
                member = self.keys()
691
            return [self._lookup(m, xp) for m in member]
692
        else:
693
            raise MetadataException("What about %s ??" % member)
694
695
    def lookup(self, member, xp=None):
696
        """
697
Lookup elements in the working metadata repository
698
699
:param member: A selector (cf below)
700
:type member: basestring
701
:param xp: An optional xpath filter
702
:type xp: basestring
703
:return: An interable of EntityDescriptor elements
704
:rtype: etree.Element
705
706
**Selector Syntax**
707
708
    - selector "+" selector
709
    - [sourceID] "!" xpath
710
    - attribute=value or {attribute}value
711
    - entityID
712
    - sourceID (@Name)
713
    - <URL containing one selector per line>
714
715
The first form results in the intersection of the results of doing a lookup on the selectors. The second form
716
results in the EntityDescriptor elements from the source (defaults to all EntityDescriptors) that match the
717
xpath expression. The attribute-value forms resuls in the EntityDescriptors that contain the specified entity
718
attribute pair. If non of these forms apply, the lookup is done using either source ID (normally @Name from
719
the EntitiesDescriptor) or the entityID of single EntityDescriptors. If member is a URI but isn't part of
720
the metadata repository then it is fetched an treated as a list of (one per line) of selectors. If all else
721
fails an empty list is returned.
722
723
        """
724
        l = self._lookup(member, xp)
725
        return list(set(filter(lambda x: x is not None, l)))
726
727
    def entity_set(self, entities, name, cacheDuration=None, validUntil=None, validate=True):
        """
        Produce an EntitiesDescriptor element from a list of entity selectors.

        :param entities: a set of entity specifiers (lookup is used to find entities from this set)
        :param name: the @Name attribute
        :param cacheDuration: an XML timedelta expression, eg PT1H for 1hr
        :param validUntil: a relative time eg 2w 4d 1h for 2 weeks, 4 days and 1hour from now.
        :param validate: when True the result is schema-validated (failures are logged, not raised)
        :return: the EntitiesDescriptor element, or None when no entity matched

        Optional Name, cacheDuration and validUntil are affixed as attributes.
        """
        attrs = dict(Name=name, nsmap=NS)
        if cacheDuration is not None:
            attrs['cacheDuration'] = cacheDuration
        if validUntil is not None:
            attrs['validUntil'] = validUntil

        t = etree.Element("{%s}EntitiesDescriptor" % NS['md'], **attrs)
        nent = 0
        seen = set()  # TODO make better de-duplication
        for member in entities:
            for ent in self.lookup(member):
                entityID = ent.get('entityID', None)
                if ent is None or entityID is None or entityID in seen:
                    continue
                t.append(deepcopy(ent))
                seen.add(entityID)
                nent += 1

        log.debug("selecting %d entities from %d entity set(s) before validation" % (
            nent, len(entities)))

        if not nent:
            return None

        if validate:
            try:
                schema().assertValid(t)
            except DocumentInvalid as ex:
                # NOTE: validation failures are deliberately non-fatal here -
                # the error is logged and the (possibly invalid) set returned.
                log.debug(_e(ex.error_log))
                #raise MetadataException(
                #    "XML schema validation failed: %s" % name)
        return t
766
767
    def error_set(self, url, title, ex):
        """
        Create an "error" EntitiesDescriptor - empty but for an annotation
        about the error that occured.

        :param url: the source URL - becomes the @Name of the set
        :param title: a short description of the failure
        :param ex: the error/exception to record in the annotation
        :return: the annotated EntitiesDescriptor element
        """
        t = etree.Element("{%s}EntitiesDescriptor" %
                          NS['md'], Name=url, nsmap=NS)
        self.annotate(t, "error", title, ex, source=url)
        # BUG FIX: the element was built and annotated but never returned,
        # so callers always received None and the error set was discarded.
        return t
774
775
    def keys(self):
        """Return the source URIs currently held in the repository."""
        return self.md.keys()
777
778
    def __getitem__(self, item):
        """Dict-style access: return the metadata tree stored under *item*."""
        return self.md[item]
780
781
    def __setitem__(self, key, value):
        """Dict-style assignment: store *value* in the repository under *key*."""
        self.md[key] = value
783
784
    def __delitem__(self, key):
        """Dict-style deletion: remove the entry stored under *key*."""
        del self.md[key]
786
787
    def summary(self, uri):
        """
        Return a dict with basic information about an EntitiesDescriptor.

        :param uri: An EntitiesDescriptor URI present in the MDRepository
        :return: an information dict with keys Name, cacheDuration, validUntil,
                 Duplicates (list of repeated entityIDs) and Size (entity count)
        """
        t = root(self[uri])
        info = {
            'Name': t.get('Name', uri),
            'cacheDuration': t.get('cacheDuration', None),
            'validUntil': t.get('validUntil', None),
            'Duplicates': [],
            'Size': 0,
        }
        seen = set()
        for e in self.entities(self[uri]):
            entityID = e.get('entityID')
            if entityID in seen:
                info['Duplicates'].append(entityID)
            else:
                seen.add(entityID)
            info['Size'] += 1

        return info
811
812
    def merge(self, t, nt, strategy=pyff.merge_strategies.replace_existing, strategy_name=None):
        """
        Merge the EntityDescriptor elements of *nt* into *t*.

        :param t: The EntitiesDescriptor element to merge *into*
        :param nt:  The EntitiesDescriptor element to merge *from*
        :param strategy: A callable implementing the merge strategy pattern
        :param strategy_name: The name of a strategy to import. Overrides the callable if present.
        :return: None

        Two EntitiesDescriptor elements are merged - the second into the first.
        For each element in the second collection that is present (using the
        @entityID attribute as key) in the first, the strategy callable is
        called with the old and new EntityDescriptor elements as parameters.
        The strategy callable thus must implement the following pattern:

        :param old_e: The EntityDescriptor from t
        :param e: The EntityDescriptor from nt
        :return: A merged EntityDescriptor element

        Before each call to strategy, old_e is removed from the MDRepository
        index, and after merge the resultant EntityDescriptor is re-added to
        the index before it is used to replace old_e in t.
        """
        if strategy_name is not None:
            # A dotted name selects a strategy from an arbitrary module;
            # a bare name is resolved within pyff.merge_strategies.
            if '.' not in strategy_name:
                strategy_name = "pyff.merge_strategies.%s" % strategy_name
            (mn, sep, fn) = strategy_name.rpartition('.')
            module = None
            if '.' in mn:
                (pn, sep, modn) = mn.rpartition('.')
                module = getattr(__import__(
                    pn, globals(), locals(), [modn], -1), modn)
            else:
                module = __import__(mn, globals(), locals(), [], -1)
            # we might as well let this fail early if the strategy is
            # wrongly named
            strategy = getattr(module, fn)

        if strategy is None:
            raise MetadataException("No merge strategy - refusing to merge")

        for e in nt.findall(".//{%s}EntityDescriptor" % NS['md']):
            entityID = e.get("entityID")
            # we assume a de-duplicated tree
            old_e = t.find(
                ".//{%s}EntityDescriptor[@entityID='%s']" % (NS['md'], entityID))

            try:
                self.index.remove(old_e)
                strategy(old_e, e)
                new_e = t.find(
                    ".//{%s}EntityDescriptor[@entityID='%s']" % (NS['md'], entityID))
                # BUG FIX: lxml element truthiness reflects the number of
                # children, so a childless EntityDescriptor was never
                # re-indexed. Test against None explicitly instead.
                if new_e is not None:
                    # we don't know which strategy was employed
                    self.index.add(new_e)
            except Exception:
                traceback.print_exc()
                # restore the index before propagating
                self.index.add(old_e)
                raise  # bare raise preserves the original traceback
873