diff --git a/magneticod/magneticod/__main__.py b/magneticod/magneticod/__main__.py index 9e3d80b..02a724b 100644 --- a/magneticod/magneticod/__main__.py +++ b/magneticod/magneticod/__main__.py @@ -13,50 +13,38 @@ # You should have received a copy of the GNU Affero General Public License along with this program. If not, see # . import argparse -import collections -import functools +import asyncio import logging import ipaddress -import selectors import textwrap import urllib.parse -import itertools import os import sys -import time import typing import appdirs import humanfriendly -from .constants import TICK_INTERVAL, MAX_ACTIVE_PEERS_PER_INFO_HASH, DEFAULT_MAX_METADATA_SIZE +from .constants import DEFAULT_MAX_METADATA_SIZE from . import __version__ -from . import bittorrent from . import dht from . import persistence -# Global variables are bad bla bla bla, BUT these variables are used so many times that I think it is justified; else -# the signatures of many functions are literally cluttered. -# -# If you are using a global variable, please always indicate that at the VERY BEGINNING of the function instead of right -# before using the variable for the first time. -selector = selectors.DefaultSelector() -database = None # type: persistence.Database -node = None -peers = collections.defaultdict(list) # type: typing.DefaultDict[dht.InfoHash, typing.List[bittorrent.DisposablePeer]] -# info hashes whose metadata is valid & complete (OR complete but deemed to be corrupt) so do NOT download them again: -complete_info_hashes = set() - - def main(): - global complete_info_hashes, database, node, peers, selector - arguments = parse_cmdline_arguments() logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s") logging.info("magneticod v%d.%d.%d started", *__version__) + # use uvloop if it's installed + try: + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + logging.info("using uvloop") + except ImportError: + pass + # noinspection PyBroadException try: path = arguments.database_file @@ -67,106 +55,31 @@ def main(): complete_info_hashes = database.get_complete_info_hashes() - node = dht.SybilNode(arguments.node_addr) + loop = asyncio.get_event_loop() + node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size) + loop.run_until_complete(node.launch(loop)) - node.when_peer_found = lambda info_hash, peer_address: on_peer_found(info_hash=info_hash, - peer_address=peer_address, - max_metadata_size=arguments.max_metadata_size) - - selector.register(node, selectors.EVENT_READ) + watch_q_task = loop.create_task(watch_q(database, node._metadata_q)) try: - loop() + loop.run_forever() except KeyboardInterrupt: logging.critical("Keyboard interrupt received! Exiting gracefully...") - pass finally: database.close() - selector.close() - node.shutdown() - for peer in itertools.chain.from_iterable(peers.values()): - peer.shutdown() + watch_q_task.cancel() + loop.run_until_complete(node.shutdown()) + loop.run_until_complete(asyncio.wait([watch_q_task])) return 0 -def on_peer_found(info_hash: dht.InfoHash, peer_address, max_metadata_size: int=DEFAULT_MAX_METADATA_SIZE) -> None: - global selector, peers, complete_info_hashes - - if len(peers[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH or info_hash in complete_info_hashes: - return - - try: - peer = bittorrent.DisposablePeer(info_hash, peer_address, max_metadata_size) - except ConnectionError: - return - - selector.register(peer, selectors.EVENT_READ | selectors.EVENT_WRITE) - peer.when_metadata_found = on_metadata_found - peer.when_error = functools.partial(on_peer_error, peer, info_hash) - peers[info_hash].append(peer) - - -def on_metadata_found(info_hash: dht.InfoHash, metadata: bytes) -> None: - global complete_info_hashes, database, peers, selector - - succeeded = database.add_metadata(info_hash, metadata) - if not succeeded: - logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex()) - - # When we fetch the metadata of an info hash completely, shut down all other peers who are trying to do the same. - for peer in peers[info_hash]: - selector.unregister(peer) - peer.shutdown() - del peers[info_hash] - - complete_info_hashes.add(info_hash) - - -def on_peer_error(peer: bittorrent.DisposablePeer, info_hash: dht.InfoHash) -> None: - global peers, selector - peer.shutdown() - peers[info_hash].remove(peer) - selector.unregister(peer) - - -# TODO: -# Consider whether time.monotonic() is a good choice. Maybe we should use CLOCK_MONOTONIC_RAW as its not affected by NTP -# adjustments, and all we need is how many seconds passed since a certain point in time. -def loop() -> None: - global selector, node, peers - - t0 = time.monotonic() +async def watch_q(database, q): while True: - keys_and_events = selector.select(timeout=TICK_INTERVAL) - - # Check if it is time to tick - delta = time.monotonic() - t0 - if delta >= TICK_INTERVAL: - if not (delta < 2 * TICK_INTERVAL): - logging.warning("Belated TICK! (Δ = %d)", delta) - - node.on_tick() - for peer_list in peers.values(): - for peer in peer_list: - peer.on_tick() - - t0 = time.monotonic() - - for key, events in keys_and_events: - if events & selectors.EVENT_READ: - key.fileobj.on_receivable() - if events & selectors.EVENT_WRITE: - key.fileobj.on_sendable() - - # Check for entities that would like to write to their socket - keymap = selector.get_map() - for fd in keymap: - fileobj = keymap[fd].fileobj - if fileobj.would_send(): - selector.modify(fileobj, selectors.EVENT_READ | selectors.EVENT_WRITE) - else: - selector.modify(fileobj, selectors.EVENT_READ) + info_hash, metadata = await q.get() + succeeded = database.add_metadata(info_hash, metadata) + if not succeeded: + logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex()) def parse_ip_port(netloc) -> typing.Optional[typing.Tuple[str, int]]: diff --git a/magneticod/magneticod/bittorrent.py b/magneticod/magneticod/bittorrent.py index 5a5b4be..993ca0b 100644 --- a/magneticod/magneticod/bittorrent.py +++ b/magneticod/magneticod/bittorrent.py @@ -12,42 +12,38 @@ # # You should have received a copy of the GNU Affero General Public License along with this program. If not, see # . -import errno +import asyncio import logging import hashlib import math -import socket import typing import os from . import bencode -from .constants import DEFAULT_MAX_METADATA_SIZE InfoHash = bytes PeerAddress = typing.Tuple[str, int] +async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size, timeout=None): + try: + return await asyncio.wait_for(DisposablePeer().run( + asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size), timeout=timeout) + except asyncio.TimeoutError: + return None + + +class ProtocolError(Exception): + pass + + class DisposablePeer: - def __init__(self, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int= DEFAULT_MAX_METADATA_SIZE): - self.__socket = socket.socket() - self.__socket.setblocking(False) - # To reduce the latency: - self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) - if hasattr(socket, "TCP_QUICKACK"): - self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK, True) - - res = self.__socket.connect_ex(peer_addr) - if res != errno.EINPROGRESS: - raise ConnectionError() - + async def run(self, loop, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int): self.__peer_addr = peer_addr self.__info_hash = info_hash self.__max_metadata_size = max_metadata_size - self.__incoming_buffer = bytearray() - self.__outgoing_buffer = bytearray() - self.__bt_handshake_complete = False # BitTorrent Handshake self.__ext_handshake_complete = False # Extension Handshake self.__ut_metadata = None # Since we don't know ut_metadata code that remote peer uses... @@ -55,103 +51,46 @@ class DisposablePeer: self.__metadata_size = None self.__metadata_received = 0 # Amount of metadata bytes received... self.__metadata = None + self._run_task = None - # To prevent double shutdown - self.__shutdown = False + self._metadata_future = loop.create_future() + self._writer = None - # After 120 ticks passed, a peer should report an error and shut itself down due to being stall. - self.__ticks_passed = 0 - - # Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message) - self.__outgoing_buffer += b"\x13BitTorrent protocol%s%s%s" % ( - b"\x00\x00\x00\x00\x00\x10\x00\x01", - self.__info_hash, - self.__random_bytes(20) - ) - - @staticmethod - def when_error() -> None: - raise NotImplementedError() - - @staticmethod - def when_metadata_found(info_hash: InfoHash, metadata: bytes) -> None: - raise NotImplementedError() - - def on_tick(self): - self.__ticks_passed += 1 - - if self.__ticks_passed == 120: - logging.debug("Peer failed to fetch metadata in time for info hash %s!", self.__info_hash.hex()) - self.when_error() - - def on_receivable(self) -> None: - while True: - try: - received = self.__socket.recv(8192) - except BlockingIOError: - break - except ConnectionResetError: - self.when_error() - return - except ConnectionRefusedError: - self.when_error() - return - except OSError: # TODO: check for "no route to host 113" error - self.when_error() - return - - if not received: - self.when_error() - return - - self.__incoming_buffer += received - # Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) badly - # designed protocols I have ever seen (I am 19 years old so what I could have seen?). - # - # Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the - # size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one canonical - # way of handshaking in the wild. - if not self.__bt_handshake_complete: - if len(self.__incoming_buffer) < 68: - # We are still receiving the handshake... - return - - if self.__incoming_buffer[1:20] != b"BitTorrent protocol": + try: + self._reader, self._writer = await asyncio.open_connection( + self.__peer_addr[0], self.__peer_addr[1], loop=loop) + # Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message) + self._writer.write(b"\x13BitTorrent protocol%s%s%s" % ( + b"\x00\x00\x00\x00\x00\x10\x00\x01", + self.__info_hash, + self.__random_bytes(20) + )) + # Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) badly + # designed protocols I have ever seen (I am 19 years old so what I could have seen?). + # + # Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the + # size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one canonical + # way of handshaking in the wild. + message = await self._reader.readexactly(68) + if message[1:20] != b"BitTorrent protocol": # Erroneous handshake, possibly unknown version... - logging.debug("Erroneous BitTorrent handshake! %s", self.__incoming_buffer[:68]) - self.when_error() - return + raise ProtocolError("Erroneous BitTorrent handshake! %s" % message) - self.__on_bt_handshake(self.__incoming_buffer[:68]) + self.__on_bt_handshake(message) - self.__bt_handshake_complete = True - self.__incoming_buffer = self.__incoming_buffer[68:] - - while len(self.__incoming_buffer) >= 4: - # Beware that while there are still messages in the incoming queue/buffer, one of previous messages might - # have caused an error that necessitates us to quit. - if self.__shutdown: - break - - length = int.from_bytes(self.__incoming_buffer[:4], "big") - if len(self.__incoming_buffer) - 4 < length: - # Message is still incoming... - return - - self.__on_message(self.__incoming_buffer[4:4+length]) - self.__incoming_buffer = self.__incoming_buffer[4+length:] - - def on_sendable(self) -> None: - while self.__outgoing_buffer: - try: - n_sent = self.__socket.send(self.__outgoing_buffer) - assert n_sent - self.__outgoing_buffer = self.__outgoing_buffer[n_sent:] - except BlockingIOError: - break - except OSError: - # In case -while looping- on_sendable is called after socket is closed (mostly because of an error) - return + while not self._metadata_future.done(): + buffer = await self._reader.readexactly(4) + length = int.from_bytes(buffer, "big") + message = await self._reader.readexactly(length) + self.__on_message(message) + except Exception: + logging.debug("closing %s to %s", self.__info_hash.hex(), self.__peer_addr) + finally: + if not self._metadata_future.done(): + self._metadata_future.set_result(None) + if self._writer: + self._writer.close() + return self._metadata_future.result() def __on_message(self, message: bytes) -> None: length = len(message) @@ -191,10 +130,10 @@ class DisposablePeer: # In case you cannot read_file hex: # 0x14 = 20 (BitTorrent ID indicating that it's an extended message) # 0x00 = 0 (Extension ID indicating that it's the handshake message) - self.__outgoing_buffer += b"%s\x14\x00%s" % ( + self._writer.write(b"%b\x14%s" % ( (2 + len(msg_dict_dump)).to_bytes(4, "big"), - msg_dict_dump - ) + b'\0' + msg_dict_dump + )) def __on_ext_handshake_message(self, message: bytes) -> None: if self.__ext_handshake_complete: @@ -217,21 +156,16 @@ class DisposablePeer: " {} max metadata size".format(self.__peer_addr[0], self.__peer_addr[1], self.__max_metadata_size) - except KeyError: - self.when_error() - return except AssertionError as e: logging.debug(str(e)) - self.when_error() - return + raise self.__ut_metadata = ut_metadata try: self.__metadata = bytearray(metadata_size) except MemoryError: logging.exception("Could not allocate %.1f KiB for the metadata!", metadata_size / 1024) - self.when_error() - return + raise self.__metadata_size = metadata_size self.__ext_handshake_complete = True @@ -268,13 +202,13 @@ class DisposablePeer: if self.__metadata_received == self.__metadata_size: if hashlib.sha1(self.__metadata).digest() == self.__info_hash: - self.when_metadata_found(self.__info_hash, bytes(self.__metadata)) + if not self._metadata_future.done(): + self._metadata_future.set_result(bytes(self.__metadata)) else: logging.debug("Invalid Metadata! Ignoring.") elif msg_type == 2: # reject logging.info("Peer rejected us.") - self.when_error() def __request_metadata_piece(self, piece: int) -> None: msg_dict_dump = bencode.dumps({ @@ -284,29 +218,11 @@ class DisposablePeer: # In case you cannot read_file hex: # 0x14 = 20 (BitTorrent ID indicating that it's an extended message) # 0x03 = 3 (Extension ID indicating that it's an ut_metadata message) - self.__outgoing_buffer += b"%b\x14%s%s" % ( + self._writer.write(b"%b\x14%s%s" % ( (2 + len(msg_dict_dump)).to_bytes(4, "big"), self.__ut_metadata.to_bytes(1, "big"), msg_dict_dump - ) - - def shutdown(self) -> None: - if self.__shutdown: - return - try: - self.__socket.shutdown(socket.SHUT_RDWR) - except OSError: - # OSError might be raised in case the connection to the remote peer fails: nevertheless, when_error should - # be called, and the supervisor will try to shutdown the peer, and ta da: OSError! - pass - self.__socket.close() - self.__shutdown = True - - def would_send(self) -> bool: - return bool(len(self.__outgoing_buffer)) - - def fileno(self) -> int: - return self.__socket.fileno() + )) @staticmethod def __random_bytes(n: int) -> bytes: diff --git a/magneticod/magneticod/constants.py b/magneticod/magneticod/constants.py index bc2c954..5f2b1b2 100644 --- a/magneticod/magneticod/constants.py +++ b/magneticod/magneticod/constants.py @@ -6,6 +6,7 @@ BOOTSTRAPPING_NODES = [ ] PENDING_INFO_HASHES = 10 # threshold for pending info hashes before being committed to database: -TICK_INTERVAL = 1 # in seconds (soft constraint) # maximum (inclusive) number of active (disposable) peers to fetch the metadata per info hash at the same time: MAX_ACTIVE_PEERS_PER_INFO_HASH = 5 + +PEER_TIMEOUT=120 # seconds diff --git a/magneticod/magneticod/dht.py b/magneticod/magneticod/dht.py index 5e2f057..b2e83e1 100644 --- a/magneticod/magneticod/dht.py +++ b/magneticod/magneticod/dht.py @@ -12,16 +12,17 @@ # # You should have received a copy of the GNU Affero General Public License along with this program. If not, see # . -import array -import collections +import asyncio +import itertools import zlib import logging import socket import typing import os -from .constants import BOOTSTRAPPING_NODES, DEFAULT_MAX_METADATA_SIZE +from .constants import BOOTSTRAPPING_NODES, MAX_ACTIVE_PEERS_PER_INFO_HASH, PEER_TIMEOUT from . import bencode +from . import bittorrent NodeID = bytes NodeAddress = typing.Tuple[str, int] @@ -30,107 +31,95 @@ InfoHash = bytes class SybilNode: - def __init__(self, address: typing.Tuple[str, int]): + def __init__(self, address: typing.Tuple[str, int], complete_info_hashes, max_metadata_size): self.__true_id = self.__random_bytes(20) - self.__socket = socket.socket(type=socket.SOCK_DGRAM) - self.__socket.bind(address) - self.__socket.setblocking(False) + self.__address = address - self.__incoming_buffer = array.array("B", (0 for _ in range(65536))) - self.__outgoing_queue = collections.deque() - - self.__routing_table = {} # type: typing.Dict[NodeID, NodeAddress] + self._routing_table = {} # type: typing.Dict[NodeID, NodeAddress] self.__token_secret = self.__random_bytes(4) # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will # stop; but until then, the total number of neighbours might exceed the threshold). self.__n_max_neighbours = 2000 + self.__tasks = {} # type: typing.Dict[dht.InfoHash, asyncio.Future] + self._complete_info_hashes = complete_info_hashes + self.__max_metadata_size = max_metadata_size + self._metadata_q = asyncio.Queue() + self._is_paused = False + self._tick_task = None logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address) - @staticmethod - def when_peer_found(info_hash: InfoHash, peer_addr: PeerAddress) -> None: - raise NotImplementedError() + async def launch(self, loop): + self._loop = loop + await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address) - def on_tick(self) -> None: - self.__bootstrap() - self.__make_neighbours() - self.__routing_table.clear() + def connection_made(self, transport): + self._tick_task = self._loop.create_task(self.on_tick()) + self._transport = transport - def on_receivable(self) -> None: - buffer = self.__incoming_buffer - while True: - try: - _, addr = self.__socket.recvfrom_into(buffer, 65536) - data = buffer.tobytes() - except BlockingIOError: - break - except ConnectionResetError: - continue - except ConnectionRefusedError: - continue + def connection_lost(self, exc): + self._is_paused = True - # Ignore nodes that uses port 0 (assholes). - if addr[1] == 0: - continue + def pause_writing(self): + self._is_paused = True - try: - message = bencode.loads(data) - except bencode.BencodeDecodingError: - continue + def resume_writing(self): + self._is_paused = False - if isinstance(message.get(b"r"), dict) and type(message[b"r"].get(b"nodes")) is bytes: - self.__on_FIND_NODE_response(message) - elif message.get(b"q") == b"get_peers": - self.__on_GET_PEERS_query(message, addr) - elif message.get(b"q") == b"announce_peer": - self.__on_ANNOUNCE_PEER_query(message, addr) + def sendto(self, data, addr): + if self._is_paused: + return + self._transport.sendto(data, addr) - def on_sendable(self) -> None: - congestion = None - while True: - try: - addr, data = self.__outgoing_queue.pop() - except IndexError: - break - - try: - self.__socket.sendto(data, addr) - except BlockingIOError: - self.__outgoing_queue.appendleft((addr, data)) - break - except PermissionError: - # This exception (EPERM errno: 1) is kernel's way of saying that "you are far too fast, chill". - # It is also likely that we have received a ICMP source quench packet (meaning, that we really need to - # slow down. - # - # Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto - # -amp-amp-EPERM.html - congestion = True - break - except OSError: - # Pass in case of trying to send to port 0 (it is much faster to catch exceptions than using an - # if-statement). - pass - - if congestion: - self.__outgoing_queue.clear() + def error_received(self, exc): + logging.error("got error %s", exc) + if isinstance(exc, PermissionError): # In case of congestion, decrease the maximum number of nodes to the 90% of the current value. if self.__n_max_neighbours < 200: logging.warning("Maximum number of neighbours are now less than 200 due to congestion!") else: self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10 - else: - # In case of the lack of congestion, increase the maximum number of nodes by 1%. - self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100 + logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours) - def would_send(self) -> bool: - """ Whether node is waiting to write on its socket or not. """ - return bool(self.__outgoing_queue) + async def on_tick(self) -> None: + while True: + await asyncio.sleep(1) + if len(self._routing_table) == 0: + await self.__bootstrap() + self.__make_neighbours() + self._routing_table.clear() + if not self._is_paused: + self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100 - def shutdown(self) -> None: - self.__socket.close() + def datagram_received(self, data, addr) -> None: + # Ignore nodes that uses port 0 (assholes). + if addr[1] == 0: + return + + if self._transport.is_closing(): + return + + try: + message = bencode.loads(data) + except bencode.BencodeDecodingError: + return + + if isinstance(message.get(b"r"), dict) and type(message[b"r"].get(b"nodes")) is bytes: + self.__on_FIND_NODE_response(message) + elif message.get(b"q") == b"get_peers": + self.__on_GET_PEERS_query(message, addr) + elif message.get(b"q") == b"announce_peer": + self.__on_ANNOUNCE_PEER_query(message, addr) + + async def shutdown(self) -> None: + tasks = list(self.__tasks.values()) + for t in tasks: + t.set_result(None) + self._tick_task.cancel() + await asyncio.wait([self._tick_task]) + self._transport.close() def __on_FIND_NODE_response(self, message: bencode.KRPCDict) -> None: try: @@ -145,8 +134,8 @@ class SybilNode: return # Add new found nodes to the routing table, assuring that we have no more than n_max_neighbours in total. - if len(self.__routing_table) < self.__n_max_neighbours: - self.__routing_table.update(nodes) + if len(self._routing_table) < self.__n_max_neighbours: + self._routing_table.update(nodes) def __on_GET_PEERS_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None: try: @@ -157,8 +146,7 @@ class SybilNode: except (TypeError, KeyError, AssertionError): return - # appendleft to prioritise GET_PEERS responses as they are the most fruitful ones! - self.__outgoing_queue.appendleft((addr, bencode.dumps({ + data = bencode.dumps({ b"y": b"r", b"t": transaction_id, b"r": { @@ -166,7 +154,10 @@ class SybilNode: b"nodes": b"", b"token": self.__calculate_token(addr, info_hash) } - }))) + }) + # we want to prioritise GET_PEERS responses as they are the most fruitful ones! + # but there is no easy way to do this with asyncio + self.sendto(data, addr) def __on_ANNOUNCE_PEER_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None: try: @@ -189,31 +180,80 @@ class SybilNode: except (TypeError, KeyError, AssertionError): return - self.__outgoing_queue.append((addr, bencode.dumps({ + data = bencode.dumps({ b"y": b"r", b"t": transaction_id, b"r": { b"id": node_id[:15] + self.__true_id[:5] } - }))) + }) + self.sendto(data, addr) if implied_port: peer_addr = (addr[0], addr[1]) else: peer_addr = (addr[0], port) - self.when_peer_found(info_hash, peer_addr) + if info_hash in self._complete_info_hashes: + return - def fileno(self) -> int: - return self.__socket.fileno() + # create the parent future + if info_hash not in self.__tasks: + parent_f = self._loop.create_future() + parent_f.child_count = 0 + parent_f.add_done_callback(lambda f: self._parent_task_done(f, info_hash)) + self.__tasks[info_hash] = parent_f - def __bootstrap(self) -> None: - for addr in BOOTSTRAPPING_NODES: - self.__outgoing_queue.append((addr, self.__build_FIND_NODE_query(self.__true_id))) + parent_f = self.__tasks[info_hash] + + if parent_f.done(): + return + if parent_f.child_count > MAX_ACTIVE_PEERS_PER_INFO_HASH: + return + + task = asyncio.ensure_future(bittorrent.fetch_metadata( + info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT)) + task.add_done_callback(lambda task: self._got_child_result(parent_f, task)) + parent_f.child_count += 1 + parent_f.add_done_callback(lambda f: task.cancel()) + + def _got_child_result(self, parent_task, child_task): + parent_task.child_count -= 1 + try: + metadata = child_task.result() + if metadata and not parent_task.done(): + parent_task.set_result(metadata) + except asyncio.CancelledError: + pass + except Exception: + logging.exception("child result is exception") + if parent_task.child_count <= 0 and not parent_task.done(): + parent_task.set_result(None) + + def _parent_task_done(self, parent_task, info_hash): + try: + metadata = parent_task.result() + if metadata: + self._complete_info_hashes.add(info_hash) + self._metadata_q.put_nowait((info_hash, metadata)) + except asyncio.CancelledError: + pass + del self.__tasks[info_hash] + + async def __bootstrap(self) -> None: + for node in BOOTSTRAPPING_NODES: + try: + # AF_INET means ip4 only + responses = await self._loop.getaddrinfo(*node, family=socket.AF_INET) + for (family, type, proto, canonname, sockaddr) in responses: + data = self.__build_FIND_NODE_query(self.__true_id) + self.sendto(data, sockaddr) + except Exception: + logging.exception("bootstrap problem") def __make_neighbours(self) -> None: - for node_id, addr in self.__routing_table.items(): - self.__outgoing_queue.append((addr, self.__build_FIND_NODE_query(node_id[:15] + self.__true_id[:5]))) + for node_id, addr in self._routing_table.items(): + self.sendto(self.__build_FIND_NODE_query(node_id[:15] + self.__true_id[:5]), addr) @staticmethod def __decode_nodes(infos: bytes) -> typing.List[typing.Tuple[NodeID, NodeAddress]]: