Use QSqlQueryModel instead of QSqlTableModel.

This allows setting the query as a QSqlQuery instead of a string, which allows:

- Escaping quotes
- Using LIMIT (needed for history-max-items)
- Using ORDER BY (needed for sorting history)
- SELECTing columns (needed for quickmark completion)
- Creating a custom select (needed for history timestamp formatting)
This commit is contained in:
Ryan Roden-Corrent 2017-02-14 08:46:31 -05:00
parent 52d7d1df0c
commit b70d5ba901
3 changed files with 62 additions and 29 deletions

View File

@ -22,9 +22,39 @@
import re import re
from PyQt5.QtCore import Qt, QModelIndex, QAbstractItemModel from PyQt5.QtCore import Qt, QModelIndex, QAbstractItemModel
from PyQt5.QtSql import QSqlTableModel, QSqlDatabase from PyQt5.QtSql import QSqlQuery, QSqlQueryModel, QSqlDatabase
from qutebrowser.utils import log from qutebrowser.utils import log
from qutebrowser.misc import sql
class SqlCompletionCategory(QSqlQueryModel):
def __init__(self, name, sort_by, sort_order, limit, columns_to_filter,
parent=None):
super().__init__(parent=parent)
self.tablename = name
query = sql.run_query('select * from {} limit 1'.format(name))
self._fields = [query.record().fieldName(i) for i in columns_to_filter]
querystr = 'select * from {} where '.format(self.tablename)
querystr += ' or '.join('{} like ?'.format(f) for f in self._fields)
querystr += " escape '\\'"
if sort_by:
sortstr = 'asc' if sort_order == Qt.AscendingOrder else 'desc'
querystr += ' order by {} {}'.format(sort_by, sortstr)
if limit:
querystr += ' limit {}'.format(limit)
self._querystr = querystr
self.set_pattern('%')
def set_pattern(self, pattern):
# TODO: kill star-args for run_query
query = sql.run_query(self._querystr, *[pattern for _ in self._fields])
self.setQuery(query)
class SqlCompletionModel(QAbstractItemModel): class SqlCompletionModel(QAbstractItemModel):
@ -58,7 +88,7 @@ class SqlCompletionModel(QAbstractItemModel):
self.srcmodel = self # TODO: dummy for compat with old API self.srcmodel = self # TODO: dummy for compat with old API
self.pattern = '' self.pattern = ''
def new_category(self, name, sort_by=None, sort_order=Qt.AscendingOrder): def new_category(self, name, sort_by=None, sort_order=None, limit=None):
"""Create a new completion category and add it to this model. """Create a new completion category and add it to this model.
Args: Args:
@ -68,14 +98,10 @@ class SqlCompletionModel(QAbstractItemModel):
Return: A new CompletionCategory. Return: A new CompletionCategory.
""" """
database = QSqlDatabase.database() cat = SqlCompletionCategory(name, parent=self, sort_by=sort_by,
cat = QSqlTableModel(parent=self, db=database) sort_order=sort_order, limit=limit,
cat.setTable(name) columns_to_filter=self.columns_to_filter)
if sort_by:
cat.setSort(cat.fieldIndex(sort_by), sort_order)
cat.select()
self._categories.append(cat) self._categories.append(cat)
return cat
def delete_cur_item(self, completion): def delete_cur_item(self, completion):
"""Delete the selected item.""" """Delete the selected item."""
@ -95,7 +121,7 @@ class SqlCompletionModel(QAbstractItemModel):
return return
if not index.parent().isValid(): if not index.parent().isValid():
if index.column() == 0: if index.column() == 0:
return self._categories[index.row()].tableName() return self._categories[index.row()].tablename
else: else:
table = self._categories[index.parent().row()] table = self._categories[index.parent().row()]
idx = table.index(index.row(), index.column()) idx = table.index(index.row(), index.column())
@ -177,6 +203,7 @@ class SqlCompletionModel(QAbstractItemModel):
Args: Args:
pattern: The filter pattern to set. pattern: The filter pattern to set.
""" """
log.completion.debug("Setting completion pattern '{}'".format(pattern))
# TODO: should pattern be saved in the view layer instead? # TODO: should pattern be saved in the view layer instead?
self.pattern = pattern self.pattern = pattern
# escape to treat a user input % or _ as a literal, not a wildcard # escape to treat a user input % or _ as a literal, not a wildcard
@ -184,14 +211,9 @@ class SqlCompletionModel(QAbstractItemModel):
pattern = pattern.replace('_', '\\_') pattern = pattern.replace('_', '\\_')
# treat spaces as wildcards to match any of the typed words # treat spaces as wildcards to match any of the typed words
pattern = re.sub(r' +', '%', pattern) pattern = re.sub(r' +', '%', pattern)
for t in self._categories: pattern = '%{}%'.format(pattern)
fields = (t.record().fieldName(i) for i in self.columns_to_filter) for cat in self._categories:
query = ' or '.join("{} like '%{}%' escape '\\'" cat.set_pattern(pattern)
.format(field, pattern)
for field in fields)
log.completion.debug("Setting filter = '{}' for table '{}'"
.format(query, t.tableName()))
t.setFilter(query)
def first_item(self): def first_item(self):
"""Return the index of the first child (non-category) in the model.""" """Return the index of the first child (non-category) in the model."""

View File

