diff --git a/README.md b/README.md index c1ea1a5..6c11ade 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,10 @@ transmission = false ; use colored output colors = true + +; the pirate bay mirror(s) to use: +; one or more space separated URLs +mirror = http://thepiratebay.org ``` Note: diff --git a/pirate/pirate.py b/pirate/pirate.py index 8bb2e56..2e5f971 100755 --- a/pirate/pirate.py +++ b/pirate/pirate.py @@ -39,6 +39,7 @@ def parse_config_file(text): config.set('Misc', 'openCommand', '') config.set('Misc', 'transmission', 'false') config.set('Misc', 'colors', 'true') + config.set('Misc', 'mirror', pirate.data.default_mirror) config.read_string(text) @@ -54,16 +55,12 @@ def parse_config_file(text): def load_config(): # user-defined config files - main = expandvars('$XDG_CONFIG_HOME/pirate-get') - alt = expanduser('~/.config/pirate-get') + config_home = os.getenv('XDG_CONFIG_HOME', '~/.config') + config = expanduser(os.path.join(config_home, 'pirate-get')) # read config file - if os.path.isfile(main): - with open(main) as f: - return parse_config_file(f.read()) - - if os.path.isfile(alt): - with open(alt) as f: + if os.path.isfile(config): + with open(config) as f: return parse_config_file(f.read()) return parse_config_file("") @@ -173,6 +170,9 @@ def parse_args(args_in): parser.add_argument('--disable-colors', dest='color', action='store_false', help='disable colored output') + parser.add_argument('-m', '--mirror', + type=str, nargs='+', + help='the pirate bay mirror(s) to use') args = parser.parse_args(args_in) return args @@ -207,6 +207,9 @@ def combine_configs(config, args): if not args.save_directory: args.save_directory = config.get('Save', 'directory') + if not args.mirror: + args.mirror = config.get('Misc', 'mirror').split() + args.transmission_command = ['transmission-remote'] if args.port: args.transmission_command.append(args.port) @@ -228,16 +231,16 @@ def combine_configs(config, args): return args -def connect_mirror(mirror, printer, pages, category, sort, action, search): +def connect_mirror(mirror, printer, args): try: printer.print('Trying', mirror, end='... ') results = pirate.torrent.remote( printer=printer, - pages=pages, - category=pirate.torrent.parse_category(printer, category), - sort=pirate.torrent.parse_sort(printer, sort), - mode=action, - terms=search, + pages=args.pages, + category=pirate.torrent.parse_category(printer, args.category), + sort=pirate.torrent.parse_sort(printer, args.sort), + mode=args.action, + terms=args.search, mirror=mirror) except (urllib.error.URLError, socket.timeout, IOError, ValueError): printer.print('Failed', color='WARN') @@ -247,11 +250,12 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search): return results, mirror -def search_mirrors(printer, *args): - # try official site - result = connect_mirror(pirate.data.default_mirror, printer, *args) - if result: - return result +def search_mirrors(printer, args): + # try default or user mirrors + for mirror in args.mirror: + result = connect_mirror(mirror, printer, args) + if result: + return result # download mirror list try: @@ -271,7 +275,7 @@ def search_mirrors(printer, *args): for mirror in mirrors: if mirror in pirate.data.blacklist: continue - result = connect_mirror(mirror, printer, *args) + result = connect_mirror(mirror, printer, args) if result: return result else: @@ -312,8 +316,7 @@ def pirate_main(args): results = pirate.local.search(args.database, args.search) elif args.source == 'tpb': try: - results, site = search_mirrors(printer, args.pages, args.category, - args.sort, args.action, args.search) + results, site = search_mirrors(printer, args) except IOError as e: printer.print(e.args[0] + ' :( ', color='ERROR') if len(e.args) > 1: diff --git a/tests/test_pirate.py b/tests/test_pirate.py index d5371eb..7ccbb03 100755 --- a/tests/test_pirate.py +++ b/tests/test_pirate.py @@ -2,6 +2,7 @@ import socket import unittest import subprocess +from argparse import Namespace from unittest import mock from unittest.mock import patch, call, MagicMock @@ -147,7 +148,9 @@ class TestPirate(unittest.TestCase): self.assertEqual(test[2][option], value) def test_search_mirrors(self): - pages, category, sort, action, search = (1, 100, 10, 'browse', []) + args = Namespace(pages=1, category=100, sort=10, + action='browse', search=[], + mirror=[pirate.data.default_mirror]) class MockResponse(): readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']]) info = mock.MagicMock() @@ -156,12 +159,12 @@ class TestPirate(unittest.TestCase): printer = MagicMock(Printer) with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: with patch('pirate.torrent.remote', return_value=[]) as remote: - results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) + results, mirror = pirate.pirate.search_mirrors(printer, args) self.assertEqual(results, []) self.assertEqual(mirror, pirate.data.default_mirror) remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror) with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote: - results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) + results, mirror = pirate.pirate.search_mirrors(printer, args) self.assertEqual(results, []) self.assertEqual(mirror, 'https://example.com') remote.assert_has_calls([