diff --git a/qutebrowser/completion/completer.py b/qutebrowser/completion/completer.py index bd4f43009..8742156a3 100644 --- a/qutebrowser/completion/completer.py +++ b/qutebrowser/completion/completer.py @@ -73,8 +73,8 @@ class Completer(QObject): A completion model or None. """ model = completion(*pos_args) - if model is None: - return None + if model is None or hasattr(model, 'set_pattern'): + return model else: return sortfilter.CompletionFilterModel(source=model, parent=self) diff --git a/qutebrowser/completion/models/sqlmodel.py b/qutebrowser/completion/models/sqlmodel.py new file mode 100644 index 000000000..a5a9bc99a --- /dev/null +++ b/qutebrowser/completion/models/sqlmodel.py @@ -0,0 +1,218 @@ +# vim: ft=python fileencoding=utf-8 sts=4 sw=4 et: + +# Copyright 2016 Ryan Roden-Corrent (rcorre) +# +# This file is part of qutebrowser. +# +# qutebrowser is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# qutebrowser is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with qutebrowser. If not, see . + +"""A completion model backed by SQL tables.""" + +import re + +from PyQt5.QtCore import Qt, QModelIndex, QAbstractItemModel +from PyQt5.QtSql import QSqlTableModel, QSqlDatabase, QSqlQuery + +from qutebrowser.utils import usertypes, log + + +class SqlCompletionModel(QAbstractItemModel): + + """A sqlite-based model that provides data for the CompletionView. + + This model is a wrapper around one or more sql tables. The tables are all + stored in a single database in qutebrowser's cache directory. + + Top level indices represent categories, each of which is backed by a single + table. Child indices represent rows of those tables. + + Class Attributes: + COLUMN_WIDTHS: The width percentages of the columns used in the + completion view. + + Attributes: + column_widths: The width percentages of the columns used in the + completion view. + columns_to_filter: A list of indices of columns to apply the filter to. + pattern: Current filter pattern, used for highlighting. + _categories: The category tables. + """ + + def __init__(self, column_widths=(30, 70, 0), columns_to_filter=None, + parent=None): + super().__init__(parent) + self.columns_to_filter = columns_to_filter or [0] + self.column_widths = column_widths + self._categories = [] + self.srcmodel = self # TODO: dummy for compat with old API + self.pattern = '' + + def new_category(self, name, sort_by=None, sort_order=Qt.AscendingOrder): + """Create a new completion category and add it to this model. + + Args: + name: Name of category, and the table in the database. + sort_by: The name of the field to sort by, or None for no sorting. + sort_order: Sorting order, if sort_by is non-None. + + Return: A new CompletionCategory. + """ + database = QSqlDatabase.database() + cat = QSqlTableModel(parent=self, db=database) + cat.setTable(name) + if sort_by: + cat.setSort(cat.fieldIndex(sort_by), sort_order) + cat.select() + self._categories.append(cat) + return cat + + def delete_cur_item(self, completion): + """Delete the selected item.""" + raise NotImplementedError + + def data(self, index, role=Qt.DisplayRole): + """Return the item data for index. + + Override QAbstractItemModel::data. + + Args: + index: The QModelIndex to get item flags for. + + Return: The item data, or None on an invalid index. + """ + if not index.isValid() or role != Qt.DisplayRole: + return + if not index.parent().isValid(): + if index.column() == 0: + return self._categories[index.row()].tableName() + else: + table = self._categories[index.parent().row()] + idx = table.index(index.row(), index.column()) + return table.data(idx) + + def flags(self, index): + """Return the item flags for index. + + Override QAbstractItemModel::flags. + + Return: The item flags, or Qt.NoItemFlags on error. + """ + if not index.isValid(): + return + if index.parent().isValid(): + # item + return (Qt.ItemIsEnabled | Qt.ItemIsSelectable | + Qt.ItemNeverHasChildren) + else: + # category + return Qt.NoItemFlags + + def index(self, row, col, parent=QModelIndex()): + """Get an index into the model. + + Override QAbstractItemModel::index. + + Return: A QModelIndex. + """ + if (row < 0 or row >= self.rowCount(parent) or + col < 0 or col >= self.columnCount(parent)): + return QModelIndex() + if parent.isValid(): + if parent.column() != 0: + return QModelIndex() + # store a pointer to the parent table in internalPointer + return self.createIndex(row, col, self._categories[parent.row()]) + return self.createIndex(row, col, None) + + def parent(self, index): + """Get an index to the parent of the given index. + + Override QAbstractItemModel::parent. + + Args: + index: The QModelIndex to get the parent index for. + """ + parent_table = index.internalPointer() + if not parent_table: + # categories have no parent + return QModelIndex() + row = self._categories.index(parent_table) + return self.createIndex(row, 0, None) + + def rowCount(self, parent=QModelIndex()): + if not parent.isValid(): + # top-level + return len(self._categories) + elif parent.internalPointer() or parent.column() != 0: + # item or nonzero category column (only first col has children) + return 0 + else: + # category + return self._categories[parent.row()].rowCount() + + def columnCount(self, parent=QModelIndex()): + # pylint: disable=unused-argument + return 3 + + def count(self): + """Return the count of non-category items.""" + return sum(t.rowCount() for t in self._categories) + + def set_pattern(self, pattern): + """Set the filter pattern for all category tables. + + This will apply to the fields indicated in columns_to_filter. + + Args: + pattern: The filter pattern to set. + """ + # TODO: should pattern be saved in the view layer instead? + self.pattern = pattern + # escape to treat a user input % or _ as a literal, not a wildcard + pattern = pattern.replace('%', '\\%') + pattern = pattern.replace('_', '\\_') + # treat spaces as wildcards to match any of the typed words + pattern = re.sub(r' +', '%', pattern) + for t in self._categories: + fields = (t.record().fieldName(i) for i in self.columns_to_filter) + query = ' or '.join("{} like '%{}%' escape '\\'" + .format(field, pattern) + for field in fields) + log.completion.debug("Setting filter = '{}' for table '{}'" + .format(query, t.tableName())) + t.setFilter(query) + + def first_item(self): + """Return the index of the first child (non-category) in the model.""" + for row, table in enumerate(self._categories): + if table.rowCount() > 0: + parent = self.index(row, 0) + return self.index(0, 0, parent) + return QModelIndex() + + def last_item(self): + """Return the index of the last child (non-category) in the model.""" + for row, table in reversed(list(enumerate(self._categories))): + childcount = table.rowCount() + if childcount > 0: + parent = self.index(row, 0) + return self.index(childcount - 1, 0, parent) + return QModelIndex() + + +class SqlException(Exception): + + """Raised on an error interacting with the SQL database.""" + + pass diff --git a/tests/unit/completion/test_sqlmodel.py b/tests/unit/completion/test_sqlmodel.py new file mode 100644 index 000000000..6b8f01f4d --- /dev/null +++ b/tests/unit/completion/test_sqlmodel.py @@ -0,0 +1,204 @@ +# vim: ft=python fileencoding=utf-8 sts=4 sw=4 et: + +# Copyright 2016 Ryan Roden-Corrent (rcorre) +# +# This file is part of qutebrowser. +# +# qutebrowser is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# qutebrowser is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with qutebrowser. If not, see . + +"""Tests for the base sql completion model.""" + +import pytest +from PyQt5.QtCore import Qt + +from qutebrowser.misc import sql +from qutebrowser.completion.models import sqlmodel + + +@pytest.fixture(autouse=True) +def init(): + sql.init() + yield + sql.close() + + +def _check_model(model, expected): + """Check that a model contains the expected items in the given order. + + Args: + expected: A list of form + [ + (cat, [(name, desc, misc), (name, desc, misc), ...]), + (cat, [(name, desc, misc), (name, desc, misc), ...]), + ... + ] + """ + assert model.rowCount() == len(expected) + for i, (expected_title, expected_items) in enumerate(expected): + catidx = model.index(i, 0) + assert model.data(catidx) == expected_title + assert model.rowCount(catidx) == len(expected_items) + for j, (name, desc, misc) in enumerate(expected_items): + assert model.data(model.index(j, 0, catidx)) == name + assert model.data(model.index(j, 1, catidx)) == desc + assert model.data(model.index(j, 2, catidx)) == misc + + +@pytest.mark.parametrize('rowcounts, expected', [ + ([0], 0), + ([1], 1), + ([2], 2), + ([0, 0], 0), + ([0, 0, 0], 0), + ([1, 1], 2), + ([3, 2, 1], 6), + ([0, 2, 0], 2), +]) +def test_count(rowcounts, expected): + model = sqlmodel.SqlCompletionModel() + for i, rowcount in enumerate(rowcounts): + name = 'Foo' + str(i) + table = sql.SqlTable(name, ['a'], primary_key='a') + for rownum in range(rowcount): + table.insert(rownum) + model.new_category(name) + assert model.count() == expected + + +@pytest.mark.parametrize('sort_by, sort_order, data, expected', [ + (None, Qt.AscendingOrder, + [('B', 'C', 'D'), ('A', 'F', 'C'), ('C', 'A', 'G')], + [('B', 'C', 'D'), ('A', 'F', 'C'), ('C', 'A', 'G')]), + + ('a', Qt.AscendingOrder, + [('B', 'C', 'D'), ('A', 'F', 'C'), ('C', 'A', 'G')], + [('A', 'F', 'C'), ('B', 'C', 'D'), ('C', 'A', 'G')]), + + ('a', Qt.DescendingOrder, + [('B', 'C', 'D'), ('A', 'F', 'C'), ('C', 'A', 'G')], + [('C', 'A', 'G'), ('B', 'C', 'D'), ('A', 'F', 'C')]), + + ('b', Qt.AscendingOrder, + [('B', 'C', 'D'), ('A', 'F', 'C'), ('C', 'A', 'G')], + [('C', 'A', 'G'), ('B', 'C', 'D'), ('A', 'F', 'C')]), + + ('b', Qt.DescendingOrder, + [('B', 'C', 'D'), ('A', 'F', 'C'), ('C', 'A', 'G')], + [('A', 'F', 'C'), ('B', 'C', 'D'), ('C', 'A', 'G')]), + + ('c', Qt.AscendingOrder, + [('B', 'C', 2), ('A', 'F', 0), ('C', 'A', 1)], + [('A', 'F', 0), ('C', 'A', 1), ('B', 'C', 2)]), + + ('c', Qt.DescendingOrder, + [('B', 'C', 2), ('A', 'F', 0), ('C', 'A', 1)], + [('B', 'C', 2), ('C', 'A', 1), ('A', 'F', 0)]), +]) +def test_sorting(sort_by, sort_order, data, expected): + table = sql.SqlTable('Foo', ['a', 'b', 'c'], primary_key='a') + for row in data: + table.insert(*row) + model = sqlmodel.SqlCompletionModel() + model.new_category('Foo', sort_by=sort_by, sort_order=sort_order) + _check_model(model, [('Foo', expected)]) + + +@pytest.mark.parametrize('pattern, filter_cols, before, after', [ + ('foo', [0], + [('A', [('foo', '', ''), ('bar', '', ''), ('aafobbb', '', '')])], + [('A', [('foo', '', '')])]), + + ('foo', [0], + [('A', [('baz', 'bar', 'foo'), ('foo', '', ''), ('bar', 'foo', '')])], + [('A', [('foo', '', '')])]), + + ('foo', [0], + [('A', [('foo', '', ''), ('bar', '', '')]), + ('B', [('foo', '', ''), ('bar', '', '')])], + [('A', [('foo', '', '')]), ('B', [('foo', '', '')])]), + + ('foo', [0], + [('A', [('fooa', '', ''), ('foob', '', ''), ('fooc', '', '')])], + [('A', [('fooa', '', ''), ('foob', '', ''), ('fooc', '', '')])]), + + ('foo', [0], + [('A', [('foo', '', '')]), ('B', [('bar', '', '')])], + [('A', [('foo', '', '')]), ('B', [])]), + + ('foo', [1], + [('A', [('foo', 'bar', ''), ('bar', 'foo', '')])], + [('A', [('bar', 'foo', '')])]), + + ('foo', [0, 1], + [('A', [('foo', 'bar', ''), ('bar', 'foo', '')])], + [('A', [('foo', 'bar', ''), ('bar', 'foo', '')])]), + + ('foo', [0, 1, 2], + [('A', [('foo', '', ''), ('bar', '', '')])], + [('A', [('foo', '', '')])]), + + ('foo bar', [0], + [('A', [('foo', '', ''), ('bar foo', '', ''), ('xfooyybarz', '', '')])], + [('A', [('xfooyybarz', '', '')])]), + + ('foo%bar', [0], + [('A', [('foo%bar', '', ''), ('foo bar', '', ''), ('foobar', '', '')])], + [('A', [('foo%bar', '', '')])]), + + ('_', [0], + [('A', [('a_b', '', ''), ('__a', '', ''), ('abc', '', '')])], + [('A', [('a_b', '', ''), ('__a', '', '')])]), +]) +def test_set_pattern(pattern, filter_cols, before, after): + """Validate the filtering and sorting results of set_pattern.""" + model = sqlmodel.SqlCompletionModel() + for name, rows in before: + table = sql.SqlTable(name, ['a', 'b', 'c'], primary_key='a') + for row in rows: + table.insert(*row) + model.new_category(name) + model.columns_to_filter = filter_cols + model.set_pattern(pattern) + _check_model(model, after) + + +@pytest.mark.parametrize('data, first, last', [ + ([('A', ['Aa'])], 'Aa', 'Aa'), + ([('A', ['Aa', 'Ba'])], 'Aa', 'Ba'), + ([('A', ['Aa', 'Ab', 'Ac']), ('B', ['Ba', 'Bb']), + ('C', ['Ca'])], 'Aa', 'Ca'), + ([('A', []), ('B', ['Ba'])], 'Ba', 'Ba'), + ([('A', []), ('B', []), ('C', ['Ca'])], 'Ca', 'Ca'), + ([('A', []), ('B', []), ('C', ['Ca', 'Cb'])], 'Ca', 'Cb'), + ([('A', ['Aa']), ('B', [])], 'Aa', 'Aa'), + ([('A', ['Aa']), ('B', []), ('C', [])], 'Aa', 'Aa'), + ([('A', ['Aa']), ('B', []), ('C', ['Ca'])], 'Aa', 'Ca'), + ([('A', []), ('B', [])], None, None), +]) +def test_first_last_item(data, first, last): + """Test that first() and last() return indexes to the first and last items. + + Args: + data: Input to _make_model + first: text of the first item + last: text of the last item + """ + model = sqlmodel.SqlCompletionModel() + for name, rows in data: + table = sql.SqlTable(name, ['a'], primary_key='a') + for row in rows: + table.insert(row) + model.new_category(name) + assert model.data(model.first_item()) == first + assert model.data(model.last_item()) == last