1
0
mirror of https://github.com/vikstrous/pirate-get synced 2025-01-10 10:04:21 +01:00

make the printer a class so it can be configured

This commit is contained in:
Viktor Stanchev 2015-09-20 14:14:00 -07:00
parent 1e7da45710
commit a2b4f29643
7 changed files with 188 additions and 178 deletions

View File

@ -11,4 +11,3 @@ blacklist = set(json.loads(get_resource('blacklist.json').decode()))
default_headers = {'User-Agent': 'pirate get'} default_headers = {'User-Agent': 'pirate get'}
default_timeout = 10 default_timeout = 10
colored_output = True

View File

@ -14,10 +14,9 @@ import webbrowser
import pirate.data import pirate.data
import pirate.torrent import pirate.torrent
import pirate.local import pirate.local
import pirate.print
from os.path import expanduser, expandvars from os.path import expanduser, expandvars
from pirate.print import print from pirate.print import Printer
def parse_config_file(text): def parse_config_file(text):
@ -227,7 +226,7 @@ def combine_configs(config, args):
return args return args
def search_mirrors(pages, category, sort, action, search): def search_mirrors(printer, pages, category, sort, action, search):
mirror_sources = [None, 'https://proxybay.co/list.txt'] mirror_sources = [None, 'https://proxybay.co/list.txt']
for mirror_source in mirror_sources: for mirror_source in mirror_sources:
mirrors = OrderedDict() mirrors = OrderedDict()
@ -239,7 +238,7 @@ def search_mirrors(pages, category, sort, action, search):
headers=pirate.data.default_headers) headers=pirate.data.default_headers)
f = request.urlopen(req, timeout=pirate.data.default_timeout) f = request.urlopen(req, timeout=pirate.data.default_timeout)
except IOError: except IOError:
print('Could not fetch additional mirrors', color='WARN') printer.print('Could not fetch additional mirrors', color='WARN')
else: else:
if f.getcode() != 200: if f.getcode() != 200:
raise IOError('The proxy bay responded with an error.') raise IOError('The proxy bay responded with an error.')
@ -251,36 +250,39 @@ def search_mirrors(pages, category, sort, action, search):
for mirror in mirrors.keys(): for mirror in mirrors.keys():
try: try:
print('Trying', mirror, end='... \n') printer.print('Trying', mirror, end='... \n')
results = pirate.torrent.remote( results = pirate.torrent.remote(
printer=printer,
pages=pages, pages=pages,
category=pirate.torrent.parse_category(category), category=pirate.torrent.parse_category(printer, category),
sort=pirate.torrent.parse_sort(sort), sort=pirate.torrent.parse_sort(printer, sort),
mode=action, mode=action,
terms=search, terms=search,
mirror=mirror mirror=mirror
) )
except (urllib.error.URLError, socket.timeout, except (urllib.error.URLError, socket.timeout,
IOError, ValueError): IOError, ValueError):
print('Failed', color='WARN') printer.print('Failed', color='WARN')
else: else:
print('Ok', color='alt') printer.print('Ok', color='alt')
return results, mirror return results, mirror
else: else:
print('No available mirrors :(', color='WARN') printer.print('No available mirrors :(', color='WARN')
return [], None return [], None
def main(): def main():
args = combine_configs(load_config(), parse_args(sys.argv[1:])) args = combine_configs(load_config(), parse_args(sys.argv[1:]))
printer = Printer(args.color)
# check it transmission is running # check it transmission is running
if args.transmission: if args.transmission:
ret = subprocess.call(args.transmission_command + ['-l'], ret = subprocess.call(args.transmission_command + ['-l'],
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL) stderr=subprocess.DEVNULL)
if ret != 0: if ret != 0:
print('Transmission is not running.') printer.print('Transmission is not running.')
sys.exit(1) sys.exit(1)
# non-torrent fetching actions # non-torrent fetching actions
@ -289,14 +291,14 @@ def main():
cur_color = 'zebra_0' cur_color = 'zebra_0'
for key, value in sorted(pirate.data.categories.items()): for key, value in sorted(pirate.data.categories.items()):
cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1' cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1'
print(str(value), '\t', key, sep='', color=cur_color) printer.print(str(value), '\t', key, sep='', color=cur_color)
return return
if args.action == 'list_sorts': if args.action == 'list_sorts':
cur_color = 'zebra_0' cur_color = 'zebra_0'
for key, value in sorted(pirate.data.sorts.items()): for key, value in sorted(pirate.data.sorts.items()):
cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1' cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1'
print(str(value), '\t', key, sep='', color=cur_color) printer.print(str(value), '\t', key, sep='', color=cur_color)
return return
# fetch torrents # fetch torrents
@ -304,38 +306,38 @@ def main():
if args.source == 'local_tpb': if args.source == 'local_tpb':
results = pirate.local.search(args.database, args.search) results = pirate.local.search(args.database, args.search)
elif args.source == 'tpb': elif args.source == 'tpb':
results, site = search_mirrors(args.pages, args.category, args.sort, args.action, args.search) results, site = search_mirrors(printer, args.pages, args.category, args.sort, args.action, args.search)
if len(results) == 0: if len(results) == 0:
print('No results') printer.print('No results')
return return
pirate.print.search_results(results, local=args.source == 'local_tpb') printer.search_results(results, local=args.source == 'local_tpb')
# number of results to pick # number of results to pick
if args.first: if args.first:
print('Choosing first result') printer.print('Choosing first result')
choices = [0] choices = [0]
elif args.download_all: elif args.download_all:
print('Downloading all results') printer.print('Downloading all results')
choices = range(len(results)) choices = range(len(results))
else: else:
# interactive loop for per-torrent actions # interactive loop for per-torrent actions
while True: while True:
print("\nSelect links (Type 'h' for more options" printer.print("\nSelect links (Type 'h' for more options"
", 'q' to quit)", end='\b', color='alt') ", 'q' to quit)", end='\b', color='alt')
try: try:
l = input(': ') l = input(': ')
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
print('\nCancelled.') printer.print('\nCancelled.')
return return
try: try:
code, choices = parse_torrent_command(l) code, choices = parse_torrent_command(l)
# Act on option, if supplied # Act on option, if supplied
print('') printer.print('')
if code == 'h': if code == 'h':
print('Options:', printer.print('Options:',
'<links>: Download selected torrents', '<links>: Download selected torrents',
'[m<links>]: Save magnets as files', '[m<links>]: Save magnets as files',
'[t<links>]: Save .torrent files', '[t<links>]: Save .torrent files',
@ -344,35 +346,35 @@ def main():
'[p] Print search results', '[p] Print search results',
'[q] Quit', sep='\n') '[q] Quit', sep='\n')
elif code == 'q': elif code == 'q':
print('Bye.', color='alt') printer.print('Bye.', color='alt')
return return
elif code == 'd': elif code == 'd':
pirate.print.descriptions(choices, results, site) printer.descriptions(choices, results, site)
elif code == 'f': elif code == 'f':
pirate.print.file_lists(choices, results, site) printer.file_lists(choices, results, site)
elif code == 'p': elif code == 'p':
pirate.print.search_results(results) printer.search_results(results)
elif code == 'm': elif code == 'm':
pirate.torrent.save_magnets(choices, results, args.save_directory) pirate.torrent.save_magnets(printer, choices, results, args.save_directory)
elif code == 't': elif code == 't':
pirate.torrent.save_torrents(choices, results, args.save_directory) pirate.torrent.save_torrents(printer, choices, results, args.save_directory)
elif not l: elif not l:
print('No links entered!', color='WARN') printer.print('No links entered!', color='WARN')
else: else:
break break
except Exception as e: except Exception as e:
print('Exception:', e, color='ERROR') printer.print('Exception:', e, color='ERROR')
return return
# output # output
if args.output == 'save_magnet_files': if args.output == 'save_magnet_files':
print('Saving selected magnets...') printer.print('Saving selected magnets...')
pirate.torrent.save_magnets(choices, results, args.save_directory) pirate.torrent.save_magnets(choices, results, args.save_directory)
return return
if args.output == 'save_torrent_files': if args.output == 'save_torrent_files':
print('Saving selected torrents...') printer.print('Saving selected torrents...')
pirate.torrent.save_torrents(choices, results, args.save_directory) pirate.torrent.save_torrents(choices, results, args.save_directory)
return return

