diff --git a/pirate/data.py b/pirate/data.py index 457809f..1efdbfd 100644 --- a/pirate/data.py +++ b/pirate/data.py @@ -11,3 +11,6 @@ blacklist = set(json.loads(get_resource('blacklist.json').decode())) default_headers = {'User-Agent': 'pirate get'} default_timeout = 10 + +MIRROR_DEFAULT = 'https://thepiratebay.mn' +MIRROR_SOURCE = 'https://proxybay.co/list.txt' diff --git a/pirate/pirate.py b/pirate/pirate.py index 7254e11..e508432 100755 --- a/pirate/pirate.py +++ b/pirate/pirate.py @@ -24,8 +24,6 @@ import pirate.local from os.path import expanduser, expandvars from pirate.print import Printer -MIRROR_DEFAULT = 'https://thepiratebay.mn' -MIRROR_SOURCE = 'https://proxybay.co/list.txt' def parse_config_file(text): config = configparser.RawConfigParser() @@ -273,14 +271,14 @@ def search_mirrors(printer, pages, category, sort, action, search, mirror): # Search on our mirror, or the default one. if not mirror: - mirror = MIRROR_DEFAULT + mirror = pirate.data.MIRROR_DEFAULT results, mirror = search_on_mirror(printer, pages, category, sort, action, search, mirror) if results: return results, mirror # If the default mirror failed, get some mirrors. - mirror_sources = [MIRROR_SOURCE] + mirror_sources = [pirate.data.MIRROR_SOURCE] for mirror_source in mirror_sources: mirrors = OrderedDict() try: @@ -343,8 +341,8 @@ def pirate_main(args): if args.source == 'local_tpb': results = pirate.local.search(args.database, args.search) elif args.source == 'tpb': - results, site = search_mirrors(printer, args.pages, args.category,\ - args.sort, args.action, args.search,\ + results, site = search_mirrors(printer, args.pages, args.category, + args.sort, args.action, args.search, args.mirror) if len(results) == 0: diff --git a/tests/test_pirate.py b/tests/test_pirate.py index 552db2f..8457d40 100755 --- a/tests/test_pirate.py +++ b/tests/test_pirate.py @@ -85,6 +85,9 @@ class TestPirate(unittest.TestCase): 'LocalDB': { 'enabled': bool, 'path': str, + }, + 'Mirror': { + 'url': str, } } config1 = """ @@ -99,6 +102,10 @@ class TestPirate(unittest.TestCase): [Save] Magnets=True """ + config3= """ + [Mirror] + url = http:abc + """ tests = [ (config1, {'Save': {'magnets': False}}), (config1, {'Save': {'torrents': False}}), @@ -106,6 +113,7 @@ class TestPirate(unittest.TestCase): (config1, {'LocalDB': {'enabled': True}}), (config1, {'LocalDB': {'path': 'abc'}}), (config2, {'Save': {'magnets': True}}), + (config3, {'Mirror': {'url': 'http:abc'}}), ] for test in tests: config = pirate.pirate.parse_config_file(test[0]) @@ -125,6 +133,7 @@ class TestPirate(unittest.TestCase): ('', ['-R'], {'action': 'recent'}), ('', ['-l'], {'action': 'list_categories'}), ('', ['--list_sorts'], {'action': 'list_sorts'}), + ('', ['--mirror', 'url'], {'mirror': 'url'}), ('', ['term'], {'action': 'search', 'source': 'tpb'}), ('', ['-L', 'filename', 'term'], {'action': 'search', 'source': 'local_tpb', 'database': 'filename'}), ('', ['term', '-S', 'dir'], {'action': 'search', 'save_directory': 'dir'}), @@ -146,7 +155,7 @@ class TestPirate(unittest.TestCase): self.assertEqual(test[2][option], value) def test_search_mirrors(self): - pages, category, sort, action, search = (1, 100, 10, 'browse', []) + pages, category, sort, action, search, mirror = (1, 100, 10, 'browse', [], None) class MockResponse(): readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']]) info = mock.MagicMock() @@ -155,7 +164,7 @@ class TestPirate(unittest.TestCase): printer = MagicMock(Printer) with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: with patch('pirate.torrent.remote', return_value=[]) as remote: - results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) + results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search, mirror) self.assertEqual(results, []) self.assertEqual(mirror, 'https://thepiratebay.mn') remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn')