Passed
Push — master ( f5a0bf...c46b9b )
by Matt
01:38
created

Table.check_signature()   B

Complexity

Conditions 7

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 7
c 1
b 0
f 0
dl 0
loc 20
rs 7.3333
1
# MIT License
2
#
3
# Copyright (c) 2017 Matt Boyer
4
#
5
# Permission is hereby granted, free of charge, to any person obtaining a copy
6
# of this software and associated documentation files (the "Software"), to deal
7
# in the Software without restriction, including without limitation the rights
8
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
# copies of the Software, and to permit persons to whom the Software is
10
# furnished to do so, subject to the following conditions:
11
#
12
# The above copyright notice and this permission notice shall be included in
13
# all copies or substantial portions of the Software.
14
#
15
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
# SOFTWARE.
22
23
from . import constants
24
from . import PROJECT_NAME, PROJECT_DESCRIPTION, USER_JSON_PATH, BUILTIN_JSON
25
26
import argparse
27
import base64
28
import collections
29
import csv
30
import json
31
import logging
32
import os
33
import os.path
34
import pdb
35
import pkg_resources
36
import re
37
import shutil
38
import sqlite3
39
import stat
40
import struct
41
import tempfile
42
43
44
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
45
_LOGGER = logging.getLogger('SQLite recovery')
46
_LOGGER.setLevel(logging.INFO)
47
48
49
# Parsed form of the 100-byte header at the start of every SQLite database
# file. Field order matches the unpack format used in
# SQLite_DB.parse_header.
SQLite_header = collections.namedtuple('SQLite_header', (
    'magic',
    'page_size',
    'write_format',
    'read_format',
    'reserved_length',
    'max_payload_fraction',
    'min_payload_fraction',
    'leaf_payload_fraction',
    'file_change_counter',
    'size_in_pages',
    'first_freelist_trunk',
    'freelist_pages',
    'schema_cookie',
    'schema_format',
    'default_page_cache_size',
    'largest_btree_page',
    'text_encoding',
    'user_version',
    'incremental_vacuum',
    'application_id',
    'version_valid',
    'sqlite_version',
))


# Parsed form of the header found at the start of every btree page.
# right_most_page_idx is only meaningful for interior pages.
SQLite_btree_page_header = collections.namedtuple('SQLite_btree_page_header', (
    'page_type',
    'first_freeblock_offset',
    'num_cells',
    'cell_content_offset',
    'num_fragmented_free_bytes',
    'right_most_page_idx',
))


# One entry of a pointer-map page: the described page's index, its type
# constant, and the index of its parent page (0 when it has none).
SQLite_ptrmap_info = collections.namedtuple('SQLite_ptrmap_info', (
    'page_idx',
    'page_type',
    'page_ptr',
))


# One field of a serialised record: the serial type code, its human-readable
# description, and the field's length and raw bytes.
SQLite_record_field = collections.namedtuple('SQLite_record_field', (
    'col_type',
    'col_type_descr',
    'field_length',
    'field_bytes',
))


# One row of the sqlite_master schema table (type/name/tbl_name/rootpage/sql).
SQLite_master_record = collections.namedtuple('SQLite_master_record', (
    'type',
    'name',
    'tbl_name',
    'rootpage',
    'sql',
))


# Maps SQL column type names (as they appear in CREATE TABLE statements) to
# the native Python type a recovered value should have. Used to build the
# per-table signatures checked by Table.check_signature.
type_specs = {
    'INTEGER': int,
    'TEXT': str,
    'VARCHAR': str,
    'LONGVARCHAR': str,
    'REAL': float,
    'FLOAT': float,
    'LONG': int,
    'BLOB': bytes,
}


# Module-level registries, populated at runtime:
# heuristics: table name -> generator function from heuristic_factory
# signatures: table name -> list of expected Python types per column
heuristics = {}
signatures = {}
123
124
125
def heuristic_factory(magic, offset):
    """Build a heuristic that locates candidate record-header starts.

    :param magic: regex byte pattern that identifies a known table's
        records
    :param offset: distance, in bytes, from the start of a record header
        to the position where ``magic`` matches (must be >= 0)
    :return: a generator function that takes a freeblock's bytes and
        yields candidate header start offsets, last match first
    """
    assert isinstance(magic, bytes)
    assert isinstance(offset, int)
    assert offset >= 0

    # We only need to compile the regex once, outside the closure
    magic_re = re.compile(magic)

    def generic_heuristic(freeblock_bytes):
        # Visit matches in reverse positional order
        for magic_match in reversed(list(magic_re.finditer(freeblock_bytes))):
            header_start = magic_match.start() - offset
            if header_start < 0:
                # Matches are visited with decreasing offsets, so once one
                # candidate start falls before the freeblock, every
                # remaining one does too
                _LOGGER.debug("Header start outside of freeblock!")
                break
            yield header_start
    return generic_heuristic
142
143
144
def load_heuristics():
    """Fill the module-level ``heuristics`` registry.

    The built-in JSON description shipped with the package is loaded
    first; a user-supplied JSON file, when present, is layered on top.
    """

    def _load_from_json(raw_json):
        # The packaged resource arrives as bytes, the user file as str
        if isinstance(raw_json, bytes):
            raw_json = raw_json.decode('utf-8')
        table_descriptions = json.loads(raw_json)
        for table_name, table_props in table_descriptions.items():
            magic_bytes = base64.standard_b64decode(table_props['magic'])
            heuristics[table_name] = heuristic_factory(
                magic_bytes,
                table_props['offset'],
            )
            _LOGGER.debug("Loaded heuristics for \"%s\"", table_name)

    with pkg_resources.resource_stream(PROJECT_NAME, BUILTIN_JSON) as builtin:
        _load_from_json(builtin.read())

    # The user's own heuristics file is optional
    if not os.path.exists(USER_JSON_PATH):
        return
    with open(USER_JSON_PATH, 'r') as user_json:
        _load_from_json(user_json.read())
165
166
167
class IndexDict(dict):
    """A dict that always iterates its keys in ascending sorted order."""

    def __iter__(self):
        yield from sorted(self.keys())
