diff --git a/qutebrowser/browser/adblock.py b/qutebrowser/browser/adblock.py index e38b44d5f..862ec18fa 100644 --- a/qutebrowser/browser/adblock.py +++ b/qutebrowser/browser/adblock.py @@ -173,15 +173,12 @@ class HostBlocker: for url in config.val.content.host_blocking.lists: if url.scheme() == 'file': filename = url.toLocalFile() - try: - fileobj = open(filename, 'rb') - except OSError as e: - message.error("adblock: Error while reading {}: {}".format( - filename, e.strerror)) - continue - download = _FakeDownload(fileobj) - self._in_progress.append(download) - self._on_download_finished(download) + if os.path.isdir(filename): + for filenames in os.scandir(filename): + if filenames.is_file(): + self._import_local(filenames.path) + else: + self._import_local(filename) else: fobj = io.BytesIO() fobj.name = 'adblock: ' + url.host() @@ -192,6 +189,22 @@ class HostBlocker: download.finished.connect( functools.partial(self._on_download_finished, download)) + def _import_local(self, filename): + """Adds the contents of a file to the blocklist. + + Args: + filename: path to a local file to import. + """ + try: + fileobj = open(filename, 'rb') + except OSError as e: + message.error("adblock: Error while reading {}: {}".format( + filename, e.strerror)) + return + download = _FakeDownload(fileobj) + self._in_progress.append(download) + self._on_download_finished(download) + def _parse_line(self, line): """Parse a line from a host file. diff --git a/tests/unit/browser/test_adblock.py b/tests/unit/browser/test_adblock.py index 470bce5cd..8bcbf3eb6 100644 --- a/tests/unit/browser/test_adblock.py +++ b/tests/unit/browser/test_adblock.py @@ -434,3 +434,22 @@ def test_config_change(config_stub, basedir, download_stub, host_blocker.read_hosts() for str_url in URLS_TO_CHECK: assert not host_blocker.is_blocked(QUrl(str_url)) + + +def test_add_directory(config_stub, basedir, download_stub, + data_tmpdir, tmpdir): + """Ensure adblocker can import all files in a directory.""" + blocklist_hosts2 = [] + for i in BLOCKLIST_HOSTS[1:]: + blocklist_hosts2.append('1' + i) + + create_blocklist(tmpdir, blocked_hosts=BLOCKLIST_HOSTS, + name='blocked-hosts', line_format='one_per_line') + create_blocklist(tmpdir, blocked_hosts=blocklist_hosts2, + name='blocked-hosts2', line_format='one_per_line') + + config_stub.val.content.host_blocking.lists = [tmpdir.strpath] + config_stub.val.content.host_blocking.enabled = True + host_blocker = adblock.HostBlocker() + host_blocker.adblock_update() + assert len(host_blocker._blocked_hosts) == len(blocklist_hosts2) * 2