Use prepared SQL queries.
This commit is contained in:
parent
20000088de
commit
e67da51662
@ -78,6 +78,19 @@ class WebHistory(sql.SqlTable):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__("History", ['url', 'title', 'atime', 'redirect'],
|
||||
parent=parent)
|
||||
self._between_query = sql.Query('SELECT * FROM History '
|
||||
'where not redirect '
|
||||
'and not url like "qute://%" '
|
||||
'and atime > ? '
|
||||
'and atime <= ? '
|
||||
'ORDER BY atime desc')
|
||||
|
||||
self._before_query = sql.Query('SELECT * FROM History '
|
||||
'where not redirect '
|
||||
'and not url like "qute://%" '
|
||||
'and atime <= ? '
|
||||
'ORDER BY atime desc '
|
||||
'limit ? offset ?')
|
||||
|
||||
def __repr__(self):
|
||||
return utils.get_repr(self, length=len(self))
|
||||
@ -101,16 +114,8 @@ class WebHistory(sql.SqlTable):
|
||||
earliest: Omit timestamps earlier than this.
|
||||
latest: Omit timestamps later than this.
|
||||
"""
|
||||
result = sql.run_query('SELECT * FROM History '
|
||||
'where not redirect '
|
||||
'and not url like "qute://%" '
|
||||
'and atime > {} '
|
||||
'and atime <= {} '
|
||||
'ORDER BY atime desc'
|
||||
.format(earliest, latest))
|
||||
while result.next():
|
||||
rec = result.record()
|
||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
||||
self._between_query.run([earliest, latest])
|
||||
return iter(self._between_query)
|
||||
|
||||
def entries_before(self, latest, limit, offset):
|
||||
"""Iterate non-redirect, non-qute entries occurring before a timestamp.
|
||||
@ -120,16 +125,8 @@ class WebHistory(sql.SqlTable):
|
||||
limit: Max number of entries to include.
|
||||
offset: Number of entries to skip.
|
||||
"""
|
||||
result = sql.run_query('SELECT * FROM History '
|
||||
'where not redirect '
|
||||
'and not url like "qute://%" '
|
||||
'and atime <= {} '
|
||||
'ORDER BY atime desc '
|
||||
'limit {} offset {}'
|
||||
.format(latest, limit, offset))
|
||||
while result.next():
|
||||
rec = result.record()
|
||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
||||
self._before_query.run([latest, limit, offset])
|
||||
return iter(self._before_query)
|
||||
|
||||
@cmdutils.register(name='history-clear', instance='web-history')
|
||||
def clear(self, force=False):
|
||||
|
@ -29,12 +29,13 @@ from qutebrowser.misc import sql
|
||||
class SqlCategory(QSqlQueryModel):
|
||||
"""Wraps a SqlQuery for use as a completion category."""
|
||||
|
||||
def __init__(self, name, *, sort_by=None, sort_order=None, select='*',
|
||||
where=None, group_by=None, parent=None):
|
||||
def __init__(self, name, *, filter_fields, sort_by=None, sort_order=None,
|
||||
select='*', where=None, group_by=None, parent=None):
|
||||
"""Create a new completion category backed by a sql table.
|
||||
|
||||
Args:
|
||||
name: Name of category, and the table in the database.
|
||||
filter_fields: Names of fields to apply filter pattern to.
|
||||
select: A custom result column expression for the select statement.
|
||||
where: An optional clause to filter out some rows.
|
||||
sort_by: The name of the field to sort by, or None for no sorting.
|
||||
@ -42,11 +43,24 @@ class SqlCategory(QSqlQueryModel):
|
||||
"""
|
||||
super().__init__(parent=parent)
|
||||
self.name = name
|
||||
self._sort_by = sort_by
|
||||
self._sort_order = sort_order
|
||||
self._select = select
|
||||
self._where = where
|
||||
self._group_by = group_by
|
||||
|
||||
querystr = 'select {} from {} where ('.format(select, name)
|
||||
# the incoming pattern will have literal % and _ escaped with '\'
|
||||
# we need to tell sql to treat '\' as an escape character
|
||||
querystr += ' or '.join("{} like ? escape '\\'".format(f)
|
||||
for f in filter_fields)
|
||||
querystr += ')'
|
||||
|
||||
if where:
|
||||
querystr += ' and ' + where
|
||||
if group_by:
|
||||
querystr += ' group by {}'.format(group_by)
|
||||
if sort_by:
|
||||
assert sort_order in ['asc', 'desc']
|
||||
querystr += ' order by {} {}'.format(sort_by, sort_order)
|
||||
|
||||
self._query = sql.Query(querystr)
|
||||
self._param_count=len(filter_fields)
|
||||
self.set_pattern('', [0])
|
||||
|
||||
def set_pattern(self, pattern, columns_to_filter):
|
||||
@ -56,31 +70,13 @@ class SqlCategory(QSqlQueryModel):
|
||||
pattern: string pattern to filter by.
|
||||
columns_to_filter: indices of columns to apply pattern to.
|
||||
"""
|
||||
query = sql.run_query('select * from {} limit 1'.format(self.name))
|
||||
fields = [query.record().fieldName(i) for i in columns_to_filter]
|
||||
|
||||
querystr = 'select {} from {} where ('.format(self._select, self.name)
|
||||
# the incoming pattern will have literal % and _ escaped with '\'
|
||||
# we need to tell sql to treat '\' as an escape character
|
||||
querystr += ' or '.join("{} like ? escape '\\'".format(f)
|
||||
for f in fields)
|
||||
querystr += ')'
|
||||
if self._where:
|
||||
querystr += ' and ' + self._where
|
||||
|
||||
if self._group_by:
|
||||
querystr += ' group by {}'.format(self._group_by)
|
||||
|
||||
if self._sort_by:
|
||||
assert self._sort_order in ['asc', 'desc']
|
||||
querystr += ' order by {} {}'.format(self._sort_by,
|
||||
self._sort_order)
|
||||
|
||||
# TODO: eliminate columns_to_filter
|
||||
#assert len(columns_to_filter) == self._param_count
|
||||
# escape to treat a user input % or _ as a literal, not a wildcard
|
||||
pattern = pattern.replace('%', '\\%')
|
||||
pattern = pattern.replace('_', '\\_')
|
||||
# treat spaces as wildcards to match any of the typed words
|
||||
pattern = re.sub(r' +', '%', pattern)
|
||||
pattern = '%{}%'.format(pattern)
|
||||
query = sql.run_query(querystr, [pattern for _ in fields])
|
||||
self.setQuery(query)
|
||||
self._query.run([pattern] * self._param_count)
|
||||
self.setQuery(self._query)
|
||||
|
@ -77,6 +77,7 @@ def url():
|
||||
select_time = "strftime('{}', max(atime), 'unixepoch')".format(timefmt)
|
||||
hist_cat = sqlcategory.SqlCategory(
|
||||
'History', sort_order='desc', sort_by='atime', group_by='url',
|
||||
filter_fields=['url', 'title'],
|
||||
select='url, title, {}'.format(select_time), where='not redirect')
|
||||
model.add_category(hist_cat)
|
||||
return model
|
||||
|
@ -50,56 +50,45 @@ def close():
|
||||
|
||||
def version():
|
||||
"""Return the sqlite version string."""
|
||||
result = run_query("select sqlite_version()")
|
||||
result.next()
|
||||
return result.record().value(0)
|
||||
q = Query("select sqlite_version()")
|
||||
q.run()
|
||||
return q.value()
|
||||
|
||||
|
||||
def _prepare_query(querystr):
|
||||
log.sql.debug('Preparing SQL query: "{}"'.format(querystr))
|
||||
database = QSqlDatabase.database()
|
||||
query = QSqlQuery(database)
|
||||
query.prepare(querystr)
|
||||
return query
|
||||
class Query(QSqlQuery):
|
||||
|
||||
"""A prepared SQL Query."""
|
||||
|
||||
def run_query(querystr, values=None):
|
||||
"""Run the given SQL query string on the database.
|
||||
def __init__(self, querystr):
|
||||
super().__init__(QSqlDatabase.database())
|
||||
log.sql.debug('Preparing SQL query: "{}"'.format(querystr))
|
||||
self.prepare(querystr)
|
||||
|
||||
Args:
|
||||
values: A list of positional parameter bindings.
|
||||
"""
|
||||
query = _prepare_query(querystr)
|
||||
for val in values or []:
|
||||
query.addBindValue(val)
|
||||
log.sql.debug('Query bindings: {}'.format(query.boundValues()))
|
||||
if not query.exec_():
|
||||
raise SqlException('Failed to exec query "{}": "{}"'.format(
|
||||
querystr, query.lastError().text()))
|
||||
return query
|
||||
def __iter__(self):
|
||||
assert self.isActive(), "Cannot iterate inactive query"
|
||||
rec = self.record()
|
||||
fields = [rec.fieldName(i) for i in range(rec.count())]
|
||||
rowtype = collections.namedtuple('ResultRow', fields)
|
||||
|
||||
while self.next():
|
||||
rec = self.record()
|
||||
yield rowtype(*[rec.value(i) for i in range(rec.count())])
|
||||
|
||||
def run_batch(querystr, values):
|
||||
"""Run the given SQL query string on the database in batch mode.
|
||||
def run(self, values=None):
|
||||
"""Execute the prepared query."""
|
||||
log.sql.debug('Running SQL query: "{}"'.format(self.lastQuery()))
|
||||
for val in values or []:
|
||||
self.addBindValue(val)
|
||||
log.sql.debug('self bindings: {}'.format(self.boundValues()))
|
||||
if not self.exec_():
|
||||
raise SqlException('Failed to exec self "{}": "{}"'.format(
|
||||
self.lastself(), self.lastError().text()))
|
||||
|
||||
Args:
|
||||
values: A list of lists, where each inner list contains positional
|
||||
bindings for one run of the batch.
|
||||
"""
|
||||
query = _prepare_query(querystr)
|
||||
transposed = [list(row) for row in zip(*values)]
|
||||
for val in transposed:
|
||||
query.addBindValue(val)
|
||||
log.sql.debug('Batch Query bindings: {}'.format(query.boundValues()))
|
||||
|
||||
db = QSqlDatabase.database()
|
||||
db.transaction()
|
||||
if not query.execBatch():
|
||||
raise SqlException('Failed to exec query "{}": "{}"'.format(
|
||||
querystr, query.lastError().text()))
|
||||
db.commit()
|
||||
|
||||
return query
|
||||
def value(self):
|
||||
"""Return the result of a single-value query (e.g. an EXISTS)."""
|
||||
ok = self.next()
|
||||
assert ok, "No result for single-result query"
|
||||
return self.record().value(0)
|
||||
|
||||
|
||||
class SqlTable(QObject):
|
||||
@ -127,17 +116,17 @@ class SqlTable(QObject):
|
||||
"""
|
||||
super().__init__(parent)
|
||||
self._name = name
|
||||
run_query("CREATE TABLE IF NOT EXISTS {} ({})"
|
||||
q = Query("CREATE TABLE IF NOT EXISTS {} ({})"
|
||||
.format(name, ','.join(fields)))
|
||||
q.run()
|
||||
# pylint: disable=invalid-name
|
||||
self.Entry = collections.namedtuple(name + '_Entry', fields)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate rows in the table."""
|
||||
result = run_query("SELECT * FROM {}".format(self._name))
|
||||
while result.next():
|
||||
rec = result.record()
|
||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
||||
q = Query("SELECT * FROM {}".format(self._name))
|
||||
q.run()
|
||||
return iter(q)
|
||||
|
||||
def contains(self, field, value):
|
||||
"""Return whether the table contains the matching item.
|
||||
@ -146,16 +135,16 @@ class SqlTable(QObject):
|
||||
field: Field to match.
|
||||
value: Value to check for the given field.
|
||||
"""
|
||||
query = run_query("Select EXISTS(SELECT * FROM {} where {} = ?)"
|
||||
.format(self._name, field), [value])
|
||||
query.next()
|
||||
return query.value(0)
|
||||
q = Query("Select EXISTS(SELECT * FROM {} where {} = ?)"
|
||||
.format(self._name, field))
|
||||
q.run([value])
|
||||
return q.value()
|
||||
|
||||
def __len__(self):
|
||||
"""Return the count of rows in the table."""
|
||||
result = run_query("SELECT count(*) FROM {}".format(self._name))
|
||||
result.next()
|
||||
return result.value(0)
|
||||
q = Query("SELECT count(*) FROM {}".format(self._name))
|
||||
q.run()
|
||||
return q.value()
|
||||
|
||||
def delete(self, value, field):
|
||||
"""Remove all rows for which `field` equals `value`.
|
||||
@ -167,9 +156,9 @@ class SqlTable(QObject):
|
||||
Return:
|
||||
The number of rows deleted.
|
||||
"""
|
||||
query = run_query("DELETE FROM {} where {} = ?".format(
|
||||
self._name, field), [value])
|
||||
if not query.numRowsAffected():
|
||||
q = Query("DELETE FROM {} where {} = ?".format(self._name, field))
|
||||
q.run([value])
|
||||
if not q.numRowsAffected():
|
||||
raise KeyError('No row with {} = "{}"'.format(field, value))
|
||||
self.changed.emit()
|
||||
|
||||
@ -180,8 +169,8 @@ class SqlTable(QObject):
|
||||
values: A list of values to insert.
|
||||
"""
|
||||
paramstr = ','.join(['?'] * len(values))
|
||||
run_query("INSERT INTO {} values({})".format(self._name, paramstr),
|
||||
values)
|
||||
q = Query("INSERT INTO {} values({})".format(self._name, paramstr))
|
||||
q.run(values)
|
||||
self.changed.emit()
|
||||
|
||||
def insert_batch(self, rows):
|
||||
@ -191,13 +180,23 @@ class SqlTable(QObject):
|
||||
rows: A list of lists, where each sub-list is a row.
|
||||
"""
|
||||
paramstr = ','.join(['?'] * len(rows[0]))
|
||||
run_batch("INSERT INTO {} values({})".format(self._name, paramstr),
|
||||
rows)
|
||||
q = Query("INSERT INTO {} values({})".format(self._name, paramstr))
|
||||
|
||||
transposed = [list(row) for row in zip(*rows)]
|
||||
for val in transposed:
|
||||
q.addBindValue(val)
|
||||
|
||||
db = QSqlDatabase.database()
|
||||
db.transaction()
|
||||
if not q.execBatch():
|
||||
raise SqlException('Failed to exec query "{}": "{}"'.format(
|
||||
q.lastQuery(), q.lastError().text()))
|
||||
db.commit()
|
||||
self.changed.emit()
|
||||
|
||||
def delete_all(self):
|
||||
"""Remove all row from the table."""
|
||||
run_query("DELETE FROM {}".format(self._name))
|
||||
Query("DELETE FROM {}".format(self._name)).run()
|
||||
self.changed.emit()
|
||||
|
||||
def select(self, sort_by, sort_order, limit=-1):
|
||||
@ -208,8 +207,7 @@ class SqlTable(QObject):
|
||||
sort_order: 'asc' or 'desc'.
|
||||
limit: max number of rows in result, defaults to -1 (unlimited).
|
||||
"""
|
||||
result = run_query('SELECT * FROM {} ORDER BY {} {} LIMIT {}'
|
||||
.format(self._name, sort_by, sort_order, limit))
|
||||
while result.next():
|
||||
rec = result.record()
|
||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
||||
q = Query('SELECT * FROM {} ORDER BY {} {} LIMIT ?'
|
||||
.format(self._name, sort_by, sort_order))
|
||||
q.run([limit])
|
||||
return q
|
||||
|
@ -74,7 +74,7 @@ def test_sorting(sort_by, sort_order, data, expected):
|
||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||
for row in data:
|
||||
table.insert(row)
|
||||
cat = sqlcategory.SqlCategory('Foo', sort_by=sort_by,
|
||||
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], sort_by=sort_by,
|
||||
sort_order=sort_order)
|
||||
_validate(cat, expected)
|
||||
|
||||
@ -129,7 +129,8 @@ def test_set_pattern(pattern, filter_cols, before, after):
|
||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||
for row in before:
|
||||
table.insert(row)
|
||||
cat = sqlcategory.SqlCategory('Foo')
|
||||
filter_fields = [['a', 'b', 'c'][i] for i in filter_cols]
|
||||
cat = sqlcategory.SqlCategory('Foo', filter_fields=filter_fields)
|
||||
cat.set_pattern(pattern, filter_cols)
|
||||
_validate(cat, after)
|
||||
|
||||
@ -137,7 +138,7 @@ def test_set_pattern(pattern, filter_cols, before, after):
|
||||
def test_select():
|
||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||
table.insert(['foo', 'bar', 'baz'])
|
||||
cat = sqlcategory.SqlCategory('Foo', select='b, c, a')
|
||||
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], select='b, c, a')
|
||||
_validate(cat, [('bar', 'baz', 'foo')])
|
||||
|
||||
|
||||
@ -145,7 +146,7 @@ def test_where():
|
||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||
table.insert(['foo', 'bar', False])
|
||||
table.insert(['baz', 'biz', True])
|
||||
cat = sqlcategory.SqlCategory('Foo', where='not c')
|
||||
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], where='not c')
|
||||
_validate(cat, [('foo', 'bar', False)])
|
||||
|
||||
|
||||
@ -155,7 +156,8 @@ def test_group():
|
||||
table.insert(['bar', 3])
|
||||
table.insert(['foo', 2])
|
||||
table.insert(['bar', 0])
|
||||
cat = sqlcategory.SqlCategory('Foo', select='a, max(b)', group_by='a')
|
||||
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'],
|
||||
select='a, max(b)', group_by='a')
|
||||
_validate(cat, [('bar', 3), ('foo', 2)])
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user