171
172
173
class SQLite_DB(object):
    """Recovery-oriented model of a single SQLite database file.

    The constructor parses the 100-byte file header and caches the raw
    bytes of every page; the ``populate_*`` methods then progressively
    replace generic Page objects with specialised ones (freelist,
    ptrmap, overflow, btree) as more structure is discovered.
    """

    def __init__(self, path):
        # Path to the SQLite file on disk
        self._path = path
        # page index -> page type constant, filled in as parsing proceeds
        self._page_types = {}
        self._header = self.parse_header()

        self._page_cache = None
        # Actual page objects go here
        self._pages = {}
        self.build_page_cache()

        self._ptrmap = {}

        # TODO Do we need all of these?
        self._table_roots = {}
        self._page_tables = {}
        self._tables = {}
        self._table_columns = {}
        self._freelist_leaves = []
        self._freelist_btree_pages = []

    @property
    def ptrmap(self):
        return self._ptrmap

    @property
    def header(self):
        return self._header

    @property
    def pages(self):
        return self._pages

    @property
    def tables(self):
        return self._tables

    @property
    def freelist_leaves(self):
        return self._freelist_leaves

    @property
    def table_columns(self):
        return self._table_columns

    def page_bytes(self, page_idx):
        """Return the cached raw bytes for page *page_idx* (1-based).

        Raises ValueError when no bytes are cached for that index.
        """
        try:
            return self._page_cache[page_idx]
        except (IndexError, KeyError, TypeError):
            # Bug fixes: the cache is a list, so a bad index raises
            # IndexError (or TypeError before build_page_cache runs), which
            # the original KeyError-only handler missed; the original also
            # passed the index logging-style, leaving the message
            # unformatted.
            raise ValueError("No cache for page {}".format(page_idx))

    def map_table_page(self, page_idx, table):
        """Record that page *page_idx* belongs to *table*."""
        assert isinstance(page_idx, int)
        assert isinstance(table, Table)
        self._page_tables[page_idx] = table

    def get_page_table(self, page_idx):
        """Return the Table mapped to *page_idx*, or None if unmapped."""
        assert isinstance(page_idx, int)
        try:
            return self._page_tables[page_idx]
        except KeyError:
            return None

    def __repr__(self):
        return '<SQLite DB, page count: {} | page size: {}>'.format(
            self.header.size_in_pages,
            self.header.page_size
        )

    def parse_header(self):
        """Read and validate the 100-byte file header.

        :return: a populated SQLite_header namedtuple
        :raises ValueError: when the header cannot be read
        """
        header_bytes = None
        file_size = None
        with open(self._path, 'br') as sqlite:
            header_bytes = sqlite.read(100)
            file_size = os.fstat(sqlite.fileno())[stat.ST_SIZE]

        if not header_bytes:
            raise ValueError("Couldn't read SQLite header")
        assert isinstance(header_bytes, bytes)
        # This DB header is always big-endian
        fields = SQLite_header(*struct.unpack(
            r'>16sHBBBBBBIIIIIIIIIIII20xII',
            header_bytes[:100]
        ))
        assert fields.page_size in constants.VALID_PAGE_SIZES
        db_size = fields.page_size * fields.size_in_pages
        assert db_size <= file_size
        assert (fields.page_size > 0) and \
            (fields.file_change_counter == fields.version_valid)

        # Files smaller than 2**30 bytes have no lock-byte page
        if file_size < 1073741824:
            _LOGGER.debug("No lock-byte page in this file!")

        if fields.first_freelist_trunk > 0:
            self._page_types[fields.first_freelist_trunk] = \
                constants.FREELIST_TRUNK_PAGE
        _LOGGER.debug(fields)
        return fields

    def build_page_cache(self):
        """Read every page's raw bytes into self._page_cache and create a
        generic Page object for each."""
        # The SQLite docs use a numbering convention for pages where the
        # first page (the one that has the header) is page 1, with the next
        # ptrmap page being page 2, etc.
        page_cache = [None, ]
        with open(self._path, 'br') as sqlite:
            for page_idx in range(self._header.size_in_pages):
                page_offset = page_idx * self._header.page_size
                sqlite.seek(page_offset, os.SEEK_SET)
                page_cache.append(sqlite.read(self._header.page_size))
        self._page_cache = page_cache
        for page_idx in range(1, len(self._page_cache)):
            # We want these to be temporary objects, to be replaced with
            # more specialised objects as parsing progresses
            self._pages[page_idx] = Page(page_idx, self)

    def populate_freelist_pages(self):
        """Walk the freelist trunk chain, replacing the generic Page
        objects for trunk and leaf pages with specialised ones."""
        if 0 == self._header.first_freelist_trunk:
            _LOGGER.debug("This database has no freelist trunk page")
            return

        _LOGGER.info("Parsing freelist pages")
        parsed_trunks = 0
        parsed_leaves = 0
        freelist_trunk_idx = self._header.first_freelist_trunk

        while freelist_trunk_idx != 0:
            _LOGGER.debug(
                "Parsing freelist trunk page %d",
                freelist_trunk_idx
            )
            trunk_bytes = bytes(self.pages[freelist_trunk_idx])

            next_freelist_trunk_page_idx, num_leaf_pages = struct.unpack(
                r'>II',
                trunk_bytes[:8]
            )

            # Now that we know how long the array of freelist page pointers
            # is, let's read it again
            trunk_array = struct.unpack(
                r'>{count}I'.format(count=2+num_leaf_pages),
                trunk_bytes[:(4*(2+num_leaf_pages))]
            )

            # We're skipping the first entries as they are really the next
            # trunk index and the leaf count
            # TODO Fix that
            leaves_in_trunk = []
            for page_idx in trunk_array[2:]:
                # Let's prepare a specialised object for this freelist leaf
                # page
                leaf_page = FreelistLeafPage(
                    page_idx, self, freelist_trunk_idx
                )
                leaves_in_trunk.append(leaf_page)
                self._freelist_leaves.append(page_idx)
                self._pages[page_idx] = leaf_page

                self._page_types[page_idx] = constants.FREELIST_LEAF_PAGE

            trunk_page = FreelistTrunkPage(
                freelist_trunk_idx,
                self,
                leaves_in_trunk
            )
            self._pages[freelist_trunk_idx] = trunk_page
            # We've parsed this trunk page
            parsed_trunks += 1
            # ...And every leaf in it
            parsed_leaves += num_leaf_pages

            freelist_trunk_idx = next_freelist_trunk_page_idx

        assert (parsed_trunks + parsed_leaves) == self._header.freelist_pages
        _LOGGER.info(
            "Freelist summary: %d trunk pages, %d leaf pages",
            parsed_trunks,
            parsed_leaves
        )

    def populate_overflow_pages(self):
        """Create OverflowPage objects for every page already classified
        as an overflow page."""
        # Knowledge of the overflow pages can come from the pointer map
        # (easy), or the parsing of individual cells in table leaf pages
        # (hard)
        #
        # For now, assume we already have a page type dict populated from
        # the ptrmap
        _LOGGER.info("Parsing overflow pages")
        overflow_count = 0
        for page_idx in sorted(self._page_types):
            page_type = self._page_types[page_idx]
            if page_type not in constants.OVERFLOW_PAGE_TYPES:
                continue
            overflow_page = OverflowPage(page_idx, self)
            self.pages[page_idx] = overflow_page
            overflow_count += 1

        _LOGGER.info("Overflow summary: %d pages", overflow_count)

    def populate_ptrmap_pages(self):
        """Parse pointer-map pages, recording every described page's type
        and parent pointer in self._ptrmap."""
        if self._header.largest_btree_page == 0:
            # We don't have ptrmap pages in this DB. That sucks.
            _LOGGER.warning("%r does not have ptrmap pages!", self)
            for page_idx in range(1, self._header.size_in_pages):
                self._page_types[page_idx] = constants.UNKNOWN_PAGE
            return

        _LOGGER.info("Parsing ptrmap pages")

        usable_size = self._header.page_size - self._header.reserved_length
        # Each ptrmap entry is 5 bytes: a type byte + a 4-byte parent index
        num_ptrmap_entries_in_page = usable_size // 5
        ptrmap_page_indices = []

        # The first ptrmap page is always page 2
        ptrmap_page_idx = 2
        while ptrmap_page_idx <= self._header.size_in_pages:
            page_bytes = self._page_cache[ptrmap_page_idx]
            ptrmap_page_indices.append(ptrmap_page_idx)
            self._page_types[ptrmap_page_idx] = constants.PTRMAP_PAGE
            page_ptrmap_entries = {}

            ptrmap_bytes = page_bytes[:5 * num_ptrmap_entries_in_page]
            for entry_idx in range(num_ptrmap_entries_in_page):
                # Each entry describes the page following the ones already
                # covered
                ptr_page_idx = ptrmap_page_idx + entry_idx + 1
                page_type, page_ptr = struct.unpack(
                    r'>BI',
                    ptrmap_bytes[5*entry_idx:5*(entry_idx+1)]
                )
                if page_type == 0:
                    break

                ptrmap_entry = SQLite_ptrmap_info(
                    ptr_page_idx, page_type, page_ptr
                )
                assert ptrmap_entry.page_type in constants.PTRMAP_PAGE_TYPES
                if page_type == constants.BTREE_ROOT_PAGE:
                    assert page_ptr == 0
                    self._page_types[ptr_page_idx] = page_type

                elif page_type == constants.FREELIST_PAGE:
                    # Freelist pages are assumed to be known already
                    assert self._page_types[ptr_page_idx] in \
                        constants.FREELIST_PAGE_TYPES
                    assert page_ptr == 0

                elif page_type == constants.FIRST_OFLOW_PAGE:
                    assert page_ptr != 0
                    self._page_types[ptr_page_idx] = page_type

                elif page_type == constants.NON_FIRST_OFLOW_PAGE:
                    assert page_ptr != 0
                    self._page_types[ptr_page_idx] = page_type

                elif page_type == constants.BTREE_NONROOT_PAGE:
                    assert page_ptr != 0
                    self._page_types[ptr_page_idx] = page_type

                self._ptrmap[ptr_page_idx] = ptrmap_entry
                page_ptrmap_entries[ptr_page_idx] = ptrmap_entry

            page = PtrmapPage(ptrmap_page_idx, self, page_ptrmap_entries)
            self._pages[ptrmap_page_idx] = page
            _LOGGER.debug("%r", page)
            # The next ptrmap page follows the run of pages this one covers
            ptrmap_page_idx += num_ptrmap_entries_in_page + 1

        _LOGGER.info(
            "Ptrmap summary: %d pages, %r",
            len(ptrmap_page_indices), ptrmap_page_indices
        )

    def populate_btree_pages(self):
        """Attempt to parse every not-yet-classified page as a btree
        page."""
        # TODO Should this use table information instead of scanning all
        # pages?
        page_idx = 1
        while page_idx <= self._header.size_in_pages:
            try:
                if self._page_types[page_idx] in \
                        constants.NON_BTREE_PAGE_TYPES:
                    page_idx += 1
                    continue
            except KeyError:
                pass

            try:
                page_obj = BTreePage(page_idx, self)
            except ValueError:
                # This page isn't a valid btree page. This can happen if we
                # don't have a ptrmap to guide us
                _LOGGER.warning(
                    "Page %d (%s) is not a btree page",
                    page_idx,
                    # Bug fix: a plain [page_idx] lookup here raised
                    # KeyError for pages with no recorded type
                    self._page_types.get(page_idx)
                )
                page_idx += 1
                continue

            page_obj.parse_cells()
            self._page_types[page_idx] = page_obj.page_type
            self._pages[page_idx] = page_obj
            page_idx += 1

    def _parse_master_leaf_page(self, page):
        """Extract schema information from a sqlite_master leaf page:
        table root pages, column names, and column type signatures."""
        for cell_idx in page.cells:
            _, master_record = page.cells[cell_idx]
            assert isinstance(master_record, Record)
            fields = [
                master_record.fields[idx].value for idx in master_record.fields
            ]
            master_record = SQLite_master_record(*fields)
            if 'table' != master_record.type:
                continue

            self._table_roots[master_record.name] = \
                self.pages[master_record.rootpage]

            # This record describes a table in the schema, which means it
            # includes a SQL statement that defines the table's columns
            # We need to parse the field names out of that statement
            assert master_record.sql.startswith('CREATE TABLE')
            columns_re = re.compile(r'^CREATE TABLE (\S+) \((.*)\)$')
            match = columns_re.match(master_record.sql)
            if match:
                assert match.group(1) == master_record.name
                column_list = match.group(2)
                # Strip parenthesised argument lists (e.g. VARCHAR(255))
                # so the comma split below only sees column separators
                csl_between_parens_re = re.compile(r'\([^)]+\)')
                expunged = csl_between_parens_re.sub('', column_list)

                cols = [
                    statement.strip() for statement in expunged.split(',')
                ]
                cols = [
                    statement for statement in cols if not (
                        statement.startswith('PRIMARY') or
                        statement.startswith('UNIQUE')
                    )
                ]
                columns = [col.split()[0] for col in cols]
                signature = []

                # Some column definitions lack a type
                for col_def in cols:
                    def_tokens = col_def.split()
                    try:
                        col_type = def_tokens[1]
                    except IndexError:
                        signature.append(object)
                        continue

                    _LOGGER.debug(
                        "Column \"%s\" is defined as \"%s\"",
                        def_tokens[0], col_type
                    )
                    try:
                        signature.append(type_specs[col_type])
                    except KeyError:
                        _LOGGER.warning("No native type for \"%s\"", col_def)
                        signature.append(object)
                _LOGGER.info(
                    "Signature for table \"%s\": %r",
                    master_record.name, signature
                )
                signatures[master_record.name] = signature

                _LOGGER.info(
                    "Columns for table \"%s\": %r",
                    master_record.name, columns
                )
                self._table_columns[master_record.name] = columns

    def map_tables(self):
        """Instantiate a Table object for every table named in the
        sqlite_master schema (plus sqlite_master itself)."""
        first_page = self.pages[1]
        assert isinstance(first_page, BTreePage)

        master_table = Table('sqlite_master', self, first_page)
        self._table_columns.update(constants.SQLITE_TABLE_COLUMNS)

        for master_leaf in master_table.leaves:
            self._parse_master_leaf_page(master_leaf)

        assert all(
            isinstance(root, BTreePage) for root in self._table_roots.values()
        )
        assert all(
            root.parent is None for root in self._table_roots.values()
        )

        self.map_table_page(1, master_table)
        self._table_roots['sqlite_master'] = self.pages[1]

        for table_name, rootpage in self._table_roots.items():
            try:
                table_obj = Table(table_name, self, rootpage)
            except Exception as ex:  # pylint:disable=W0703
                # Bug fix: removed a stray pdb.set_trace() debugging
                # breakpoint that would hang non-interactive runs
                _LOGGER.warning(
                    "Caught %r while instantiating table object for \"%s\"",
                    ex, table_name
                )
            else:
                self._tables[table_name] = table_obj

    def reparent_orphaned_table_leaf_pages(self):
        """Re-attach table-leaf pages whose btree ancestry was lost (e.g.
        freed pages), using record type signatures when no ancestor still
        maps to a table."""
        reparented_pages = []
        for page in self.pages.values():
            if not isinstance(page, BTreePage):
                continue
            if page.page_type != "Table Leaf":
                continue

            table = page.table
            if not table:
                # Walk up the ancestry, keeping the last table we see
                parent = page
                root_table = None
                while parent:
                    root_table = parent.table
                    parent = parent.parent

                if root_table is None:
                    # No ancestor knows its table: remember this page and
                    # try to match its first record against the known
                    # table signatures
                    self._freelist_btree_pages.append(page)
                    if not page.cells:
                        continue

                    first_record = page.cells[0][1]
                    matches = []
                    for table_name in signatures:
                        # All records within a given page are for the same
                        # table
                        if self.tables[table_name].check_signature(
                                first_record):
                            matches.append(self.tables[table_name])
                    if not matches:
                        _LOGGER.error(
                            "Couldn't find a matching table for %r",
                            page
                        )
                        continue
                    if len(matches) > 1:
                        _LOGGER.error(
                            "Multiple matching tables for %r: %r",
                            page, matches
                        )
                        continue
                    elif len(matches) == 1:
                        root_table = matches[0]

                _LOGGER.debug(
                    "Reparenting %r to table \"%s\"",
                    page, root_table.name
                )
                root_table.add_leaf(page)
                self.map_table_page(page.idx, root_table)
                reparented_pages.append(page)

        if reparented_pages:
            _LOGGER.info(
                "Reparented %d pages: %r",
                len(reparented_pages), [p.idx for p in reparented_pages]
            )

    def grep(self, needle):
        """Log every page whose raw bytes contain *needle* (a str,
        matched as a UTF-8 regex against each page's bytes)."""
        match_found = False
        page_idx = 1
        needle_re = re.compile(needle.encode('utf-8'))
        while page_idx <= self.header.size_in_pages:
            page = self.pages[page_idx]
            page_offsets = [
                match.start() for match in needle_re.finditer(bytes(page))
            ]
            if page_offsets:
                # Bug fix: match_found was never set, so the "not found"
                # warning below fired even after successful matches
                match_found = True
                _LOGGER.info(
                    "Found search term in page %r @ offset(s) %s",
                    page, ', '.join(str(offset) for offset in page_offsets)
                )
            page_idx += 1
        if not match_found:
            _LOGGER.warning(
                "Search term not found",
            )
