diff --git a/magneticod/magneticod/__main__.py b/magneticod/magneticod/__main__.py index 9205ffb..d859aec 100644 --- a/magneticod/magneticod/__main__.py +++ b/magneticod/magneticod/__main__.py @@ -31,8 +31,8 @@ from . import dht from . import persistence -def main(): - arguments = parse_cmdline_arguments() +def create_tasks(): + arguments = parse_cmdline_arguments(sys.argv[1:]) logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s") logging.info("magneticod v%d.%d.%d started", *__version__) @@ -56,22 +56,15 @@ def main(): - complete_info_hashes = database.get_complete_info_hashes() loop = asyncio.get_event_loop() - node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size) - loop.run_until_complete(node.launch()) + node = dht.SybilNode(arguments.node_addr, database.is_infohash_new, arguments.max_metadata_size) + loop.create_task(node.launch(loop)) + watch_q_task = loop.create_task(metadata_queue_watcher(database, node.metadata_q())) + watch_q_task.add_done_callback(lambda x: clean_up(loop, database, node)) + return watch_q_task - watch_q_task = loop.create_task(metadata_queue_watcher(database, node.__metadata_queue)) - try: - loop.run_forever() - except KeyboardInterrupt: - logging.critical("Keyboard interrupt received! Exiting gracefully...") - finally: - database.close() - watch_q_task.cancel() - loop.run_until_complete(node.shutdown()) - loop.run_until_complete(asyncio.wait([watch_q_task])) - - return 0 +def clean_up(loop, database, node): + database.close() + loop.create_task(node.shutdown()) async def metadata_queue_watcher(database: persistence.Database, metadata_queue: asyncio.Queue) -> None: @@ -112,7 +105,7 @@ def parse_size(value: str) -> int: raise argparse.ArgumentTypeError("Invalid argument. 
{}".format(e)) -def parse_cmdline_arguments() -> typing.Optional[argparse.Namespace]: +def parse_cmdline_arguments(args) -> typing.Optional[argparse.Namespace]: parser = argparse.ArgumentParser( description="Autonomous BitTorrent DHT crawler and metadata fetcher.", epilog=textwrap.dedent("""\ @@ -156,7 +150,19 @@ def parse_cmdline_arguments() -> typing.Optional[argparse.Namespace]: action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.INFO, help="Print debugging information in addition to normal processing.", ) - return parser.parse_args(sys.argv[1:]) + return parser.parse_args(args) + + +def main(): + main_task = create_tasks() + try: + asyncio.get_event_loop().run_forever() + except KeyboardInterrupt: + logging.critical("Keyboard interrupt received! Exiting gracefully...") + finally: + main_task.cancel() + asyncio.get_event_loop().run_until_complete(asyncio.wait([main_task])) + return 0 if __name__ == "__main__": diff --git a/magneticod/magneticod/dht.py b/magneticod/magneticod/dht.py index f451710..b1332d9 100644 --- a/magneticod/magneticod/dht.py +++ b/magneticod/magneticod/dht.py @@ -31,8 +31,8 @@ InfoHash = bytes Metadata = bytes -class SybilNode(asyncio.DatagramProtocol): - def __init__(self, address: typing.Tuple[str, int], complete_info_hashes, max_metadata_size): +class SybilNode: + def __init__(self, address: typing.Tuple[str, int], is_infohash_new, max_metadata_size): self.__true_id = self.__random_bytes(20) self.__address = address @@ -43,8 +43,8 @@ class SybilNode(asyncio.DatagramProtocol): # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will # stop; but until then, the total number of neighbours might exceed the threshold). 
self.__n_max_neighbours = 2000 - self.__parent_futures = {} # type: typing.Dict[dht.InfoHash, asyncio.Future] - self._complete_info_hashes = complete_info_hashes + self.__tasks = {} # type: typing.Dict[dht.InfoHash, asyncio.Future] + self._is_inforhash_new = is_infohash_new self.__max_metadata_size = max_metadata_size # Complete metadatas will be added to the queue, to be retrieved and committed to the database. self.__metadata_queue = asyncio.Queue() # typing.Collection[typing.Tuple[InfoHash, Metadata]] @@ -53,9 +53,12 @@ class SybilNode(asyncio.DatagramProtocol): logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address) - async def launch(self) -> None: - event_loop = asyncio.get_event_loop() - await event_loop.create_datagram_endpoint(lambda: self, local_addr=self.__address) + def metadata_q(self): + return self.__metadata_queue + + async def launch(self, loop): + self._loop = loop + await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address) def connection_made(self, transport: asyncio.DatagramTransport) -> None: event_loop = asyncio.get_event_loop() @@ -107,6 +110,9 @@ class SybilNode(asyncio.DatagramProtocol): self._routing_table.clear() if not self._is_writing_paused: self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100 + logging.debug("fetch metadata task count: %d", sum( + x.child_count for x in self.__tasks.values())) + logging.debug("asyncio task count: %d", len(asyncio.Task.all_tasks())) def datagram_received(self, data, addr) -> None: # Ignore nodes that uses port 0 (assholes). 
@@ -209,7 +215,7 @@ class SybilNode(asyncio.DatagramProtocol): else: peer_addr = (addr[0], port) - if info_hash in self._complete_info_hashes: + if not self._is_inforhash_new(info_hash): return event_loop = asyncio.get_event_loop() @@ -277,8 +283,7 @@ class SybilNode(asyncio.DatagramProtocol): try: metadata = parent_task.result() if metadata: - self._complete_info_hashes.add(info_hash) - self.__metadata_queue.put_nowait((info_hash, metadata)) + self.__metadata_queue.put_nowait((info_hash, metadata)) except asyncio.CancelledError: pass - del self.__parent_futures[info_hash] + del self.__tasks[info_hash] diff --git a/magneticod/magneticod/persistence.py b/magneticod/magneticod/persistence.py index 47b3872..e1a57de 100644 --- a/magneticod/magneticod/persistence.py +++ b/magneticod/magneticod/persistence.py @@ -94,11 +94,14 @@ class Database: return True - def get_complete_info_hashes(self) -> typing.Set[bytes]: + def is_infohash_new(self, info_hash): + if info_hash in [x[0] for x in self.__pending_metadata]: + return False cur = self.__db_conn.cursor() try: - cur.execute("SELECT info_hash FROM torrents;") - return set(x[0] for x in cur.fetchall()) + cur.execute("SELECT count(info_hash) FROM torrents where info_hash = ?;", [info_hash]) + x, = cur.fetchone() + return x == 0 finally: cur.close()