View File

@ -13,133 +13,137 @@ import veryprettytable
import pirate.data import pirate.data
def print(*args, **kwargs): class Printer:
if kwargs.get('color', False) and pirate.data.colored_output: def __init__(self, enable_color):
colorama.init() self.enable_color = enable_color
color_dict = {
'default': '',
'header': colorama.Back.BLACK + colorama.Fore.WHITE,
'alt': colorama.Fore.YELLOW,
'zebra_0': '',
'zebra_1': colorama.Fore.BLUE,
'WARN': colorama.Fore.MAGENTA,
'ERROR': colorama.Fore.RED}
c = color_dict[kwargs.pop('color')] def print(self, *args, **kwargs):
args = (c + args[0],) + args[1:] + (colorama.Style.RESET_ALL,) if kwargs.get('color', False) and self.enable_color:
kwargs.pop('color', None) colorama.init()
return builtins.print(*args, **kwargs) color_dict = {
else: 'default': '',
kwargs.pop('color', None) 'header': colorama.Back.BLACK + colorama.Fore.WHITE,
return builtins.print(*args, **kwargs) 'alt': colorama.Fore.YELLOW,
'zebra_0': '',
'zebra_1': colorama.Fore.BLUE,
'WARN': colorama.Fore.MAGENTA,
'ERROR': colorama.Fore.RED}
c = color_dict[kwargs.pop('color')]
args = (c + args[0],) + args[1:] + (colorama.Style.RESET_ALL,)
kwargs.pop('color', None)
return builtins.print(*args, **kwargs)
else:
kwargs.pop('color', None)
return builtins.print(*args, **kwargs)
# TODO: extract the name from the search results instead of from the magnet link when possible # TODO: extract the name from the search results instead of from the magnet link when possible
def search_results(results, local=None): def search_results(self, results, local=None):
columns = shutil.get_terminal_size((80, 20)).columns columns = shutil.get_terminal_size((80, 20)).columns
even = True even = True
if local:
table = veryprettytable.VeryPrettyTable(['LINK', 'NAME'])
else:
table = veryprettytable.VeryPrettyTable(['LINK', 'SEED', 'LEECH', 'RATIO', 'SIZE', '', 'UPLOAD', 'NAME'])
table.align['NAME'] = 'l'
table.align['SEED'] = 'r'
table.align['LEECH'] = 'r'
table.align['RATIO'] = 'r'
table.align['SIZE'] = 'r'
table.align['UPLOAD'] = 'l'
table.max_width = columns
table.border = False
table.padding_width = 1
for n, result in enumerate(results):
name = re.search(r'dn=([^\&]*)', result['magnet'])
torrent_name = parse.unquote_plus(name.group(1))
if local: if local:
content = [n, torrent_name[:columns - 7]] table = veryprettytable.VeryPrettyTable(['LINK', 'NAME'])
else: else:
no_seeders = int(result['seeds']) table = veryprettytable.VeryPrettyTable(['LINK', 'SEED', 'LEECH', 'RATIO', 'SIZE', '', 'UPLOAD', 'NAME'])
no_leechers = int(result['leechers']) table.align['NAME'] = 'l'
if result['size'] != []: table.align['SEED'] = 'r'
size = float(result['size'][0]) table.align['LEECH'] = 'r'
unit = result['size'][1] table.align['RATIO'] = 'r'
table.align['SIZE'] = 'r'
table.align['UPLOAD'] = 'l'
table.max_width = columns
table.border = False
table.padding_width = 1
for n, result in enumerate(results):
name = re.search(r'dn=([^\&]*)', result['magnet'])
torrent_name = parse.unquote_plus(name.group(1))
if local:
content = [n, torrent_name[:columns - 7]]
else: else:
size = 0 no_seeders = int(result['seeds'])
unit = '???' no_leechers = int(result['leechers'])
date = result['uploaded'] if result['size'] != []:
size = float(result['size'][0])
unit = result['size'][1]
else:
size = 0
unit = '???'
date = result['uploaded']
# compute the S/L ratio (Higher is better) # compute the S/L ratio (Higher is better)
try: try:
ratio = no_seeders / no_leechers ratio = no_seeders / no_leechers
except ZeroDivisionError: except ZeroDivisionError:
ratio = float('inf') ratio = float('inf')
content = [n, no_seeders, no_leechers, '{:.1f}'.format(ratio), content = [n, no_seeders, no_leechers, '{:.1f}'.format(ratio),
'{:.1f}'.format(size), unit, date, torrent_name[:columns - 53]] '{:.1f}'.format(size), unit, date, torrent_name[:columns - 53]]
if even: if even:
table.add_row(content) table.add_row(content)
else: else:
table.add_row(content, fore_color='blue') table.add_row(content, fore_color='blue')
# Alternate between colors # Alternate between colors
even = not even even = not even
print(table) self.print(table)
def descriptions(chosen_links, results, site): def descriptions(self, chosen_links, results, site):
for link in chosen_links: for link in chosen_links:
path = '/torrent/%s/' % results[link]['id'] path = '/torrent/%s/' % results[link]['id']
req = request.Request(site + path, headers=pirate.data.default_headers) req = request.Request(site + path, headers=pirate.data.default_headers)
req.add_header('Accept-encoding', 'gzip') req.add_header('Accept-encoding', 'gzip')
f = request.urlopen(req, timeout=pirate.data.default_timeout) f = request.urlopen(req, timeout=pirate.data.default_timeout)
if f.info().get('Content-Encoding') == 'gzip': if f.info().get('Content-Encoding') == 'gzip':
f = gzip.GzipFile(fileobj=BytesIO(f.read())) f = gzip.GzipFile(fileobj=BytesIO(f.read()))
res = f.read().decode('utf-8') res = f.read().decode('utf-8')
name = re.search(r'dn=([^\&]*)', results[link]['magnet']) name = re.search(r'dn=([^\&]*)', results[link]['magnet'])
torrent_name = parse.unquote(name.group(1)).replace('+', ' ') torrent_name = parse.unquote(name.group(1)).replace('+', ' ')
desc = re.search(r'<div class="nfo">\s*<pre>(.+?)(?=</pre>)', desc = re.search(r'<div class="nfo">\s*<pre>(.+?)(?=</pre>)',
res, re.DOTALL).group(1) res, re.DOTALL).group(1)
# Replace HTML links with markdown style versions # Replace HTML links with markdown style versions
desc = re.sub(r'<a href="\s*([^"]+?)\s*"[^>]*>(\s*)([^<]+?)(\s*' desc = re.sub(r'<a href="\s*([^"]+?)\s*"[^>]*>(\s*)([^<]+?)(\s*'
r')</a>', r'\2[\3](\1)\4', desc) r')</a>', r'\2[\3](\1)\4', desc)
print('Description for "%s":' % torrent_name, color='zebra_1') self.print('Description for "%s":' % torrent_name, color='zebra_1')
print(desc, color='zebra_0') self.print(desc, color='zebra_0')
def file_lists(chosen_links, results, site): def file_lists(self, chosen_links, results, site):
for link in chosen_links: for link in chosen_links:
path = '/ajax_details_filelist.php' path = '/ajax_details_filelist.php'
query = '?id=' + results[link]['id'] query = '?id=' + results[link]['id']
req = request.Request(site + path + query, req = request.Request(site + path + query,
headers=pirate.data.default_headers) headers=pirate.data.default_headers)
req.add_header('Accept-encoding', 'gzip') req.add_header('Accept-encoding', 'gzip')
f = request.urlopen(req, timeout=pirate.data.default_timeout) f = request.urlopen(req, timeout=pirate.data.default_timeout)
if f.info().get('Content-Encoding') == 'gzip': if f.info().get('Content-Encoding') == 'gzip':
f = gzip.GzipFile(fileobj=BytesIO(f.read())) f = gzip.GzipFile(fileobj=BytesIO(f.read()))
# TODO: proper html decoding/parsing # TODO: proper html decoding/parsing
res = f.read().decode('utf-8').replace('&nbsp;', ' ') res = f.read().decode('utf-8').replace('&nbsp;', ' ')
if 'File list not available.' in res: if 'File list not available.' in res:
print('File list not available.') self.print('File list not available.')
return return
files = re.findall(r'<td align="left">\s*([^<]+?)\s*</td><td ali' files = re.findall(r'<td align="left">\s*([^<]+?)\s*</td><td ali'
r'gn="right">\s*([^<]+?)\s*</tr>', res) r'gn="right">\s*([^<]+?)\s*</tr>', res)
name = re.search(r'dn=([^\&]*)', results[link]['magnet']) name = re.search(r'dn=([^\&]*)', results[link]['magnet'])
torrent_name = parse.unquote(name.group(1)).replace('+', ' ') torrent_name = parse.unquote(name.group(1)).replace('+', ' ')
print('Files in "%s":' % torrent_name, color='zebra_1') self.print('Files in "%s":' % torrent_name, color='zebra_1')
cur_color = 'zebra_0' cur_color = 'zebra_0'
for f in files: for f in files:
print('{0[0]:>11} {0[1]}'.format(f), color=cur_color) self.print('{0[0]:>11} {0[1]}'.format(f), color=cur_color)
cur_color = 'zebra_0' if (cur_color == 'zebra_1') else 'zebra_1' cur_color = 'zebra_0' if (cur_color == 'zebra_1') else 'zebra_1'