652
653
654
class Table(object):
    """A database table: its name, root btree page, columns, and the
    leaf pages that hold its records."""

    def __init__(self, name, db, rootpage):
        self._name = name
        self._db = db
        assert(isinstance(rootpage, BTreePage))
        self._root = rootpage
        self._leaves = []
        try:
            self._columns = self._db.table_columns[self.name]
        except KeyError:
            # No schema information for this table (yet)
            self._columns = None

        # We want this to be a list of leaf-type pages, sorted in the order
        # of their smallest rowid
        self._populate_pages()

    @property
    def name(self):
        return self._name

    def add_leaf(self, leaf_page):
        """Attach *leaf_page* to this table (used for reparented pages)."""
        self._leaves.append(leaf_page)

    @property
    def columns(self):
        return self._columns

    def __repr__(self):
        return "<SQLite table \"{}\", root: {}, leaves: {}>".format(
            self.name, self._root.idx, len(self._leaves)
        )

    def _populate_pages(self):
        """Breadth-first walk of the table's btree, collecting leaf pages
        into self._leaves and mapping every visited page to this table."""
        _LOGGER.info("Page %d is root for %s", self._root.idx, self.name)
        table_pages = [self._root]

        if self._root.btree_header.right_most_page_idx is not None:
            rightmost_idx = self._root.btree_header.right_most_page_idx
            rightmost_page = self._db.pages[rightmost_idx]
            if rightmost_page is not self._root:
                _LOGGER.info(
                    "Page %d is rightmost for %s",
                    rightmost_idx, self.name
                )
                table_pages.append(rightmost_page)

        page_queue = list(table_pages)
        while page_queue:
            table_page = page_queue.pop(0)
            # table_pages is initialised with the table's rootpage, which
            # may be a leaf page for a very small table
            if table_page.page_type != 'Table Interior':
                self._leaves.append(table_page)
                continue

            for cell_idx in table_page.cells:
                page_ptr, max_row_in_page = table_page.cells[cell_idx]

                page = self._db.pages[page_ptr]
                _LOGGER.debug("B-Tree cell: (%r, %d)", page, max_row_in_page)
                table_pages.append(page)
                if page.page_type == 'Table Interior':
                    page_queue.append(page)
                elif page.page_type == 'Table Leaf':
                    self._leaves.append(page)

        assert(all(p.page_type == 'Table Leaf' for p in self._leaves))
        for page in table_pages:
            self._db.map_table_page(page.idx, self)

    @property
    def leaves(self):
        """Yield this table's leaf pages in discovery order."""
        for leaf_page in self._leaves:
            yield leaf_page

    def recover_records(self):
        """Attempt record recovery from the freeblocks of every leaf page
        that has any."""
        for page in self.leaves:
            assert isinstance(page, BTreePage)
            if not page.freeblocks:
                continue

            _LOGGER.info("%r", page)
            page.recover_freeblock_records()
            page.print_recovered_records()

    def csv_dump(self, out_dir):
        """Dump all live and recovered records to <out_dir>/<name>.csv.

        :raises ValueError: when the output file already exists
        """
        csv_path = os.path.join(out_dir, self.name + '.csv')
        if os.path.exists(csv_path):
            raise ValueError("Output file {} exists!".format(csv_path))

        _LOGGER.info("Dumping table \"%s\" to CSV", self.name)
        with tempfile.TemporaryFile('w+', newline='') as csv_temp:
            writer = csv.DictWriter(csv_temp, fieldnames=self._columns)
            writer.writeheader()

            for leaf_page in self.leaves:
                for cell_idx in leaf_page.cells:
                    rowid, record = leaf_page.cells[cell_idx]
                    # assert(self.check_signature(record))

                    _LOGGER.debug('Record %d: %r', rowid, record.header)
                    fields_iter = (
                        repr(record.fields[idx]) for idx in record.fields
                    )
                    _LOGGER.debug(', '.join(fields_iter))

                    values_iter = (
                        record.fields[idx].value for idx in record.fields
                    )
                    writer.writerow(dict(zip(self._columns, values_iter)))

                if not leaf_page.recovered_records:
                    continue

                # Recovered records are in an unordered set because their
                # rowid has been lost, making sorting impossible
                for record in leaf_page.recovered_records:
                    values_iter = (
                        record.fields[idx].value for idx in record.fields
                    )
                    writer.writerow(dict(zip(self._columns, values_iter)))

            if csv_temp.tell() > 0:
                csv_temp.seek(0)
                # Bug fix: newline='' stops the text layer re-translating
                # the CRLF line endings the csv module already wrote
                with open(csv_path, 'w', newline='') as csv_file:
                    csv_file.write(csv_temp.read())

    def build_insert_SQL(self, record):
        """Build a parameterized INSERT statement for *record*.

        :return: (statement, bind_kwargs); columns the record lacks a
            field for are bound as None (NULL)
        """
        column_placeholders = (
            ':' + col_name for col_name in self._columns
        )
        insert_statement = 'INSERT INTO {} VALUES ({})'.format(
            self.name,
            ', '.join(column_placeholders),
        )
        value_kwargs = {}
        for col_idx, col_name in enumerate(self._columns):
            try:
                # A None value binds as NULL, so no special-casing needed
                value_kwargs[col_name] = record.fields[col_idx].value
            except KeyError:
                # Records may have fewer fields than the table has columns
                # (e.g. columns added by ALTER TABLE); bind NULL for those
                value_kwargs[col_name] = None

        return insert_statement, value_kwargs

    def check_signature(self, record):
        """Return True when *record*'s field types are compatible with
        this table's column type signature."""
        assert isinstance(record, Record)
        try:
            sig = signatures[self.name]
        except KeyError:
            # The sqlite schema tables don't have a signature (or need one)
            return True
        if len(record.fields) > len(self.columns):
            return False

        # It's OK for a record to have fewer fields than there are columns
        # in this table, this is seen when NULLable or default-valued
        # columns are added in an ALTER TABLE statement.
        for field_idx, field in record.fields.items():
            # NULL can be a value for any column type
            if field.value is None:
                continue
            if not isinstance(field.value, sig[field_idx]):
                return False
        return True
