diff --git a/pirate/data.py b/pirate/data.py index 8c8390f..457809f 100644 --- a/pirate/data.py +++ b/pirate/data.py @@ -11,4 +11,3 @@ blacklist = set(json.loads(get_resource('blacklist.json').decode())) default_headers = {'User-Agent': 'pirate get'} default_timeout = 10 -colored_output = True diff --git a/pirate/pirate.py b/pirate/pirate.py index f534593..c3d4113 100755 --- a/pirate/pirate.py +++ b/pirate/pirate.py @@ -14,10 +14,9 @@ import webbrowser import pirate.data import pirate.torrent import pirate.local -import pirate.print from os.path import expanduser, expandvars -from pirate.print import print +from pirate.print import Printer def parse_config_file(text): @@ -227,7 +226,7 @@ def combine_configs(config, args): return args -def search_mirrors(pages, category, sort, action, search): +def search_mirrors(printer, pages, category, sort, action, search): mirror_sources = [None, 'https://proxybay.co/list.txt'] for mirror_source in mirror_sources: mirrors = OrderedDict() @@ -239,7 +238,7 @@ def search_mirrors(pages, category, sort, action, search): headers=pirate.data.default_headers) f = request.urlopen(req, timeout=pirate.data.default_timeout) except IOError: - print('Could not fetch additional mirrors', color='WARN') + printer.print('Could not fetch additional mirrors', color='WARN') else: if f.getcode() != 200: raise IOError('The proxy bay responded with an error.') @@ -251,36 +250,39 @@ def search_mirrors(pages, category, sort, action, search): for mirror in mirrors.keys(): try: - print('Trying', mirror, end='... \n') + printer.print('Trying', mirror, end='... \n') results = pirate.torrent.remote( + printer=printer, pages=pages, - category=pirate.torrent.parse_category(category), - sort=pirate.torrent.parse_sort(sort), + category=pirate.torrent.parse_category(printer, category), + sort=pirate.torrent.parse_sort(printer, sort), mode=action, terms=search, mirror=mirror ) except (urllib.error.URLError, socket.timeout, IOError, ValueError): - print('Failed', color='WARN') + printer.print('Failed', color='WARN') else: - print('Ok', color='alt') + printer.print('Ok', color='alt') return results, mirror else: - print('No available mirrors :(', color='WARN') + printer.print('No available mirrors :(', color='WARN') return [], None def main(): args = combine_configs(load_config(), parse_args(sys.argv[1:])) + printer = Printer(args.color) + # check it transmission is running if args.transmission: ret = subprocess.call(args.transmission_command + ['-l'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) if ret != 0: - print('Transmission is not running.') + printer.print('Transmission is not running.') sys.exit(1) # non-torrent fetching actions @@ -289,14 +291,14 @@ def main(): cur_color = 'zebra_0' for key, value in sorted(pirate.data.categories.items()): cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1' - print(str(value), '\t', key, sep='', color=cur_color) + printer.print(str(value), '\t', key, sep='', color=cur_color) return if args.action == 'list_sorts': cur_color = 'zebra_0' for key, value in sorted(pirate.data.sorts.items()): cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1' - print(str(value), '\t', key, sep='', color=cur_color) + printer.print(str(value), '\t', key, sep='', color=cur_color) return # fetch torrents @@ -304,38 +306,38 @@ def main(): if args.source == 'local_tpb': results = pirate.local.search(args.database, args.search) elif args.source == 'tpb': - results, site = search_mirrors(args.pages, args.category, args.sort, args.action, args.search) + results, site = search_mirrors(printer, args.pages, args.category, args.sort, args.action, args.search) if len(results) == 0: - print('No results') + printer.print('No results') return - pirate.print.search_results(results, local=args.source == 'local_tpb') + printer.search_results(results, local=args.source == 'local_tpb') # number of results to pick if args.first: - print('Choosing first result') + printer.print('Choosing first result') choices = [0] elif args.download_all: - print('Downloading all results') + printer.print('Downloading all results') choices = range(len(results)) else: # interactive loop for per-torrent actions while True: - print("\nSelect links (Type 'h' for more options" + printer.print("\nSelect links (Type 'h' for more options" ", 'q' to quit)", end='\b', color='alt') try: l = input(': ') except (KeyboardInterrupt, EOFError): - print('\nCancelled.') + printer.print('\nCancelled.') return try: code, choices = parse_torrent_command(l) # Act on option, if supplied - print('') + printer.print('') if code == 'h': - print('Options:', + printer.print('Options:', ': Download selected torrents', '[m]: Save magnets as files', '[t]: Save .torrent files', @@ -344,35 +346,35 @@ def main(): '[p] Print search results', '[q] Quit', sep='\n') elif code == 'q': - print('Bye.', color='alt') + printer.print('Bye.', color='alt') return elif code == 'd': - pirate.print.descriptions(choices, results, site) + printer.descriptions(choices, results, site) elif code == 'f': - pirate.print.file_lists(choices, results, site) + printer.file_lists(choices, results, site) elif code == 'p': - pirate.print.search_results(results) + printer.search_results(results) elif code == 'm': - pirate.torrent.save_magnets(choices, results, args.save_directory) + pirate.torrent.save_magnets(printer, choices, results, args.save_directory) elif code == 't': - pirate.torrent.save_torrents(choices, results, args.save_directory) + pirate.torrent.save_torrents(printer, choices, results, args.save_directory) elif not l: - print('No links entered!', color='WARN') + printer.print('No links entered!', color='WARN') else: break except Exception as e: - print('Exception:', e, color='ERROR') + printer.print('Exception:', e, color='ERROR') return # output if args.output == 'save_magnet_files': - print('Saving selected magnets...') + printer.print('Saving selected magnets...') pirate.torrent.save_magnets(choices, results, args.save_directory) return if args.output == 'save_torrent_files': - print('Saving selected torrents...') + printer.print('Saving selected torrents...') pirate.torrent.save_torrents(choices, results, args.save_directory) return diff --git a/pirate/print.py b/pirate/print.py index a83b89b..6c11f0e 100644 --- a/pirate/print.py +++ b/pirate/print.py @@ -2,136 +2,148 @@ import builtins import re import os import gzip -import colorama import urllib.parse as parse import urllib.request as request import shutil from io import BytesIO +import colorama +import veryprettytable + import pirate.data -def print(*args, **kwargs): - if kwargs.get('color', False) and pirate.data.colored_output: - colorama.init() - color_dict = { - 'default': '', - 'header': colorama.Back.BLACK + colorama.Fore.WHITE, - 'alt': colorama.Fore.YELLOW, - 'zebra_0': '', - 'zebra_1': colorama.Fore.BLUE, - 'WARN': colorama.Fore.MAGENTA, - 'ERROR': colorama.Fore.RED} +class Printer: + def __init__(self, enable_color): + self.enable_color = enable_color - c = color_dict[kwargs.pop('color')] - args = (c + args[0],) + args[1:] + (colorama.Style.RESET_ALL,) - kwargs.pop('color', None) - return builtins.print(*args, **kwargs) - else: - kwargs.pop('color', None) - return builtins.print(*args, **kwargs) + def print(self, *args, **kwargs): + if kwargs.get('color', False) and self.enable_color: + colorama.init() + color_dict = { + 'default': '', + 'header': colorama.Back.BLACK + colorama.Fore.WHITE, + 'alt': colorama.Fore.YELLOW, + 'zebra_0': '', + 'zebra_1': colorama.Fore.BLUE, + 'WARN': colorama.Fore.MAGENTA, + 'ERROR': colorama.Fore.RED} + + c = color_dict[kwargs.pop('color')] + args = (c + args[0],) + args[1:] + (colorama.Style.RESET_ALL,) + kwargs.pop('color', None) + return builtins.print(*args, **kwargs) + else: + kwargs.pop('color', None) + return builtins.print(*args, **kwargs) -# TODO: extract the name from the search results instead of the magnet link when possible -def search_results(results, local=None): - columns = shutil.get_terminal_size((80, 20)).columns - cur_color = 'zebra_0' - - if local: - print('{:>4} {:{length}}'.format( - 'LINK', 'NAME', length=columns - 8), - color='header') - else: - print('{:>4} {:>5} {:>5} {:>5} {:9} {:11} {:{length}}'.format( - 'LINK', 'SEED', 'LEECH', 'RATIO', - 'SIZE', 'UPLOAD', 'NAME', length=columns - 52), - color='header') - - for n, result in enumerate(results): - # Alternate between colors - cur_color = 'zebra_0' if cur_color == 'zebra_1' else 'zebra_1' - - name = re.search(r'dn=([^\&]*)', result['magnet']) - torrent_name = parse.unquote_plus(name.group(1)) + # TODO: extract the name from the search results instead of from the magnet link when possible + def search_results(self, results, local=None): + columns = shutil.get_terminal_size((80, 20)).columns + even = True if local: - line = '{:5} {:{length}}' - content = [n, torrent_name[:columns]] + table = veryprettytable.VeryPrettyTable(['LINK', 'NAME']) else: - no_seeders = int(result['seeds']) - no_leechers = int(result['leechers']) - if result['size'] != []: - size = float(result['size'][0]) - unit = result['size'][1] + table = veryprettytable.VeryPrettyTable(['LINK', 'SEED', 'LEECH', 'RATIO', 'SIZE', '', 'UPLOAD', 'NAME']) + table.align['NAME'] = 'l' + table.align['SEED'] = 'r' + table.align['LEECH'] = 'r' + table.align['RATIO'] = 'r' + table.align['SIZE'] = 'r' + table.align['UPLOAD'] = 'l' + + table.max_width = columns + table.border = False + table.padding_width = 1 + + for n, result in enumerate(results): + + name = re.search(r'dn=([^\&]*)', result['magnet']) + torrent_name = parse.unquote_plus(name.group(1)) + + if local: + content = [n, torrent_name[:columns - 7]] else: - size = 0 - unit = '???' - date = result['uploaded'] + no_seeders = int(result['seeds']) + no_leechers = int(result['leechers']) + if result['size'] != []: + size = float(result['size'][0]) + unit = result['size'][1] + else: + size = 0 + unit = '???' + date = result['uploaded'] - # compute the S/L ratio (Higher is better) - try: - ratio = no_seeders / no_leechers - except ZeroDivisionError: - ratio = float('inf') + # compute the S/L ratio (Higher is better) + try: + ratio = no_seeders / no_leechers + except ZeroDivisionError: + ratio = float('inf') - line = ('{:4} {:5} {:5} {:5.1f} {:5.1f}' - ' {:3} {:<11} {:{length}}') - content = [n, no_seeders, no_leechers, ratio, - size, unit, date, torrent_name[:columns - 52]] + content = [n, no_seeders, no_leechers, '{:.1f}'.format(ratio), + '{:.1f}'.format(size), unit, date, torrent_name[:columns - 53]] - # enhanced print output with justified columns - print(line.format(*content, length=columns - 52), color=cur_color) + if even or not self.enable_color: + table.add_row(content) + else: + table.add_row(content, fore_color='blue') + + # Alternate between colors + even = not even + self.print(table) -def descriptions(chosen_links, results, site): - for link in chosen_links: - path = '/torrent/%s/' % results[link]['id'] - req = request.Request(site + path, headers=pirate.data.default_headers) - req.add_header('Accept-encoding', 'gzip') - f = request.urlopen(req, timeout=pirate.data.default_timeout) + def descriptions(self, chosen_links, results, site): + for link in chosen_links: + path = '/torrent/%s/' % results[link]['id'] + req = request.Request(site + path, headers=pirate.data.default_headers) + req.add_header('Accept-encoding', 'gzip') + f = request.urlopen(req, timeout=pirate.data.default_timeout) - if f.info().get('Content-Encoding') == 'gzip': - f = gzip.GzipFile(fileobj=BytesIO(f.read())) + if f.info().get('Content-Encoding') == 'gzip': + f = gzip.GzipFile(fileobj=BytesIO(f.read())) - res = f.read().decode('utf-8') - name = re.search(r'dn=([^\&]*)', results[link]['magnet']) - torrent_name = parse.unquote(name.group(1)).replace('+', ' ') - desc = re.search(r'
\s*
(.+?)(?=
)', - res, re.DOTALL).group(1) + res = f.read().decode('utf-8') + name = re.search(r'dn=([^\&]*)', results[link]['magnet']) + torrent_name = parse.unquote(name.group(1)).replace('+', ' ') + desc = re.search(r'
\s*
(.+?)(?=
)', + res, re.DOTALL).group(1) - # Replace HTML links with markdown style versions - desc = re.sub(r']*>(\s*)([^<]+?)(\s*' - r')', r'\2[\3](\1)\4', desc) + # Replace HTML links with markdown style versions + desc = re.sub(r']*>(\s*)([^<]+?)(\s*' + r')', r'\2[\3](\1)\4', desc) - print('Description for "%s":' % torrent_name, color='zebra_1') - print(desc, color='zebra_0') + self.print('Description for "%s":' % torrent_name, color='zebra_1') + self.print(desc, color='zebra_0') -def file_lists(chosen_links, results, site): - for link in chosen_links: - path = '/ajax_details_filelist.php' - query = '?id=' + results[link]['id'] - req = request.Request(site + path + query, - headers=pirate.data.default_headers) - req.add_header('Accept-encoding', 'gzip') - f = request.urlopen(req, timeout=pirate.data.default_timeout) + def file_lists(self, chosen_links, results, site): + for link in chosen_links: + path = '/ajax_details_filelist.php' + query = '?id=' + results[link]['id'] + req = request.Request(site + path + query, + headers=pirate.data.default_headers) + req.add_header('Accept-encoding', 'gzip') + f = request.urlopen(req, timeout=pirate.data.default_timeout) - if f.info().get('Content-Encoding') == 'gzip': - f = gzip.GzipFile(fileobj=BytesIO(f.read())) + if f.info().get('Content-Encoding') == 'gzip': + f = gzip.GzipFile(fileobj=BytesIO(f.read())) - # TODO: proper html decoding/parsing - res = f.read().decode('utf-8').replace(' ', ' ') - if 'File list not available.' in res: - print('File list not available.') - return - files = re.findall(r'\s*([^<]+?)\s*\s*([^<]+?)\s*', res) - name = re.search(r'dn=([^\&]*)', results[link]['magnet']) - torrent_name = parse.unquote(name.group(1)).replace('+', ' ') + # TODO: proper html decoding/parsing + res = f.read().decode('utf-8').replace(' ', ' ') + if 'File list not available.' in res: + self.print('File list not available.') + return + files = re.findall(r'\s*([^<]+?)\s*\s*([^<]+?)\s*', res) + name = re.search(r'dn=([^\&]*)', results[link]['magnet']) + torrent_name = parse.unquote(name.group(1)).replace('+', ' ') - print('Files in "%s":' % torrent_name, color='zebra_1') - cur_color = 'zebra_0' + self.print('Files in "%s":' % torrent_name, color='zebra_1') + cur_color = 'zebra_0' - for f in files: - print('{0[0]:>11} {0[1]}'.format(f), color=cur_color) - cur_color = 'zebra_0' if (cur_color == 'zebra_1') else 'zebra_1' + for f in files: + self.print('{0[0]:>11} {0[1]}'.format(f), color=cur_color) + cur_color = 'zebra_0' if (cur_color == 'zebra_1') else 'zebra_1' diff --git a/pirate/torrent.py b/pirate/torrent.py index ece8a20..039157b 100644 --- a/pirate/torrent.py +++ b/pirate/torrent.py @@ -9,7 +9,6 @@ import os.path from pyquery import PyQuery as pq import pirate.data -from pirate.print import print from io import BytesIO @@ -17,7 +16,7 @@ from io import BytesIO parser_regex = r'"(magnet\:\?xt=[^"]*)|([^<]+)' -def parse_category(category): +def parse_category(printer, category): try: category = int(category) except ValueError: @@ -27,11 +26,11 @@ def parse_category(category): elif category in pirate.data.categories.keys(): return pirate.data.categories[category] else: - print('Invalid category ignored', color='WARN') + printer.print('Invalid category ignored', color='WARN') return 0 -def parse_sort(sort): +def parse_sort(printer, sort): try: sort = int(sort) except ValueError: @@ -41,7 +40,7 @@ def parse_sort(sort): elif sort in pirate.data.sorts.keys(): return pirate.data.sorts[sort] else: - print('Invalid sort ignored', color='WARN') + printer.print('Invalid sort ignored', color='WARN') return 99 @@ -119,7 +118,7 @@ def parse_page(html): return results -def remote(pages, category, sort, mode, terms, mirror): +def remote(printer, pages, category, sort, mode, terms, mirror): res_l = [] if pages < 1: @@ -142,7 +141,7 @@ def remote(pages, category, sort, mode, terms, mirror): res_l += parse_page(res) except KeyboardInterrupt: - print('\nCancelled.') + printer.print('\nCancelled.') sys.exit(0) return res_l @@ -162,7 +161,7 @@ def get_torrent(info_hash): # TODO: handle slashes in torrent names -def save_torrents(chosen_links, results, folder): +def save_torrents(printer, chosen_links, results, folder): for link in chosen_links: magnet = results[link]['magnet'] name = re.search(r'dn=([^\&]*)', magnet) @@ -173,14 +172,14 @@ def save_torrents(chosen_links, results, folder): try: torrent = get_torrent(info_hash) except urllib.error.HTTPError: - print('There is no cached file for this torrent :(', color='ERROR') + printer.print('There is no cached file for this torrent :(', color='ERROR') else: open(file, 'wb').write(torrent) - print('Saved {:X} in {}'.format(info_hash, file)) + printer.print('Saved {:X} in {}'.format(info_hash, file)) # TODO: handle slashes in torrent names -def save_magnets(chosen_links, results, folder): +def save_magnets(printer, chosen_links, results, folder): for link in chosen_links: magnet = results[link]['magnet'] name = re.search(r'dn=([^\&]*)', magnet) @@ -188,6 +187,6 @@ def save_magnets(chosen_links, results, folder): info_hash = int(re.search(r'btih:([a-f0-9]{40})', magnet).group(1), 16) file = os.path.join(folder, torrent_name + '.magnet') - print('Saved {:X} in {}'.format(info_hash, file)) + printer.print('Saved {:X} in {}'.format(info_hash, file)) with open(file, 'w') as f: f.write(magnet + '\n') diff --git a/setup.py b/setup.py index 8d30a87..67f4029 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setup(name='pirate-get', entry_points={ 'console_scripts': ['pirate-get = pirate.pirate:main'] }, - install_requires=['colorama>=0.3.3', 'pyquery>=1.2.9'], + install_requires=['colorama>=0.3.3', 'pyquery>=1.2.9', 'veryprettytable>=0.8.1'], keywords=['torrent', 'magnet', 'download', 'tpb', 'client'], classifiers=[ 'Topic :: Utilities', diff --git a/tests/test_pirate.py b/tests/test_pirate.py index b5b047a..6451adb 100755 --- a/tests/test_pirate.py +++ b/tests/test_pirate.py @@ -2,9 +2,10 @@ import socket import unittest from unittest import mock -from unittest.mock import patch, call +from unittest.mock import patch, call, MagicMock import pirate.pirate +from pirate.print import Printer class TestPirate(unittest.TestCase): @@ -119,19 +120,20 @@ class TestPirate(unittest.TestCase): info = mock.MagicMock() getcode = mock.MagicMock(return_value=200) response_obj = MockResponse() + printer = MagicMock(Printer) with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: with patch('pirate.torrent.remote', return_value=[]) as remote: - results, mirror = pirate.pirate.search_mirrors(pages, category, sort, action, search) + results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) self.assertEqual(results, []) self.assertEqual(mirror, 'https://thepiratebay.mn') - remote.assert_called_once_with(pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn') + remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn') with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote: - results, mirror = pirate.pirate.search_mirrors(pages, category, sort, action, search) + results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search) self.assertEqual(results, []) self.assertEqual(mirror, 'https://example.com') remote.assert_has_calls([ - call(pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn'), - call(pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://example.com') + call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://thepiratebay.mn'), + call(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror='https://example.com') ]) if __name__ == '__main__': diff --git a/tests/test_print.py b/tests/test_print.py index 438170d..32749ff 100755 --- a/tests/test_print.py +++ b/tests/test_print.py @@ -1,15 +1,20 @@ #!/usr/bin/env python3 import unittest -from unittest.mock import patch -from unittest.mock import call +from unittest.mock import patch, call, MagicMock -import pirate.print +from pirate.print import Printer class TestPrint(unittest.TestCase): - def test_print_results(self): - with patch('pirate.print.print') as mock: + def test_print_results_remote(self): + class MockTable: + add_row = MagicMock() + align = {} + mock = MockTable() + printer = Printer(False) + printer.print = MagicMock() + with patch('veryprettytable.VeryPrettyTable', return_value=mock) as prettytable: results = [{ 'magnet': 'dn=name', 'seeds': 1, @@ -17,13 +22,87 @@ class TestPrint(unittest.TestCase): 'size': ['3','MiB'], 'uploaded': 'never' }] - pirate.print.search_results(results) - actual = mock.call_args_list - expected = [ - call('LINK SEED LEECH RATIO SIZE UPLOAD NAME ', color='header'), - call(' 0 1 2 0.5 3.0 MiB never name ', color='zebra_1'), - ] - self.assertEqual(expected, actual) + printer.search_results(results) + prettytable.assert_called_once_with(['LINK', 'SEED', 'LEECH', 'RATIO', 'SIZE', '', 'UPLOAD', 'NAME']) + mock.add_row.assert_has_calls([call([0, 1, 2, '0.5', '3.0', 'MiB', 'never', 'name'])]) + + def test_print_results_local(self): + class MockTable: + add_row = MagicMock() + align = {} + mock = MockTable() + printer = Printer(False) + printer.print = MagicMock() + with patch('veryprettytable.VeryPrettyTable', return_value=mock) as prettytable: + results = [{ + 'magnet': 'dn=name', + 'Name': 'name', + },{ + 'magnet': 'dn=name2', + 'Name': 'name2', + }] + printer.search_results(results, local=True) + prettytable.assert_called_once_with(['LINK', 'NAME']) + mock.add_row.assert_has_calls([call([0, 'name']), call([1, 'name2'])]) + + def test_print_color(self): + printer = Printer(False) + with patch('pirate.print.builtins.print') as mock_print: + printer.print('abc', color='zebra_1') + mock_print.assert_called_once_with('abc') + printer = Printer(True) + with patch('pirate.print.builtins.print') as mock_print: + printer.print('abc', color='zebra_1') + mock_print.assert_called_once_with('\x1b[34mabc', '\x1b[0m') + + def test_print_results_local(self): + class MockTable: + add_row = MagicMock() + align = {} + mock = MockTable() + printer = Printer(True) + printer.print = MagicMock() + with patch('veryprettytable.VeryPrettyTable', return_value=mock) as prettytable: + results = [{ + 'magnet': 'dn=name', + 'Name': 'name', + },{ + 'magnet': 'dn=name2', + 'Name': 'name2', + }] + printer.search_results(results, local=True) + prettytable.assert_called_once_with(['LINK', 'NAME']) + mock.add_row.assert_has_calls([call([0, 'name']), call([1, 'name2'], fore_color='blue')]) + + def test_print_descriptions(self): + printer = Printer(False) + printer.print = MagicMock() + class MockRequest(): + add_header = MagicMock() + request_obj = MockRequest() + class MockResponse(): + read = MagicMock(return_value='
stuff link
'.encode('utf8')) + info = MagicMock() + response_obj = MockResponse() + with patch('urllib.request.Request', return_value=request_obj) as request: + with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: + printer.descriptions([0], [{'id': '1', 'magnet': 'dn=name'}], 'example.com') + printer.print.assert_has_calls([call('Description for "name":', color='zebra_1'),call('stuff [link](href)', color='zebra_0')]) + + def test_print_file_lists(self): + printer = Printer(False) + printer.print = MagicMock() + class MockRequest(): + add_header = MagicMock() + request_obj = MockRequest() + class MockResponse(): + read = MagicMock(return_value='1.filename'.encode('utf8')) + info = MagicMock() + response_obj = MockResponse() + with patch('urllib.request.Request', return_value=request_obj) as request: + with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: + printer.file_lists([0], [{'id': '1', 'magnet': 'dn=name'}], 'example.com') + printer.print.assert_has_calls([call('Files in "name":', color='zebra_1'),call(' 1. filename', color='zebra_0')]) if __name__ == '__main__': unittest.main() diff --git a/tests/test_torrent.py b/tests/test_torrent.py index d8fb4ad..5ee3ff5 100755 --- a/tests/test_torrent.py +++ b/tests/test_torrent.py @@ -1,16 +1,16 @@ #!/usr/bin/env python3 import unittest from unittest import mock -from unittest.mock import patch -import pirate.torrent -import pirate.data +from unittest.mock import patch, MagicMock import os import io import urllib +import pirate.torrent +import pirate.data +from pirate.print import Printer from tests import util - class TestTorrent(unittest.TestCase): def test_no_hits(self): @@ -47,25 +47,25 @@ class TestTorrent(unittest.TestCase): self.assertEqual(actual, expected) def test_parse_category(self): - category = pirate.torrent.parse_category('Audio') + category = pirate.torrent.parse_category(MagicMock(Printer), 'Audio') self.assertEqual(100, category) - category = pirate.torrent.parse_category('Video') + category = pirate.torrent.parse_category(MagicMock(Printer), 'Video') self.assertEqual(200, category) - category = pirate.torrent.parse_category('100') + category = pirate.torrent.parse_category(MagicMock(Printer), '100') self.assertEqual(100, category) - category = pirate.torrent.parse_category('asdf') + category = pirate.torrent.parse_category(MagicMock(Printer), 'asdf') self.assertEqual(0, category) - category = pirate.torrent.parse_category('9001') + category = pirate.torrent.parse_category(MagicMock(Printer), '9001') self.assertEqual(0, category) def test_parse_sort(self): - sort = pirate.torrent.parse_sort('SeedersDsc') + sort = pirate.torrent.parse_sort(MagicMock(Printer), 'SeedersDsc') self.assertEqual(7, sort) - sort = pirate.torrent.parse_sort('7') + sort = pirate.torrent.parse_sort(MagicMock(Printer), '7') self.assertEqual(7, sort) - sort = pirate.torrent.parse_sort('asdf') + sort = pirate.torrent.parse_sort(MagicMock(Printer), 'asdf') self.assertEqual(99, sort) - sort = pirate.torrent.parse_sort('7000') + sort = pirate.torrent.parse_sort(MagicMock(Printer), '7000') self.assertEqual(99, sort) def test_request_path(self): @@ -94,19 +94,19 @@ class TestTorrent(unittest.TestCase): def test_save_torrents(self, get_torrent): with patch('pirate.torrent.open', mock.mock_open(), create=True) as open_: magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969' - pirate.torrent.save_torrents([0], [{'magnet':magnet}], 'path') + pirate.torrent.save_torrents(MagicMock(Printer), [0], [{'magnet':magnet}], 'path') get_torrent.assert_called_once_with(293294978876299923284263767676068334936407502135) open_.assert_called_once_with('path/Test Drive Unlimited [PC Version].torrent', 'wb') @patch('pirate.torrent.get_torrent', side_effect=urllib.error.HTTPError('', '', '', '', io.StringIO())) def test_save_torrents_fail(self, get_torrent): magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969' - pirate.torrent.save_torrents([0], [{'magnet':magnet}], 'path') + pirate.torrent.save_torrents(MagicMock(Printer), [0], [{'magnet':magnet}], 'path') def test_save_magnets(self): with patch('pirate.torrent.open', mock.mock_open(), create=True) as open_: magnet = 'magnet:?xt=urn:btih:335fcd3cfbecc85554616d73de888033c6c16d37&dn=Test+Drive+Unlimited+%5BPC+Version%5D&tr=udp%3A%2F%2Ftracker.openbittorrent.com%3A80&tr=udp%3A%2F%2Fopen.demonii.com%3A1337&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Fexodus.desync.com%3A6969' - pirate.torrent.save_magnets([0], [{'magnet':magnet}], 'path') + pirate.torrent.save_magnets(MagicMock(Printer), [0], [{'magnet':magnet}], 'path') open_.assert_called_once_with('path/Test Drive Unlimited [PC Version].magnet', 'w') @patch('urllib.request.urlopen') @@ -129,7 +129,7 @@ class TestTorrent(unittest.TestCase): response_obj = MockResponse() with patch('urllib.request.Request', return_value=request_obj) as request: with patch('urllib.request.urlopen', return_value=response_obj) as urlopen: - res = pirate.torrent.remote(1, 100, 10, 'browse', [], 'http://example.com') + res = pirate.torrent.remote(MagicMock(Printer), 1, 100, 10, 'browse', [], 'http://example.com') request.assert_called_once_with('http://example.com/browse/100/0/10', headers=pirate.data.default_headers) urlopen.assert_called_once_with(request_obj, timeout=pirate.data.default_timeout) self.assertEqual(res, [])