Use prepared SQL queries.
This commit is contained in:
parent
20000088de
commit
e67da51662
@ -78,6 +78,19 @@ class WebHistory(sql.SqlTable):
|
|||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
super().__init__("History", ['url', 'title', 'atime', 'redirect'],
|
super().__init__("History", ['url', 'title', 'atime', 'redirect'],
|
||||||
parent=parent)
|
parent=parent)
|
||||||
|
self._between_query = sql.Query('SELECT * FROM History '
|
||||||
|
'where not redirect '
|
||||||
|
'and not url like "qute://%" '
|
||||||
|
'and atime > ? '
|
||||||
|
'and atime <= ? '
|
||||||
|
'ORDER BY atime desc')
|
||||||
|
|
||||||
|
self._before_query = sql.Query('SELECT * FROM History '
|
||||||
|
'where not redirect '
|
||||||
|
'and not url like "qute://%" '
|
||||||
|
'and atime <= ? '
|
||||||
|
'ORDER BY atime desc '
|
||||||
|
'limit ? offset ?')
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return utils.get_repr(self, length=len(self))
|
return utils.get_repr(self, length=len(self))
|
||||||
@ -101,16 +114,8 @@ class WebHistory(sql.SqlTable):
|
|||||||
earliest: Omit timestamps earlier than this.
|
earliest: Omit timestamps earlier than this.
|
||||||
latest: Omit timestamps later than this.
|
latest: Omit timestamps later than this.
|
||||||
"""
|
"""
|
||||||
result = sql.run_query('SELECT * FROM History '
|
self._between_query.run([earliest, latest])
|
||||||
'where not redirect '
|
return iter(self._between_query)
|
||||||
'and not url like "qute://%" '
|
|
||||||
'and atime > {} '
|
|
||||||
'and atime <= {} '
|
|
||||||
'ORDER BY atime desc'
|
|
||||||
.format(earliest, latest))
|
|
||||||
while result.next():
|
|
||||||
rec = result.record()
|
|
||||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
|
||||||
|
|
||||||
def entries_before(self, latest, limit, offset):
|
def entries_before(self, latest, limit, offset):
|
||||||
"""Iterate non-redirect, non-qute entries occurring before a timestamp.
|
"""Iterate non-redirect, non-qute entries occurring before a timestamp.
|
||||||
@ -120,16 +125,8 @@ class WebHistory(sql.SqlTable):
|
|||||||
limit: Max number of entries to include.
|
limit: Max number of entries to include.
|
||||||
offset: Number of entries to skip.
|
offset: Number of entries to skip.
|
||||||
"""
|
"""
|
||||||
result = sql.run_query('SELECT * FROM History '
|
self._before_query.run([latest, limit, offset])
|
||||||
'where not redirect '
|
return iter(self._before_query)
|
||||||
'and not url like "qute://%" '
|
|
||||||
'and atime <= {} '
|
|
||||||
'ORDER BY atime desc '
|
|
||||||
'limit {} offset {}'
|
|
||||||
.format(latest, limit, offset))
|
|
||||||
while result.next():
|
|
||||||
rec = result.record()
|
|
||||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
|
||||||
|
|
||||||
@cmdutils.register(name='history-clear', instance='web-history')
|
@cmdutils.register(name='history-clear', instance='web-history')
|
||||||
def clear(self, force=False):
|
def clear(self, force=False):
|
||||||
|
@ -29,12 +29,13 @@ from qutebrowser.misc import sql
|
|||||||
class SqlCategory(QSqlQueryModel):
|
class SqlCategory(QSqlQueryModel):
|
||||||
"""Wraps a SqlQuery for use as a completion category."""
|
"""Wraps a SqlQuery for use as a completion category."""
|
||||||
|
|
||||||
def __init__(self, name, *, sort_by=None, sort_order=None, select='*',
|
def __init__(self, name, *, filter_fields, sort_by=None, sort_order=None,
|
||||||
where=None, group_by=None, parent=None):
|
select='*', where=None, group_by=None, parent=None):
|
||||||
"""Create a new completion category backed by a sql table.
|
"""Create a new completion category backed by a sql table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of category, and the table in the database.
|
name: Name of category, and the table in the database.
|
||||||
|
filter_fields: Names of fields to apply filter pattern to.
|
||||||
select: A custom result column expression for the select statement.
|
select: A custom result column expression for the select statement.
|
||||||
where: An optional clause to filter out some rows.
|
where: An optional clause to filter out some rows.
|
||||||
sort_by: The name of the field to sort by, or None for no sorting.
|
sort_by: The name of the field to sort by, or None for no sorting.
|
||||||
@ -42,11 +43,24 @@ class SqlCategory(QSqlQueryModel):
|
|||||||
"""
|
"""
|
||||||
super().__init__(parent=parent)
|
super().__init__(parent=parent)
|
||||||
self.name = name
|
self.name = name
|
||||||
self._sort_by = sort_by
|
|
||||||
self._sort_order = sort_order
|
querystr = 'select {} from {} where ('.format(select, name)
|
||||||
self._select = select
|
# the incoming pattern will have literal % and _ escaped with '\'
|
||||||
self._where = where
|
# we need to tell sql to treat '\' as an escape character
|
||||||
self._group_by = group_by
|
querystr += ' or '.join("{} like ? escape '\\'".format(f)
|
||||||
|
for f in filter_fields)
|
||||||
|
querystr += ')'
|
||||||
|
|
||||||
|
if where:
|
||||||
|
querystr += ' and ' + where
|
||||||
|
if group_by:
|
||||||
|
querystr += ' group by {}'.format(group_by)
|
||||||
|
if sort_by:
|
||||||
|
assert sort_order in ['asc', 'desc']
|
||||||
|
querystr += ' order by {} {}'.format(sort_by, sort_order)
|
||||||
|
|
||||||
|
self._query = sql.Query(querystr)
|
||||||
|
self._param_count=len(filter_fields)
|
||||||
self.set_pattern('', [0])
|
self.set_pattern('', [0])
|
||||||
|
|
||||||
def set_pattern(self, pattern, columns_to_filter):
|
def set_pattern(self, pattern, columns_to_filter):
|
||||||
@ -56,31 +70,13 @@ class SqlCategory(QSqlQueryModel):
|
|||||||
pattern: string pattern to filter by.
|
pattern: string pattern to filter by.
|
||||||
columns_to_filter: indices of columns to apply pattern to.
|
columns_to_filter: indices of columns to apply pattern to.
|
||||||
"""
|
"""
|
||||||
query = sql.run_query('select * from {} limit 1'.format(self.name))
|
# TODO: eliminate columns_to_filter
|
||||||
fields = [query.record().fieldName(i) for i in columns_to_filter]
|
#assert len(columns_to_filter) == self._param_count
|
||||||
|
|
||||||
querystr = 'select {} from {} where ('.format(self._select, self.name)
|
|
||||||
# the incoming pattern will have literal % and _ escaped with '\'
|
|
||||||
# we need to tell sql to treat '\' as an escape character
|
|
||||||
querystr += ' or '.join("{} like ? escape '\\'".format(f)
|
|
||||||
for f in fields)
|
|
||||||
querystr += ')'
|
|
||||||
if self._where:
|
|
||||||
querystr += ' and ' + self._where
|
|
||||||
|
|
||||||
if self._group_by:
|
|
||||||
querystr += ' group by {}'.format(self._group_by)
|
|
||||||
|
|
||||||
if self._sort_by:
|
|
||||||
assert self._sort_order in ['asc', 'desc']
|
|
||||||
querystr += ' order by {} {}'.format(self._sort_by,
|
|
||||||
self._sort_order)
|
|
||||||
|
|
||||||
# escape to treat a user input % or _ as a literal, not a wildcard
|
# escape to treat a user input % or _ as a literal, not a wildcard
|
||||||
pattern = pattern.replace('%', '\\%')
|
pattern = pattern.replace('%', '\\%')
|
||||||
pattern = pattern.replace('_', '\\_')
|
pattern = pattern.replace('_', '\\_')
|
||||||
# treat spaces as wildcards to match any of the typed words
|
# treat spaces as wildcards to match any of the typed words
|
||||||
pattern = re.sub(r' +', '%', pattern)
|
pattern = re.sub(r' +', '%', pattern)
|
||||||
pattern = '%{}%'.format(pattern)
|
pattern = '%{}%'.format(pattern)
|
||||||
query = sql.run_query(querystr, [pattern for _ in fields])
|
self._query.run([pattern] * self._param_count)
|
||||||
self.setQuery(query)
|
self.setQuery(self._query)
|
||||||
|
@ -77,6 +77,7 @@ def url():
|
|||||||
select_time = "strftime('{}', max(atime), 'unixepoch')".format(timefmt)
|
select_time = "strftime('{}', max(atime), 'unixepoch')".format(timefmt)
|
||||||
hist_cat = sqlcategory.SqlCategory(
|
hist_cat = sqlcategory.SqlCategory(
|
||||||
'History', sort_order='desc', sort_by='atime', group_by='url',
|
'History', sort_order='desc', sort_by='atime', group_by='url',
|
||||||
|
filter_fields=['url', 'title'],
|
||||||
select='url, title, {}'.format(select_time), where='not redirect')
|
select='url, title, {}'.format(select_time), where='not redirect')
|
||||||
model.add_category(hist_cat)
|
model.add_category(hist_cat)
|
||||||
return model
|
return model
|
||||||
|
@ -50,56 +50,45 @@ def close():
|
|||||||
|
|
||||||
def version():
|
def version():
|
||||||
"""Return the sqlite version string."""
|
"""Return the sqlite version string."""
|
||||||
result = run_query("select sqlite_version()")
|
q = Query("select sqlite_version()")
|
||||||
result.next()
|
q.run()
|
||||||
return result.record().value(0)
|
return q.value()
|
||||||
|
|
||||||
|
|
||||||
def _prepare_query(querystr):
|
class Query(QSqlQuery):
|
||||||
log.sql.debug('Preparing SQL query: "{}"'.format(querystr))
|
|
||||||
database = QSqlDatabase.database()
|
|
||||||
query = QSqlQuery(database)
|
|
||||||
query.prepare(querystr)
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
"""A prepared SQL Query."""
|
||||||
|
|
||||||
def run_query(querystr, values=None):
|
def __init__(self, querystr):
|
||||||
"""Run the given SQL query string on the database.
|
super().__init__(QSqlDatabase.database())
|
||||||
|
log.sql.debug('Preparing SQL query: "{}"'.format(querystr))
|
||||||
|
self.prepare(querystr)
|
||||||
|
|
||||||
Args:
|
def __iter__(self):
|
||||||
values: A list of positional parameter bindings.
|
assert self.isActive(), "Cannot iterate inactive query"
|
||||||
"""
|
rec = self.record()
|
||||||
query = _prepare_query(querystr)
|
fields = [rec.fieldName(i) for i in range(rec.count())]
|
||||||
for val in values or []:
|
rowtype = collections.namedtuple('ResultRow', fields)
|
||||||
query.addBindValue(val)
|
|
||||||
log.sql.debug('Query bindings: {}'.format(query.boundValues()))
|
|
||||||
if not query.exec_():
|
|
||||||
raise SqlException('Failed to exec query "{}": "{}"'.format(
|
|
||||||
querystr, query.lastError().text()))
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
while self.next():
|
||||||
|
rec = self.record()
|
||||||
|
yield rowtype(*[rec.value(i) for i in range(rec.count())])
|
||||||
|
|
||||||
def run_batch(querystr, values):
|
def run(self, values=None):
|
||||||
"""Run the given SQL query string on the database in batch mode.
|
"""Execute the prepared query."""
|
||||||
|
log.sql.debug('Running SQL query: "{}"'.format(self.lastQuery()))
|
||||||
|
for val in values or []:
|
||||||
|
self.addBindValue(val)
|
||||||
|
log.sql.debug('self bindings: {}'.format(self.boundValues()))
|
||||||
|
if not self.exec_():
|
||||||
|
raise SqlException('Failed to exec self "{}": "{}"'.format(
|
||||||
|
self.lastself(), self.lastError().text()))
|
||||||
|
|
||||||
Args:
|
def value(self):
|
||||||
values: A list of lists, where each inner list contains positional
|
"""Return the result of a single-value query (e.g. an EXISTS)."""
|
||||||
bindings for one run of the batch.
|
ok = self.next()
|
||||||
"""
|
assert ok, "No result for single-result query"
|
||||||
query = _prepare_query(querystr)
|
return self.record().value(0)
|
||||||
transposed = [list(row) for row in zip(*values)]
|
|
||||||
for val in transposed:
|
|
||||||
query.addBindValue(val)
|
|
||||||
log.sql.debug('Batch Query bindings: {}'.format(query.boundValues()))
|
|
||||||
|
|
||||||
db = QSqlDatabase.database()
|
|
||||||
db.transaction()
|
|
||||||
if not query.execBatch():
|
|
||||||
raise SqlException('Failed to exec query "{}": "{}"'.format(
|
|
||||||
querystr, query.lastError().text()))
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
|
||||||
class SqlTable(QObject):
|
class SqlTable(QObject):
|
||||||
@ -127,17 +116,17 @@ class SqlTable(QObject):
|
|||||||
"""
|
"""
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self._name = name
|
self._name = name
|
||||||
run_query("CREATE TABLE IF NOT EXISTS {} ({})"
|
q = Query("CREATE TABLE IF NOT EXISTS {} ({})"
|
||||||
.format(name, ','.join(fields)))
|
.format(name, ','.join(fields)))
|
||||||
|
q.run()
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
self.Entry = collections.namedtuple(name + '_Entry', fields)
|
self.Entry = collections.namedtuple(name + '_Entry', fields)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Iterate rows in the table."""
|
"""Iterate rows in the table."""
|
||||||
result = run_query("SELECT * FROM {}".format(self._name))
|
q = Query("SELECT * FROM {}".format(self._name))
|
||||||
while result.next():
|
q.run()
|
||||||
rec = result.record()
|
return iter(q)
|
||||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
|
||||||
|
|
||||||
def contains(self, field, value):
|
def contains(self, field, value):
|
||||||
"""Return whether the table contains the matching item.
|
"""Return whether the table contains the matching item.
|
||||||
@ -146,16 +135,16 @@ class SqlTable(QObject):
|
|||||||
field: Field to match.
|
field: Field to match.
|
||||||
value: Value to check for the given field.
|
value: Value to check for the given field.
|
||||||
"""
|
"""
|
||||||
query = run_query("Select EXISTS(SELECT * FROM {} where {} = ?)"
|
q = Query("Select EXISTS(SELECT * FROM {} where {} = ?)"
|
||||||
.format(self._name, field), [value])
|
.format(self._name, field))
|
||||||
query.next()
|
q.run([value])
|
||||||
return query.value(0)
|
return q.value()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Return the count of rows in the table."""
|
"""Return the count of rows in the table."""
|
||||||
result = run_query("SELECT count(*) FROM {}".format(self._name))
|
q = Query("SELECT count(*) FROM {}".format(self._name))
|
||||||
result.next()
|
q.run()
|
||||||
return result.value(0)
|
return q.value()
|
||||||
|
|
||||||
def delete(self, value, field):
|
def delete(self, value, field):
|
||||||
"""Remove all rows for which `field` equals `value`.
|
"""Remove all rows for which `field` equals `value`.
|
||||||
@ -167,9 +156,9 @@ class SqlTable(QObject):
|
|||||||
Return:
|
Return:
|
||||||
The number of rows deleted.
|
The number of rows deleted.
|
||||||
"""
|
"""
|
||||||
query = run_query("DELETE FROM {} where {} = ?".format(
|
q = Query("DELETE FROM {} where {} = ?".format(self._name, field))
|
||||||
self._name, field), [value])
|
q.run([value])
|
||||||
if not query.numRowsAffected():
|
if not q.numRowsAffected():
|
||||||
raise KeyError('No row with {} = "{}"'.format(field, value))
|
raise KeyError('No row with {} = "{}"'.format(field, value))
|
||||||
self.changed.emit()
|
self.changed.emit()
|
||||||
|
|
||||||
@ -180,8 +169,8 @@ class SqlTable(QObject):
|
|||||||
values: A list of values to insert.
|
values: A list of values to insert.
|
||||||
"""
|
"""
|
||||||
paramstr = ','.join(['?'] * len(values))
|
paramstr = ','.join(['?'] * len(values))
|
||||||
run_query("INSERT INTO {} values({})".format(self._name, paramstr),
|
q = Query("INSERT INTO {} values({})".format(self._name, paramstr))
|
||||||
values)
|
q.run(values)
|
||||||
self.changed.emit()
|
self.changed.emit()
|
||||||
|
|
||||||
def insert_batch(self, rows):
|
def insert_batch(self, rows):
|
||||||
@ -191,13 +180,23 @@ class SqlTable(QObject):
|
|||||||
rows: A list of lists, where each sub-list is a row.
|
rows: A list of lists, where each sub-list is a row.
|
||||||
"""
|
"""
|
||||||
paramstr = ','.join(['?'] * len(rows[0]))
|
paramstr = ','.join(['?'] * len(rows[0]))
|
||||||
run_batch("INSERT INTO {} values({})".format(self._name, paramstr),
|
q = Query("INSERT INTO {} values({})".format(self._name, paramstr))
|
||||||
rows)
|
|
||||||
|
transposed = [list(row) for row in zip(*rows)]
|
||||||
|
for val in transposed:
|
||||||
|
q.addBindValue(val)
|
||||||
|
|
||||||
|
db = QSqlDatabase.database()
|
||||||
|
db.transaction()
|
||||||
|
if not q.execBatch():
|
||||||
|
raise SqlException('Failed to exec query "{}": "{}"'.format(
|
||||||
|
q.lastQuery(), q.lastError().text()))
|
||||||
|
db.commit()
|
||||||
self.changed.emit()
|
self.changed.emit()
|
||||||
|
|
||||||
def delete_all(self):
|
def delete_all(self):
|
||||||
"""Remove all row from the table."""
|
"""Remove all row from the table."""
|
||||||
run_query("DELETE FROM {}".format(self._name))
|
Query("DELETE FROM {}".format(self._name)).run()
|
||||||
self.changed.emit()
|
self.changed.emit()
|
||||||
|
|
||||||
def select(self, sort_by, sort_order, limit=-1):
|
def select(self, sort_by, sort_order, limit=-1):
|
||||||
@ -208,8 +207,7 @@ class SqlTable(QObject):
|
|||||||
sort_order: 'asc' or 'desc'.
|
sort_order: 'asc' or 'desc'.
|
||||||
limit: max number of rows in result, defaults to -1 (unlimited).
|
limit: max number of rows in result, defaults to -1 (unlimited).
|
||||||
"""
|
"""
|
||||||
result = run_query('SELECT * FROM {} ORDER BY {} {} LIMIT {}'
|
q = Query('SELECT * FROM {} ORDER BY {} {} LIMIT ?'
|
||||||
.format(self._name, sort_by, sort_order, limit))
|
.format(self._name, sort_by, sort_order))
|
||||||
while result.next():
|
q.run([limit])
|
||||||
rec = result.record()
|
return q
|
||||||
yield self.Entry(*[rec.value(i) for i in range(rec.count())])
|
|
||||||
|
@ -74,7 +74,7 @@ def test_sorting(sort_by, sort_order, data, expected):
|
|||||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||||
for row in data:
|
for row in data:
|
||||||
table.insert(row)
|
table.insert(row)
|
||||||
cat = sqlcategory.SqlCategory('Foo', sort_by=sort_by,
|
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], sort_by=sort_by,
|
||||||
sort_order=sort_order)
|
sort_order=sort_order)
|
||||||
_validate(cat, expected)
|
_validate(cat, expected)
|
||||||
|
|
||||||
@ -129,7 +129,8 @@ def test_set_pattern(pattern, filter_cols, before, after):
|
|||||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||||
for row in before:
|
for row in before:
|
||||||
table.insert(row)
|
table.insert(row)
|
||||||
cat = sqlcategory.SqlCategory('Foo')
|
filter_fields = [['a', 'b', 'c'][i] for i in filter_cols]
|
||||||
|
cat = sqlcategory.SqlCategory('Foo', filter_fields=filter_fields)
|
||||||
cat.set_pattern(pattern, filter_cols)
|
cat.set_pattern(pattern, filter_cols)
|
||||||
_validate(cat, after)
|
_validate(cat, after)
|
||||||
|
|
||||||
@ -137,7 +138,7 @@ def test_set_pattern(pattern, filter_cols, before, after):
|
|||||||
def test_select():
|
def test_select():
|
||||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||||
table.insert(['foo', 'bar', 'baz'])
|
table.insert(['foo', 'bar', 'baz'])
|
||||||
cat = sqlcategory.SqlCategory('Foo', select='b, c, a')
|
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], select='b, c, a')
|
||||||
_validate(cat, [('bar', 'baz', 'foo')])
|
_validate(cat, [('bar', 'baz', 'foo')])
|
||||||
|
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ def test_where():
|
|||||||
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
table = sql.SqlTable('Foo', ['a', 'b', 'c'])
|
||||||
table.insert(['foo', 'bar', False])
|
table.insert(['foo', 'bar', False])
|
||||||
table.insert(['baz', 'biz', True])
|
table.insert(['baz', 'biz', True])
|
||||||
cat = sqlcategory.SqlCategory('Foo', where='not c')
|
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'], where='not c')
|
||||||
_validate(cat, [('foo', 'bar', False)])
|
_validate(cat, [('foo', 'bar', False)])
|
||||||
|
|
||||||
|
|
||||||
@ -155,7 +156,8 @@ def test_group():
|
|||||||
table.insert(['bar', 3])
|
table.insert(['bar', 3])
|
||||||
table.insert(['foo', 2])
|
table.insert(['foo', 2])
|
||||||
table.insert(['bar', 0])
|
table.insert(['bar', 0])
|
||||||
cat = sqlcategory.SqlCategory('Foo', select='a, max(b)', group_by='a')
|
cat = sqlcategory.SqlCategory('Foo', filter_fields=['a'],
|
||||||
|
select='a, max(b)', group_by='a')
|
||||||
_validate(cat, [('bar', 3), ('foo', 2)])
|
_validate(cat, [('bar', 3), ('foo', 2)])
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user