821
822
823
class Page(object):
    """A raw page of an SQLite database file.

    Base class for the specialised page types: holds the page index, a
    reference to the owning database object and the page's raw bytes.
    """

    def __init__(self, page_idx, db):
        self._page_idx = page_idx
        self._db = db
        self._bytes = db.page_bytes(self.idx)

    def __bytes__(self):
        return self._bytes

    def __repr__(self):
        return "<SQLite Page {0}>".format(self.idx)

    @property
    def idx(self):
        # 1-based page number within the database file
        return self._page_idx

    @property
    def usable_size(self):
        # Page size minus the reserved region at the end of every page
        header = self._db.header
        return header.page_size - header.reserved_length

    @property
    def parent(self):
        """The parent page recorded in the pointer map, or None."""
        try:
            pointer = self._db.ptrmap[self.idx]
        except KeyError:
            return None
        # A zero back-pointer means this page has no parent
        return self._db.pages[pointer.page_ptr] if pointer.page_ptr else None
854
855
856
class FreelistTrunkPage(Page):
    """A freelist trunk page, tracking a number of freelist leaf pages."""

    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db, leaves):
        super().__init__(page_idx, db)
        self._leaves = leaves

    def __repr__(self):
        leaf_count = len(self._leaves)
        return "<SQLite Freelist Trunk Page {0}: {1} leaves>".format(
            self.idx, leaf_count
        )
867
868
869
class FreelistLeafPage(Page):
    """A freelist leaf page: an unused page referenced by a trunk page."""

    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db, trunk_idx):
        super().__init__(page_idx, db)
        # Resolve the trunk index to the already-instantiated trunk page
        self._trunk = self._db.pages[trunk_idx]

    def __repr__(self):
        return "<SQLite Freelist Leaf Page {0}. Trunk: {1}>".format(
            self.idx, self._trunk.idx
        )
