Pull Request — master (#9) by Matt — created 01:37 (analysis completed)

Table.check_signature() — grade B

Complexity: Conditions 7
Size: Total Lines 20
Duplication: Lines 0, Ratio 0 %
Importance: Changes 5, Bugs 0, Features 0

Metric  Value
cc      7
c       5
b       0
f       0
dl      0
loc     20
rs      7.3333

1
# MIT License
2
#
3
# Copyright (c) 2017 Matt Boyer
4
#
5
# Permission is hereby granted, free of charge, to any person obtaining a copy
6
# of this software and associated documentation files (the "Software"), to deal
7
# in the Software without restriction, including without limitation the rights
8
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
# copies of the Software, and to permit persons to whom the Software is
10
# furnished to do so, subject to the following conditions:
11
#
12
# The above copyright notice and this permission notice shall be included in
13
# all copies or substantial portions of the Software.
14
#
15
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
# SOFTWARE.
22
23
from . import constants
24
from . import PROJECT_NAME, PROJECT_DESCRIPTION, USER_JSON_PATH, BUILTIN_JSON
25
26
import argparse
27
import base64
28
import collections
29
import csv
30
import json
31
import logging
32
import os
33
import os.path
34
import pdb
35
import pkg_resources
36
import re
37
import shutil
38
import sqlite3
39
import stat
40
import struct
41
import tempfile
42
43
44
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
45
_LOGGER = logging.getLogger('SQLite recovery')
46
_LOGGER.setLevel(logging.INFO)
47
48
49
SQLite_header = collections.namedtuple('SQLite_header', (
50
    'magic',
51
    'page_size',
52
    'write_format',
53
    'read_format',
54
    'reserved_length',
55
    'max_payload_fraction',
56
    'min_payload_fraction',
57
    'leaf_payload_fraction',
58
    'file_change_counter',
59
    'size_in_pages',
60
    'first_freelist_trunk',
61
    'freelist_pages',
62
    'schema_cookie',
63
    'schema_format',
64
    'default_page_cache_size',
65
    'largest_btree_page',
66
    'text_encoding',
67
    'user_version',
68
    'incremental_vacuum',
69
    'application_id',
70
    'version_valid',
71
    'sqlite_version',
72
))
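# These fields mirror the 100-byte SQLite database header; they are unpacked
# big-endian in SQLite_DB.parse_header() below.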
73
74
75
SQLite_btree_page_header = collections.namedtuple('SQLite_btree_page_header', (
76
    'page_type',
77
    'first_freeblock_offset',
78
    'num_cells',
79
    'cell_content_offset',
80
    'num_fragmented_free_bytes',
81
    'right_most_page_idx',
82
))
83
84
85
SQLite_ptrmap_info = collections.namedtuple('SQLite_ptrmap_info', (
86
    'page_idx',
87
    'page_type',
88
    'page_ptr',
89
))
90
91
92
SQLite_record_field = collections.namedtuple('SQLite_record_field', (
93
    'col_type',
94
    'col_type_descr',
95
    'field_length',
96
    'field_bytes',
97
))
98
99
100
SQLite_master_record = collections.namedtuple('SQLite_master_record', (
101
    'type',
102
    'name',
103
    'tbl_name',
104
    'rootpage',
105
    'sql',
106
))
107
108
109
type_specs = {
110
    'INTEGER': int,
111
    'TEXT': str,
112
    'VARCHAR': str,
113
    'LONGVARCHAR': str,
114
    'REAL': float,
115
    'FLOAT': float,
116
    'LONG': int,
117
    'BLOB': bytes,
118
}
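# Maps SQL type names found in CREATE TABLE statements to the Python types
# used when building and checking table signatures.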
119
120
121
heuristics = {}
122
signatures = {}
123
124
125
def heuristic_factory(magic, offset):
126
    assert(isinstance(magic, bytes))
127
    assert(isinstance(offset, int))
128
    assert(offset >= 0)
129
130
    # We only need to compile the regex once
131
    magic_re = re.compile(magic)
132
133
    def generic_heuristic(freeblock_bytes):
134
        all_matches = [match for match in magic_re.finditer(freeblock_bytes)]
135
        for magic_match in all_matches[::-1]:
136
            header_start = magic_match.start()-offset
137
            if header_start < 0:
138
                _LOGGER.debug("Header start outside of freeblock!")
139
                break
140
            yield header_start
141
    return generic_heuristic
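# The callable returned by heuristic_factory() yields candidate record header
# offsets within a freeblock, starting from the last occurrence of the magic
# bytes and working backwards (see BTreePage.recover_freeblock_records()).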
142
143
144
def load_heuristics():
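    # The heuristics JSON maps a table name to its magic bytes (base64-encoded)
    # and the offset of those bytes from the start of the record header.
    # Illustrative example only:
    #   {"my_table": {"magic": "U1FMaXRl", "offset": 2}}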
145
146
    def _load_from_json(raw_json):
147
        if isinstance(raw_json, bytes):
148
            raw_json = raw_json.decode('utf-8')
149
        for table_name, table_props in json.loads(raw_json).items():
150
            magic = base64.standard_b64decode(
151
                table_props['magic']
152
            )
153
            heuristics[table_name] = heuristic_factory(
154
                magic, table_props['offset']
155
            )
156
            _LOGGER.debug("Loaded heuristics for \"%s\"", table_name)
157
158
    with pkg_resources.resource_stream(PROJECT_NAME, BUILTIN_JSON) as builtin:
159
        _load_from_json(builtin.read())
160
161
    if not os.path.exists(USER_JSON_PATH):
162
        return
163
    with open(USER_JSON_PATH, 'r') as user_json:
164
        _load_from_json(user_json.read())
165
166
167
class IndexDict(dict):
168
    def __iter__(self):
169
        for k in sorted(self.keys()):
170
            yield k
171
172
173
class SQLite_DB(object):
174
    def __init__(self, path):
175
        self._path = path
176
        self._page_types = {}
177
        self._header = self.parse_header()
178
179
        self._page_cache = None
180
        # Actual page objects go here
181
        self._pages = {}
182
        self.build_page_cache()
183
184
        self._ptrmap = {}
185
186
        # TODO Do we need all of these?
187
        self._table_roots = {}
188
        self._page_tables = {}
189
        self._tables = {}
190
        self._table_columns = {}
191
        self._freelist_leaves = []
192
        self._freelist_btree_pages = []
193
194
    @property
195
    def ptrmap(self):
196
        return self._ptrmap
197
198
    @property
199
    def header(self):
200
        return self._header
201
202
    @property
203
    def pages(self):
204
        return self._pages
205
206
    @property
207
    def tables(self):
208
        return self._tables
209
210
    @property
211
    def freelist_leaves(self):
212
        return self._freelist_leaves
213
214
    @property
215
    def table_columns(self):
216
        return self._table_columns
217
218
    def page_bytes(self, page_idx):
219
        try:
220
            return self._page_cache[page_idx]
221
        except KeyError:
222
            raise ValueError("No cache for page %d", page_idx)
223
224
    def map_table_page(self, page_idx, table):
225
        assert isinstance(page_idx, int)
226
        assert isinstance(table, Table)
227
        self._page_tables[page_idx] = table
228
229
    def get_page_table(self, page_idx):
230
        assert isinstance(page_idx, int)
231
        try:
232
            return self._page_tables[page_idx]
233
        except KeyError:
234
            return None
235
236
    def __repr__(self):
237
        return '<SQLite DB, page count: {} | page size: {}>'.format(
238
            self.header.size_in_pages,
239
            self.header.page_size
240
        )
241
242
    def parse_header(self):
243
        header_bytes = None
244
        file_size = None
245
        with open(self._path, 'br') as sqlite:
246
            header_bytes = sqlite.read(100)
247
            file_size = os.fstat(sqlite.fileno())[stat.ST_SIZE]
248
249
        if not header_bytes:
250
            raise ValueError("Couldn't read SQLite header")
251
        assert isinstance(header_bytes, bytes)
252
        # This DB header is always big-endian
253
        fields = SQLite_header(*struct.unpack(
254
            r'>16sHBBBBBBIIIIIIIIIIII20xII',
255
            header_bytes[:100]
256
        ))
257
        assert fields.page_size in constants.VALID_PAGE_SIZES
258
        db_size = fields.page_size * fields.size_in_pages
259
        assert db_size <= file_size
260
        assert (fields.page_size > 0) and \
261
            (fields.file_change_counter == fields.version_valid)
262
263
        if file_size < 1073741824:
264
            _LOGGER.debug("No lock-byte page in this file!")
265
266
        if fields.first_freelist_trunk > 0:
267
            self._page_types[fields.first_freelist_trunk] = \
268
                constants.FREELIST_TRUNK_PAGE
269
        _LOGGER.debug(fields)
270
        return fields
271
272
    def build_page_cache(self):
273
        # The SQLite docs use a numbering convention for pages where the
274
        # first page (the one that has the header) is page 1, with the next
275
        # ptrmap page being page 2, etc.
276
        page_cache = [None, ]
277
        with open(self._path, 'br') as sqlite:
278
            for page_idx in range(self._header.size_in_pages):
279
                page_offset = page_idx * self._header.page_size
280
                sqlite.seek(page_offset, os.SEEK_SET)
281
                page_cache.append(sqlite.read(self._header.page_size))
282
        self._page_cache = page_cache
283
        for page_idx in range(1, len(self._page_cache)):
284
            # We want these to be temporary objects, to be replaced with
285
            # more specialised objects as parsing progresses
286
            self._pages[page_idx] = Page(page_idx, self)
287
288
    def populate_freelist_pages(self):
289
        if 0 == self._header.first_freelist_trunk:
290
            _LOGGER.debug("This database has no freelist trunk page")
291
            return
292
293
        _LOGGER.info("Parsing freelist pages")
294
        parsed_trunks = 0
295
        parsed_leaves = 0
296
        freelist_trunk_idx = self._header.first_freelist_trunk
297
298
        while freelist_trunk_idx != 0:
299
            _LOGGER.debug(
300
                "Parsing freelist trunk page %d",
301
                freelist_trunk_idx
302
            )
303
            trunk_bytes = bytes(self.pages[freelist_trunk_idx])
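            # Trunk page layout: a 4-byte pointer to the next trunk page, a
            # 4-byte count of leaf pages, then that many 4-byte leaf page
            # numbers.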
304
305
            next_freelist_trunk_page_idx, num_leaf_pages = struct.unpack(
306
                r'>II',
307
                trunk_bytes[:8]
308
            )
309
310
            # Now that we know how long the array of freelist page pointers is,
311
            # let's read it again
312
            trunk_array = struct.unpack(
313
                r'>{count}I'.format(count=2+num_leaf_pages),
314
                trunk_bytes[:(4*(2+num_leaf_pages))]
315
            )
316
317
            # We're skipping the first entries as they are really the next trunk
318
            # index and the leaf count
319
            # TODO Fix that
320
            leaves_in_trunk = []
321
            for page_idx in trunk_array[2:]:
322
                # Let's prepare a specialised object for this freelist leaf
323
                # page
324
                leaf_page = FreelistLeafPage(
325
                    page_idx, self, freelist_trunk_idx
326
                )
327
                leaves_in_trunk.append(leaf_page)
328
                self._freelist_leaves.append(page_idx)
329
                self._pages[page_idx] = leaf_page
330
331
                self._page_types[page_idx] = constants.FREELIST_LEAF_PAGE
332
333
            trunk_page = FreelistTrunkPage(
334
                freelist_trunk_idx,
335
                self,
336
                leaves_in_trunk
337
            )
338
            self._pages[freelist_trunk_idx] = trunk_page
339
            # We've parsed this trunk page
340
            parsed_trunks += 1
341
            # ...And every leaf in it
342
            parsed_leaves += num_leaf_pages
343
344
            freelist_trunk_idx = next_freelist_trunk_page_idx
345
346
        assert (parsed_trunks + parsed_leaves) == self._header.freelist_pages
347
        _LOGGER.info(
348
            "Freelist summary: %d trunk pages, %d leaf pages",
349
            parsed_trunks,
350
            parsed_leaves
351
        )
352
353
    def populate_overflow_pages(self):
354
        # Knowledge of the overflow pages can come from the pointer map (easy),
355
        # or the parsing of individual cells in table leaf pages (hard)
356
        #
357
        # For now, assume we already have a page type dict populated from the
358
        # ptrmap
359
        _LOGGER.info("Parsing overflow pages")
360
        overflow_count = 0
361
        for page_idx in sorted(self._page_types):
362
            page_type = self._page_types[page_idx]
363
            if page_type not in constants.OVERFLOW_PAGE_TYPES:
364
                continue
365
            overflow_page = OverflowPage(page_idx, self)
366
            self.pages[page_idx] = overflow_page
367
            overflow_count += 1
368
369
        _LOGGER.info("Overflow summary: %d pages", overflow_count)
370
371
    def populate_ptrmap_pages(self):
372
        if self._header.largest_btree_page == 0:
373
            # We don't have ptrmap pages in this DB. That sucks.
374
            _LOGGER.warning("%r does not have ptrmap pages!", self)
375
            for page_idx in range(1, self._header.size_in_pages + 1):
376
                self._page_types[page_idx] = constants.UNKNOWN_PAGE
377
            return
378
379
        _LOGGER.info("Parsing ptrmap pages")
380
381
382
        usable_size = self._header.page_size - self._header.reserved_length
383
        num_ptrmap_entries_in_page = usable_size // 5
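        # Each ptrmap entry is 5 bytes: a one-byte page type followed by a
        # 4-byte big-endian parent page number (unpacked with '>BI' below).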
384
        ptrmap_page_indices = []
385
386
        ptrmap_page_idx = 2
387
        while ptrmap_page_idx <= self._header.size_in_pages:
388
            page_bytes = self._page_cache[ptrmap_page_idx]
389
            ptrmap_page_indices.append(ptrmap_page_idx)
390
            self._page_types[ptrmap_page_idx] = constants.PTRMAP_PAGE
391
            page_ptrmap_entries = {}
392
393
            ptrmap_bytes = page_bytes[:5 * num_ptrmap_entries_in_page]
394
            for entry_idx in range(num_ptrmap_entries_in_page):
395
                ptr_page_idx = ptrmap_page_idx + entry_idx + 1
396
                page_type, page_ptr = struct.unpack(
397
                    r'>BI',
398
                    ptrmap_bytes[5*entry_idx:5*(entry_idx+1)]
399
                )
400
                if page_type == 0:
401
                    break
402
403
                ptrmap_entry = SQLite_ptrmap_info(
404
                    ptr_page_idx, page_type, page_ptr
405
                )
406
                assert ptrmap_entry.page_type in constants.PTRMAP_PAGE_TYPES
407
                if page_type == constants.BTREE_ROOT_PAGE:
408
                    assert page_ptr == 0
409
                    self._page_types[ptr_page_idx] = page_type
410
411
                elif page_type == constants.FREELIST_PAGE:
412
                    # Freelist pages are assumed to be known already
413
                    assert self._page_types[ptr_page_idx] in \
414
                        constants.FREELIST_PAGE_TYPES
415
                    assert page_ptr == 0
416
417
                elif page_type == constants.FIRST_OFLOW_PAGE:
418
                    assert page_ptr != 0
419
                    self._page_types[ptr_page_idx] = page_type
420
421
                elif page_type == constants.NON_FIRST_OFLOW_PAGE:
422
                    assert page_ptr != 0
423
                    self._page_types[ptr_page_idx] = page_type
424
425
                elif page_type == constants.BTREE_NONROOT_PAGE:
426
                    assert page_ptr != 0
427
                    self._page_types[ptr_page_idx] = page_type
428
429
                # _LOGGER.debug("%r", ptrmap_entry)
430
                self._ptrmap[ptr_page_idx] = ptrmap_entry
431
                page_ptrmap_entries[ptr_page_idx] = ptrmap_entry
432
433
            page = PtrmapPage(ptrmap_page_idx, self, page_ptrmap_entries)
434
            self._pages[ptrmap_page_idx] = page
435
            _LOGGER.debug("%r", page)
436
            ptrmap_page_idx += num_ptrmap_entries_in_page + 1
437
438
        _LOGGER.info(
439
            "Ptrmap summary: %d pages, %r",
440
            len(ptrmap_page_indices), ptrmap_page_indices
441
        )
442
443
    def populate_btree_pages(self):
444
        # TODO Should this use table information instead of scanning all pages?
445
        page_idx = 1
446
        while page_idx <= self._header.size_in_pages:
447
            try:
448
                if self._page_types[page_idx] in \
449
                        constants.NON_BTREE_PAGE_TYPES:
450
                    page_idx += 1
451
                    continue
452
            except KeyError:
453
                pass
454
455
            try:
456
                page_obj = BTreePage(page_idx, self)
457
            except ValueError:
458
                # This page isn't a valid btree page. This can happen if we
459
                # don't have a ptrmap to guide us
460
                _LOGGER.warning(
461
                    "Page %d (%s) is not a btree page",
462
                    page_idx,
463
                    self._page_types.get(page_idx)
464
                )
465
                page_idx += 1
466
                continue
467
468
            page_obj.parse_cells()
469
            self._page_types[page_idx] = page_obj.page_type
470
            self._pages[page_idx] = page_obj
471
            page_idx += 1
472
473
    def _parse_master_leaf_page(self, page):
474
        for cell_idx in page.cells:
475
            _, master_record = page.cells[cell_idx]
476
            assert isinstance(master_record, Record)
477
            fields = [
478
                master_record.fields[idx].value for idx in master_record.fields
479
            ]
480
            master_record = SQLite_master_record(*fields)
481
            if 'table' != master_record.type:
482
                continue
483
484
            self._table_roots[master_record.name] = \
485
                self.pages[master_record.rootpage]
486
487
            # This record describes a table in the schema, which means it
488
            # includes a SQL statement that defines the table's columns
489
            # We need to parse the field names out of that statement
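            # e.g. "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT)"
            # yields columns ['id', 'name'] and a signature of [int, str]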
490
            assert master_record.sql.startswith('CREATE TABLE')
491
            columns_re = re.compile(r'^CREATE TABLE (\S+) \((.*)\)$')
492
            match = columns_re.match(master_record.sql)
493
            if match:
494
                assert match.group(1) == master_record.name
495
                column_list = match.group(2)
496
                csl_between_parens_re = re.compile(r'\([^)]+\)')
497
                expunged = csl_between_parens_re.sub('', column_list)
498
499
                cols = [
500
                    statement.strip() for statement in expunged.split(',')
501
                ]
502
                cols = [
503
                    statement for statement in cols if not (
504
                        statement.startswith('PRIMARY') or
505
                        statement.startswith('UNIQUE')
506
                    )
507
                ]
508
                columns = [col.split()[0] for col in cols]
509
                signature = []
510
511
                # Some column definitions lack a type
512
                for col_def in cols:
513
                    def_tokens = col_def.split()
514
                    try:
515
                        col_type = def_tokens[1]
516
                    except IndexError:
517
                        signature.append(object)
518
                        continue
519
520
                    _LOGGER.debug(
521
                        "Column \"%s\" is defined as \"%s\"",
522
                        def_tokens[0], col_type
523
                    )
524
                    try:
525
                        signature.append(type_specs[col_type])
526
                    except KeyError:
527
                        _LOGGER.warning("No native type for \"%s\"", col_def)
528
                        signature.append(object)
529
                _LOGGER.info(
530
                    "Signature for table \"%s\": %r",
531
                    master_record.name, signature
532
                )
533
                signatures[master_record.name] = signature
534
535
                _LOGGER.info(
536
                    "Columns for table \"%s\": %r",
537
                    master_record.name, columns
538
                )
539
                self._table_columns[master_record.name] = columns
540
541
    def map_tables(self):
542
        first_page = self.pages[1]
543
        assert isinstance(first_page, BTreePage)
544
545
        master_table = Table('sqlite_master', self, first_page)
546
        self._table_columns.update(constants.SQLITE_TABLE_COLUMNS)
547
548
        for master_leaf in master_table.leaves:
549
            self._parse_master_leaf_page(master_leaf)
550
551
        assert all(
552
            isinstance(root, BTreePage) for root in self._table_roots.values()
553
        )
554
        assert all(
555
            root.parent is None for root in self._table_roots.values()
556
        )
557
558
        self.map_table_page(1, master_table)
559
        self._table_roots['sqlite_master'] = self.pages[1]
560
561
        for table_name, rootpage in self._table_roots.items():
562
            try:
563
                table_obj = Table(table_name, self, rootpage)
564
            except Exception as ex:  # pylint:disable=W0703
565
                pdb.set_trace()
566
                _LOGGER.warning(
567
                    "Caught %r while instantiating table object for \"%s\"",
568
                    ex, table_name
569
                )
570
            else:
571
                self._tables[table_name] = table_obj
572
573
    def reparent_orphaned_table_leaf_pages(self):
574
        reparented_pages = []
575
        for page in self.pages.values():
576
            if not isinstance(page, BTreePage):
577
                continue
578
            if page.page_type != "Table Leaf":
579
                continue
580
581
            table = page.table
582
            if not table:
583
                parent = page
584
                root_table = None
585
                while parent:
586
                    root_table = parent.table
587
                    parent = parent.parent
588
                if root_table is None:
589
                    self._freelist_btree_pages.append(page)
590
591
                if root_table is None:
592
                    # So that's our main problem. We have a valid B-Tree Leaf
593
                    # page, but no idea what table it belongs to. The only
594
                    # thing we can use to determine its table is the make-up of
595
                    # its records.
596
                    #
597
                    # Basically, we need to extend the logic currently in use
598
                    # to associate regexps and offsets with tables and add some
599
                    # sort of signature mechanism that we can use to determine
600
                    # whether a given freeleaf record matches the invariant
601
                    # fields of a given known table. Integers, NULLs and
602
                    # fixed-length strings (GUIDs) would be used as part of
603
                    # that signature mechanism
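                    # (check_signature() below is a first cut at this: NULL
                    # matches any column, and every non-NULL value must be an
                    # instance of the Python type in the table's signature.)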
604
                    if not page.cells:
605
                        continue
606
607
                    first_record = page.cells[0][1]
608
                    matches = []
609
                    for table_name in signatures:
610
                        # All records within a given page are for the same
611
                        # table
612
                        if self.tables[table_name].check_signature(
613
                                first_record):
614
                            matches.append(self.tables[table_name])
615
                    if not matches:
616
                        _LOGGER.error(
617
                            "Couldn't find a matching table for %r",
618
                            page
619
                        )
620
                        continue
621
                    if len(matches) > 1:
622
                        _LOGGER.error(
623
                            "Multiple matching tables for %r: %r",
624
                            page, matches
625
                        )
626
                        continue
627
                    elif len(matches) == 1:
628
                        root_table = matches[0]
629
630
                _LOGGER.debug(
631
                    "Reparenting %r to table \"%s\"",
632
                    page, root_table.name
633
                )
634
                root_table.add_leaf(page)
635
                self.map_table_page(page.idx, root_table)
636
                reparented_pages.append(page)
637
638
        if reparented_pages:
639
            _LOGGER.info(
640
                "Reparented %d pages: %r",
641
                len(reparented_pages), [p.idx for p in reparented_pages]
642
            )
643
644
    def grep(self, needle):
645
        match_found = False
646
        page_idx = 1
647
        needle_re = re.compile(needle.encode('utf-8'))
648
        while (page_idx <= self.header.size_in_pages):
649
            page = self.pages[page_idx]
650
            page_offsets = []
651
            for match in needle_re.finditer(bytes(page)):
652
                needle_offset = match.start()
653
                page_offsets.append(needle_offset)
654
            if page_offsets:
                match_found = True
655
                _LOGGER.info(
656
                    "Found search term in page %r @ offset(s) %s",
657
                    page, ', '.join(str(offset) for offset in page_offsets)
658
                )
659
            page_idx += 1
660
        if not match_found:
661
            _LOGGER.warning(
662
                "Search term not found",
663
            )
664
665
666
class Table(object):
667
    def __init__(self, name, db, rootpage):
668
        self._name = name
669
        self._db = db
670
        assert(isinstance(rootpage, BTreePage))
671
        self._root = rootpage
672
        self._leaves = []
673
        try:
674
            self._columns = self._db.table_columns[self.name]
675
        except KeyError:
676
            self._columns = None
677
678
        # We want this to be a list of leaf-type pages, sorted in the order of
679
        # their smallest rowid
680
        self._populate_pages()
681
682
    @property
683
    def name(self):
684
        return self._name
685
686
    def add_leaf(self, leaf_page):
687
        self._leaves.append(leaf_page)
688
689
    @property
690
    def columns(self):
691
        return self._columns
692
693
    def __repr__(self):
694
        return "<SQLite table \"{}\", root: {}, leaves: {}>".format(
695
            self.name, self._root.idx, len(self._leaves)
696
        )
697
698
    def _populate_pages(self):
699
        _LOGGER.info("Page %d is root for %s", self._root.idx, self.name)
700
        table_pages = [self._root]
701
702
        if self._root.btree_header.right_most_page_idx is not None:
703
            rightmost_idx = self._root.btree_header.right_most_page_idx
704
            rightmost_page = self._db.pages[rightmost_idx]
705
            if rightmost_page is not self._root:
706
                _LOGGER.info(
707
                    "Page %d is rightmost for %s",
708
                    rightmost_idx, self.name
709
                )
710
                table_pages.append(rightmost_page)
711
712
        page_queue = list(table_pages)
713
        while page_queue:
714
            table_page = page_queue.pop(0)
715
            # table_pages is initialised with the table's rootpage, which
716
            # may be a leaf page for a very small table
717
            if table_page.page_type != 'Table Interior':
718
                self._leaves.append(table_page)
719
                continue
720
721
            for cell_idx in table_page.cells:
722
                page_ptr, max_row_in_page = table_page.cells[cell_idx]
723
724
                page = self._db.pages[page_ptr]
725
                _LOGGER.debug("B-Tree cell: (%r, %d)", page, max_row_in_page)
726
                table_pages.append(page)
727
                if page.page_type == 'Table Interior':
728
                    page_queue.append(page)
729
                elif page.page_type == 'Table Leaf':
730
                    self._leaves.append(page)
731
732
        assert(all(p.page_type == 'Table Leaf' for p in self._leaves))
733
        for page in table_pages:
734
            self._db.map_table_page(page.idx, self)
735
736
    @property
737
    def leaves(self):
738
        for leaf_page in self._leaves:
739
            yield leaf_page
740
741
    def recover_records(self):
742
        for page in self.leaves:
743
            assert isinstance(page, BTreePage)
744
            if not page.freeblocks:
745
                continue
746
747
            _LOGGER.info("%r", page)
748
            page.recover_freeblock_records()
749
            page.print_recovered_records()
750
751
    def csv_dump(self, out_dir):
752
        csv_path = os.path.join(out_dir, self.name + '.csv')
753
        if os.path.exists(csv_path):
754
            raise ValueError("Output file {} exists!".format(csv_path))
755
756
        _LOGGER.info("Dumping table \"%s\" to CSV", self.name)
757
        with tempfile.TemporaryFile('w+', newline='') as csv_temp:
758
            writer = csv.DictWriter(csv_temp, fieldnames=self._columns)
759
            writer.writeheader()
760
761
            for leaf_page in self.leaves:
762
                for cell_idx in leaf_page.cells:
763
                    rowid, record = leaf_page.cells[cell_idx]
764
                    # assert(self.check_signature(record))
765
766
                    _LOGGER.debug('Record %d: %r', rowid, record.header)
767
                    fields_iter = (
768
                        repr(record.fields[idx]) for idx in record.fields
769
                    )
770
                    _LOGGER.debug(', '.join(fields_iter))
771
772
                    values_iter = (
773
                        record.fields[idx].value for idx in record.fields
774
                    )
775
                    writer.writerow(dict(zip(self._columns, values_iter)))
776
777
                if not leaf_page.recovered_records:
778
                    continue
779
780
                # Recovered records are in an unordered set because their rowid
781
                # has been lost, making sorting impossible
782
                for record in leaf_page.recovered_records:
783
                    values_iter = (
784
                        record.fields[idx].value for idx in record.fields
785
                    )
786
                    writer.writerow(dict(zip(self._columns, values_iter)))
787
788
            if csv_temp.tell() > 0:
789
                csv_temp.seek(0)
790
                with open(csv_path, 'w') as csv_file:
791
                    csv_file.write(csv_temp.read())
792
793
    def build_insert_SQL(self, record):
794
        column_placeholders = (
795
            ':' + col_name for col_name in self._columns
796
        )
797
        insert_statement = 'INSERT INTO {} VALUES ({})'.format(
798
            self.name,
799
            ', '.join(c for c in column_placeholders),
800
        )
801
        value_kwargs = {}
802
        for col_idx, col_name in enumerate(self._columns):
803
            try:
804
                if record.fields[col_idx].value is None:
805
                    value_kwargs[col_name] = None
806
                else:
807
                    value_kwargs[col_name] = record.fields[col_idx].value
808
            except KeyError:
809
                value_kwargs[col_name] = None
810
811
        return insert_statement, value_kwargs
812
813
    def check_signature(self, record):
814
        assert isinstance(record, Record)
815
        try:
816
            sig = signatures[self.name]
817
        except KeyError:
818
            # The sqlite schema tables don't have a signature (or need one)
819
            return True
820
        if len(record.fields) > len(self.columns):
821
            return False
822
823
        # It's OK for a record to have fewer fields than there are columns in
824
        # this table; this is seen when NULLable or default-valued columns are
825
        # added in an ALTER TABLE statement.
826
        for field_idx, field in record.fields.items():
827
            # NULL can be a value for any column type
828
            if field.value is None:
829
                continue
830
            if not isinstance(field.value, sig[field_idx]):
831
                return False
832
        return True
833
834
835
class Page(object):
836
    def __init__(self, page_idx, db):
837
        self._page_idx = page_idx
838
        self._db = db
839
        self._bytes = db.page_bytes(self.idx)
840
841
    @property
842
    def idx(self):
843
        return self._page_idx
844
845
    @property
846
    def usable_size(self):
847
        return self._db.header.page_size - self._db.header.reserved_length
848
849
    def __bytes__(self):
850
        return self._bytes
851
852
    @property
853
    def parent(self):
854
        try:
855
            parent_idx = self._db.ptrmap[self.idx].page_ptr
856
        except KeyError:
857
            return None
858
859
        if 0 == parent_idx:
860
            return None
861
        else:
862
            return self._db.pages[parent_idx]
863
864
    def __repr__(self):
865
        return "<SQLite Page {0}>".format(self.idx)
866
867
868
class FreelistTrunkPage(Page):
869
    # XXX Maybe it would make sense to expect a Page instance as constructor
870
    # argument?
871
    def __init__(self, page_idx, db, leaves):
872
        super().__init__(page_idx, db)
873
        self._leaves = leaves
874
875
    def __repr__(self):
876
        return "<SQLite Freelist Trunk Page {0}: {1} leaves>".format(
877
            self.idx, len(self._leaves)
878
        )
879
880
881
class FreelistLeafPage(Page):
882
    # XXX Maybe it would make sense to expect a Page instance as constructor
883
    # argument?
884
    def __init__(self, page_idx, db, trunk_idx):
885
        super().__init__(page_idx, db)
886
        self._trunk = self._db.pages[trunk_idx]
887
888
    def __repr__(self):
889
        return "<SQLite Freelist Leaf Page {0}. Trunk: {1}>".format(
890
            self.idx, self._trunk.idx
891
        )
892
893
894
class PtrmapPage(Page):
895
    # XXX Maybe it would make sense to expect a Page instance as constructor
896
    # argument?
897
    def __init__(self, page_idx, db, ptr_array):
898
        super().__init__(page_idx, db)
899
        self._pointers = ptr_array
900
901
    @property
902
    def pointers(self):
903
        return self._pointers
904
905
    def __repr__(self):
906
        return "<SQLite Ptrmap Page {0}. {1} pointers>".format(
907
            self.idx, len(self.pointers)
908
        )
909
910
911
class OverflowPage(Page):
912
    # XXX Maybe it would make sense to expect a Page instance as constructor
913
    # argument?
914
    def __init__(self, page_idx, db):
915
        super().__init__(page_idx, db)
916
        self._parse()
917
918
    def _parse(self):
919
        # TODO We should have parsing here for the next page index in the
920
        # overflow chain
921
        pass
922
923
    def __repr__(self):
924
        return "<SQLite Overflow Page {0}. Continuation of {1}>".format(
925
            self.idx, self.parent.idx
926
        )
927
928
929
class BTreePage(Page):
930
    btree_page_types = {
931
        0x02:   "Index Interior",
932
        0x05:   "Table Interior",
933
        0x0A:   "Index Leaf",
934
        0x0D:   "Table Leaf",
935
    }
936
937
    def __init__(self, page_idx, db):
938
        # XXX We don't know a page's type until we've had a look at the header.
939
        # Or do we?
940
        super().__init__(page_idx, db)
941
        self._header_size = 8
942
        page_header_bytes = self._get_btree_page_header()
943
        self._btree_header = SQLite_btree_page_header(
944
            # Set the right-most page index to None in the 1st pass
945
            *struct.unpack(r'>BHHHB', page_header_bytes), None
946
        )
947
        self._cell_ptr_array = []
948
        self._freeblocks = IndexDict()
949
        self._cells = IndexDict()
950
        self._recovered_records = set()
951
        self._overflow_threshold = self.usable_size - 35
952
953
        if self._btree_header.page_type not in BTreePage.btree_page_types:
954
            # pdb.set_trace()
955
            raise ValueError
956
957
        # We have a twelve-byte header, need to read it again
958
        if self._btree_header.page_type in (0x02, 0x05):
959
            self._header_size = 12
960
            page_header_bytes = self._get_btree_page_header()
961
            self._btree_header = SQLite_btree_page_header(*struct.unpack(
962
                r'>BHHHBI', page_header_bytes
963
            ))
964
965
        # Page 1 (and page 2, but that's the 1st ptrmap page) does not have a
966
        # ptrmap entry.
967
        # The first ptrmap page will contain back pointer information for pages
968
        # 3 through J+2, inclusive.
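        # (J is the number of 5-byte ptrmap entries that fit in one page,
        # i.e. usable_size // 5 -- 819 for a 4096-byte page with no reserved
        # bytes.)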
969
        if self._db.ptrmap:
970
            if self.idx >= 3 and self.idx not in self._db.ptrmap:
971
                _LOGGER.warning(
972
                    "BTree page %d doesn't have ptrmap entry!", self.idx
973
                )
974
975
        if self._btree_header.num_cells > 0:
976
            cell_ptr_bytes = self._get_btree_ptr_array(
977
                self._btree_header.num_cells
978
            )
979
            self._cell_ptr_array = struct.unpack(
980
                r'>{count}H'.format(count=self._btree_header.num_cells),
981
                cell_ptr_bytes
982
            )
983
            smallest_cell_offset = min(self._cell_ptr_array)
984
            if self._btree_header.cell_content_offset != smallest_cell_offset:
985
                _LOGGER.warning(
986
                    (
987
                        "Inconsistent cell ptr array in page %d! Cell content "
988
                        "starts at offset %d, but min cell pointer is %d"
989
                    ),
990
                    self.idx,
991
                    self._btree_header.cell_content_offset,
992
                    smallest_cell_offset
993
                )
994
995
    @property
996
    def btree_header(self):
997
        return self._btree_header
998
999
    @property
1000
    def page_type(self):
1001
        try:
1002
            return self.btree_page_types[self._btree_header.page_type]
1003
        except KeyError:
1004
            pdb.set_trace()
1005
            _LOGGER.warning(
1006
                "Unknown B-Tree page type: %d", self._btree_header.page_type
1007
            )
1008
            raise
1009
1010
    @property
1011
    def freeblocks(self):
1012
        return self._freeblocks
1013
1014
    @property
1015
    def cells(self):
1016
        return self._cells
1017
1018
    def __repr__(self):
1019
        # TODO Include table in repr, where available
1020
        return "<SQLite B-Tree Page {0} ({1}) {2} cells>".format(
1021
            self.idx, self.page_type, len(self._cell_ptr_array)
1022
        )
1023
1024
    @property
1025
    def table(self):
1026
        return self._db.get_page_table(self.idx)
1027
1028
    def _get_btree_page_header(self):
1029
        header_offset = 0
1030
        if self.idx == 1:
1031
            header_offset += 100
1032
        return bytes(self)[header_offset:self._header_size + header_offset]
1033
1034
    def _get_btree_ptr_array(self, num_cells):
1035
        array_offset = self._header_size
1036
        if self.idx == 1:
1037
            array_offset += 100
1038
        return bytes(self)[array_offset:2 * num_cells + array_offset]
1039
1040
    def parse_cells(self):
1041
        if self.btree_header.page_type == 0x05:
1042
            self.parse_table_interior_cells()
1043
        elif self.btree_header.page_type == 0x0D:
1044
            self.parse_table_leaf_cells()
1045
        self.parse_freeblocks()
1046
1047
    def parse_table_interior_cells(self):
1048
        if self.btree_header.page_type != 0x05:
1049
            assert False
1050
1051
        _LOGGER.debug("Parsing cells in table interior cell %d", self.idx)
1052
        for cell_idx, offset in enumerate(self._cell_ptr_array):
1053
            _LOGGER.debug("Parsing cell %d @ offset %d", cell_idx, offset)
1054
            left_ptr_bytes = bytes(self)[offset:offset + 4]
1055
            left_ptr, = struct.unpack(r'>I', left_ptr_bytes)
1056
1057
            offset += 4
1058
            integer_key = Varint(bytes(self)[offset:offset+9])
1059
            self._cells[cell_idx] = (left_ptr, int(integer_key))
1060
1061
    def parse_table_leaf_cells(self):
1062
        if self.btree_header.page_type != 0x0d:
1063
            assert False
1064
1065
        _LOGGER.debug("Parsing cells in table leaf cell %d", self.idx)
1066
        for cell_idx, cell_offset in enumerate(self._cell_ptr_array):
1067
            _LOGGER.debug("Parsing cell %d @ offset %d", cell_idx, cell_offset)
1068
1069
            # This is the total size of the payload, which may include overflow
1070
            offset = cell_offset
1071
            payload_length_varint = Varint(bytes(self)[offset:offset+9])
1072
            total_payload_size = int(payload_length_varint)
1073
1074
            overflow = False
1075
            # Let X be U-35. If the payload size P is less than or equal to X
1076
            # then the entire payload is stored on the b-tree leaf page. Let M
1077
            # be ((U-12)*32/255)-23 and let K be M+((P-M)%(U-4)). If P is
1078
            # greater than X then the number of bytes stored on the table
1079
            # b-tree leaf page is K if K is less or equal to X or M otherwise.
1080
            # The number of bytes stored on the leaf page is never less than M.
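            # Worked example: with U = 1024 and no reserved bytes, X = 989 and
            # M = 103; a payload of P = 1200 gives K = 103 + (1097 % 1020) =
            # 180, so 180 bytes stay in the cell and the rest spill to
            # overflow pages.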
1081
            cell_payload_size = 0
1082
            if total_payload_size > self._overflow_threshold:
1083
                m = int(((self.usable_size - 12) * 32/255)-23)
1084
                k = m + ((total_payload_size - m) % (self.usable_size - 4))
1085
                if k <= self._overflow_threshold:
1086
                    cell_payload_size = k
1087
                else:
1088
                    cell_payload_size = m
1089
                overflow = True
1090
            else:
1091
                cell_payload_size = total_payload_size
1092
1093
            offset += len(payload_length_varint)
1094
1095
            integer_key = Varint(bytes(self)[offset:offset+9])
1096
            offset += len(integer_key)
1097
1098
            overflow_bytes = bytes()
1099
            if overflow:
1100
                first_oflow_page_bytes = bytes(self)[
1101
                    offset + cell_payload_size:offset + cell_payload_size + 4
1102
                ]
1103
                if not first_oflow_page_bytes:
1104
                    continue
1105
1106
                first_oflow_idx, = struct.unpack(
1107
                    r'>I', first_oflow_page_bytes
1108
                )
1109
                next_oflow_idx = first_oflow_idx
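                # Each overflow page begins with a 4-byte big-endian pointer
                # to the next overflow page (0 on the last page), followed by
                # payload bytes.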
1110
                while next_oflow_idx != 0:
1111
                    oflow_page_bytes = self._db.page_bytes(next_oflow_idx)
1112
1113
                    len_overflow = min(
1114
                        len(oflow_page_bytes) - 4,
1115
                        (
1116
                            total_payload_size - cell_payload_size +
1117
                            len(overflow_bytes)
1118
                        )
1119
                    )
1120
                    overflow_bytes += oflow_page_bytes[4:4 + len_overflow]
1121
1122
                    first_four_bytes = oflow_page_bytes[:4]
1123
                    next_oflow_idx, = struct.unpack(
1124
                        r'>I', first_four_bytes
1125
                    )
1126
1127
            try:
1128
                cell_data = bytes(self)[offset:offset + cell_payload_size]
1129
                if overflow_bytes:
1130
                    cell_data += overflow_bytes
1131
1132
                # All payload bytes should be accounted for
1133
                assert len(cell_data) == total_payload_size
1134
1135
                record_obj = Record(cell_data)
1136
                _LOGGER.debug("Created record: %r", record_obj)
1137
1138
            except TypeError as ex:
1139
                _LOGGER.warning(
1140
                    "Caught %r while instantiating record %d",
1141
                    ex, int(integer_key)
1142
                )
1143
                pdb.set_trace()
1144
                raise
1145
1146
            self._cells[cell_idx] = (int(integer_key), record_obj)
1147
1148
    def parse_freeblocks(self):
1149
        # The first 2 bytes of a freeblock are a big-endian integer which is
1150
        # the offset in the b-tree page of the next freeblock in the chain, or
1151
        # zero if the freeblock is the last on the chain. The third and fourth
1152
        # bytes of each freeblock form a big-endian integer which is the size
1153
        # of the freeblock in bytes, including the 4-byte header. Freeblocks
1154
        # are always connected in order of increasing offset. The second field
1155
        # of the b-tree page header is the offset of the first freeblock, or
1156
        # zero if there are no freeblocks on the page. In a well-formed b-tree
1157
        # page, there will always be at least one cell before the first
1158
        # freeblock.
1159
        #
1160
        # TODO But what about deleted records that exceeded the overflow
1161
        # threshold in the past?
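        # Example: a freeblock header of b'\x00\x00\x00\x10' means this is the
        # last freeblock in the chain (next offset 0) and that it spans 16
        # bytes, i.e. 12 bytes of reclaimed cell space after the 4-byte
        # header.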
1162
        block_offset = self.btree_header.first_freeblock_offset
1163
        while block_offset != 0:
1164
            freeblock_header = bytes(self)[block_offset:block_offset + 4]
1165
            # Freeblock_size includes the 4-byte header
1166
            next_freeblock_offset, freeblock_size = struct.unpack(
1167
                r'>HH',
1168
                freeblock_header
1169
            )
1170
            freeblock_bytes = bytes(self)[
1171
                block_offset + 4:block_offset + freeblock_size
1172
            ]
1173
            self._freeblocks[block_offset] = freeblock_bytes
1174
            block_offset = next_freeblock_offset
1175
1176
    def print_cells(self):
1177
        for cell_idx in self.cells.keys():
1178
            rowid, record = self.cells[cell_idx]
1179
            _LOGGER.info(
1180
                "Cell %d, rowid: %d, record: %r",
1181
                cell_idx, rowid, record
1182
            )
1183
            record.print_fields(table=self.table)
1184
1185
    def recover_freeblock_records(self):
1186
        # If we're lucky (i.e. if no overwriting has taken place), we should be
1187
        # able to find whole record headers in freeblocks.
1188
        # We need to start from the end of the freeblock and work our way back
1189
        # to the start. That means we don't know where a cell header will
1190
        # start, but I suppose we can take a guess
1191
        table = self.table
1192
        if not table or table.name not in heuristics:
1193
            return
1194
1195
        _LOGGER.info("Attempting to recover records from freeblocks")
1196
        for freeblock_idx, freeblock_offset in enumerate(self._freeblocks):
1197
            freeblock_bytes = self._freeblocks[freeblock_offset]
1198
            if 0 == len(freeblock_bytes):
1199
                continue
1200
            _LOGGER.debug(
1201
                "Freeblock %d/%d in page, offset %d, %d bytes",
1202
                1 + freeblock_idx,
1203
                len(self._freeblocks),
1204
                freeblock_offset,
1205
                len(freeblock_bytes)
1206
            )
1207
1208
            recovered_bytes = 0
1209
            recovered_in_freeblock = 0
1210
1211
            # TODO Maybe we need to guess the record header lengths rather than
1212
            # try and read them from the freeblocks
1213
            for header_start in heuristics[table.name](freeblock_bytes):
1214
                _LOGGER.debug(
1215
                    (
1216
                        "Trying potential record header start at "
1217
                        "freeblock offset %d/%d"
1218
                    ),
1219
                    header_start, len(freeblock_bytes)
1220
                )
1221
                _LOGGER.debug("%r", freeblock_bytes)
1222
                try:
1223
                    # We don't know how to handle overflow in deleted records,
1224
                    # so we'll have to truncate the bytes object used to
1225
                    # instantiate the Record object
1226
                    record_bytes = freeblock_bytes[
1227
                        header_start:header_start+self._overflow_threshold
1228
                    ]
1229
                    record_obj = Record(record_bytes)
1230
                except MalformedRecord:
1231
                    # This isn't a well-formed record, let's move to the next
1232
                    # candidate
1233
                    continue
1234
1235
                field_lengths = sum(
1236
                    len(field_obj) for field_obj in record_obj.fields.values()
1237
                )
1238
                record_obj.truncate(field_lengths + len(record_obj.header))
1239
                self._recovered_records.add(record_obj)
1240
1241
                recovered_bytes += len(bytes(record_obj))
1242
                recovered_in_freeblock += 1
1243
1244
            _LOGGER.info(
1245
                (
1246
                    "Recovered %d record(s): %d bytes out of %d "
1247
                    "freeblock bytes @ offset %d"
1248
                ),
1249
                recovered_in_freeblock,
1250
                recovered_bytes,
1251
                len(freeblock_bytes),
1252
                freeblock_offset,
1253
            )
1254
1255
    @property
1256
    def recovered_records(self):
1257
        return self._recovered_records
1258
1259
    def print_recovered_records(self):
1260
        if not self._recovered_records:
1261
            return
1262
1263
        for record_obj in self._recovered_records:
1264
            _LOGGER.info("Recovered record: %r", record_obj)
1265
            _LOGGER.info("Recovered record header: %s", record_obj.header)
1266
            record_obj.print_fields(table=self.table)
1267
1268
1269
class Record(object):
1270
1271
    column_types = {
1272
        0: (0, "NULL"),
1273
        1: (1, "8-bit twos-complement integer"),
1274
        2: (2, "big-endian 16-bit twos-complement integer"),
1275
        3: (3, "big-endian 24-bit twos-complement integer"),
1276
        4: (4, "big-endian 32-bit twos-complement integer"),
1277
        5: (6, "big-endian 48-bit twos-complement integer"),
1278
        6: (8, "big-endian 64-bit twos-complement integer"),
1279
        7: (8, "Floating point"),
1280
        8: (0, "Integer 0"),
1281
        9: (0, "Integer 1"),
1282
    }
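    # Serial types of 12 and above encode variable-length data: an even value
    # N is a BLOB of (N - 12) / 2 bytes, and an odd value N >= 13 is text of
    # (N - 13) / 2 bytes, decoded as UTF-8 in Field._parse().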
1283
1284
    def __init__(self, record_bytes):
1285
        self._bytes = record_bytes
1286
        self._header_bytes = None
1287
        self._fields = IndexDict()
1288
        self._parse()
1289
1290
    def __bytes__(self):
1291
        return self._bytes
1292
1293
    @property
1294
    def header(self):
1295
        return self._header_bytes
1296
1297
    @property
1298
    def fields(self):
1299
        return self._fields
1300
1301
    def truncate(self, new_length):
1302
        self._bytes = self._bytes[:new_length]
1303
        self._parse()
1304
1305
    def _parse(self):
1306
        header_offset = 0
1307
1308
        header_length_varint = Varint(
1309
            # A varint is encoded on *at most* 9 bytes
1310
            bytes(self)[header_offset:9 + header_offset]
1311
        )
1312
1313
        # Let's keep track of how many bytes of the Record header (including
1314
        # the header length itself) we've successfully parsed
1315
        parsed_header_bytes = len(header_length_varint)
1316
1317
        if len(bytes(self)) < int(header_length_varint):
1318
            raise MalformedRecord(
1319
                "Not enough bytes to fully read the record header!"
1320
            )
1321
1322
        header_offset += len(header_length_varint)
1323
        self._header_bytes = bytes(self)[:int(header_length_varint)]
1324
1325
        col_idx = 0
1326
        field_offset = int(header_length_varint)
1327
        while header_offset < int(header_length_varint):
1328
            serial_type_varint = Varint(
1329
                bytes(self)[header_offset:9 + header_offset]
1330
            )
1331
            serial_type = int(serial_type_varint)
1332
            col_length = None
1333
1334
            try:
1335
                col_length, _ = self.column_types[serial_type]
1336
            except KeyError:
1337
                if serial_type >= 13 and (1 == serial_type % 2):
1338
                    col_length = (serial_type - 13) // 2
1339
                elif serial_type >= 12 and (0 == serial_type % 2):
1340
                    col_length = (serial_type - 12) // 2
1341
                else:
1342
                    raise ValueError(
1343
                        "Unknown serial type {}".format(serial_type)
1344
                    )
1345
1346
            try:
1347
                field_obj = Field(
1348
                    col_idx,
1349
                    serial_type,
1350
                    bytes(self)[field_offset:field_offset + col_length]
1351
                )
1352
            except MalformedField as ex:
1353
                _LOGGER.warning(
1354
                    "Caught %r while instantiating field %d (%d)",
1355
                    ex, col_idx, serial_type
1356
                )
1357
                raise MalformedRecord
1358
            except Exception as ex:
1359
                _LOGGER.warning(
1360
                    "Caught %r while instantiating field %d (%d)",
1361
                    ex, col_idx, serial_type
1362
                )
1363
                pdb.set_trace()
1364
                raise
1365
1366
            self._fields[col_idx] = field_obj
1367
            col_idx += 1
1368
            field_offset += col_length
1369
1370
            parsed_header_bytes += len(serial_type_varint)
1371
            header_offset += len(serial_type_varint)
1372
1373
            if field_offset > len(bytes(self)):
1374
                raise MalformedRecord
1375
1376
        # assert(parsed_header_bytes == int(header_length_varint))
1377
1378
    def print_fields(self, table=None):
1379
        for field_idx in self._fields:
1380
            field_obj = self._fields[field_idx]
1381
            if not table or table.columns is None:
1382
                _LOGGER.info(
1383
                    "\tField %d (%d bytes), type %d: %s",
1384
                    field_obj.index,
1385
                    len(field_obj),
1386
                    field_obj.serial_type,
1387
                    field_obj.value
1388
                )
1389
            else:
1390
                _LOGGER.info(
1391
                    "\t%s: %s",
1392
                    table.columns[field_obj.index],
1393
                    field_obj.value
1394
                )
1395
1396
    def __repr__(self):
1397
        return '<Record {} fields, {} bytes, header: {} bytes>'.format(
1398
            len(self._fields), len(bytes(self)), len(self.header)
1399
        )
1400
1401
1402
class MalformedField(Exception):
1403
    pass
1404
1405
1406
class MalformedRecord(Exception):
1407
    pass
1408
1409
1410
class Field(object):
1411
    def __init__(self, idx, serial_type, serial_bytes):
1412
        self._index = idx
1413
        self._type = serial_type
1414
        self._bytes = serial_bytes
1415
        self._value = None
1416
        self._parse()
1417
1418
    def _check_length(self, expected_length):
1419
        if len(self) != expected_length:
1420
            raise MalformedField
1421
1422
    # TODO Raise a specific exception when bad bytes are encountered for the
1423
    # fields and then use this to weed out bad freeblock records
    def _parse(self):
        if self._type == 0:
            self._value = None
        # Integer types
        elif self._type == 1:
            self._check_length(1)
            self._value = decode_twos_complement(bytes(self)[0:1], 8)
        elif self._type == 2:
            self._check_length(2)
            self._value = decode_twos_complement(bytes(self)[0:2], 16)
        elif self._type == 3:
            self._check_length(3)
            self._value = decode_twos_complement(bytes(self)[0:3], 24)
        elif self._type == 4:
            self._check_length(4)
            self._value = decode_twos_complement(bytes(self)[0:4], 32)
        elif self._type == 5:
            self._check_length(6)
            self._value = decode_twos_complement(bytes(self)[0:6], 48)
        elif self._type == 6:
            self._check_length(8)
            self._value = decode_twos_complement(bytes(self)[0:8], 64)

        elif self._type == 7:
            self._value = struct.unpack(r'>d', bytes(self)[0:8])[0]
        elif self._type == 8:
            self._value = 0
        elif self._type == 9:
            self._value = 1
        elif self._type >= 13 and (1 == self._type % 2):
            try:
                self._value = bytes(self).decode('utf-8')
            except UnicodeDecodeError:
                raise MalformedField

        elif self._type >= 12 and (0 == self._type % 2):
            self._value = bytes(self)

    def __bytes__(self):
        return self._bytes

    def __repr__(self):
        return "<Field {}: {} ({} bytes)>".format(
            self._index, self._value, len(bytes(self))
        )

    def __len__(self):
        return len(bytes(self))

    @property
    def index(self):
        return self._index

    @property
    def value(self):
        return self._value

    @property
    def serial_type(self):
        return self._type

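# A SQLite variable-length integer: big-endian, one to nine bytes, with the
# low seven bits of each byte contributing to the value and a set high bit
# signalling that another byte follows. (A full nine-byte varint uses all
# eight bits of its final byte, a case this reader does not special-case.)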
class Varint(object):
    def __init__(self, varint_bytes):
        self._bytes = varint_bytes
        self._len = 0
        self._value = 0

        varint_bits = []
        for b in self._bytes:
            self._len += 1
            if b & 0x80:
                varint_bits.append(b & 0x7F)
            else:
                varint_bits.append(b)
                break

        varint_twos_complement = 0
        for position, b in enumerate(varint_bits[::-1]):
            varint_twos_complement += b * (1 << (7*position))

        # Eight bytes are needed here: a varint can encode values well beyond
        # the 32-bit range, which would overflow a four-byte conversion.
        self._value = decode_twos_complement(
            int.to_bytes(varint_twos_complement, 8, byteorder='big'), 64
        )

    def __int__(self):
        return self._value

    def __len__(self):
        return self._len

    def __repr__(self):
        return "<Varint {} ({} bytes)>".format(int(self), len(self))

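# Interpret big-endian bytes as a signed twos-complement integer of the given
# bit width, e.g. decode_twos_complement(b'\xff', 8) == -1.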
def decode_twos_complement(encoded, bit_length):
    assert(0 == bit_length % 8)
    encoded_int = int.from_bytes(encoded, byteorder='big')
    mask = 2**(bit_length - 1)
    value = -(encoded_int & mask) + (encoded_int & ~mask)
    return value

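# Derive an output directory next to the database file, appending a numeric
# suffix when the preferred name is already taken.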
def gen_output_dir(db_path):
    db_abspath = os.path.abspath(db_path)
    db_dir, db_name = os.path.split(db_abspath)

    munged_name = db_name.replace('.', '_')
    out_dir = os.path.join(db_dir, munged_name)
    if not os.path.exists(out_dir):
        return out_dir
    suffix = 1
    while suffix <= 10:
        out_dir = os.path.join(db_dir, "{}_{}".format(munged_name, suffix))
        if not os.path.exists(out_dir):
            return out_dir
        suffix += 1
    raise SystemError(
        "Unreasonable number of output directories for {}".format(db_path)
    )

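# Common loading path for every subcommand: parse the file, classify its
# pages and associate b-tree pages with the tables they belong to.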
def _load_db(sqlite_path):
    _LOGGER.info("Processing %s", sqlite_path)

    load_heuristics()

    db = SQLite_DB(sqlite_path)
    _LOGGER.info("Database: %r", db)

    db.populate_freelist_pages()
    db.populate_ptrmap_pages()
    db.populate_overflow_pages()

    # Should we aim to instantiate specialised b-tree objects here, or is the
    # use of generic btree page objects acceptable?
    db.populate_btree_pages()

    db.map_tables()

    # We need a first pass to process table leaf pages that are disconnected
    # from their table's root page
    db.reparent_orphaned_table_leaf_pages()

    # All pages should now be represented by specialised objects
    assert(all(isinstance(p, Page) for p in db.pages.values()))
    assert(not any(type(p) is Page for p in db.pages.values()))
    return db

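# 'csv' subcommand: recover what records we can for each table and dump the
# results as CSV under out_dir.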
def dump_to_csv(args):
    out_dir = args.output_dir or gen_output_dir(args.sqlite_path)
    db = _load_db(args.sqlite_path)

    if os.path.exists(out_dir):
        raise ValueError("Output directory {} exists!".format(out_dir))
    os.mkdir(out_dir)

    for table_name in sorted(db.tables):
        table = db.tables[table_name]
        _LOGGER.info("Table \"%s\"", table)
        table.recover_records()
        table.csv_dump(out_dir)

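# 'undelete' subcommand: work on a copy of the database and re-INSERT every
# recovered record into it, tallying successes, failures and constraint
# violations per table.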
def undelete(args):
    db_abspath = os.path.abspath(args.sqlite_path)
    db = _load_db(db_abspath)

    output_path = os.path.abspath(args.output_path)
    if os.path.exists(output_path):
        raise ValueError("Output file {} exists!".format(output_path))

    shutil.copyfile(db_abspath, output_path)
    with sqlite3.connect(output_path) as output_db_connection:
        cursor = output_db_connection.cursor()
        for table_name in sorted(db.tables):
            table = db.tables[table_name]
            _LOGGER.info("Table \"%s\"", table)
            table.recover_records()

            failed_inserts = 0
            constraint_violations = 0
            successful_inserts = 0
            for leaf_page in table.leaves:
                if not leaf_page.recovered_records:
                    continue

                for record in leaf_page.recovered_records:
                    insert_statement, values = table.build_insert_SQL(record)

                    try:
                        cursor.execute(insert_statement, values)
                    except sqlite3.IntegrityError:
                        # We have to soldier on; there's not much we can do if
                        # a constraint is violated by this insert
                        constraint_violations += 1
                    except (
                                sqlite3.ProgrammingError,
                                sqlite3.OperationalError,
                                sqlite3.InterfaceError
                            ) as insert_ex:
                        _LOGGER.warning(
                            (
                                "Caught %r while executing INSERT statement "
                                "in \"%s\""
                            ),
                            insert_ex,
                            table
                        )
                        failed_inserts += 1
                        # pdb.set_trace()
                    else:
                        successful_inserts += 1
            if failed_inserts > 0:
                _LOGGER.warning(
                    "%d failed INSERT statements in \"%s\"",
                    failed_inserts, table
                )
            if constraint_violations > 0:
                _LOGGER.warning(
                    "%d constraint violations in \"%s\"",
                    constraint_violations, table
                )
            _LOGGER.info(
                "%d successful INSERT statements in \"%s\"",
                successful_inserts, table
            )

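# 'grep' subcommand: search the loaded database for the needle string.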
def find_in_db(args):
    db = _load_db(args.sqlite_path)
    db.grep(args.needle)

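# Maps each CLI subcommand to the function that implements it.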
subcmd_actions = {
    'csv':  dump_to_csv,
    'grep': find_in_db,
    'undelete': undelete,
}


def subcmd_dispatcher(arg_ns):
    return subcmd_actions[arg_ns.subcmd](arg_ns)

def main():

    verbose_parser = argparse.ArgumentParser(add_help=False)
    verbose_parser.add_argument(
        '-v', '--verbose',
        action='count',
        help='Give *A LOT* more output.',
    )

    cli_parser = argparse.ArgumentParser(
        description=PROJECT_DESCRIPTION,
        parents=[verbose_parser],
    )

    subcmd_parsers = cli_parser.add_subparsers(
        title='Subcommands',
        description='%(prog)s implements the following subcommands:',
        dest='subcmd',
    )

    csv_parser = subcmd_parsers.add_parser(
        'csv',
        parents=[verbose_parser],
        help='Dumps visible and recovered records to CSV files',
        description=(
            'Recovers as many records as possible from the database passed as '
            'argument and outputs all visible and recovered records to CSV '
            'files in output_dir'
        ),
    )
    csv_parser.add_argument(
        'sqlite_path',
        help='sqlite3 file path'
    )
    csv_parser.add_argument(
        'output_dir',
        nargs='?',
        default=None,
        help='Output directory'
    )

    grep_parser = subcmd_parsers.add_parser(
        'grep',
        parents=[verbose_parser],
        help='Matches a string in one or more pages of the database',
        description=(
            'Searches one or more pages of the database for the string '
            'passed as argument'
        ),
    )
    grep_parser.add_argument(
        'sqlite_path',
        help='sqlite3 file path'
    )
    grep_parser.add_argument(
        'needle',
        help='String to match in the database'
    )

    undelete_parser = subcmd_parsers.add_parser(
        'undelete',
        parents=[verbose_parser],
        help='Inserts recovered records into a copy of the database',
        description=(
            'Recovers as many records as possible from the database passed as '
            'argument and inserts all recovered records into a copy of '
            'the database.'
        ),
    )
    undelete_parser.add_argument(
        'sqlite_path',
        help='sqlite3 file path'
    )
    undelete_parser.add_argument(
        'output_path',
        help='Output database path'
    )

    cli_args = cli_parser.parse_args()
    if cli_args.verbose:
        _LOGGER.setLevel(logging.DEBUG)

    if cli_args.subcmd:
        subcmd_dispatcher(cli_args)
    else:
        # No subcommand specified, print the usage and bail
        cli_parser.print_help()