View File

@ -9,7 +9,6 @@ import os.path
from pyquery import PyQuery as pq from pyquery import PyQuery as pq
import pirate.data import pirate.data
from pirate.print import print
from io import BytesIO from io import BytesIO
@ -17,7 +16,7 @@ from io import BytesIO
parser_regex = r'"(magnet\:\?xt=[^"]*)|<td align="right">([^<]+)</td>' parser_regex = r'"(magnet\:\?xt=[^"]*)|<td align="right">([^<]+)</td>'
def parse_category(category): def parse_category(printer, category):
try: try:
category = int(category) category = int(category)
except ValueError: except ValueError:
@ -27,11 +26,11 @@ def parse_category(category):
elif category in pirate.data.categories.keys(): elif category in pirate.data.categories.keys():
return pirate.data.categories[category] return pirate.data.categories[category]
else: else:
print('Invalid category ignored', color='WARN') printer.print('Invalid category ignored', color='WARN')
return 0 return 0
def parse_sort(sort): def parse_sort(printer, sort):
try: try:
sort = int(sort) sort = int(sort)
except ValueError: except ValueError:
@ -41,7 +40,7 @@ def parse_sort(sort):
elif sort in pirate.data.sorts.keys(): elif sort in pirate.data.sorts.keys():
return pirate.data.sorts[sort] return pirate.data.sorts[sort]
else: else:
print('Invalid sort ignored', color='WARN') printer.print('Invalid sort ignored', color='WARN')
return 99 return 99
@ -119,7 +118,7 @@ def parse_page(html):
return results return results
def remote(pages, category, sort, mode, terms, mirror): def remote(printer, pages, category, sort, mode, terms, mirror):
res_l = [] res_l = []
if pages < 1: if pages < 1:
@ -142,7 +141,7 @@ def remote(pages, category, sort, mode, terms, mirror):
res_l += parse_page(res) res_l += parse_page(res)
except KeyboardInterrupt: except KeyboardInterrupt:
print('\nCancelled.') printer.print('\nCancelled.')
sys.exit(0) sys.exit(0)
return res_l return res_l
@ -162,7 +161,7 @@ def get_torrent(info_hash):
# TODO: handle slashes in torrent names # TODO: handle slashes in torrent names
def save_torrents(chosen_links, results, folder): def save_torrents(printer, chosen_links, results, folder):
for link in chosen_links: for link in chosen_links:
magnet = results[link]['magnet'] magnet = results[link]['magnet']
name = re.search(r'dn=([^\&]*)', magnet) name = re.search(r'dn=([^\&]*)', magnet)
@ -173,14 +172,14 @@ def save_torrents(chosen_links, results, folder):
try: try:
torrent = get_torrent(info_hash) torrent = get_torrent(info_hash)
except urllib.error.HTTPError: except urllib.error.HTTPError:
print('There is no cached file for this torrent :(', color='ERROR') printer.print('There is no cached file for this torrent :(', color='ERROR')
else: else:
open(file, 'wb').write(torrent) open(file, 'wb').write(torrent)
print('Saved {:X} in {}'.format(info_hash, file)) printer.print('Saved {:X} in {}'.format(info_hash, file))
# TODO: handle slashes in torrent names # TODO: handle slashes in torrent names
def save_magnets(chosen_links, results, folder): def save_magnets(printer, chosen_links, results, folder):
for link in chosen_links: for link in chosen_links:
magnet = results[link]['magnet'] magnet = results[link]['magnet']
name = re.search(r'dn=([^\&]*)', magnet) name = re.search(r'dn=([^\&]*)', magnet)
@ -188,6 +187,6 @@ def save_magnets(chosen_links, results, folder):
info_hash = int(re.search(r'btih:([a-f0-9]{40})', magnet).group(1), 16) info_hash = int(re.search(r'btih:([a-f0-9]{40})', magnet).group(1), 16)
file = os.path.join(folder, torrent_name + '.magnet') file = os.path.join(folder, torrent_name + '.magnet')
print('Saved {:X} in {}'.format(info_hash, file)) printer.print('Saved {:X} in {}'.format(info_hash, file))
with open(file, 'w') as f: with open(file, 'w') as f:
f.write(magnet + '\n') f.write(magnet + '\n')