880
881
882
class PtrmapPage(Page):
    """A pointer-map page, holding an array of back-pointer entries."""

    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db, ptr_array):
        super().__init__(page_idx, db)
        self._pointers = ptr_array

    def __repr__(self):
        return "<SQLite Ptrmap Page {0}. {1} pointers>".format(
            self.idx, len(self.pointers)
        )

    @property
    def pointers(self):
        return self._pointers
897
898
899
class OverflowPage(Page):
    """An overflow page: payload bytes that did not fit on a b-tree page."""

    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db):
        super().__init__(page_idx, db)
        self._parse()

    def _parse(self):
        # TODO We should have parsing here for the next page index in the
        # overflow chain
        pass

    def __repr__(self):
        return "<SQLite Overflow Page {0}. Continuation of {1}>".format(
            self.idx, self.parent.idx
        )
915
916
917
class BTreePage(Page):
    """A b-tree page (table/index, interior/leaf) and its cell contents.

    Cells are parsed on demand via ``parse_cells()``; deleted records may
    afterwards be rebuilt from the page's freeblocks with
    ``recover_freeblock_records()``.
    """

    btree_page_types = {
        0x02:   "Index Interior",
        0x05:   "Table Interior",
        0x0A:   "Index Leaf",
        0x0D:   "Table Leaf",
    }

    def __init__(self, page_idx, db):
        # XXX We don't know a page's type until we've had a look at the header.
        # Or do we?
        super().__init__(page_idx, db)
        self._header_size = 8
        page_header_bytes = self._get_btree_page_header()
        self._btree_header = SQLite_btree_page_header(
            # Set the right-most page index to None in the 1st pass
            *struct.unpack(r'>BHHHB', page_header_bytes), None
        )
        self._cell_ptr_array = []
        self._freeblocks = IndexDict()
        self._cells = IndexDict()
        self._recovered_records = set()
        # Payloads larger than this spill into an overflow chain ("X" in
        # the SQLite file format documentation)
        self._overflow_threshold = self.usable_size - 35

        if self._btree_header.page_type not in BTreePage.btree_page_types:
            # Was a bare "raise ValueError" (plus a commented-out debugger
            # breakpoint); include the offending type in the message.
            raise ValueError(
                "Unknown B-Tree page type: {}".format(
                    self._btree_header.page_type
                )
            )

        # We have a twelve-byte header, need to read it again
        if self._btree_header.page_type in (0x02, 0x05):
            self._header_size = 12
            page_header_bytes = self._get_btree_page_header()
            self._btree_header = SQLite_btree_page_header(*struct.unpack(
                r'>BHHHBI', page_header_bytes
            ))

        # Page 1 (and page 2, but that's the 1st ptrmap page) does not have a
        # ptrmap entry.
        # The first ptrmap page will contain back pointer information for pages
        # 3 through J+2, inclusive.
        if self._db.ptrmap:
            if self.idx >= 3 and self.idx not in self._db.ptrmap:
                _LOGGER.warning(
                    "BTree page %d doesn't have ptrmap entry!", self.idx
                )

        if self._btree_header.num_cells > 0:
            cell_ptr_bytes = self._get_btree_ptr_array(
                self._btree_header.num_cells
            )
            self._cell_ptr_array = struct.unpack(
                r'>{count}H'.format(count=self._btree_header.num_cells),
                cell_ptr_bytes
            )
            smallest_cell_offset = min(self._cell_ptr_array)
            if self._btree_header.cell_content_offset != smallest_cell_offset:
                _LOGGER.warning(
                    (
                        "Inconsistent cell ptr array in page %d! Cell content "
                        "starts at offset %d, but min cell pointer is %d"
                    ),
                    self.idx,
                    self._btree_header.cell_content_offset,
                    smallest_cell_offset
                )

    @property
    def btree_header(self):
        return self._btree_header

    @property
    def page_type(self):
        """Human-readable page type; raises KeyError when unknown."""
        try:
            return self.btree_page_types[self._btree_header.page_type]
        except KeyError:
            # A pdb.set_trace() breakpoint used to live here; library code
            # must not drop into the debugger — log and re-raise instead.
            _LOGGER.warning(
                "Unknown B-Tree page type: %d", self._btree_header.page_type
            )
            raise

    @property
    def freeblocks(self):
        return self._freeblocks

    @property
    def cells(self):
        return self._cells

    def __repr__(self):
        # TODO Include table in repr, where available
        return "<SQLite B-Tree Page {0} ({1}) {2} cells>".format(
            self.idx, self.page_type, len(self._cell_ptr_array)
        )

    @property
    def table(self):
        """The Table object this page belongs to, resolved through the DB."""
        return self._db.get_page_table(self.idx)

    def _get_btree_page_header(self):
        # Page 1 starts with the 100-byte database header
        header_offset = 0
        if self.idx == 1:
            header_offset += 100
        return bytes(self)[header_offset:self._header_size + header_offset]

    def _get_btree_ptr_array(self, num_cells):
        # The cell pointer array immediately follows the b-tree page header
        array_offset = self._header_size
        if self.idx == 1:
            array_offset += 100
        return bytes(self)[array_offset:2 * num_cells + array_offset]

    def parse_cells(self):
        """Parse this page's cells (table pages only) and its freeblocks."""
        if self.btree_header.page_type == 0x05:
            self.parse_table_interior_cells()
        elif self.btree_header.page_type == 0x0D:
            self.parse_table_leaf_cells()
        self.parse_freeblocks()

    def parse_table_interior_cells(self):
        """Parse interior cells into (left child pointer, rowid) pairs."""
        assert self.btree_header.page_type == 0x05

        _LOGGER.debug("Parsing cells in table interior cell %d", self.idx)
        for cell_idx, offset in enumerate(self._cell_ptr_array):
            _LOGGER.debug("Parsing cell %d @ offset %d", cell_idx, offset)
            left_ptr_bytes = bytes(self)[offset:offset + 4]
            left_ptr, = struct.unpack(r'>I', left_ptr_bytes)

            offset += 4
            integer_key = Varint(bytes(self)[offset:offset+9])
            self._cells[cell_idx] = (left_ptr, int(integer_key))

    def parse_table_leaf_cells(self):
        """Parse leaf cells into (rowid, Record) pairs, following overflow
        chains when a payload spills past this page."""
        assert self.btree_header.page_type == 0x0d

        _LOGGER.debug("Parsing cells in table leaf cell %d", self.idx)
        for cell_idx, cell_offset in enumerate(self._cell_ptr_array):
            _LOGGER.debug("Parsing cell %d @ offset %d", cell_idx, cell_offset)

            # This is the total size of the payload, which may include overflow
            offset = cell_offset
            payload_length_varint = Varint(bytes(self)[offset:offset+9])
            total_payload_size = int(payload_length_varint)

            overflow = False
            # Let X be U-35. If the payload size P is less than or equal to X
            # then the entire payload is stored on the b-tree leaf page. Let M
            # be ((U-12)*32/255)-23 and let K be M+((P-M)%(U-4)). If P is
            # greater than X then the number of bytes stored on the table
            # b-tree leaf page is K if K is less or equal to X or M otherwise.
            # The number of bytes stored on the leaf page is never less than M.
            cell_payload_size = 0
            if total_payload_size > self._overflow_threshold:
                m = int(((self.usable_size - 12) * 32/255)-23)
                k = m + ((total_payload_size - m) % (self.usable_size - 4))
                if k <= self._overflow_threshold:
                    cell_payload_size = k
                else:
                    cell_payload_size = m
                overflow = True
            else:
                cell_payload_size = total_payload_size

            offset += len(payload_length_varint)

            integer_key = Varint(bytes(self)[offset:offset+9])
            offset += len(integer_key)

            overflow_bytes = bytes()
            if overflow:
                first_oflow_page_bytes = bytes(self)[
                    offset + cell_payload_size:offset + cell_payload_size + 4
                ]
                if not first_oflow_page_bytes:
                    continue

                first_oflow_idx, = struct.unpack(
                    r'>I', first_oflow_page_bytes
                )
                next_oflow_idx = first_oflow_idx
                while next_oflow_idx != 0:
                    oflow_page_bytes = self._db.page_bytes(next_oflow_idx)

                    # Read no more than the bytes still missing from the
                    # payload. This previously ADDED len(overflow_bytes),
                    # which overshot on the final page of the chain and
                    # tripped the total-size assertion below.
                    len_overflow = min(
                        len(oflow_page_bytes) - 4,
                        (
                            total_payload_size - cell_payload_size -
                            len(overflow_bytes)
                        )
                    )
                    overflow_bytes += oflow_page_bytes[4:4 + len_overflow]

                    # The first 4 bytes of an overflow page point to the
                    # next page in the chain; 0 terminates the chain.
                    first_four_bytes = oflow_page_bytes[:4]
                    next_oflow_idx, = struct.unpack(
                        r'>I', first_four_bytes
                    )

            try:
                cell_data = bytes(self)[offset:offset + cell_payload_size]
                if overflow_bytes:
                    cell_data += overflow_bytes

                # All payload bytes should be accounted for
                assert len(cell_data) == total_payload_size

                record_obj = Record(cell_data)
                _LOGGER.debug("Created record: %r", record_obj)

            except TypeError as ex:
                # A pdb.set_trace() breakpoint used to live here; log the
                # failure and propagate it instead.
                _LOGGER.warning(
                    "Caught %r while instantiating record %d",
                    ex, int(integer_key)
                )
                raise

            self._cells[cell_idx] = (int(integer_key), record_obj)

    def parse_freeblocks(self):
        """Collect the bytes of every freeblock on this page."""
        # The first 2 bytes of a freeblock are a big-endian integer which is
        # the offset in the b-tree page of the next freeblock in the chain, or
        # zero if the freeblock is the last on the chain. The third and fourth
        # bytes of each freeblock form a big-endian integer which is the size
        # of the freeblock in bytes, including the 4-byte header. Freeblocks
        # are always connected in order of increasing offset. The second field
        # of the b-tree page header is the offset of the first freeblock, or
        # zero if there are no freeblocks on the page. In a well-formed b-tree
        # page, there will always be at least one cell before the first
        # freeblock.
        #
        # TODO But what about deleted records that exceeded the overflow
        # threshold in the past?
        block_offset = self.btree_header.first_freeblock_offset
        while block_offset != 0:
            freeblock_header = bytes(self)[block_offset:block_offset + 4]
            # Freeblock_size includes the 4-byte header
            next_freeblock_offset, freeblock_size = struct.unpack(
                r'>HH',
                freeblock_header
            )
            # The content is the (freeblock_size - 4) bytes that follow the
            # header. The previous upper bound of
            # "block_offset + freeblock_size - 4" dropped the last 4 bytes
            # of every freeblock.
            freeblock_bytes = bytes(self)[
                block_offset + 4:block_offset + freeblock_size
            ]
            self._freeblocks[block_offset] = freeblock_bytes
            block_offset = next_freeblock_offset

    def print_cells(self):
        """Log every parsed cell along with its record's fields."""
        for cell_idx in self.cells.keys():
            rowid, record = self.cells[cell_idx]
            _LOGGER.info(
                "Cell %d, rowid: %d, record: %r",
                cell_idx, rowid, record
            )
            record.print_fields(table=self.table)

    def recover_freeblock_records(self):
        """Attempt to rebuild deleted records from this page's freeblocks.

        Candidate header offsets come from the table's heuristic;
        candidates that parse as well-formed records are added to
        ``recovered_records``.
        """
        # If we're lucky (i.e. if no overwriting has taken place), we should be
        # able to find whole record headers in freeblocks.
        # We need to start from the end of the freeblock and work our way back
        # to the start. That means we don't know where a cell header will
        # start, but I suppose we can take a guess
        table = self.table
        if not table or table.name not in heuristics:
            return

        _LOGGER.info("Attempting to recover records from freeblocks")
        for freeblock_idx, freeblock_offset in enumerate(self._freeblocks):
            freeblock_bytes = self._freeblocks[freeblock_offset]
            if 0 == len(freeblock_bytes):
                continue
            _LOGGER.debug(
                "Freeblock %d/%d in page, offset %d, %d bytes",
                1 + freeblock_idx,
                len(self._freeblocks),
                freeblock_offset,
                len(freeblock_bytes)
            )

            recovered_bytes = 0
            recovered_in_freeblock = 0

            # TODO Maybe we need to guess the record header lengths rather than
            # try and read them from the freeblocks
            for header_start in heuristics[table.name](freeblock_bytes):
                _LOGGER.debug(
                    (
                        "Trying potential record header start at "
                        "freeblock offset %d/%d"
                    ),
                    header_start, len(freeblock_bytes)
                )
                _LOGGER.debug("%r", freeblock_bytes)
                try:
                    # We don't know how to handle overflow in deleted records,
                    # so we'll have to truncate the bytes object used to
                    # instantiate the Record object
                    record_bytes = freeblock_bytes[
                        header_start:header_start+self._overflow_threshold
                    ]
                    record_obj = Record(record_bytes)
                except MalformedRecord:
                    # This isn't a well-formed record, let's move to the next
                    # candidate
                    continue

                field_lengths = sum(
                    len(field_obj) for field_obj in record_obj.fields.values()
                )
                record_obj.truncate(field_lengths + len(record_obj.header))
                self._recovered_records.add(record_obj)

                recovered_bytes += len(bytes(record_obj))
                recovered_in_freeblock += 1

            _LOGGER.info(
                (
                    "Recovered %d record(s): %d bytes out of %d "
                    "freeblock bytes @ offset %d"
                ),
                recovered_in_freeblock,
                recovered_bytes,
                len(freeblock_bytes),
                freeblock_offset,
            )

    @property
    def recovered_records(self):
        return self._recovered_records

    def print_recovered_records(self):
        """Log every record recovered from this page's freeblocks."""
        if not self._recovered_records:
            return

        for record_obj in self._recovered_records:
            _LOGGER.info("Recovered record: %r", record_obj)
            _LOGGER.info("Recovered record header: %s", record_obj.header)
            record_obj.print_fields(table=self.table)
