Properly clean up fetch_metadata tasks.

This commit is contained in:
Richard Kiss 2017-05-17 13:16:30 -07:00
parent 29b99a338e
commit 9b1bbfcaa1
2 changed files with 35 additions and 27 deletions

View File

@@ -25,9 +25,17 @@ InfoHash = bytes
 PeerAddress = typing.Tuple[str, int]
 
-async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size):
-    return await DisposablePeer().run(
-        asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size)
+async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size, timeout=None):
+    loop = asyncio.get_event_loop()
+    task = asyncio.ensure_future(DisposablePeer().run(
+        asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size))
+    h = None
+    if timeout is not None:
+        h = loop.call_later(timeout, lambda: task.cancel())
+    try:
+        return await task
+    except asyncio.CancelledError:
+        return None
 
 
 class ProtocolError(Exception):
@@ -80,12 +88,13 @@ class DisposablePeer:
                 length = int.from_bytes(buffer, "big")
                 message = await self._reader.readexactly(length)
                 self.__on_message(message)
-        except Exception as ex:
+        except Exception:
             logging.debug("closing %s to %s", self.__info_hash.hex(), self.__peer_addr)
+        finally:
             if not self._metadata_future.done():
                 self._metadata_future.set_result(None)
             if self._writer:
                 self._writer.close()
         return self._metadata_future.result()
 
     def __on_message(self, message: bytes) -> None:
@@ -199,7 +208,7 @@ class DisposablePeer:
             if self.__metadata_received == self.__metadata_size:
                 if hashlib.sha1(self.__metadata).digest() == self.__info_hash:
                     if not self._metadata_future.done():
-                        self._metadata_future.set_result((self.__info_hash, bytes(self.__metadata)))
+                        self._metadata_future.set_result(bytes(self.__metadata))
                 else:
                     logging.debug("Invalid Metadata! Ignoring.")

View File

@@ -43,8 +43,8 @@ class SybilNode:
         # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will
         # stop; but until then, the total number of neighbours might exceed the threshold).
         self.__n_max_neighbours = 2000
-        self.__peers = collections.defaultdict(
-            list)  # type: typing.DefaultDict[dht.InfoHash, typing.List[bittorrent.DisposablePeer]]
+        self.__tasks = collections.defaultdict(
+            set)  # type: typing.DefaultDict[dht.InfoHash, typing.Set[asyncio.Task]]
         self._complete_info_hashes = complete_info_hashes
         self.__max_metadata_size = max_metadata_size
         self._metadata_q = asyncio.Queue()
@@ -116,7 +116,7 @@ class SybilNode:
             self.__on_ANNOUNCE_PEER_query(message, addr)
 
     async def shutdown(self) -> None:
-        futures = [peer for peer in itertools.chain.from_iterable(self.__peers.values())]
+        futures = [task for task in itertools.chain.from_iterable(self.__tasks.values())]
         if self._tick_task:
             futures.append(self._tick_task)
         for future in futures:
@@ -197,25 +197,24 @@ class SybilNode:
         else:
             peer_addr = (addr[0], port)
 
-        if len(self.__peers[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH or \
-                info_hash in self._complete_info_hashes:
+        if info_hash in self._complete_info_hashes or \
+                len(self.__tasks[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH:
             return
 
-        peer = self._loop.create_task(self._launch_fetch(info_hash, peer_addr))
-        self.__peers[info_hash].append(peer)
-
-    async def _launch_fetch(self, info_hash, peer_addr):
-        try:
-            f = bittorrent.fetch_metadata(info_hash, peer_addr, self.__max_metadata_size)
-            r = await asyncio.wait_for(f, timeout=PEER_TIMEOUT)
-            if r:
-                info_hash, metadata = r
-                for peer in self.__peers[info_hash]:
-                    peer.cancel()
-                self._complete_info_hashes.add(info_hash)
-                await self._metadata_q.put(r)
-        except asyncio.TimeoutError:
-            pass
+        task = self._loop.create_task(bittorrent.fetch_metadata(
+            info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT))
+        self.__tasks[info_hash].add(task)
+        task.add_done_callback(lambda f: self._got_result(task, info_hash))
+
+    def _got_result(self, task, info_hash):
+        task_set = self.__tasks[info_hash]
+        metadata = task.result()
+        if metadata:
+            self._complete_info_hashes.add(info_hash)
+            self._metadata_q.put_nowait((info_hash, metadata))
+        for task in task_set:
+            task.cancel()
+        if len(task_set) == 0:
+            del self.__tasks[info_hash]
 
     async def __bootstrap(self) -> None:
         for node in BOOTSTRAPPING_NODES: