Use prepared SQL queries.

This commit is contained in:
Ryan Roden-Corrent 2017-05-23 09:01:20 -04:00
parent 20000088de
commit e67da51662
5 changed files with 115 additions and 121 deletions

View File

@ -78,6 +78,19 @@ class WebHistory(sql.SqlTable):
def __init__(self, parent=None):
super().__init__("History", ['url', 'title', 'atime', 'redirect'],
parent=parent)
self._between_query = sql.Query('SELECT * FROM History '
'where not redirect '
'and not url like "qute://%" '
'and atime > ? '
'and atime <= ? '
'ORDER BY atime desc')
self._before_query = sql.Query('SELECT * FROM History '
'where not redirect '
'and not url like "qute://%" '
'and atime <= ? '
'ORDER BY atime desc '
'limit ? offset ?')
def __repr__(self):
return utils.get_repr(self, length=len(self))
@ -101,16 +114,8 @@ class WebHistory(sql.SqlTable):
earliest: Omit timestamps earlier than this.
latest: Omit timestamps later than this.
"""
result = sql.run_query('SELECT * FROM History '
'where not redirect '
'and not url like "qute://%" '
'and atime > {} '
'and atime <= {} '
'ORDER BY atime desc'
.format(earliest, latest))
while result.next():
rec = result.record()
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
self._between_query.run([earliest, latest])
return iter(self._between_query)
def entries_before(self, latest, limit, offset):
"""Iterate non-redirect, non-qute entries occurring before a timestamp.
@ -120,16 +125,8 @@ class WebHistory(sql.SqlTable):
limit: Max number of entries to include.
offset: Number of entries to skip.
"""
result = sql.run_query('SELECT * FROM History '
'where not redirect '
'and not url like "qute://%" '
'and atime <= {} '
'ORDER BY atime desc '
'limit {} offset {}'
.format(latest, limit, offset))
while result.next():
rec = result.record()
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
self._before_query.run([latest, limit, offset])
return iter(self._before_query)
@cmdutils.register(name='history-clear', instance='web-history')
def clear(self, force=False):

View File

@ -29,12 +29,13 @@ from qutebrowser.misc import sql
class SqlCategory(QSqlQueryModel):
"""Wraps a SqlQuery for use as a completion category."""
def __init__(self, name, *, sort_by=None, sort_order=None, select='*',
where=None, group_by=None, parent=None):
def __init__(self, name, *, filter_fields, sort_by=None, sort_order=None,
select='*', where=None, group_by=None, parent=None):
"""Create a new completion category backed by a sql table.
Args:
name: Name of category, and the table in the database.
filter_fields: Names of fields to apply filter pattern to.
select: A custom result column expression for the select statement.
where: An optional clause to filter out some rows.
sort_by: The name of the field to sort by, or None for no sorting.
@ -42,11 +43,24 @@ class SqlCategory(QSqlQueryModel):
"""
super().__init__(parent=parent)
self.name = name
self._sort_by = sort_by
self._sort_order = sort_order
self._select = select
self._where = where
self._group_by = group_by
querystr = 'select {} from {} where ('.format(select, name)
# the incoming pattern will have literal % and _ escaped with '\'
# we need to tell sql to treat '\' as an escape character
querystr += ' or '.join("{} like ? escape '\\'".format(f)
for f in filter_fields)
querystr += ')'
if where:
querystr += ' and ' + where
if group_by:
querystr += ' group by {}'.format(group_by)
if sort_by:
assert sort_order in ['asc', 'desc']
querystr += ' order by {} {}'.format(sort_by, sort_order)
self._query = sql.Query(querystr)
self._param_count=len(filter_fields)
self.set_pattern('', [0])
def set_pattern(self, pattern, columns_to_filter):
@ -56,31 +70,13 @@ class SqlCategory(QSqlQueryModel):
pattern: string pattern to filter by.
columns_to_filter: indices of columns to apply pattern to.
"""
query = sql.run_query('select * from {} limit 1'.format(self.name))
fields = [query.record().fieldName(i) for i in columns_to_filter]
querystr = 'select {} from {} where ('.format(self._select, self.name)
# the incoming pattern will have literal % and _ escaped with '\'
# we need to tell sql to treat '\' as an escape character
querystr += ' or '.join("{} like ? escape '\\'".format(f)
for f in fields)
querystr += ')'
if self._where:
querystr += ' and ' + self._where
if self._group_by:
querystr += ' group by {}'.format(self._group_by)
if self._sort_by:
assert self._sort_order in ['asc', 'desc']
querystr += ' order by {} {}'.format(self._sort_by,
self._sort_order)
# TODO: eliminate columns_to_filter
#assert len(columns_to_filter) == self._param_count
# escape to treat a user input % or _ as a literal, not a wildcard
pattern = pattern.replace('%', '\\%')
pattern = pattern.replace('_', '\\_')
# treat spaces as wildcards to match any of the typed words
pattern = re.sub(r' +', '%', pattern)
pattern = '%{}%'.format(pattern)
query = sql.run_query(querystr, [pattern for _ in fields])
self.setQuery(query)
self._query.run([pattern] * self._param_count)
self.setQuery(self._query)

View File

@ -77,6 +77,7 @@ def url():
select_time = "strftime('{}', max(atime), 'unixepoch')".format(timefmt)
hist_cat = sqlcategory.SqlCategory(
'History', sort_order='desc', sort_by='atime', group_by='url',
filter_fields=['url', 'title'],
select='url, title, {}'.format(select_time), where='not redirect')
model.add_category(hist_cat)
return model

View File

@ -50,56 +50,45 @@ def close():
def version():
"""Return the sqlite version string."""
result = run_query("select sqlite_version()")
result.next()
return result.record().value(0)
q = Query("select sqlite_version()")
q.run()
return q.value()
def _prepare_query(querystr):
log.sql.debug('Preparing SQL query: "{}"'.format(querystr))
database = QSqlDatabase.database()
query = QSqlQuery(database)
query.prepare(querystr)
return query
class Query(QSqlQuery):
"""A prepared SQL Query."""
def run_query(querystr, values=None):
"""Run the given SQL query string on the database.
def __init__(self, querystr):
super().__init__(QSqlDatabase.database())
log.sql.debug('Preparing SQL query: "{}"'.format(querystr))
self.prepare(querystr)
Args:
values: A list of positional parameter bindings.
"""
query = _prepare_query(querystr)
for val in values or []:
query.addBindValue(val)
log.sql.debug('Query bindings: {}'.format(query.boundValues()))
if not query.exec_():
raise SqlException('Failed to exec query "{}": "{}"'.format(
querystr, query.lastError().text()))
return query
def __iter__(self):
assert self.isActive(), "Cannot iterate inactive query"
rec = self.record()
fields = [rec.fieldName(i) for i in range(rec.count())]
rowtype = collections.namedtuple('ResultRow', fields)
while self.next():
rec = self.record()
yield rowtype(*[rec.value(i) for i in range(rec.count())])
def run_batch(querystr, values):
"""Run the given SQL query string on the database in batch mode.
def run(self, values=None):
"""Execute the prepared query."""
log.sql.debug('Running SQL query: "{}"'.format(self.lastQuery()))
for val in values or []:
self.addBindValue(val)
log.sql.debug('self bindings: {}'.format(self.boundValues()))
if not self.exec_():
raise SqlException('Failed to exec self "{}": "{}"'.format(
self.lastself(), self.lastError().text()))
Args:
values: A list of lists, where each inner list contains positional
bindings for one run of the batch.
"""
query = _prepare_query(querystr)
transposed = [list(row) for row in zip(*values)]
for val in transposed:
query.addBindValue(val)
log.sql.debug('Batch Query bindings: {}'.format(query.boundValues()))
db = QSqlDatabase.database()
db.transaction()
if not query.execBatch():
raise SqlException('Failed to exec query "{}": "{}"'.format(
querystr, query.lastError().text()))
db.commit()
return query
def value(self):
"""Return the result of a single-value query (e.g. an EXISTS)."""
ok = self.next()
assert ok, "No result for single-result query"
return self.record().value(0)
class SqlTable(QObject):
@ -127,17 +116,17 @@ class SqlTable(QObject):
"""
super().__init__(parent)
self._name = name
run_query("CREATE TABLE IF NOT EXISTS {} ({})"
q = Query("CREATE TABLE IF NOT EXISTS {} ({})"
.format(name, ','.join(fields)))
q.run()
# pylint: disable=invalid-name
self.Entry = collections.namedtuple(name + '_Entry', fields)
def __iter__(self):
"""Iterate rows in the table."""
result = run_query("SELECT * FROM {}".format(self._name))
while result.next():
rec = result.record()
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
q = Query("SELECT * FROM {}".format(self._name))
q.run()
return iter(q)
def contains(self, field, value):
"""Return whether the table contains the matching item.
@ -146,16 +135,16 @@ class SqlTable(QObject):
field: Field to match.
value: Value to check for the given field.
"""
query = run_query("Select EXISTS(SELECT * FROM {} where {} = ?)"
.format(self._name, field), [value])
query.next()
return query.value(0)
q = Query("Select EXISTS(SELECT * FROM {} where {} = ?)"
.format(self._name, field))
q.run([value])
return q.value()
def __len__(self):
"""Return the count of rows in the table."""
result = run_query("SELECT count(*) FROM {}".format(self._name))
result.next()
return result.value(0)
q = Query("SELECT count(*) FROM {}".format(self._name))
q.run()
return q.value()
def delete(self, value, field):
"""Remove all rows for which `field` equals `value`.
@ -167,9 +156,9 @@ class SqlTable(QObject):
Return:
The number of rows deleted.
"""
query = run_query("DELETE FROM {} where {} = ?".format(
self._name, field), [value])
if not query.numRowsAffected():
q = Query("DELETE FROM {} where {} = ?".format(self._name, field))
q.run([value])
if not q.numRowsAffected():
raise KeyError('No row with {} = "{}"'.format(field, value))
self.changed.emit()
@ -180,8 +169,8 @@ class SqlTable(QObject):
values: A list of values to insert.
"""
paramstr = ','.join(['?'] * len(values))
run_query("INSERT INTO {} values({})".format(self._name, paramstr),
values)
q = Query("INSERT INTO {} values({})".format(self._name, paramstr))
q.run(values)
self.changed.emit()
def insert_batch(self, rows):
@ -191,13 +180,23 @@ class SqlTable(QObject):
rows: A list of lists, where each sub-list is a row.
"""
paramstr = ','.join(['?'] * len(rows[0]))
run_batch("INSERT INTO {} values({})".format(self._name, paramstr),
rows)
q = Query("INSERT INTO {} values({})".format(self._name, paramstr))
transposed = [list(row) for row in zip(*rows)]
for val in transposed:
q.addBindValue(val)
db = QSqlDatabase.database()
db.transaction()
if not q.execBatch():
raise SqlException('Failed to exec query "{}": "{}"'.format(
q.lastQuery(), q.lastError().text()))
db.commit()
self.changed.emit()
def delete_all(self):
"""Remove all row from the table."""
run_query("DELETE FROM {}".format(self._name))
Query("DELETE FROM {}".format(self._name)).run()
self.changed.emit()
def select(self, sort_by, sort_order, limit=-1):
@ -208,8 +207,7 @@ class SqlTable(QObject):
sort_order: 'asc' or 'desc'.
limit: max number of rows in result, defaults to -1 (unlimited).
"""
result = run_query('SELECT * FROM {} ORDER BY {} {} LIMIT {}'
.format(self._name, sort_by, sort_order, limit))
while result.next():
rec = result.record()
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
q = Query('SELECT * FROM {} ORDER BY {} {} LIMIT ?'
.format(self._name, sort_by, sort_order))
q.run([limit])
return q

View File

@ -74,7 +74,7 @@ def test_sorting(sort_by, sort_order, data, expected):
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
for row in data:
table.insert(row)
cat = sqlcategory.SqlCategory('Foo', sort_by=sort_by,
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], sort_by=sort_by,
sort_order=sort_order)
_validate(cat, expected)
@ -129,7 +129,8 @@ def test_set_pattern(pattern, filter_cols, before, after):
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
for row in before:
table.insert(row)
cat = sqlcategory.SqlCategory('Foo')
filter_fields = [['a', 'b', 'c'][i] for i in filter_cols]
cat = sqlcategory.SqlCategory('Foo', filter_fields=filter_fields)
cat.set_pattern(pattern, filter_cols)
_validate(cat, after)
@ -137,7 +138,7 @@ def test_set_pattern(pattern, filter_cols, before, after):
def test_select():
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
table.insert(['foo', 'bar', 'baz'])
cat = sqlcategory.SqlCategory('Foo', select='b, c, a')
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], select='b, c, a')
_validate(cat, [('bar', 'baz', 'foo')])
@ -145,7 +146,7 @@ def test_where():
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
table.insert(['foo', 'bar', False])
table.insert(['baz', 'biz', True])
cat = sqlcategory.SqlCategory('Foo', where='not c')
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], where='not c')
_validate(cat, [('foo', 'bar', False)])
@ -155,7 +156,8 @@ def test_group():
table.insert(['bar', 3])
table.insert(['foo', 2])
table.insert(['bar', 0])
cat = sqlcategory.SqlCategory('Foo', select='a, max(b)', group_by='a')
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'],
select='a, max(b)', group_by='a')
_validate(cat, [('bar', 3), ('foo', 2)])