diff --git a/qutebrowser/browser/history.py b/qutebrowser/browser/history.py index e19463e75..d8b1f81d5 100644 --- a/qutebrowser/browser/history.py +++ b/qutebrowser/browser/history.py @@ -78,6 +78,19 @@ class WebHistory(sql.SqlTable): def __init__(self, parent=None): super().__init__("History", ['url', 'title', 'atime', 'redirect'], parent=parent) + self._between_query = sql.Query('SELECT * FROM History ' + 'where not redirect ' + 'and not url like "qute://%" ' + 'and atime > ? ' + 'and atime <= ? ' + 'ORDER BY atime desc') + + self._before_query = sql.Query('SELECT * FROM History ' + 'where not redirect ' + 'and not url like "qute://%" ' + 'and atime <= ? ' + 'ORDER BY atime desc ' + 'limit ? offset ?') def __repr__(self): return utils.get_repr(self, length=len(self)) @@ -101,16 +114,8 @@ class WebHistory(sql.SqlTable): earliest: Omit timestamps earlier than this. latest: Omit timestamps later than this. """ - result = sql.run_query('SELECT * FROM History ' - 'where not redirect ' - 'and not url like "qute://%" ' - 'and atime > {} ' - 'and atime <= {} ' - 'ORDER BY atime desc' - .format(earliest, latest)) - while result.next(): - rec = result.record() - yield self.Entry(*[rec.value(i) for i in range(rec.count())]) + self._between_query.run([earliest, latest]) + return iter(self._between_query) def entries_before(self, latest, limit, offset): """Iterate non-redirect, non-qute entries occurring before a timestamp. @@ -120,16 +125,8 @@ class WebHistory(sql.SqlTable): limit: Max number of entries to include. offset: Number of entries to skip. """ - result = sql.run_query('SELECT * FROM History ' - 'where not redirect ' - 'and not url like "qute://%" ' - 'and atime <= {} ' - 'ORDER BY atime desc ' - 'limit {} offset {}' - .format(latest, limit, offset)) - while result.next(): - rec = result.record() - yield self.Entry(*[rec.value(i) for i in range(rec.count())]) + self._before_query.run([latest, limit, offset]) + return iter(self._before_query) @cmdutils.register(name='history-clear', instance='web-history') def clear(self, force=False): diff --git a/qutebrowser/completion/models/sqlcategory.py b/qutebrowser/completion/models/sqlcategory.py index 9f78163ca..543ea2ade 100644 --- a/qutebrowser/completion/models/sqlcategory.py +++ b/qutebrowser/completion/models/sqlcategory.py @@ -29,12 +29,13 @@ from qutebrowser.misc import sql class SqlCategory(QSqlQueryModel): """Wraps a SqlQuery for use as a completion category.""" - def __init__(self, name, *, sort_by=None, sort_order=None, select='*', - where=None, group_by=None, parent=None): + def __init__(self, name, *, filter_fields, sort_by=None, sort_order=None, + select='*', where=None, group_by=None, parent=None): """Create a new completion category backed by a sql table. Args: name: Name of category, and the table in the database. + filter_fields: Names of fields to apply filter pattern to. select: A custom result column expression for the select statement. where: An optional clause to filter out some rows. sort_by: The name of the field to sort by, or None for no sorting. @@ -42,11 +43,24 @@ class SqlCategory(QSqlQueryModel): """ super().__init__(parent=parent) self.name = name - self._sort_by = sort_by - self._sort_order = sort_order - self._select = select - self._where = where - self._group_by = group_by + + querystr = 'select {} from {} where ('.format(select, name) + # the incoming pattern will have literal % and _ escaped with '\' + # we need to tell sql to treat '\' as an escape character + querystr += ' or '.join("{} like ? escape '\\'".format(f) + for f in filter_fields) + querystr += ')' + + if where: + querystr += ' and ' + where + if group_by: + querystr += ' group by {}'.format(group_by) + if sort_by: + assert sort_order in ['asc', 'desc'] + querystr += ' order by {} {}'.format(sort_by, sort_order) + + self._query = sql.Query(querystr) + self._param_count=len(filter_fields) self.set_pattern('', [0]) def set_pattern(self, pattern, columns_to_filter): @@ -56,31 +70,13 @@ class SqlCategory(QSqlQueryModel): pattern: string pattern to filter by. columns_to_filter: indices of columns to apply pattern to. """ - query = sql.run_query('select * from {} limit 1'.format(self.name)) - fields = [query.record().fieldName(i) for i in columns_to_filter] - - querystr = 'select {} from {} where ('.format(self._select, self.name) - # the incoming pattern will have literal % and _ escaped with '\' - # we need to tell sql to treat '\' as an escape character - querystr += ' or '.join("{} like ? escape '\\'".format(f) - for f in fields) - querystr += ')' - if self._where: - querystr += ' and ' + self._where - - if self._group_by: - querystr += ' group by {}'.format(self._group_by) - - if self._sort_by: - assert self._sort_order in ['asc', 'desc'] - querystr += ' order by {} {}'.format(self._sort_by, - self._sort_order) - + # TODO: eliminate columns_to_filter + #assert len(columns_to_filter) == self._param_count # escape to treat a user input % or _ as a literal, not a wildcard pattern = pattern.replace('%', '\\%') pattern = pattern.replace('_', '\\_') # treat spaces as wildcards to match any of the typed words pattern = re.sub(r' +', '%', pattern) pattern = '%{}%'.format(pattern) - query = sql.run_query(querystr, [pattern for _ in fields]) - self.setQuery(query) + self._query.run([pattern] * self._param_count) + self.setQuery(self._query) diff --git a/qutebrowser/completion/models/urlmodel.py b/qutebrowser/completion/models/urlmodel.py index 9e601cd39..29ecdca9f 100644 --- a/qutebrowser/completion/models/urlmodel.py +++ b/qutebrowser/completion/models/urlmodel.py @@ -77,6 +77,7 @@ def url(): select_time = "strftime('{}', max(atime), 'unixepoch')".format(timefmt) hist_cat = sqlcategory.SqlCategory( 'History', sort_order='desc', sort_by='atime', group_by='url', + filter_fields=['url', 'title'], select='url, title, {}'.format(select_time), where='not redirect') model.add_category(hist_cat) return model diff --git a/qutebrowser/misc/sql.py b/qutebrowser/misc/sql.py index 5c49767e8..5acc620fb 100644 --- a/qutebrowser/misc/sql.py +++ b/qutebrowser/misc/sql.py @@ -50,56 +50,45 @@ def close(): def version(): """Return the sqlite version string.""" - result = run_query("select sqlite_version()") - result.next() - return result.record().value(0) + q = Query("select sqlite_version()") + q.run() + return q.value() -def _prepare_query(querystr): - log.sql.debug('Preparing SQL query: "{}"'.format(querystr)) - database = QSqlDatabase.database() - query = QSqlQuery(database) - query.prepare(querystr) - return query +class Query(QSqlQuery): + """A prepared SQL Query.""" -def run_query(querystr, values=None): - """Run the given SQL query string on the database. + def __init__(self, querystr): + super().__init__(QSqlDatabase.database()) + log.sql.debug('Preparing SQL query: "{}"'.format(querystr)) + self.prepare(querystr) - Args: - values: A list of positional parameter bindings. - """ - query = _prepare_query(querystr) - for val in values or []: - query.addBindValue(val) - log.sql.debug('Query bindings: {}'.format(query.boundValues())) - if not query.exec_(): - raise SqlException('Failed to exec query "{}": "{}"'.format( - querystr, query.lastError().text())) - return query + def __iter__(self): + assert self.isActive(), "Cannot iterate inactive query" + rec = self.record() + fields = [rec.fieldName(i) for i in range(rec.count())] + rowtype = collections.namedtuple('ResultRow', fields) + while self.next(): + rec = self.record() + yield rowtype(*[rec.value(i) for i in range(rec.count())]) -def run_batch(querystr, values): - """Run the given SQL query string on the database in batch mode. + def run(self, values=None): + """Execute the prepared query.""" + log.sql.debug('Running SQL query: "{}"'.format(self.lastQuery())) + for val in values or []: + self.addBindValue(val) + log.sql.debug('self bindings: {}'.format(self.boundValues())) + if not self.exec_(): + raise SqlException('Failed to exec self "{}": "{}"'.format( + self.lastself(), self.lastError().text())) - Args: - values: A list of lists, where each inner list contains positional - bindings for one run of the batch. - """ - query = _prepare_query(querystr) - transposed = [list(row) for row in zip(*values)] - for val in transposed: - query.addBindValue(val) - log.sql.debug('Batch Query bindings: {}'.format(query.boundValues())) - - db = QSqlDatabase.database() - db.transaction() - if not query.execBatch(): - raise SqlException('Failed to exec query "{}": "{}"'.format( - querystr, query.lastError().text())) - db.commit() - - return query + def value(self): + """Return the result of a single-value query (e.g. an EXISTS).""" + ok = self.next() + assert ok, "No result for single-result query" + return self.record().value(0) class SqlTable(QObject): @@ -127,17 +116,17 @@ class SqlTable(QObject): """ super().__init__(parent) self._name = name - run_query("CREATE TABLE IF NOT EXISTS {} ({})" + q = Query("CREATE TABLE IF NOT EXISTS {} ({})" .format(name, ','.join(fields))) + q.run() # pylint: disable=invalid-name self.Entry = collections.namedtuple(name + '_Entry', fields) def __iter__(self): """Iterate rows in the table.""" - result = run_query("SELECT * FROM {}".format(self._name)) - while result.next(): - rec = result.record() - yield self.Entry(*[rec.value(i) for i in range(rec.count())]) + q = Query("SELECT * FROM {}".format(self._name)) + q.run() + return iter(q) def contains(self, field, value): """Return whether the table contains the matching item. @@ -146,16 +135,16 @@ class SqlTable(QObject): field: Field to match. value: Value to check for the given field. """ - query = run_query("Select EXISTS(SELECT * FROM {} where {} = ?)" - .format(self._name, field), [value]) - query.next() - return query.value(0) + q = Query("Select EXISTS(SELECT * FROM {} where {} = ?)" + .format(self._name, field)) + q.run([value]) + return q.value() def __len__(self): """Return the count of rows in the table.""" - result = run_query("SELECT count(*) FROM {}".format(self._name)) - result.next() - return result.value(0) + q = Query("SELECT count(*) FROM {}".format(self._name)) + q.run() + return q.value() def delete(self, value, field): """Remove all rows for which `field` equals `value`. @@ -167,9 +156,9 @@ class SqlTable(QObject): Return: The number of rows deleted. """ - query = run_query("DELETE FROM {} where {} = ?".format( - self._name, field), [value]) - if not query.numRowsAffected(): + q = Query("DELETE FROM {} where {} = ?".format(self._name, field)) + q.run([value]) + if not q.numRowsAffected(): raise KeyError('No row with {} = "{}"'.format(field, value)) self.changed.emit() @@ -180,8 +169,8 @@ class SqlTable(QObject): values: A list of values to insert. """ paramstr = ','.join(['?'] * len(values)) - run_query("INSERT INTO {} values({})".format(self._name, paramstr), - values) + q = Query("INSERT INTO {} values({})".format(self._name, paramstr)) + q.run(values) self.changed.emit() def insert_batch(self, rows): @@ -191,13 +180,23 @@ class SqlTable(QObject): rows: A list of lists, where each sub-list is a row. """ paramstr = ','.join(['?'] * len(rows[0])) - run_batch("INSERT INTO {} values({})".format(self._name, paramstr), - rows) + q = Query("INSERT INTO {} values({})".format(self._name, paramstr)) + + transposed = [list(row) for row in zip(*rows)] + for val in transposed: + q.addBindValue(val) + + db = QSqlDatabase.database() + db.transaction() + if not q.execBatch(): + raise SqlException('Failed to exec query "{}": "{}"'.format( + q.lastQuery(), q.lastError().text())) + db.commit() self.changed.emit() def delete_all(self): """Remove all row from the table.""" - run_query("DELETE FROM {}".format(self._name)) + Query("DELETE FROM {}".format(self._name)).run() self.changed.emit() def select(self, sort_by, sort_order, limit=-1): @@ -208,8 +207,7 @@ class SqlTable(QObject): sort_order: 'asc' or 'desc'. limit: max number of rows in result, defaults to -1 (unlimited). """ - result = run_query('SELECT * FROM {} ORDER BY {} {} LIMIT {}' - .format(self._name, sort_by, sort_order, limit)) - while result.next(): - rec = result.record() - yield self.Entry(*[rec.value(i) for i in range(rec.count())]) + q = Query('SELECT * FROM {} ORDER BY {} {} LIMIT ?' + .format(self._name, sort_by, sort_order)) + q.run([limit]) + return q diff --git a/tests/unit/completion/test_sqlcategory.py b/tests/unit/completion/test_sqlcategory.py index f29589717..3d0b07d07 100644 --- a/tests/unit/completion/test_sqlcategory.py +++ b/tests/unit/completion/test_sqlcategory.py @@ -74,7 +74,7 @@ def test_sorting(sort_by, sort_order, data, expected): table = sql.SqlTable('Foo', ['a', 'b', 'c']) for row in data: table.insert(row) - cat = sqlcategory.SqlCategory('Foo', sort_by=sort_by, + cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], sort_by=sort_by, sort_order=sort_order) _validate(cat, expected) @@ -129,7 +129,8 @@ def test_set_pattern(pattern, filter_cols, before, after): table = sql.SqlTable('Foo', ['a', 'b', 'c']) for row in before: table.insert(row) - cat = sqlcategory.SqlCategory('Foo') + filter_fields = [['a', 'b', 'c'][i] for i in filter_cols] + cat = sqlcategory.SqlCategory('Foo', filter_fields=filter_fields) cat.set_pattern(pattern, filter_cols) _validate(cat, after) @@ -137,7 +138,7 @@ def test_set_pattern(pattern, filter_cols, before, after): def test_select(): table = sql.SqlTable('Foo', ['a', 'b', 'c']) table.insert(['foo', 'bar', 'baz']) - cat = sqlcategory.SqlCategory('Foo', select='b, c, a') + cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], select='b, c, a') _validate(cat, [('bar', 'baz', 'foo')]) @@ -145,7 +146,7 @@ def test_where(): table = sql.SqlTable('Foo', ['a', 'b', 'c']) table.insert(['foo', 'bar', False]) table.insert(['baz', 'biz', True]) - cat = sqlcategory.SqlCategory('Foo', where='not c') + cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], where='not c') _validate(cat, [('foo', 'bar', False)]) @@ -155,7 +156,8 @@ def test_group(): table.insert(['bar', 3]) table.insert(['foo', 2]) table.insert(['bar', 0]) - cat = sqlcategory.SqlCategory('Foo', select='a, max(b)', group_by='a') + cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], + select='a, max(b)', group_by='a') _validate(cat, [('bar', 3), ('foo', 2)])