mirror of
https://github.com/vikstrous/pirate-get
synced 2025-01-10 10:04:21 +01:00
Merge pull request #89 from vikstrous/mirrors
Allow users to set custom mirrors
This commit is contained in:
commit
68e61af5d2
@ -55,6 +55,10 @@ transmission = false
|
||||
|
||||
; use colored output
|
||||
colors = true
|
||||
|
||||
; the pirate bay mirror(s) to use:
|
||||
; one or more space separated URLs
|
||||
mirror = http://thepiratebay.org
|
||||
```
|
||||
|
||||
Note:
|
||||
|
@ -39,6 +39,7 @@ def parse_config_file(text):
|
||||
config.set('Misc', 'openCommand', '')
|
||||
config.set('Misc', 'transmission', 'false')
|
||||
config.set('Misc', 'colors', 'true')
|
||||
config.set('Misc', 'mirror', pirate.data.default_mirror)
|
||||
|
||||
config.read_string(text)
|
||||
|
||||
@ -54,16 +55,12 @@ def parse_config_file(text):
|
||||
|
||||
def load_config():
|
||||
# user-defined config files
|
||||
main = expandvars('$XDG_CONFIG_HOME/pirate-get')
|
||||
alt = expanduser('~/.config/pirate-get')
|
||||
config_home = os.getenv('XDG_CONFIG_HOME', '~/.config')
|
||||
config = expanduser(os.path.join(config_home, 'pirate-get'))
|
||||
|
||||
# read config file
|
||||
if os.path.isfile(main):
|
||||
with open(main) as f:
|
||||
return parse_config_file(f.read())
|
||||
|
||||
if os.path.isfile(alt):
|
||||
with open(alt) as f:
|
||||
if os.path.isfile(config):
|
||||
with open(config) as f:
|
||||
return parse_config_file(f.read())
|
||||
|
||||
return parse_config_file("")
|
||||
@ -173,6 +170,9 @@ def parse_args(args_in):
|
||||
parser.add_argument('--disable-colors', dest='color',
|
||||
action='store_false',
|
||||
help='disable colored output')
|
||||
parser.add_argument('-m', '--mirror',
|
||||
type=str, nargs='+',
|
||||
help='the pirate bay mirror(s) to use')
|
||||
args = parser.parse_args(args_in)
|
||||
|
||||
return args
|
||||
@ -207,6 +207,9 @@ def combine_configs(config, args):
|
||||
if not args.save_directory:
|
||||
args.save_directory = config.get('Save', 'directory')
|
||||
|
||||
if not args.mirror:
|
||||
args.mirror = config.get('Misc', 'mirror').split()
|
||||
|
||||
args.transmission_command = ['transmission-remote']
|
||||
if args.port:
|
||||
args.transmission_command.append(args.port)
|
||||
@ -228,16 +231,16 @@ def combine_configs(config, args):
|
||||
return args
|
||||
|
||||
|
||||
def connect_mirror(mirror, printer, pages, category, sort, action, search):
|
||||
def connect_mirror(mirror, printer, args):
|
||||
try:
|
||||
printer.print('Trying', mirror, end='... ')
|
||||
results = pirate.torrent.remote(
|
||||
printer=printer,
|
||||
pages=pages,
|
||||
category=pirate.torrent.parse_category(printer, category),
|
||||
sort=pirate.torrent.parse_sort(printer, sort),
|
||||
mode=action,
|
||||
terms=search,
|
||||
pages=args.pages,
|
||||
category=pirate.torrent.parse_category(printer, args.category),
|
||||
sort=pirate.torrent.parse_sort(printer, args.sort),
|
||||
mode=args.action,
|
||||
terms=args.search,
|
||||
mirror=mirror)
|
||||
except (urllib.error.URLError, socket.timeout, IOError, ValueError):
|
||||
printer.print('Failed', color='WARN')
|
||||
@ -247,11 +250,12 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search):
|
||||
return results, mirror
|
||||
|
||||
|
||||
def search_mirrors(printer, *args):
|
||||
# try official site
|
||||
result = connect_mirror(pirate.data.default_mirror, printer, *args)
|
||||
if result:
|
||||
return result
|
||||
def search_mirrors(printer, args):
|
||||
# try default or user mirrors
|
||||
for mirror in args.mirror:
|
||||
result = connect_mirror(mirror, printer, args)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# download mirror list
|
||||
try:
|
||||
@ -271,7 +275,7 @@ def search_mirrors(printer, *args):
|
||||
for mirror in mirrors:
|
||||
if mirror in pirate.data.blacklist:
|
||||
continue
|
||||
result = connect_mirror(mirror, printer, *args)
|
||||
result = connect_mirror(mirror, printer, args)
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
@ -312,8 +316,7 @@ def pirate_main(args):
|
||||
results = pirate.local.search(args.database, args.search)
|
||||
elif args.source == 'tpb':
|
||||
try:
|
||||
results, site = search_mirrors(printer, args.pages, args.category,
|
||||
args.sort, args.action, args.search)
|
||||
results, site = search_mirrors(printer, args)
|
||||
except IOError as e:
|
||||
printer.print(e.args[0] + ' :( ', color='ERROR')
|
||||
if len(e.args) > 1:
|
||||
|
@ -2,6 +2,7 @@
|
||||
import socket
|
||||
import unittest
|
||||
import subprocess
|
||||
from argparse import Namespace
|
||||
from unittest import mock
|
||||
from unittest.mock import patch, call, MagicMock
|
||||
|
||||
@ -147,7 +148,9 @@ class TestPirate(unittest.TestCase):
|
||||
self.assertEqual(test[2][option], value)
|
||||
|
||||
def test_search_mirrors(self):
|
||||
pages, category, sort, action, search = (1, 100, 10, 'browse', [])
|
||||
args = Namespace(pages=1, category=100, sort=10,
|
||||
action='browse', search=[],
|
||||
mirror=[pirate.data.default_mirror])
|
||||
class MockResponse():
|
||||
readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']])
|
||||
info = mock.MagicMock()
|
||||
@ -156,12 +159,12 @@ class TestPirate(unittest.TestCase):
|
||||
printer = MagicMock(Printer)
|
||||
with patch('urllib.request.urlopen', return_value=response_obj) as urlopen:
|
||||
with patch('pirate.torrent.remote', return_value=[]) as remote:
|
||||
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
|
||||
results, mirror = pirate.pirate.search_mirrors(printer, args)
|
||||
self.assertEqual(results, [])
|
||||
self.assertEqual(mirror, pirate.data.default_mirror)
|
||||
remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror)
|
||||
with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote:
|
||||
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
|
||||
results, mirror = pirate.pirate.search_mirrors(printer, args)
|
||||
self.assertEqual(results, [])
|
||||
self.assertEqual(mirror, 'https://example.com')
|
||||
remote.assert_has_calls([
|
||||
|
Loading…
Reference in New Issue
Block a user