1255
1256
1257
class Record(object):
    """A decoded SQLite record (row payload): a header of serial types
    followed by the concatenated field contents."""

    # serial type -> (content length in bytes, description)
    column_types = {
        0: (0, "NULL"),
        1: (1, "8-bit twos-complement integer"),
        2: (2, "big-endian 16-bit twos-complement integer"),
        3: (3, "big-endian 24-bit twos-complement integer"),
        4: (4, "big-endian 32-bit twos-complement integer"),
        5: (6, "big-endian 48-bit twos-complement integer"),
        6: (8, "big-endian 64-bit twos-complement integer"),
        7: (8, "Floating point"),
        8: (0, "Integer 0"),
        9: (0, "Integer 1"),
    }

    def __init__(self, record_bytes):
        self._bytes = record_bytes
        self._header_bytes = None
        self._fields = IndexDict()
        self._parse()

    def __bytes__(self):
        return self._bytes

    @property
    def header(self):
        return self._header_bytes

    @property
    def fields(self):
        return self._fields

    def truncate(self, new_length):
        """Drop all bytes past ``new_length`` and re-parse the record."""
        self._bytes = self._bytes[:new_length]
        self._parse()

    def _parse(self):
        """Parse the record header and instantiate one Field per column.

        Raises MalformedRecord when the bytes cannot form a well-formed
        record (header or field contents run past the available bytes).
        """
        header_offset = 0

        header_length_varint = Varint(
            # A varint is encoded on *at most* 9 bytes
            bytes(self)[header_offset:9 + header_offset]
        )

        # Let's keep track of how many bytes of the Record header (including
        # the header length itself) we've succesfully parsed
        parsed_header_bytes = len(header_length_varint)

        if len(bytes(self)) < int(header_length_varint):
            raise MalformedRecord(
                "Not enough bytes to fully read the record header!"
            )

        header_offset += len(header_length_varint)
        self._header_bytes = bytes(self)[:int(header_length_varint)]

        col_idx = 0
        field_offset = int(header_length_varint)
        while header_offset < int(header_length_varint):
            serial_type_varint = Varint(
                bytes(self)[header_offset:9 + header_offset]
            )
            serial_type = int(serial_type_varint)
            col_length = None

            try:
                col_length, _ = self.column_types[serial_type]
            except KeyError:
                if serial_type >= 13 and (1 == serial_type % 2):
                    # Text column: content length is (N-13)/2
                    col_length = (serial_type - 13) // 2
                elif serial_type >= 12 and (0 == serial_type % 2):
                    # BLOB column: content length is (N-12)/2
                    col_length = (serial_type - 12) // 2
                else:
                    raise ValueError(
                        "Unknown serial type {}".format(serial_type)
                    )

            try:
                field_obj = Field(
                    col_idx,
                    serial_type,
                    bytes(self)[field_offset:field_offset + col_length]
                )
            except MalformedField as ex:
                _LOGGER.warning(
                    "Caught %r while instantiating field %d (%d)",
                    ex, col_idx, serial_type
                )
                raise MalformedRecord
            except Exception as ex:
                # Unexpected decoding failure: log and propagate. A
                # pdb.set_trace() breakpoint used to live here; library
                # code must not drop into the debugger.
                _LOGGER.warning(
                    "Caught %r while instantiating field %d (%d)",
                    ex, col_idx, serial_type
                )
                raise

            self._fields[col_idx] = field_obj
            col_idx += 1
            field_offset += col_length

            parsed_header_bytes += len(serial_type_varint)
            header_offset += len(serial_type_varint)

            if field_offset > len(bytes(self)):
                raise MalformedRecord

        # NOTE: for a fully well-formed header, parsed_header_bytes should
        # equal int(header_length_varint) at this point.

    def print_fields(self, table=None):
        """Log each field, using ``table``'s column names when available."""
        for field_idx in self._fields:
            field_obj = self._fields[field_idx]
            if not table or table.columns is None:
                _LOGGER.info(
                    "\tField %d (%d bytes), type %d: %s",
                    field_obj.index,
                    len(field_obj),
                    field_obj.serial_type,
                    field_obj.value
                )
            else:
                _LOGGER.info(
                    "\t%s: %s",
                    table.columns[field_obj.index],
                    field_obj.value
                )

    def __repr__(self):
        return '<Record {} fields, {} bytes, header: {} bytes>'.format(
            len(self._fields), len(bytes(self)), len(self.header)
        )
