From 2f40523db246ed9c7ceae64f03d7b33a52634960 Mon Sep 17 00:00:00 2001 From: rnhmjoj Date: Wed, 31 Aug 2016 21:48:25 +0200 Subject: [PATCH] tidy up search_mirrors function --- pirate/data.py | 3 +++ pirate/pirate.py | 26 ++++++++++++++++---------- tests/test_pirate.py | 7 ++++--- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pirate/data.py b/pirate/data.py index 457809f..e33cb5a 100644 --- a/pirate/data.py +++ b/pirate/data.py @@ -11,3 +11,6 @@ blacklist = set(json.loads(get_resource('blacklist.json').decode())) default_headers = {'User-Agent': 'pirate get'} default_timeout = 10 + +default_mirror = 'https://thepiratebay.org/' +mirror_list = 'https://proxybay.co/list.txt' diff --git a/pirate/pirate.py b/pirate/pirate.py index 07ce4cf..8bb2e56 100755 --- a/pirate/pirate.py +++ b/pirate/pirate.py @@ -241,6 +241,7 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search): mirror=mirror) except (urllib.error.URLError, socket.timeout, IOError, ValueError): printer.print('Failed', color='WARN') + return None else: printer.print('Ok', color='alt') return results, mirror @@ -248,21 +249,21 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search): def search_mirrors(printer, *args): # try official site - result = connect_mirror('https://thepiratebay.mn', printer, *args) + result = connect_mirror(pirate.data.default_mirror, printer, *args) if result: return result # download mirror list try: - req = request.Request('https://proxybay.co/list.txt', + req = request.Request(pirate.data.mirror_list, headers=pirate.data.default_headers) f = request.urlopen(req, timeout=pirate.data.default_timeout) - except IOError: - printer.print('Could not fetch mirrors :(', color='ERROR') - sys.exit(1) + except urllib.error.URLError as e: + raise IOError('Could not fetch mirrors', e.reason) if f.getcode() != 200: - raise IOError('The proxy bay responded with an error') + raise IOError('The proxy bay responded with an error', + f.read().decode('utf-8')) mirrors = [i.decode('utf-8').strip() for i in f.readlines()][3:] @@ -274,8 +275,7 @@ def search_mirrors(printer, *args): if result: return result else: - printer.print('No more available mirrors :(', color='ERROR') - sys.exit(1) + raise IOError('No more available mirrors') def pirate_main(args): @@ -311,8 +311,14 @@ def pirate_main(args): if args.source == 'local_tpb': results = pirate.local.search(args.database, args.search) elif args.source == 'tpb': - results, site = search_mirrors(printer, args.pages, args.category, - args.sort, args.action, args.search) + try: + results, site = search_mirrors(printer, args.pages, args.category, + args.sort, args.action, args.search) + except IOError as e: + printer.print(e.args[0] + ' :( ', color='ERROR') + if len(e.args) > 1: + printer.print(e.args[1]) + sys.exit(1) if len(results) == 0: printer.print('No results') diff --git a/tests/test_pirate.py b/tests/test_pirate.py index 552db2f..d5371eb 100755 --- a/tests/test_pirate.py +++ b/tests/test_pirate.py @@ -6,6 +6,7 @@ from unittest import mock from unittest.mock import patch, call, MagicMock import pirate.pirate +import pirate.data from pirate.print import Printer @@ -157,14 +158,14 @@ class TestPirate(unittest.TestCase): with patch('pirate.torrent.remote', return_value=[]) as remote: results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) self.assertEqual(results, []) - self.assertEqual(mirror, 'https://thepiratebay.mn') - remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn') + self.assertEqual(mirror, pirate.data.default_mirror) + remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror) with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote: results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) self.assertEqual(results, []) self.assertEqual(mirror, 'https://example.com') remote.assert_has_calls([ - call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn'), + call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror), call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://example.com') ])