@ -39,7 +39,7 @@ def close():
QSqlDatabase.removeDatabase(QSqlDatabase.database().connectionName()) QSqlDatabase.removeDatabase(QSqlDatabase.database().connectionName())
def _run_query(querystr, *values): def run_query(querystr, *values):
"""Run the given SQL query string on the database. """Run the given SQL query string on the database.
Args: Args:
@ -85,12 +85,12 @@ class SqlTable(QObject):
super().__init__(parent) super().__init__(parent)
self._name = name self._name = name
self._primary_key = primary_key self._primary_key = primary_key
_run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))" run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))"
.format(name, ','.join(fields), primary_key)) .format(name, ','.join(fields), primary_key))
def __iter__(self): def __iter__(self):
"""Iterate rows in the table.""" """Iterate rows in the table."""
result = _run_query("SELECT * FROM {}".format(self._name)) result = run_query("SELECT * FROM {}".format(self._name))
while result.next(): while result.next():
rec = result.record() rec = result.record()
yield tuple(rec.value(i) for i in range(rec.count())) yield tuple(rec.value(i) for i in range(rec.count()))
@ -101,13 +101,13 @@ class SqlTable(QObject):
Args: Args:
key: Primary key value to search for. key: Primary key value to search for.
""" """
query = _run_query("SELECT * FROM {} where {} = ?" query = run_query("SELECT * FROM {} where {} = ?"
.format(self._name, self._primary_key), key) .format(self._name, self._primary_key), key)
return query.next() return query.next()
def __len__(self): def __len__(self):
"""Return the count of rows in the table.""" """Return the count of rows in the table."""
result = _run_query("SELECT count(*) FROM {}".format(self._name)) result = run_query("SELECT count(*) FROM {}".format(self._name))
result.next() result.next()
return result.value(0) return result.value(0)
@ -117,7 +117,7 @@ class SqlTable(QObject):
Args: Args:
key: Primary key value to fetch. key: Primary key value to fetch.
""" """
result = _run_query("SELECT * FROM {} where {} = ?" result = run_query("SELECT * FROM {} where {} = ?"
.format(self._name, self._primary_key), key) .format(self._name, self._primary_key), key)
result.next() result.next()
rec = result.record() rec = result.record()
@ -134,7 +134,7 @@ class SqlTable(QObject):
The number of rows deleted. The number of rows deleted.
""" """
field = field or self._primary_key field = field or self._primary_key
query = _run_query("DELETE FROM {} where {} = ?" query = run_query("DELETE FROM {} where {} = ?"
.format(self._name, field), value) .format(self._name, field), value)
if not query.numRowsAffected(): if not query.numRowsAffected():
raise KeyError('No row with {} = "{}"'.format(field, value)) raise KeyError('No row with {} = "{}"'.format(field, value))
@ -149,13 +149,13 @@ class SqlTable(QObject):
""" """
cmd = "REPLACE" if replace else "INSERT" cmd = "REPLACE" if replace else "INSERT"
paramstr = ','.join(['?'] * len(values)) paramstr = ','.join(['?'] * len(values))
_run_query("{} INTO {} values({})" run_query("{} INTO {} values({})"
.format(cmd, self._name, paramstr), *values) .format(cmd, self._name, paramstr), *values)
self.changed.emit() self.changed.emit()
def delete_all(self): def delete_all(self):
"""Remove all row from the table.""" """Remove all row from the table."""
_run_query("DELETE FROM {}".format(self._name)) run_query("DELETE FROM {}".format(self._name))
self.changed.emit() self.changed.emit()

View File

@ -155,16 +155,19 @@ def test_sorting(sort_by, sort_order, data, expected):
('_', [0], ('_', [0],
[('A', [('a_b', '', ''), ('__a', '', ''), ('abc', '', '')])], [('A', [('a_b', '', ''), ('__a', '', ''), ('abc', '', '')])],
[('A', [('a_b', '', ''), ('__a', '', '')])]), [('A', [('a_b', '', ''), ('__a', '', '')])]),
("can't", [0],
[('A', [("can't touch this", '', ''), ('a', '', '')])],
[('A', [("can't touch this", '', '')])]),
]) ])
def test_set_pattern(pattern, filter_cols, before, after): def test_set_pattern(pattern, filter_cols, before, after):
"""Validate the filtering and sorting results of set_pattern.""" """Validate the filtering and sorting results of set_pattern."""
model = sqlmodel.SqlCompletionModel() model = sqlmodel.SqlCompletionModel(columns_to_filter=filter_cols)
for name, rows in before: for name, rows in before:
table = sql.SqlTable(name, ['a', 'b', 'c'], primary_key='a') table = sql.SqlTable(name, ['a', 'b', 'c'], primary_key='a')
for row in rows: for row in rows:
table.insert(*row) table.insert(*row)
model.new_category(name) model.new_category(name)
model.columns_to_filter = filter_cols
model.set_pattern(pattern) model.set_pattern(pattern)
_check_model(model, after) _check_model(model, after)
@ -198,3 +201,11 @@ def test_first_last_item(data, first, last):
model.new_category(name) model.new_category(name)
assert model.data(model.first_item()) == first assert model.data(model.first_item()) == first
assert model.data(model.last_item()) == last assert model.data(model.last_item()) == last
def test_limit():
table = sql.SqlTable('test_limit', ['a'], primary_key='a')
for i in range(5):
table.insert([i])
model = sqlmodel.SqlCompletionModel()
model.new_category('test_limit', limit=3)
assert model.count() == 3