1
0
mirror of https://github.com/vikstrous/pirate-get synced 2025-01-10 10:04:21 +01:00

Allow users to set custom mirrors

This commit is contained in:
rnhmjoj 2016-09-03 14:53:00 +02:00
parent a1cba67656
commit bba0f41224
No known key found for this signature in database
GPG Key ID: 362BB82B7E496B7C
3 changed files with 35 additions and 25 deletions

View File

@ -55,6 +55,10 @@ transmission = false
; use colored output
colors = true
; the pirate bay mirror(s) to use:
; one or more space separated URLs
mirror = http://thepiratebay.org
```
Note:

View File

@ -39,6 +39,7 @@ def parse_config_file(text):
config.set('Misc', 'openCommand', '')
config.set('Misc', 'transmission', 'false')
config.set('Misc', 'colors', 'true')
config.set('Misc', 'mirror', pirate.data.default_mirror)
config.read_string(text)
@ -54,16 +55,12 @@ def parse_config_file(text):
def load_config():
# user-defined config files
main = expandvars('$XDG_CONFIG_HOME/pirate-get')
alt = expanduser('~/.config/pirate-get')
config_home = os.getenv('XDG_CONFIG_HOME', '~/.config')
config = expanduser(os.path.join(config_home, 'pirate-get'))
# read config file
if os.path.isfile(main):
with open(main) as f:
return parse_config_file(f.read())
if os.path.isfile(alt):
with open(alt) as f:
if os.path.isfile(config):
with open(config) as f:
return parse_config_file(f.read())
return parse_config_file("")
@ -173,6 +170,9 @@ def parse_args(args_in):
parser.add_argument('--disable-colors', dest='color',
action='store_false',
help='disable colored output')
parser.add_argument('-m', '--mirror',
type=str, nargs='+',
help='the pirate bay mirror(s) to use')
args = parser.parse_args(args_in)
return args
@ -207,6 +207,9 @@ def combine_configs(config, args):
if not args.save_directory:
args.save_directory = config.get('Save', 'directory')
if not args.mirror:
args.mirror = config.get('Misc', 'mirror').split()
args.transmission_command = ['transmission-remote']
if args.port:
args.transmission_command.append(args.port)
@ -228,16 +231,16 @@ def combine_configs(config, args):
return args
def connect_mirror(mirror, printer, pages, category, sort, action, search):
def connect_mirror(mirror, printer, args):
try:
printer.print('Trying', mirror, end='... ')
results = pirate.torrent.remote(
printer=printer,
pages=pages,
category=pirate.torrent.parse_category(printer, category),
sort=pirate.torrent.parse_sort(printer, sort),
mode=action,
terms=search,
pages=args.pages,
category=pirate.torrent.parse_category(printer, args.category),
sort=pirate.torrent.parse_sort(printer, args.sort),
mode=args.action,
terms=args.search,
mirror=mirror)
except (urllib.error.URLError, socket.timeout, IOError, ValueError):
printer.print('Failed', color='WARN')
@ -247,9 +250,10 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search):
return results, mirror
def search_mirrors(printer, *args):
# try official site
result = connect_mirror(pirate.data.default_mirror, printer, *args)
def search_mirrors(printer, args):
# try default or user mirrors
for mirror in args.mirror:
result = connect_mirror(mirror, printer, args)
if result:
return result
@ -271,7 +275,7 @@ def search_mirrors(printer, *args):
for mirror in mirrors:
if mirror in pirate.data.blacklist:
continue
result = connect_mirror(mirror, printer, *args)
result = connect_mirror(mirror, printer, args)
if result:
return result
else:
@ -312,8 +316,7 @@ def pirate_main(args):
results = pirate.local.search(args.database, args.search)
elif args.source == 'tpb':
try:
results, site = search_mirrors(printer, args.pages, args.category,
args.sort, args.action, args.search)
results, site = search_mirrors(printer, args)
except IOError as e:
printer.print(e.args[0] + ' :( ', color='ERROR')
if len(e.args) > 1:

View File

@ -2,6 +2,7 @@
import socket
import unittest
import subprocess
from argparse import Namespace
from unittest import mock
from unittest.mock import patch, call, MagicMock
@ -147,7 +148,9 @@ class TestPirate(unittest.TestCase):
self.assertEqual(test[2][option], value)
def test_search_mirrors(self):
pages, category, sort, action, search = (1, 100, 10, 'browse', [])
args = Namespace(pages=1, category=100, sort=10,
action='browse', search=[],
mirror=[pirate.data.default_mirror])
class MockResponse():
readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']])
info = mock.MagicMock()
@ -156,12 +159,12 @@ class TestPirate(unittest.TestCase):
printer = MagicMock(Printer)
with patch('urllib.request.urlopen', return_value=response_obj) as urlopen:
with patch('pirate.torrent.remote', return_value=[]) as remote:
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
results, mirror = pirate.pirate.search_mirrors(printer, args)
self.assertEqual(results, [])
self.assertEqual(mirror, pirate.data.default_mirror)
remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror)
with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote:
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
results, mirror = pirate.pirate.search_mirrors(printer, args)
self.assertEqual(results, [])
self.assertEqual(mirror, 'https://example.com')
remote.assert_has_calls([