diff --git a/magneticod/magneticod/__main__.py b/magneticod/magneticod/__main__.py index 02a724b..dc7e0cf 100644 --- a/magneticod/magneticod/__main__.py +++ b/magneticod/magneticod/__main__.py @@ -31,8 +31,8 @@ from . import dht from . import persistence -def main(): - arguments = parse_cmdline_arguments() +def create_tasks(): + arguments = parse_cmdline_arguments(sys.argv[1:]) logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s") logging.info("magneticod v%d.%d.%d started", *__version__) @@ -57,21 +57,15 @@ def main(): loop = asyncio.get_event_loop() node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size) - loop.run_until_complete(node.launch(loop)) + loop.create_task(node.launch(loop)) + watch_q_task = loop.create_task(watch_q(database, node.metadata_q())) + watch_q_task.add_done_callback(lambda x: clean_up(loop, database, node)) + return watch_q_task - watch_q_task = loop.create_task(watch_q(database, node._metadata_q)) - try: - loop.run_forever() - except KeyboardInterrupt: - logging.critical("Keyboard interrupt received! Exiting gracefully...") - finally: - database.close() - watch_q_task.cancel() - loop.run_until_complete(node.shutdown()) - loop.run_until_complete(asyncio.wait([watch_q_task])) - - return 0 +def clean_up(loop, database, node): + database.close() + loop.run_until_complete(node.shutdown()) async def watch_q(database, q): @@ -109,7 +103,7 @@ def parse_size(value: str) -> int: raise argparse.ArgumentTypeError("Invalid argument. {}".format(e)) -def parse_cmdline_arguments() -> typing.Optional[argparse.Namespace]: +def parse_cmdline_arguments(args) -> typing.Optional[argparse.Namespace]: parser = argparse.ArgumentParser( description="Autonomous BitTorrent DHT crawler and metadata fetcher.", epilog=textwrap.dedent("""\ @@ -153,7 +147,19 @@ def parse_cmdline_arguments() -> typing.Optional[argparse.Namespace]: action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.INFO, help="Print debugging information in addition to normal processing.", ) - return parser.parse_args(sys.argv[1:]) + return parser.parse_args(args) + + +def main(): + main_task = create_tasks() + try: + asyncio.get_event_loop().run_forever() + except KeyboardInterrupt: + logging.critical("Keyboard interrupt received! Exiting gracefully...") + finally: + main_task.cancel() + + return 0 if __name__ == "__main__": diff --git a/magneticod/magneticod/dht.py b/magneticod/magneticod/dht.py index c136401..9393a16 100644 --- a/magneticod/magneticod/dht.py +++ b/magneticod/magneticod/dht.py @@ -51,6 +51,9 @@ class SybilNode: logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address) + def metadata_q(self): + return self._metadata_q + async def launch(self, loop): self._loop = loop await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address)