1388
1389
1390
class MalformedField(Exception):
    """Raised when a field's bytes cannot be decoded for its serial type."""
    pass
1392
1393
1394
class MalformedRecord(Exception):
    """Raised when bytes cannot be parsed as a well-formed SQLite record."""
    pass
1396
1397
1398
class Field(object):
    """A single column value decoded from a record.

    ``serial_type`` follows the SQLite record format: 0-9 are fixed-width
    types, >= 12 even is a BLOB, >= 13 odd is UTF-8 text.
    """

    def __init__(self, idx, serial_type, serial_bytes):
        self._index = idx
        self._type = serial_type
        self._bytes = serial_bytes
        self._value = None
        self._parse()

    def _check_length(self, expected_length):
        # Raise (don't assert): short reads are expected when scanning
        # freeblocks for deleted records, and callers weed candidates out
        # by catching MalformedField.
        if len(self) != expected_length:
            raise MalformedField

    # TODO Raise a specific exception when bad bytes are encountered for the
    # fields and then use this to weed out bad freeblock records
    def _parse(self):
        if self._type == 0:
            self._value = None
        # Integer types
        elif self._type == 1:
            self._check_length(1)
            self._value = decode_twos_complement(bytes(self)[0:1], 8)
        elif self._type == 2:
            self._check_length(2)
            self._value = decode_twos_complement(bytes(self)[0:2], 16)
        elif self._type == 3:
            self._check_length(3)
            self._value = decode_twos_complement(bytes(self)[0:3], 24)
        elif self._type == 4:
            self._check_length(4)
            self._value = decode_twos_complement(bytes(self)[0:4], 32)
        elif self._type == 5:
            self._check_length(6)
            self._value = decode_twos_complement(bytes(self)[0:6], 48)
        elif self._type == 6:
            self._check_length(8)
            self._value = decode_twos_complement(bytes(self)[0:8], 64)

        elif self._type == 7:
            # Fixed 8-byte IEEE 754 float. Check the length first, so a
            # truncated field raises MalformedField rather than
            # struct.error (which callers don't handle); this matches the
            # other fixed-width types above.
            self._check_length(8)
            self._value = struct.unpack(r'>d', bytes(self)[0:8])[0]
        elif self._type == 8:
            self._value = 0
        elif self._type == 9:
            self._value = 1
        elif self._type >= 13 and (1 == self._type % 2):
            try:
                self._value = bytes(self).decode('utf-8')
            except UnicodeDecodeError:
                raise MalformedField

        elif self._type >= 12 and (0 == self._type % 2):
            self._value = bytes(self)

    def __bytes__(self):
        return self._bytes

    def __repr__(self):
        return "<Field {}: {} ({} bytes)>".format(
            self._index, self._value, len(bytes(self))
        )

    def __len__(self):
        return len(bytes(self))

    @property
    def index(self):
        return self._index

    @property
    def value(self):
        return self._value

    @property
    def serial_type(self):
        return self._type
1472
1473
1474
class Varint(object):
    """A SQLite variable-length integer: 1-9 bytes, big-endian, 7 value
    bits per byte with the high bit used as a continuation flag."""

    def __init__(self, varint_bytes):
        self._bytes = varint_bytes
        self._len = 0
        self._value = 0

        # NOTE(review): per the SQLite format, the 9th byte of a 9-byte
        # varint contributes all 8 of its bits; this loop treats it like
        # the others — confirm against the file-format spec.
        varint_bits = []
        for b in self._bytes:
            self._len += 1
            if b & 0x80:
                varint_bits.append(b & 0x7F)
            else:
                varint_bits.append(b)
                break

        varint_twos_complement = 0
        for position, b in enumerate(varint_bits[::-1]):
            varint_twos_complement += b * (1 << (7*position))

        # A varint encodes a 64-bit two's complement value, so serialize
        # on 8 bytes here: the previous width of 4 raised OverflowError
        # for any varint value >= 2**32 (i.e. 5+ byte varints).
        self._value = decode_twos_complement(
            int.to_bytes(varint_twos_complement, 8, byteorder='big'), 64
        )

    def __int__(self):
        return self._value

    def __len__(self):
        return self._len

    def __repr__(self):
        return "<Varint {} ({} bytes)>".format(int(self), len(self))
1505
1506
1507
def decode_twos_complement(encoded, bit_length):
    """Decode ``encoded`` (big-endian bytes) as a signed two's complement
    integer of ``bit_length`` bits."""
    assert(0 == bit_length % 8)
    raw = int.from_bytes(encoded, byteorder='big')
    sign_bit = 2**(bit_length - 1)
    # Clear the sign bit, then subtract its weight when it was set
    return (raw & ~sign_bit) - (raw & sign_bit)
1513
1514
1515
def gen_output_dir(db_path):
    """Pick a non-existing output directory name next to ``db_path``.

    Tries the database name with dots replaced by underscores first,
    then up to ten numbered variants; raises SystemError when every
    candidate already exists.
    """
    db_abspath = os.path.abspath(db_path)
    db_dir, db_name = os.path.split(db_abspath)
    munged_name = db_name.replace('.', '_')

    candidates = [munged_name]
    candidates.extend(
        "{}_{}".format(munged_name, suffix) for suffix in range(1, 11)
    )
    for candidate in candidates:
        out_dir = os.path.join(db_dir, candidate)
        if not os.path.exists(out_dir):
            return out_dir
    raise SystemError(
        "Unreasonable number of output directories for {}".format(db_path)
    )
