1
0
mirror of https://github.com/vikstrous/pirate-get synced 2025-01-10 10:04:21 +01:00

Allow users to set custom mirrors

This commit is contained in:
rnhmjoj 2016-09-03 14:53:00 +02:00
parent a1cba67656
commit bba0f41224
No known key found for this signature in database
GPG Key ID: 362BB82B7E496B7C
3 changed files with 35 additions and 25 deletions

View File

@ -55,6 +55,10 @@ transmission = false
; use colored output ; use colored output
colors = true colors = true
; the pirate bay mirror(s) to use:
; one or more space separated URLs
mirror = http://thepiratebay.org
``` ```
Note: Note:

View File

@ -39,6 +39,7 @@ def parse_config_file(text):
config.set('Misc', 'openCommand', '') config.set('Misc', 'openCommand', '')
config.set('Misc', 'transmission', 'false') config.set('Misc', 'transmission', 'false')
config.set('Misc', 'colors', 'true') config.set('Misc', 'colors', 'true')
config.set('Misc', 'mirror', pirate.data.default_mirror)
config.read_string(text) config.read_string(text)
@ -54,16 +55,12 @@ def parse_config_file(text):
def load_config(): def load_config():
# user-defined config files # user-defined config files
main = expandvars('$XDG_CONFIG_HOME/pirate-get') config_home = os.getenv('XDG_CONFIG_HOME', '~/.config')
alt = expanduser('~/.config/pirate-get') config = expanduser(os.path.join(config_home, 'pirate-get'))
# read config file # read config file
if os.path.isfile(main): if os.path.isfile(config):
with open(main) as f: with open(config) as f:
return parse_config_file(f.read())
if os.path.isfile(alt):
with open(alt) as f:
return parse_config_file(f.read()) return parse_config_file(f.read())
return parse_config_file("") return parse_config_file("")
@ -173,6 +170,9 @@ def parse_args(args_in):
parser.add_argument('--disable-colors', dest='color', parser.add_argument('--disable-colors', dest='color',
action='store_false', action='store_false',
help='disable colored output') help='disable colored output')
parser.add_argument('-m', '--mirror',
type=str, nargs='+',
help='the pirate bay mirror(s) to use')
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
return args return args
@ -207,6 +207,9 @@ def combine_configs(config, args):
if not args.save_directory: if not args.save_directory:
args.save_directory = config.get('Save', 'directory') args.save_directory = config.get('Save', 'directory')
if not args.mirror:
args.mirror = config.get('Misc', 'mirror').split()
args.transmission_command = ['transmission-remote'] args.transmission_command = ['transmission-remote']
if args.port: if args.port:
args.transmission_command.append(args.port) args.transmission_command.append(args.port)
@ -228,16 +231,16 @@ def combine_configs(config, args):
return args return args
def connect_mirror(mirror, printer, pages, category, sort, action, search): def connect_mirror(mirror, printer, args):
try: try:
printer.print('Trying', mirror, end='... ') printer.print('Trying', mirror, end='... ')
results = pirate.torrent.remote( results = pirate.torrent.remote(
printer=printer, printer=printer,
pages=pages, pages=args.pages,
category=pirate.torrent.parse_category(printer, category), category=pirate.torrent.parse_category(printer, args.category),
sort=pirate.torrent.parse_sort(printer, sort), sort=pirate.torrent.parse_sort(printer, args.sort),
mode=action, mode=args.action,
terms=search, terms=args.search,
mirror=mirror) mirror=mirror)
except (urllib.error.URLError, socket.timeout, IOError, ValueError): except (urllib.error.URLError, socket.timeout, IOError, ValueError):
printer.print('Failed', color='WARN') printer.print('Failed', color='WARN')
@ -247,11 +250,12 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search):
return results, mirror return results, mirror
def search_mirrors(printer, *args): def search_mirrors(printer, args):
# try official site # try default or user mirrors
result = connect_mirror(pirate.data.default_mirror, printer, *args) for mirror in args.mirror:
if result: result = connect_mirror(mirror, printer, args)
return result if result:
return result
# download mirror list # download mirror list
try: try:
@ -271,7 +275,7 @@ def search_mirrors(printer, *args):
for mirror in mirrors: for mirror in mirrors:
if mirror in pirate.data.blacklist: if mirror in pirate.data.blacklist:
continue continue
result = connect_mirror(mirror, printer, *args) result = connect_mirror(mirror, printer, args)
if result: if result:
return result return result
else: else:
@ -312,8 +316,7 @@ def pirate_main(args):
results = pirate.local.search(args.database, args.search) results = pirate.local.search(args.database, args.search)
elif args.source == 'tpb': elif args.source == 'tpb':
try: try:
results, site = search_mirrors(printer, args.pages, args.category, results, site = search_mirrors(printer, args)
args.sort, args.action, args.search)
except IOError as e: except IOError as e:
printer.print(e.args[0] + ' :( ', color='ERROR') printer.print(e.args[0] + ' :( ', color='ERROR')
if len(e.args) > 1: if len(e.args) > 1:

View File

@ -2,6 +2,7 @@
import socket import socket
import unittest import unittest
import subprocess import subprocess
from argparse import Namespace
from unittest import mock from unittest import mock
from unittest.mock import patch, call, MagicMock from unittest.mock import patch, call, MagicMock
@ -147,7 +148,9 @@ class TestPirate(unittest.TestCase):
self.assertEqual(test[2][option], value) self.assertEqual(test[2][option], value)
def test_search_mirrors(self): def test_search_mirrors(self):
pages, category, sort, action, search = (1, 100, 10, 'browse', []) args = Namespace(pages=1, category=100, sort=10,
action='browse', search=[],
mirror=[pirate.data.default_mirror])
class MockResponse(): class MockResponse():
readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']]) readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']])
info = mock.MagicMock() info = mock.MagicMock()
@ -156,12 +159,12 @@ class TestPirate(unittest.TestCase):
printer = MagicMock(Printer) printer = MagicMock(Printer)
with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: with patch('urllib.request.urlopen', return_value=response_obj) as urlopen:
with patch('pirate.torrent.remote', return_value=[]) as remote: with patch('pirate.torrent.remote', return_value=[]) as remote:
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) results, mirror = pirate.pirate.search_mirrors(printer, args)
self.assertEqual(results, []) self.assertEqual(results, [])
self.assertEqual(mirror, pirate.data.default_mirror) self.assertEqual(mirror, pirate.data.default_mirror)
remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror) remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror)
with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote: with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote:
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) results, mirror = pirate.pirate.search_mirrors(printer, args)
self.assertEqual(results, []) self.assertEqual(results, [])
self.assertEqual(mirror, 'https://example.com') self.assertEqual(mirror, 'https://example.com')
remote.assert_has_calls([ remote.assert_has_calls([