From b70d5ba901de09aee744993913a58e910d884936 Mon Sep 17 00:00:00 2001 From: Ryan Roden-Corrent Date: Tue, 14 Feb 2017 08:46:31 -0500 Subject: [PATCH] Use QSqlQueryModel instead of QSqlTableModel. This allows setting the query as a QSqlQuery instead of a string, which allows: - Escaping quotes - Using LIMIT (needed for history-max-items) - Using ORDER BY (needed for sorting history) - SELECTing columns (needed for quickmark completion) - Creating a custom select (needed for history timestamp formatting) --- qutebrowser/completion/models/sqlmodel.py | 58 ++++++++++++++++------- qutebrowser/misc/sql.py | 18 +++---- tests/unit/completion/test_sqlmodel.py | 15 +++++- 3 files changed, 62 insertions(+), 29 deletions(-) diff --git a/qutebrowser/completion/models/sqlmodel.py b/qutebrowser/completion/models/sqlmodel.py index 2f0a5059b..40b2e27ca 100644 --- a/qutebrowser/completion/models/sqlmodel.py +++ b/qutebrowser/completion/models/sqlmodel.py @@ -22,9 +22,39 @@ import re from PyQt5.QtCore import Qt, QModelIndex, QAbstractItemModel -from PyQt5.QtSql import QSqlTableModel, QSqlDatabase +from PyQt5.QtSql import QSqlQuery, QSqlQueryModel, QSqlDatabase from qutebrowser.utils import log +from qutebrowser.misc import sql + + +class SqlCompletionCategory(QSqlQueryModel): + def __init__(self, name, sort_by, sort_order, limit, columns_to_filter, + parent=None): + super().__init__(parent=parent) + self.tablename = name + + query = sql.run_query('select * from {} limit 1'.format(name)) + self._fields = [query.record().fieldName(i) for i in columns_to_filter] + + querystr = 'select * from {} where '.format(self.tablename) + querystr += ' or '.join('{} like ?'.format(f) for f in self._fields) + querystr += " escape '\\'" + + if sort_by: + sortstr = 'asc' if sort_order == Qt.AscendingOrder else 'desc' + querystr += ' order by {} {}'.format(sort_by, sortstr) + + if limit: + querystr += ' limit {}'.format(limit) + + self._querystr = querystr + self.set_pattern('%') + + def set_pattern(self, pattern): + # TODO: kill star-args for run_query + query = sql.run_query(self._querystr, *[pattern for _ in self._fields]) + self.setQuery(query) class SqlCompletionModel(QAbstractItemModel): @@ -58,7 +88,7 @@ class SqlCompletionModel(QAbstractItemModel): self.srcmodel = self # TODO: dummy for compat with old API self.pattern = '' - def new_category(self, name, sort_by=None, sort_order=Qt.AscendingOrder): + def new_category(self, name, sort_by=None, sort_order=None, limit=None): """Create a new completion category and add it to this model. Args: @@ -68,14 +98,10 @@ class SqlCompletionModel(QAbstractItemModel): Return: A new CompletionCategory. """ - database = QSqlDatabase.database() - cat = QSqlTableModel(parent=self, db=database) - cat.setTable(name) - if sort_by: - cat.setSort(cat.fieldIndex(sort_by), sort_order) - cat.select() + cat = SqlCompletionCategory(name, parent=self, sort_by=sort_by, + sort_order=sort_order, limit=limit, + columns_to_filter=self.columns_to_filter) self._categories.append(cat) - return cat def delete_cur_item(self, completion): """Delete the selected item.""" @@ -95,7 +121,7 @@ class SqlCompletionModel(QAbstractItemModel): return if not index.parent().isValid(): if index.column() == 0: - return self._categories[index.row()].tableName() + return self._categories[index.row()].tablename else: table = self._categories[index.parent().row()] idx = table.index(index.row(), index.column()) @@ -177,6 +203,7 @@ class SqlCompletionModel(QAbstractItemModel): Args: pattern: The filter pattern to set. """ + log.completion.debug("Setting completion pattern '{}'".format(pattern)) # TODO: should pattern be saved in the view layer instead? self.pattern = pattern # escape to treat a user input % or _ as a literal, not a wildcard @@ -184,14 +211,9 @@ class SqlCompletionModel(QAbstractItemModel): pattern = pattern.replace('_', '\\_') # treat spaces as wildcards to match any of the typed words pattern = re.sub(r' +', '%', pattern) - for t in self._categories: - fields = (t.record().fieldName(i) for i in self.columns_to_filter) - query = ' or '.join("{} like '%{}%' escape '\\'" - .format(field, pattern) - for field in fields) - log.completion.debug("Setting filter = '{}' for table '{}'" - .format(query, t.tableName())) - t.setFilter(query) + pattern = '%{}%'.format(pattern) + for cat in self._categories: + cat.set_pattern(pattern) def first_item(self): """Return the index of the first child (non-category) in the model.""" diff --git a/qutebrowser/misc/sql.py b/qutebrowser/misc/sql.py index fd7257bdb..9065caa48 100644 --- a/qutebrowser/misc/sql.py +++ b/qutebrowser/misc/sql.py @@ -39,7 +39,7 @@ def close(): QSqlDatabase.removeDatabase(QSqlDatabase.database().connectionName()) -def _run_query(querystr, *values): +def run_query(querystr, *values): """Run the given SQL query string on the database. Args: @@ -85,12 +85,12 @@ class SqlTable(QObject): super().__init__(parent) self._name = name self._primary_key = primary_key - _run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))" + run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))" .format(name, ','.join(fields), primary_key)) def __iter__(self): """Iterate rows in the table.""" - result = _run_query("SELECT * FROM {}".format(self._name)) + result = run_query("SELECT * FROM {}".format(self._name)) while result.next(): rec = result.record() yield tuple(rec.value(i) for i in range(rec.count())) @@ -101,13 +101,13 @@ class SqlTable(QObject): Args: key: Primary key value to search for. """ - query = _run_query("SELECT * FROM {} where {} = ?" + query = run_query("SELECT * FROM {} where {} = ?" .format(self._name, self._primary_key), key) return query.next() def __len__(self): """Return the count of rows in the table.""" - result = _run_query("SELECT count(*) FROM {}".format(self._name)) + result = run_query("SELECT count(*) FROM {}".format(self._name)) result.next() return result.value(0) @@ -117,7 +117,7 @@ class SqlTable(QObject): Args: key: Primary key value to fetch. """ - result = _run_query("SELECT * FROM {} where {} = ?" + result = run_query("SELECT * FROM {} where {} = ?" .format(self._name, self._primary_key), key) result.next() rec = result.record() @@ -134,7 +134,7 @@ class SqlTable(QObject): The number of rows deleted. """ field = field or self._primary_key - query = _run_query("DELETE FROM {} where {} = ?" + query = run_query("DELETE FROM {} where {} = ?" .format(self._name, field), value) if not query.numRowsAffected(): raise KeyError('No row with {} = "{}"'.format(field, value)) @@ -149,13 +149,13 @@ class SqlTable(QObject): """ cmd = "REPLACE" if replace else "INSERT" paramstr = ','.join(['?'] * len(values)) - _run_query("{} INTO {} values({})" + run_query("{} INTO {} values({})" .format(cmd, self._name, paramstr), *values) self.changed.emit() def delete_all(self): """Remove all row from the table.""" - _run_query("DELETE FROM {}".format(self._name)) + run_query("DELETE FROM {}".format(self._name)) self.changed.emit() diff --git a/tests/unit/completion/test_sqlmodel.py b/tests/unit/completion/test_sqlmodel.py index 49cb216dd..024b69ec2 100644 --- a/tests/unit/completion/test_sqlmodel.py +++ b/tests/unit/completion/test_sqlmodel.py @@ -155,16 +155,19 @@ def test_sorting(sort_by, sort_order, data, expected): ('_', [0], [('A', [('a_b', '', ''), ('__a', '', ''), ('abc', '', '')])], [('A', [('a_b', '', ''), ('__a', '', '')])]), + + ("can't", [0], + [('A', [("can't touch this", '', ''), ('a', '', '')])], + [('A', [("can't touch this", '', '')])]), ]) def test_set_pattern(pattern, filter_cols, before, after): """Validate the filtering and sorting results of set_pattern.""" - model = sqlmodel.SqlCompletionModel() + model = sqlmodel.SqlCompletionModel(columns_to_filter=filter_cols) for name, rows in before: table = sql.SqlTable(name, ['a', 'b', 'c'], primary_key='a') for row in rows: table.insert(*row) model.new_category(name) - model.columns_to_filter = filter_cols model.set_pattern(pattern) _check_model(model, after) @@ -198,3 +201,11 @@ def test_first_last_item(data, first, last): model.new_category(name) assert model.data(model.first_item()) == first assert model.data(model.last_item()) == last + +def test_limit(): + table = sql.SqlTable('test_limit', ['a'], primary_key='a') + for i in range(5): + table.insert([i]) + model = sqlmodel.SqlCompletionModel() + model.new_category('test_limit', limit=3) + assert model.count() == 3