1532
1533
1534
def _load_db(sqlite_path):
    """Open the SQLite file at ``sqlite_path`` and return a fully
    populated SQLite_DB object: freelist, ptrmap, overflow and b-tree
    pages parsed, tables mapped, orphaned leaf pages reparented.

    NOTE(review): the populate_* calls appear order-sensitive (b-tree
    page construction consults db.ptrmap) — confirm before reordering.
    """
    _LOGGER.info("Processing %s", sqlite_path)

    load_heuristics()

    db = SQLite_DB(sqlite_path)
    _LOGGER.info("Database: %r", db)

    db.populate_freelist_pages()
    db.populate_ptrmap_pages()
    db.populate_overflow_pages()

    # Should we aim to instantiate specialised b-tree objects here, or is the
    # use of generic btree page objects acceptable?
    db.populate_btree_pages()

    db.map_tables()

    # We need a first pass to process tables that are disconnected
    # from their table's root page
    db.reparent_orphaned_table_leaf_pages()

    # All pages should now be represented by specialised objects
    assert(all(isinstance(p, Page) for p in db.pages.values()))
    assert(not any(type(p) is Page for p in db.pages.values()))
    return db
1560
1561
1562
def dump_to_csv(args):
    """Recover as many records as possible and dump every table to CSV.

    Uses ``args.output_dir`` when given, otherwise derives a fresh
    directory name from the database path; refuses to overwrite an
    existing directory.
    """
    out_dir = args.output_dir or gen_output_dir(args.sqlite_path)
    db = _load_db(args.sqlite_path)

    if os.path.exists(out_dir):
        raise ValueError("Output directory {} exists!".format(out_dir))
    os.mkdir(out_dir)

    for table_name in sorted(db.tables):
        table = db.tables[table_name]
        _LOGGER.info("Table \"%s\"", table)
        table.recover_records()
        table.csv_dump(out_dir)
1575
1576
1577
def undelete(args):
    """Recover deleted records and INSERT them into a copy of the DB.

    Copies the database at ``args.sqlite_path`` to ``args.output_path``
    (which must not already exist) and re-inserts every record recovered
    from each table's leaf pages.
    """
    db_abspath = os.path.abspath(args.sqlite_path)
    db = _load_db(db_abspath)

    output_path = os.path.abspath(args.output_path)
    if os.path.exists(output_path):
        raise ValueError("Output file {} exists!".format(output_path))

    shutil.copyfile(db_abspath, output_path)
    with sqlite3.connect(output_path) as output_db_connection:
        cursor = output_db_connection.cursor()
        for table_name in sorted(db.tables):
            table = db.tables[table_name]
            _LOGGER.info("Table \"%s\"", table)
            table.recover_records()
            _reinsert_recovered_records(cursor, table)


def _reinsert_recovered_records(cursor, table):
    """INSERT every record recovered for ``table``, tallying outcomes."""
    failed_inserts = 0
    constraint_violations = 0
    successful_inserts = 0
    for leaf_page in table.leaves:
        for record in leaf_page.recovered_records:
            insert_statement, values = table.build_insert_SQL(record)

            try:
                cursor.execute(insert_statement, values)
            except sqlite3.IntegrityError:
                # We gotta soldier on, there's not much we can do if a
                # constraint is violated by this insert
                constraint_violations += 1
            except (
                        sqlite3.ProgrammingError,
                        sqlite3.OperationalError,
                        sqlite3.InterfaceError
                    ) as insert_ex:
                _LOGGER.warning(
                    (
                        "Caught %r while executing INSERT statement "
                        "in \"%s\""
                    ),
                    insert_ex,
                    table
                )
                failed_inserts += 1
            else:
                successful_inserts += 1

    if failed_inserts > 0:
        _LOGGER.warning(
            "%d failed INSERT statements in \"%s\"",
            failed_inserts, table
        )
    if constraint_violations > 0:
        # Message used to read "constraint violations statements"
        _LOGGER.warning(
            "%d constraint violations in \"%s\"",
            constraint_violations, table
        )
    _LOGGER.info(
        "%d successful INSERT statements in \"%s\"",
        successful_inserts, table
    )
1640
1641
1642
def find_in_db(args):
    """Grep the SQLite database at ``args.sqlite_path`` for ``args.needle``."""
    _load_db(args.sqlite_path).grep(args.needle)
1645
1646
1647
# Maps each CLI subcommand name to the callable that implements it.
subcmd_actions = dict(
    csv=dump_to_csv,
    grep=find_in_db,
    undelete=undelete,
)
1652
1653
1654
def subcmd_dispatcher(arg_ns):
    """Look up and invoke the handler registered for the parsed subcommand."""
    handler = subcmd_actions[arg_ns.subcmd]
    return handler(arg_ns)
1656
1657
1658
def main():
    """Command-line entry point.

    Builds the argument parser and its subcommands (``csv``, ``grep``,
    ``undelete``), parses the command line, raises log verbosity when
    requested, and dispatches to the selected subcommand — or prints
    usage when no subcommand was given.
    """
    # Shared -v/--verbose option. It is declared on a parent parser so it
    # can be attached both to the top-level parser and to every subcommand
    # parser, allowing it in either position on the command line.
    verbose_parser = argparse.ArgumentParser(add_help=False)
    verbose_parser.add_argument(
        '-v', '--verbose',
        action='count',
        help='Give *A LOT* more output.',
    )

    cli_parser = argparse.ArgumentParser(
        description=PROJECT_DESCRIPTION,
        parents=[verbose_parser],
    )

    subcmd_parsers = cli_parser.add_subparsers(
        title='Subcommands',
        description='%(prog)s implements the following subcommands:',
        dest='subcmd',
    )

    csv_parser = subcmd_parsers.add_parser(
        'csv',
        parents=[verbose_parser],
        help='Dumps visible and recovered records to CSV files',
        description=(
            'Recovers as many records as possible from the database passed as '
            'argument and outputs all visible and recovered records to CSV '
            'files in output_dir'
        ),
    )
    csv_parser.add_argument(
        'sqlite_path',
        help='sqlite3 file path'
    )
    csv_parser.add_argument(
        'output_dir',
        nargs='?',
        default=None,
        help='Output directory'
    )

    grep_parser = subcmd_parsers.add_parser(
        'grep',
        parents=[verbose_parser],
        help='Matches a string in one or more pages of the database',
        # Fix: replaced the leftover placeholder description ('Bar') with a
        # real description of what the grep subcommand does.
        description=(
            'Searches all pages of the database passed as argument for the '
            'given string'
        ),
    )
    grep_parser.add_argument(
        'sqlite_path',
        help='sqlite3 file path'
    )
    grep_parser.add_argument(
        'needle',
        help='String to match in the database'
    )

    undelete_parser = subcmd_parsers.add_parser(
        'undelete',
        parents=[verbose_parser],
        help='Inserts recovered records into a copy of the database',
        description=(
            'Recovers as many records as possible from the database passed as '
            # Fix: the adjacent implicitly-concatenated string literals were
            # missing a separating space, rendering "a copy ofthe database."
            # in --help output.
            'argument and inserts all recovered records into a copy of '
            'the database.'
        ),
    )
    undelete_parser.add_argument(
        'sqlite_path',
        help='sqlite3 file path'
    )
    undelete_parser.add_argument(
        'output_path',
        help='Output database path'
    )

    cli_args = cli_parser.parse_args()
    if cli_args.verbose:
        _LOGGER.setLevel(logging.DEBUG)

    if cli_args.subcmd:
        subcmd_dispatcher(cli_args)
    else:
        # No subcommand specified, print the usage and bail
        cli_parser.print_help()
1742