Use QSqlQueryModel instead of QSqlTableModel.
This allows setting the query as a QSqlQuery instead of a string, which allows: - Escaping quotes - Using LIMIT (needed for history-max-items) - Using ORDER BY (needed for sorting history) - SELECTing columns (needed for quickmark completion) - Creating a custom select (needed for history timestamp formatting)
This commit is contained in:
parent
52d7d1df0c
commit
b70d5ba901
@ -22,9 +22,39 @@
|
||||
import re
|
||||
|
||||
from PyQt5.QtCore import Qt, QModelIndex, QAbstractItemModel
|
||||
from PyQt5.QtSql import QSqlTableModel, QSqlDatabase
|
||||
from PyQt5.QtSql import QSqlQuery, QSqlQueryModel, QSqlDatabase
|
||||
|
||||
from qutebrowser.utils import log
|
||||
from qutebrowser.misc import sql
|
||||
|
||||
|
||||
class SqlCompletionCategory(QSqlQueryModel):
|
||||
def __init__(self, name, sort_by, sort_order, limit, columns_to_filter,
|
||||
parent=None):
|
||||
super().__init__(parent=parent)
|
||||
self.tablename = name
|
||||
|
||||
query = sql.run_query('select * from {} limit 1'.format(name))
|
||||
self._fields = [query.record().fieldName(i) for i in columns_to_filter]
|
||||
|
||||
querystr = 'select * from {} where '.format(self.tablename)
|
||||
querystr += ' or '.join('{} like ?'.format(f) for f in self._fields)
|
||||
querystr += " escape '\\'"
|
||||
|
||||
if sort_by:
|
||||
sortstr = 'asc' if sort_order == Qt.AscendingOrder else 'desc'
|
||||
querystr += ' order by {} {}'.format(sort_by, sortstr)
|
||||
|
||||
if limit:
|
||||
querystr += ' limit {}'.format(limit)
|
||||
|
||||
self._querystr = querystr
|
||||
self.set_pattern('%')
|
||||
|
||||
def set_pattern(self, pattern):
|
||||
# TODO: kill star-args for run_query
|
||||
query = sql.run_query(self._querystr, *[pattern for _ in self._fields])
|
||||
self.setQuery(query)
|
||||
|
||||
|
||||
class SqlCompletionModel(QAbstractItemModel):
|
||||
@ -58,7 +88,7 @@ class SqlCompletionModel(QAbstractItemModel):
|
||||
self.srcmodel = self # TODO: dummy for compat with old API
|
||||
self.pattern = ''
|
||||
|
||||
def new_category(self, name, sort_by=None, sort_order=Qt.AscendingOrder):
|
||||
def new_category(self, name, sort_by=None, sort_order=None, limit=None):
|
||||
"""Create a new completion category and add it to this model.
|
||||
|
||||
Args:
|
||||
@ -68,14 +98,10 @@ class SqlCompletionModel(QAbstractItemModel):
|
||||
|
||||
Return: A new CompletionCategory.
|
||||
"""
|
||||
database = QSqlDatabase.database()
|
||||
cat = QSqlTableModel(parent=self, db=database)
|
||||
cat.setTable(name)
|
||||
if sort_by:
|
||||
cat.setSort(cat.fieldIndex(sort_by), sort_order)
|
||||
cat.select()
|
||||
cat = SqlCompletionCategory(name, parent=self, sort_by=sort_by,
|
||||
sort_order=sort_order, limit=limit,
|
||||
columns_to_filter=self.columns_to_filter)
|
||||
self._categories.append(cat)
|
||||
return cat
|
||||
|
||||
def delete_cur_item(self, completion):
|
||||
"""Delete the selected item."""
|
||||
@ -95,7 +121,7 @@ class SqlCompletionModel(QAbstractItemModel):
|
||||
return
|
||||
if not index.parent().isValid():
|
||||
if index.column() == 0:
|
||||
return self._categories[index.row()].tableName()
|
||||
return self._categories[index.row()].tablename
|
||||
else:
|
||||
table = self._categories[index.parent().row()]
|
||||
idx = table.index(index.row(), index.column())
|
||||
@ -177,6 +203,7 @@ class SqlCompletionModel(QAbstractItemModel):
|
||||
Args:
|
||||
pattern: The filter pattern to set.
|
||||
"""
|
||||
log.completion.debug("Setting completion pattern '{}'".format(pattern))
|
||||
# TODO: should pattern be saved in the view layer instead?
|
||||
self.pattern = pattern
|
||||
# escape to treat a user input % or _ as a literal, not a wildcard
|
||||
@ -184,14 +211,9 @@ class SqlCompletionModel(QAbstractItemModel):
|
||||
pattern = pattern.replace('_', '\\_')
|
||||
# treat spaces as wildcards to match any of the typed words
|
||||
pattern = re.sub(r' +', '%', pattern)
|
||||
for t in self._categories:
|
||||
fields = (t.record().fieldName(i) for i in self.columns_to_filter)
|
||||
query = ' or '.join("{} like '%{}%' escape '\\'"
|
||||
.format(field, pattern)
|
||||
for field in fields)
|
||||
log.completion.debug("Setting filter = '{}' for table '{}'"
|
||||
.format(query, t.tableName()))
|
||||
t.setFilter(query)
|
||||
pattern = '%{}%'.format(pattern)
|
||||
for cat in self._categories:
|
||||
cat.set_pattern(pattern)
|
||||
|
||||
def first_item(self):
|
||||
"""Return the index of the first child (non-category) in the model."""
|
||||
|
@ -39,7 +39,7 @@ def close():
|
||||
QSqlDatabase.removeDatabase(QSqlDatabase.database().connectionName())
|
||||
|
||||
|
||||
def _run_query(querystr, *values):
|
||||
def run_query(querystr, *values):
|
||||
"""Run the given SQL query string on the database.
|
||||
|
||||
Args:
|
||||
@ -85,12 +85,12 @@ class SqlTable(QObject):
|
||||
super().__init__(parent)
|
||||
self._name = name
|
||||
self._primary_key = primary_key
|
||||
_run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))"
|
||||
run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))"
|
||||
.format(name, ','.join(fields), primary_key))
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate rows in the table."""
|
||||
result = _run_query("SELECT * FROM {}".format(self._name))
|
||||
result = run_query("SELECT * FROM {}".format(self._name))
|
||||
while result.next():
|
||||
rec = result.record()
|
||||
yield tuple(rec.value(i) for i in range(rec.count()))
|
||||
@ -101,13 +101,13 @@ class SqlTable(QObject):
|
||||
Args:
|
||||
key: Primary key value to search for.
|
||||
"""
|
||||
query = _run_query("SELECT * FROM {} where {} = ?"
|
||||
query = run_query("SELECT * FROM {} where {} = ?"
|
||||
.format(self._name, self._primary_key), key)
|
||||
return query.next()
|
||||
|
||||
def __len__(self):
|
||||
"""Return the count of rows in the table."""
|
||||
result = _run_query("SELECT count(*) FROM {}".format(self._name))
|
||||
result = run_query("SELECT count(*) FROM {}".format(self._name))
|
||||
result.next()
|
||||
return result.value(0)
|
||||
|
||||
@ -117,7 +117,7 @@ class SqlTable(QObject):
|
||||
Args:
|
||||
key: Primary key value to fetch.
|
||||
"""
|
||||
result = _run_query("SELECT * FROM {} where {} = ?"
|
||||
result = run_query("SELECT * FROM {} where {} = ?"
|
||||
.format(self._name, self._primary_key), key)
|
||||
result.next()
|
||||
rec = result.record()
|
||||
@ -134,7 +134,7 @@ class SqlTable(QObject):
|
||||
The number of rows deleted.
|
||||
"""
|
||||
field = field or self._primary_key
|
||||
query = _run_query("DELETE FROM {} where {} = ?"
|
||||
query = run_query("DELETE FROM {} where {} = ?"
|
||||
.format(self._name, field), value)
|
||||
if not query.numRowsAffected():
|
||||
raise KeyError('No row with {} = "{}"'.format(field, value))
|
||||
@ -149,13 +149,13 @@ class SqlTable(QObject):
|
||||
"""
|
||||
cmd = "REPLACE" if replace else "INSERT"
|
||||
paramstr = ','.join(['?'] * len(values))
|
||||
_run_query("{} INTO {} values({})"
|
||||
run_query("{} INTO {} values({})"
|
||||
.format(cmd, self._name, paramstr), *values)
|
||||
self.changed.emit()
|
||||
|
||||
def delete_all(self):
|
||||
"""Remove all row from the table."""
|
||||
_run_query("DELETE FROM {}".format(self._name))
|
||||
run_query("DELETE FROM {}".format(self._name))
|
||||
self.changed.emit()
|
||||
|
||||
|
||||
|
@ -155,16 +155,19 @@ def test_sorting(sort_by, sort_order, data, expected):
|
||||
('_', [0],
|
||||
[('A', [('a_b', '', ''), ('__a', '', ''), ('abc', '', '')])],
|
||||
[('A', [('a_b', '', ''), ('__a', '', '')])]),
|
||||
|
||||
("can't", [0],
|
||||
[('A', [("can't touch this", '', ''), ('a', '', '')])],
|
||||
[('A', [("can't touch this", '', '')])]),
|
||||
])
|
||||
def test_set_pattern(pattern, filter_cols, before, after):
|
||||
"""Validate the filtering and sorting results of set_pattern."""
|
||||
model = sqlmodel.SqlCompletionModel()
|
||||
model = sqlmodel.SqlCompletionModel(columns_to_filter=filter_cols)
|
||||
for name, rows in before:
|
||||
table = sql.SqlTable(name, ['a', 'b', 'c'], primary_key='a')
|
||||
for row in rows:
|
||||
table.insert(*row)
|
||||
model.new_category(name)
|
||||
model.columns_to_filter = filter_cols
|
||||
model.set_pattern(pattern)
|
||||
_check_model(model, after)
|
||||
|
||||
@ -198,3 +201,11 @@ def test_first_last_item(data, first, last):
|
||||
model.new_category(name)
|
||||
assert model.data(model.first_item()) == first
|
||||
assert model.data(model.last_item()) == last
|
||||
|
||||
def test_limit():
|
||||
table = sql.SqlTable('test_limit', ['a'], primary_key='a')
|
||||
for i in range(5):
|
||||
table.insert([i])
|
||||
model = sqlmodel.SqlCompletionModel()
|
||||
model.new_category('test_limit', limit=3)
|
||||
assert model.count() == 3
|
||||
|
Loading…
Reference in New Issue
Block a user