View File

@ -2,9 +2,10 @@
import socket import socket
import unittest import unittest
from unittest import mock from unittest import mock
from unittest.mock import patch, call from unittest.mock import patch, call, MagicMock
import pirate.pirate import pirate.pirate
from pirate.print import Printer
class TestPirate(unittest.TestCase): class TestPirate(unittest.TestCase):
@ -119,19 +120,20 @@ class TestPirate(unittest.TestCase):
info = mock.MagicMock() info = mock.MagicMock()
getcode = mock.MagicMock(return_value=200) getcode = mock.MagicMock(return_value=200)
response_obj = MockResponse() response_obj = MockResponse()
printer = MagicMock(Printer)
with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: with patch('urllib.request.urlopen', return_value=response_obj) as urlopen:
with patch('pirate.torrent.remote', return_value=[]) as remote: with patch('pirate.torrent.remote', return_value=[]) as remote:
results, mirror = pirate.pirate.search_mirrors(pages, category, sort, action, search) results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
self.assertEqual(results, []) self.assertEqual(results, [])
self.assertEqual(mirror, 'https://thepiratebay.mn') self.assertEqual(mirror, 'https://thepiratebay.mn')
remote.assert_called_once_with(pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn') remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn')
with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote: with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote:
results, mirror = pirate.pirate.search_mirrors(pages, category, sort, action, search) results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
self.assertEqual(results, []) self.assertEqual(results, [])
self.assertEqual(mirror, 'https://example.com') self.assertEqual(mirror, 'https://example.com')
remote.assert_has_calls([ remote.assert_has_calls([
call(pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn'), call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn'),
call(pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://example.com') call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://example.com')
]) ])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -2,7 +2,7 @@
import unittest import unittest
from unittest.mock import patch, call, MagicMock from unittest.mock import patch, call, MagicMock
import pirate.print from pirate.print import Printer
class TestPrint(unittest.TestCase): class TestPrint(unittest.TestCase):
@ -12,6 +12,8 @@ class TestPrint(unittest.TestCase):
add_row = MagicMock() add_row = MagicMock()
align = {} align = {}
mock = MockTable() mock = MockTable()
printer = Printer(False)
printer.print = MagicMock()
with patch('prettytable.PrettyTable', return_value=mock) as prettytable: with patch('prettytable.PrettyTable', return_value=mock) as prettytable:
results = [{ results = [{
'magnet': 'dn=name', 'magnet': 'dn=name',
@ -20,7 +22,7 @@ class TestPrint(unittest.TestCase):
'size': ['3','MiB'], 'size': ['3','MiB'],
'uploaded': 'never' 'uploaded': 'never'
}] }]
pirate.print.search_results(results) printer.search_results(results)
prettytable.assert_called_once_with(['LINK', 'SEED', 'LEECH', 'RATIO', 'SIZE', '', 'UPLOAD', 'NAME']) prettytable.assert_called_once_with(['LINK', 'SEED', 'LEECH', 'RATIO', 'SIZE', '', 'UPLOAD', 'NAME'])
mock.add_row.assert_has_calls([call([0, 1, 2, '0.5', '3.0', 'MiB', 'never', 'name'])]) mock.add_row.assert_has_calls([call([0, 1, 2, '0.5', '3.0', 'MiB', 'never', 'name'])])
@ -29,12 +31,14 @@ class TestPrint(unittest.TestCase):
add_row = MagicMock() add_row = MagicMock()
align = {} align = {}
mock = MockTable() mock = MockTable()
printer = Printer(False)
printer.print = MagicMock()
with patch('veryprettytable.VeryPrettyTable', return_value=mock) as prettytable: with patch('veryprettytable.VeryPrettyTable', return_value=mock) as prettytable:
results = [{ results = [{
'magnet': 'dn=name', 'magnet': 'dn=name',
'Name': 'name', 'Name': 'name',
}] }]
pirate.print.search_results(results, local=True) printer.search_results(results, local=True)
prettytable.assert_called_once_with(['LINK', 'NAME']) prettytable.assert_called_once_with(['LINK', 'NAME'])
mock.add_row.assert_has_calls([call([0, 'name'])]) mock.add_row.assert_has_calls([call([0, 'name'])])

View File

@ -1,16 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import unittest import unittest
from unittest import mock from unittest import mock
from unittest.mock import patch from unittest.mock import patch, MagicMock
import pirate.torrent
import pirate.data
import os import os
import io import io
import urllib import urllib
import pirate.torrent
import pirate.data
from pirate.print import Printer
from tests import util from tests import util
class TestTorrent(unittest.TestCase): class TestTorrent(unittest.TestCase):
def test_no_hits(self): def test_no_hits(self):
@ -47,25 +47,25 @@ class TestTorrent(unittest.TestCase):
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
def test_parse_category(self): def test_parse_category(self):
category = pirate.torrent.parse_category('Audio') category = pirate.torrent.parse_category(MagicMock(Printer), 'Audio')
self.assertEqual(100, category) self.assertEqual(100, category)
category = pirate.torrent.parse_category('Video') category = pirate.torrent.parse_category(MagicMock(Printer), 'Video')
self.assertEqual(200, category) self.assertEqual(200, category)
category = pirate.torrent.parse_category('100') category = pirate.torrent.parse_category(MagicMock(Printer), '100')
self.assertEqual(100, category) self.assertEqual(100, category)
category = pirate.torrent.parse_category('asdf') category = pirate.torrent.parse_category(MagicMock(Printer), 'asdf')
self.assertEqual(0, category) self.assertEqual(0, category)
category = pirate.torrent.parse_category('9001') category = pirate.torrent.parse_category(MagicMock(Printer), '9001')
self.assertEqual(0, category) self.assertEqual(0, category)
def test_parse_sort(self): def test_parse_sort(self):
sort = pirate.torrent.parse_sort('SeedersDsc') sort = pirate.torrent.parse_sort(MagicMock(Printer), 'SeedersDsc')
self.assertEqual(7, sort) self.assertEqual(7, sort)
sort = pirate.torrent.parse_sort('7') sort = pirate.torrent.parse_sort(MagicMock(Printer), '7')
self.assertEqual(7, sort) self.assertEqual(7, sort)
sort = pirate.torrent.parse_sort('asdf') sort = pirate.torrent.parse_sort(MagicMock(Printer), 'asdf')
self.assertEqual(99, sort) self.assertEqual(99, sort)
sort = pirate.torrent.parse_sort('7000') sort = pirate.torrent.parse_sort(MagicMock(Printer), '7000')
self.assertEqual(99, sort) self.assertEqual(99, sort)
def test_request_path(self): def test_request_path(self):
@ -94,19 +94,19 @@ class TestTorrent(unittest.TestCase):
def test_save_torrents(self, get_torrent): def test_save_torrents(self, get_torrent):
with patch('pirate.torrent.open', mock.mock_open(), create=True) as open_: with patch('pirate.torrent.open', mock.mock_open(), create=True) as open_:
magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969' magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969'
pirate.torrent.save_torrents([0], [{'magnet':magnet}], 'path') pirate.torrent.save_torrents(MagicMock(Printer), [0], [{'magnet':magnet}], 'path')
get_torrent.assert_called_once_with(293294978876299923284263767676068334936407502135) get_torrent.assert_called_once_with(293294978876299923284263767676068334936407502135)
open_.assert_called_once_with('path/Test Drive Unlimited [PC Version].torrent', 'wb') open_.assert_called_once_with('path/Test Drive Unlimited [PC Version].torrent', 'wb')
@patch('pirate.torrent.get_torrent', side_effect=urllib.error.HTTPError('', '', '', '', io.StringIO())) @patch('pirate.torrent.get_torrent', side_effect=urllib.error.HTTPError('', '', '', '', io.StringIO()))
def test_save_torrents_fail(self, get_torrent): def test_save_torrents_fail(self, get_torrent):
magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969' magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969'
pirate.torrent.save_torrents([0], [{'magnet':magnet}], 'path') pirate.torrent.save_torrents(MagicMock(Printer), [0], [{'magnet':magnet}], 'path')
def test_save_magnets(self): def test_save_magnets(self):
with patch('pirate.torrent.open', mock.mock_open(), create=True) as open_: with patch('pirate.torrent.open', mock.mock_open(), create=True) as open_:
magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969' magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969'
pirate.torrent.save_magnets([0], [{'magnet':magnet}], 'path') pirate.torrent.save_magnets(MagicMock(Printer), [0], [{'magnet':magnet}], 'path')
open_.assert_called_once_with('path/Test Drive Unlimited [PC Version].magnet', 'w') open_.assert_called_once_with('path/Test Drive Unlimited [PC Version].magnet', 'w')
@patch('urllib.request.urlopen') @patch('urllib.request.urlopen')
@ -129,7 +129,7 @@ class TestTorrent(unittest.TestCase):
response_obj = MockResponse() response_obj = MockResponse()
with patch('urllib.request.Request', return_value=request_obj) as request: with patch('urllib.request.Request', return_value=request_obj) as request:
with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: with patch('urllib.request.urlopen', return_value=response_obj) as urlopen:
res = pirate.torrent.remote(1, 100, 10, 'browse', [], 'http://example.com') res = pirate.torrent.remote(MagicMock(Printer), 1, 100, 10, 'browse', [], 'http://example.com')
request.assert_called_once_with('http://example.com/browse/100/0/10', headers=pirate.data.default_headers) request.assert_called_once_with('http://example.com/browse/100/0/10', headers=pirate.data.default_headers)
urlopen.assert_called_once_with(request_obj, timeout=pirate.data.default_timeout) urlopen.assert_called_once_with(request_obj, timeout=pirate.data.default_timeout)
self.assertEqual(res, []) self.assertEqual(res, [])