Merge remote-tracking branch 'origin/pr/4191'

This commit is contained in:
Florian Bruhin 2018-09-30 22:01:57 +02:00
commit 4b495303f9
2 changed files with 41 additions and 9 deletions

View File

@ -173,15 +173,12 @@ class HostBlocker:
for url in config.val.content.host_blocking.lists:
if url.scheme() == 'file':
filename = url.toLocalFile()
try:
fileobj = open(filename, 'rb')
except OSError as e:
message.error("adblock: Error while reading {}: {}".format(
filename, e.strerror))
continue
download = _FakeDownload(fileobj)
self._in_progress.append(download)
self._on_download_finished(download)
if os.path.isdir(filename):
for filenames in os.scandir(filename):
if filenames.is_file():
self._import_local(filenames.path)
else:
self._import_local(filename)
else:
fobj = io.BytesIO()
fobj.name = 'adblock: ' + url.host()
@ -192,6 +189,22 @@ class HostBlocker:
download.finished.connect(
functools.partial(self._on_download_finished, download))
def _import_local(self, filename):
"""Adds the contents of a file to the blocklist.
Args:
filename: path to a local file to import.
"""
try:
fileobj = open(filename, 'rb')
except OSError as e:
message.error("adblock: Error while reading {}: {}".format(
filename, e.strerror))
return
download = _FakeDownload(fileobj)
self._in_progress.append(download)
self._on_download_finished(download)
def _parse_line(self, line):
"""Parse a line from a host file.

View File

@ -434,3 +434,22 @@ def test_config_change(config_stub, basedir, download_stub,
host_blocker.read_hosts()
for str_url in URLS_TO_CHECK:
assert not host_blocker.is_blocked(QUrl(str_url))
def test_add_directory(config_stub, basedir, download_stub,
data_tmpdir, tmpdir):
"""Ensure adblocker can import all files in a directory."""
blocklist_hosts2 = []
for i in BLOCKLIST_HOSTS[1:]:
blocklist_hosts2.append('1' + i)
create_blocklist(tmpdir, blocked_hosts=BLOCKLIST_HOSTS,
name='blocked-hosts', line_format='one_per_line')
create_blocklist(tmpdir, blocked_hosts=blocklist_hosts2,
name='blocked-hosts2', line_format='one_per_line')
config_stub.val.content.host_blocking.lists = [tmpdir.strpath]
config_stub.val.content.host_blocking.enabled = True
host_blocker = adblock.HostBlocker()
host_blocker.adblock_update()
assert len(host_blocker._blocked_hosts) == len(blocklist_hosts2) * 2