Return namedtuples from SqlTable.

Instead of returning a regular tuple and trying to remember which index maps to
which field, return named tuples that allow accessing the fields by name.
This commit is contained in:
Ryan Roden-Corrent 2017-02-15 09:06:23 -05:00
parent d89898ef7d
commit 0e650ad719
3 changed files with 12 additions and 3 deletions

View File

@ -193,7 +193,7 @@ class QuickmarkManager(UrlMarkManager):
if name not in self: if name not in self:
raise DoesNotExistError("Quickmark '{}' does not exist!" raise DoesNotExistError("Quickmark '{}' does not exist!"
.format(name)) .format(name))
urlstr = self[name] urlstr = self[name].url
try: try:
url = urlutils.fuzzy_url(urlstr, do_search=False) url = urlutils.fuzzy_url(urlstr, do_search=False)
except urlutils.InvalidUrlError as e: except urlutils.InvalidUrlError as e:

View File

@ -24,6 +24,8 @@ from PyQt5.QtSql import QSqlDatabase, QSqlQuery
from qutebrowser.utils import log from qutebrowser.utils import log
import collections
def init(): def init():
"""Initialize the SQL database connection.""" """Initialize the SQL database connection."""
@ -87,13 +89,14 @@ class SqlTable(QObject):
self._primary_key = primary_key self._primary_key = primary_key
run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))" run_query("CREATE TABLE {} ({}, PRIMARY KEY ({}))"
.format(name, ','.join(fields), primary_key)) .format(name, ','.join(fields), primary_key))
self.Entry = collections.namedtuple(name + '_Entry', fields)
def __iter__(self): def __iter__(self):
"""Iterate rows in the table.""" """Iterate rows in the table."""
result = run_query("SELECT * FROM {}".format(self._name)) result = run_query("SELECT * FROM {}".format(self._name))
while result.next(): while result.next():
rec = result.record() rec = result.record()
yield tuple(rec.value(i) for i in range(rec.count())) yield self.Entry(*[rec.value(i) for i in range(rec.count())])
def __contains__(self, key): def __contains__(self, key):
"""Return whether the table contains the matching item. """Return whether the table contains the matching item.
@ -121,7 +124,7 @@ class SqlTable(QObject):
.format(self._name, self._primary_key), [key]) .format(self._name, self._primary_key), [key])
result.next() result.next()
rec = result.record() rec = result.record()
return tuple(rec.value(i) for i in range(rec.count())) return self.Entry(*[rec.value(i) for i in range(rec.count())])
def delete(self, value, field=None): def delete(self, value, field=None):
"""Remove all rows for which `field` equals `value`. """Remove all rows for which `field` equals `value`.

View File

@ -231,3 +231,9 @@ def test_where():
model = sqlmodel.SqlCompletionModel() model = sqlmodel.SqlCompletionModel()
model.new_category('test_where', where='not c') model.new_category('test_where', where='not c')
_check_model(model, [('test_where', [('foo', 'bar', False)])]) _check_model(model, [('test_where', [('foo', 'bar', False)])])
def test_entry():
table = sql.SqlTable('test_entry', ['a', 'b', 'c'], primary_key='a')
assert hasattr(table.Entry, 'a')
assert hasattr(table.Entry, 'b')
assert hasattr(table.Entry, 'c')