From 8df4015e06d099b58e7fb215f51ef73fcd2c5c2e Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 24 May 2017 12:36:47 -0700 Subject: [PATCH] Be a little smarter with task clean-up. --- magneticod/magneticod/dht.py | 54 ++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/magneticod/magneticod/dht.py b/magneticod/magneticod/dht.py index ae6092b..ea2dcd2 100644 --- a/magneticod/magneticod/dht.py +++ b/magneticod/magneticod/dht.py @@ -13,7 +13,6 @@ # You should have received a copy of the GNU Affero General Public License along with this program. If not, see # . import asyncio -import collections import itertools import zlib import logging @@ -43,8 +42,7 @@ class SybilNode: # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will # stop; but until then, the total number of neighbours might exceed the threshold). self.__n_max_neighbours = 2000 - self.__tasks = collections.defaultdict( - set) # type: typing.DefaultDict[dht.InfoHash, typing.Set[asyncio.Task]] + self.__tasks = {} # type: typing.Dict[dht.InfoHash, asyncio.Future] self._complete_info_hashes = complete_info_hashes self.__max_metadata_size = max_metadata_size self._metadata_q = asyncio.Queue() @@ -116,7 +114,7 @@ class SybilNode: self.__on_ANNOUNCE_PEER_query(message, addr) async def shutdown(self) -> None: - futures = [task for task in itertools.chain.from_iterable(self.__tasks.values())] + futures = list(self.__tasks.values()) if self._tick_task: futures.append(self._tick_task) for future in futures: @@ -197,24 +195,46 @@ class SybilNode: else: peer_addr = (addr[0], port) - if info_hash in self._complete_info_hashes or \ - len(self.__tasks[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH: + if info_hash in self._complete_info_hashes: return - task = self._loop.create_task(bittorrent.fetch_metadata( - info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT)) - self.__tasks[info_hash].add(task) - task.add_done_callback(lambda f: self._got_result(task, info_hash)) - def _got_result(self, task, info_hash): - task_set = self.__tasks[info_hash] - metadata = task.result() + # create the parent future + if info_hash not in self.__tasks: + parent_f = self._loop.create_future() + parent_f.child_count = 0 + parent_f.add_done_callback(lambda f: self._parent_task_done(f, info_hash)) + self.__tasks[info_hash] = parent_f + + parent_f = self.__tasks[info_hash] + + if parent_f.done(): + return + if parent_f.child_count > MAX_ACTIVE_PEERS_PER_INFO_HASH: + return + + task = asyncio.ensure_future(bittorrent.fetch_metadata( + info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT)) + task.add_done_callback(lambda task: self._got_child_result(parent_f, task)) + parent_f.child_count += 1 + parent_f.add_done_callback(lambda f: task.cancel()) + + def _got_child_result(self, parent_task, child_task): + parent_task.child_count -= 1 + try: + metadata = child_task.result() + if metadata and not parent_task.done(): + parent_task.set_result(metadata) + except Exception: + logging.exception("child result is exception") + if parent_task.child_count <= 0 and not parent_task.done(): + parent_task.set_result(None) + + def _parent_task_done(self, parent_task, info_hash): + metadata = parent_task.result() if metadata: self._complete_info_hashes.add(info_hash) self._metadata_q.put_nowait((info_hash, metadata)) - for task in task_set: - task.cancel() - if len(task_set) == 0: - del self.__tasks[info_hash] + del self.__tasks[info_hash] async def __bootstrap(self) -> None: for node in BOOTSTRAPPING_NODES: