mhtml: Sanitize default filename suggestion

This commit is contained in:
Daniel 2015-10-30 21:08:37 +01:00
parent 4f01382c64
commit a1bc020fec
4 changed files with 37 additions and 1 deletions

View File

@ -1187,6 +1187,7 @@ class CommandDispatcher:
tab_id = self._current_index() tab_id = self._current_index()
if dest is None: if dest is None:
suggested_fn = self._current_title() + ".mht" suggested_fn = self._current_title() + ".mht"
suggested_fn = utils.sanitize_filename(suggested_fn)
q = usertypes.Question() q = usertypes.Question()
q.text = "Save page to: " q.text = "Save page to: "
q.mode = usertypes.PromptMode.text q.mode = usertypes.PromptMode.text

View File

@ -458,7 +458,7 @@ def start_download_checked(dest, win_id, tab_id):
# The default name is 'page title.mht' # The default name is 'page title.mht'
title = (objreg.get('webview', scope='tab', window=win_id, tab=tab_id) title = (objreg.get('webview', scope='tab', window=win_id, tab=tab_id)
.title()) .title())
default_name = title + '.mht' default_name = utils.sanitize_filename(title + '.mht')
# Remove characters which cannot be expressed in the file system encoding # Remove characters which cannot be expressed in the file system encoding
encoding = sys.getfilesystemencoding() encoding = sys.getfilesystemencoding()

View File

@ -611,6 +611,27 @@ def force_encoding(text, encoding):
return text.encode(encoding, errors='replace').decode(encoding) return text.encode(encoding, errors='replace').decode(encoding)
def sanitize_filename(name, replacement='_'):
"""Replace invalid filename characters.
Note: This should be used for the basename, as it also removes the path
separator.
Args:
name: The filename.
replacement: The replacement character (or None).
"""
if replacement is None:
replacement = ''
# Bad characters taken from Windows, there are even fewer on Linux
# See also
# https://en.wikipedia.org/wiki/Filename#Reserved_characters_and_words
bad_chars = '\\/:*?"<>|'
for bad_char in bad_chars:
name = name.replace(bad_char, replacement)
return name
def newest_slice(iterable, count): def newest_slice(iterable, count):
"""Get an iterable for the n newest items of the given iterable. """Get an iterable for the n newest items of the given iterable.

View File

@ -839,6 +839,20 @@ def test_force_encoding(inp, enc, expected):
assert utils.force_encoding(inp, enc) == expected assert utils.force_encoding(inp, enc) == expected
@pytest.mark.parametrize('inp, expected', [
('normal.txt', 'normal.txt'),
('user/repo issues.mht', 'user_repo issues.mht'),
('<Test\\File> - "*?:|', '_Test_File_ - _____'),
])
def test_sanitize_filename(inp, expected):
assert utils.sanitize_filename(inp) == expected
def test_sanitize_filename_empty_replacement():
name = '/<Bad File>/'
assert utils.sanitize_filename(name, replacement=None) == 'Bad File'
class TestNewestSlice: class TestNewestSlice:
"""Test newest_slice.""" """Test newest_slice."""