diff --git a/magneticod/magneticod/bittorrent.py b/magneticod/magneticod/bittorrent.py index 7a9a60a..1f9e0c7 100644 --- a/magneticod/magneticod/bittorrent.py +++ b/magneticod/magneticod/bittorrent.py @@ -25,9 +25,17 @@ InfoHash = bytes PeerAddress = typing.Tuple[str, int] -async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size): - return await DisposablePeer().run( - asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size) +async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size, timeout=None): + loop = asyncio.get_event_loop() + task = asyncio.ensure_future(DisposablePeer().run( + asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size)) + h = None + if timeout is not None: + h = loop.call_later(timeout, lambda: task.cancel()) + try: + return await task + except asyncio.CancelledError: + return None class ProtocolError(Exception): @@ -80,12 +88,13 @@ class DisposablePeer: length = int.from_bytes(buffer, "big") message = await self._reader.readexactly(length) self.__on_message(message) - except Exception as ex: + except Exception: logging.debug("closing %s to %s", self.__info_hash.hex(), self.__peer_addr) + finally: if not self._metadata_future.done(): self._metadata_future.set_result(None) - if self._writer: - self._writer.close() + if self._writer: + self._writer.close() return self._metadata_future.result() def __on_message(self, message: bytes) -> None: @@ -199,7 +208,7 @@ class DisposablePeer: if self.__metadata_received == self.__metadata_size: if hashlib.sha1(self.__metadata).digest() == self.__info_hash: if not self._metadata_future.done(): - self._metadata_future.set_result((self.__info_hash, bytes(self.__metadata))) + self._metadata_future.set_result(bytes(self.__metadata)) else: logging.debug("Invalid Metadata! Ignoring.") diff --git a/magneticod/magneticod/dht.py b/magneticod/magneticod/dht.py index 56e2417..ae6092b 100644 --- a/magneticod/magneticod/dht.py +++ b/magneticod/magneticod/dht.py @@ -43,8 +43,8 @@ class SybilNode: # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will # stop; but until then, the total number of neighbours might exceed the threshold). self.__n_max_neighbours = 2000 - self.__peers = collections.defaultdict( - list) # type: typing.DefaultDict[dht.InfoHash, typing.List[bittorrent.DisposablePeer]] + self.__tasks = collections.defaultdict( + set) # type: typing.DefaultDict[dht.InfoHash, typing.Set[asyncio.Task]] self._complete_info_hashes = complete_info_hashes self.__max_metadata_size = max_metadata_size self._metadata_q = asyncio.Queue() @@ -116,7 +116,7 @@ class SybilNode: self.__on_ANNOUNCE_PEER_query(message, addr) async def shutdown(self) -> None: - futures = [peer for peer in itertools.chain.from_iterable(self.__peers.values())] + futures = [task for task in itertools.chain.from_iterable(self.__tasks.values())] if self._tick_task: futures.append(self._tick_task) for future in futures: @@ -197,25 +197,24 @@ class SybilNode: else: peer_addr = (addr[0], port) - if len(self.__peers[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH or \ - info_hash in self._complete_info_hashes: + if info_hash in self._complete_info_hashes or \ + len(self.__tasks[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH: return + task = self._loop.create_task(bittorrent.fetch_metadata( + info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT)) + self.__tasks[info_hash].add(task) + task.add_done_callback(lambda f: self._got_result(task, info_hash)) - peer = self._loop.create_task(self._launch_fetch(info_hash, peer_addr)) - self.__peers[info_hash].append(peer) - - async def _launch_fetch(self, info_hash, peer_addr): - try: - f = bittorrent.fetch_metadata(info_hash, peer_addr, self.__max_metadata_size) - r = await asyncio.wait_for(f, timeout=PEER_TIMEOUT) - if r: - info_hash, metadata = r - for peer in self.__peers[info_hash]: - peer.cancel() - self._complete_info_hashes.add(info_hash) - await self._metadata_q.put(r) - except asyncio.TimeoutError: - pass + def _got_result(self, task, info_hash): + task_set = self.__tasks[info_hash] + metadata = task.result() + if metadata: + self._complete_info_hashes.add(info_hash) + self._metadata_q.put_nowait((info_hash, metadata)) + for task in task_set: + task.cancel() + if len(task_set) == 0: + del self.__tasks[info_hash] async def __bootstrap(self) -> None: for node in BOOTSTRAPPING_NODES: