huge commit, code review done on asyncio port

This commit is contained in:
Bora M. Alper 2017-06-02 15:34:22 +03:00
parent 04e6d583f3
commit 6a459d5e58
4 changed files with 106 additions and 63 deletions

View File

@ -41,9 +41,9 @@ def main():
try: try:
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logging.info("using uvloop") logging.info("uvloop is being used")
except ImportError: except ImportError:
pass logging.exception("uvloop could not be imported, using the default asyncio implementation")
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -57,9 +57,9 @@ def main():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size) node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size)
loop.run_until_complete(node.launch(loop)) loop.run_until_complete(node.launch())
watch_q_task = loop.create_task(watch_q(database, node._metadata_q)) watch_q_task = loop.create_task(metadata_queue_watcher(database, node.__metadata_queue))
try: try:
loop.run_forever() loop.run_forever()
@ -74,9 +74,12 @@ def main():
return 0 return 0
async def watch_q(database, q): async def metadata_queue_watcher(database: persistence.Database, metadata_queue: asyncio.Queue) -> None:
"""
Watches for the metadata queue to commit any complete info hashes to the database.
"""
while True: while True:
info_hash, metadata = await q.get() info_hash, metadata = await metadata_queue.get()
succeeded = database.add_metadata(info_hash, metadata) succeeded = database.add_metadata(info_hash, metadata)
if not succeeded: if not succeeded:
logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex()) logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex())

View File

@ -25,10 +25,9 @@ InfoHash = bytes
PeerAddress = typing.Tuple[str, int] PeerAddress = typing.Tuple[str, int]
async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size, timeout=None): async def fetch_metadata_from_peer(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int, timeout=None):
try: try:
return await asyncio.wait_for(DisposablePeer().run( return await asyncio.wait_for(DisposablePeer(info_hash, peer_addr, max_metadata_size).run(), timeout=timeout)
asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size), timeout=timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
return None return None
@ -38,39 +37,40 @@ class ProtocolError(Exception):
class DisposablePeer: class DisposablePeer:
async def run(self, loop, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int): def __init__(self, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int) -> None:
self.__peer_addr = peer_addr self.__peer_addr = peer_addr
self.__info_hash = info_hash self.__info_hash = info_hash
self.__max_metadata_size = max_metadata_size
self.__bt_handshake_complete = False # BitTorrent Handshake
self.__ext_handshake_complete = False # Extension Handshake self.__ext_handshake_complete = False # Extension Handshake
self.__ut_metadata = None # Since we don't know ut_metadata code that remote peer uses... self.__ut_metadata = None # Since we don't know ut_metadata code that remote peer uses...
self.__max_metadata_size = max_metadata_size
self.__metadata_size = None self.__metadata_size = None
self.__metadata_received = 0 # Amount of metadata bytes received... self.__metadata_received = 0 # Amount of metadata bytes received...
self.__metadata = None self.__metadata = None
self._run_task = None
self._metadata_future = loop.create_future() self._run_task = None
self._writer = None self._writer = None
async def run(self):
event_loop = asyncio.get_event_loop()
self._metadata_future = event_loop.create_future()
try: try:
self._reader, self._writer = await asyncio.open_connection( self._reader, self._writer = await asyncio.open_connection(*self.__peer_addr, loop=event_loop)
self.__peer_addr[0], self.__peer_addr[1], loop=loop)
# Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message) # Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message)
self._writer.write(b"\x13BitTorrent protocol%s%s%s" % ( self._writer.write(b"\x13BitTorrent protocol%s%s%s" % (
b"\x00\x00\x00\x00\x00\x10\x00\x01", b"\x00\x00\x00\x00\x00\x10\x00\x01",
self.__info_hash, self.__info_hash,
self.__random_bytes(20) self.__random_bytes(20)
)) ))
# Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) badly # Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but)
# designed protocols I have ever seen (I am 19 years old so what I could have seen?). # badly designed protocols I have ever seen (I am 19 years old so what I could have seen?).
# #
# Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the # Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the
# size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one canonical # size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one
# way of handshaking in the wild. # canonical way of handshaking in the wild.
message = await self._reader.readexactly(68) message = await self._reader.readexactly(68)
if message[1:20] != b"BitTorrent protocol": if message[1:20] != b"BitTorrent protocol":
# Erroneous handshake, possibly unknown version... # Erroneous handshake, possibly unknown version...
@ -93,12 +93,6 @@ class DisposablePeer:
return self._metadata_future.result() return self._metadata_future.result()
def __on_message(self, message: bytes) -> None: def __on_message(self, message: bytes) -> None:
length = len(message)
if length < 2:
# An extension message has minimum length of 2.
return
# Every extension message has BitTorrent Message ID = 20 # Every extension message has BitTorrent Message ID = 20
if message[0] != 20: if message[0] != 20:
# logging.debug("Message is NOT an EXTension message! %s", message[:200]) # logging.debug("Message is NOT an EXTension message! %s", message[:200])
@ -127,7 +121,7 @@ class DisposablePeer:
b"ut_metadata": 1 b"ut_metadata": 1
} }
}) })
# In case you cannot read_file hex: # In case you cannot read hex:
# 0x14 = 20 (BitTorrent ID indicating that it's an extended message) # 0x14 = 20 (BitTorrent ID indicating that it's an extended message)
# 0x00 = 0 (Extension ID indicating that it's the handshake message) # 0x00 = 0 (Extension ID indicating that it's the handshake message)
self._writer.write(b"%b\x14%s" % ( self._writer.write(b"%b\x14%s" % (

View File

@ -6,6 +6,8 @@ BOOTSTRAPPING_NODES = [
] ]
PENDING_INFO_HASHES = 10 # threshold for pending info hashes before being committed to database: PENDING_INFO_HASHES = 10 # threshold for pending info hashes before being committed to database:
TICK_INTERVAL = 1 # in seconds
# maximum (inclusive) number of active (disposable) peers to fetch the metadata per info hash at the same time: # maximum (inclusive) number of active (disposable) peers to fetch the metadata per info hash at the same time:
MAX_ACTIVE_PEERS_PER_INFO_HASH = 5 MAX_ACTIVE_PEERS_PER_INFO_HASH = 5

View File

@ -20,7 +20,7 @@ import socket
import typing import typing
import os import os
from .constants import BOOTSTRAPPING_NODES, MAX_ACTIVE_PEERS_PER_INFO_HASH, PEER_TIMEOUT from .constants import BOOTSTRAPPING_NODES, MAX_ACTIVE_PEERS_PER_INFO_HASH, PEER_TIMEOUT, TICK_INTERVAL
from . import bencode from . import bencode
from . import bittorrent from . import bittorrent
@ -28,9 +28,10 @@ NodeID = bytes
NodeAddress = typing.Tuple[str, int] NodeAddress = typing.Tuple[str, int]
PeerAddress = typing.Tuple[str, int] PeerAddress = typing.Tuple[str, int]
InfoHash = bytes InfoHash = bytes
Metadata = bytes
class SybilNode: class SybilNode(asyncio.DatagramProtocol):
def __init__(self, address: typing.Tuple[str, int], complete_info_hashes, max_metadata_size): def __init__(self, address: typing.Tuple[str, int], complete_info_hashes, max_metadata_size):
self.__true_id = self.__random_bytes(20) self.__true_id = self.__random_bytes(20)
@ -42,55 +43,69 @@ class SybilNode:
# Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will
# stop; but until then, the total number of neighbours might exceed the threshold). # stop; but until then, the total number of neighbours might exceed the threshold).
self.__n_max_neighbours = 2000 self.__n_max_neighbours = 2000
self.__tasks = {} # type: typing.Dict[dht.InfoHash, asyncio.Future] self.__parent_futures = {} # type: typing.Dict[dht.InfoHash, asyncio.Future]
self._complete_info_hashes = complete_info_hashes self._complete_info_hashes = complete_info_hashes
self.__max_metadata_size = max_metadata_size self.__max_metadata_size = max_metadata_size
self._metadata_q = asyncio.Queue() # Complete metadatas will be added to the queue, to be retrieved and committed to the database.
self._is_paused = False self.__metadata_queue = asyncio.Queue() # typing.Collection[typing.Tuple[InfoHash, Metadata]]
self._is_writing_paused = False
self._tick_task = None self._tick_task = None
logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address) logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address)
async def launch(self, loop): async def launch(self) -> None:
self._loop = loop event_loop = asyncio.get_event_loop()
await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address) await event_loop.create_datagram_endpoint(lambda: self, local_addr=self.__address)
def connection_made(self, transport): def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self._tick_task = self._loop.create_task(self.on_tick()) event_loop = asyncio.get_event_loop()
self._tick_task = event_loop.create_task(self.tick_periodically())
self._transport = transport self._transport = transport
def connection_lost(self, exc): def connection_lost(self, exc) -> None:
self._is_paused = True logging.critical("SybilNode's connection is lost.")
self._is_writing_paused = True
def pause_writing(self): def pause_writing(self) -> None:
self._is_paused = True self._is_writing_paused = True
# In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10
logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours)
def resume_writing(self): def resume_writing(self) -> None:
self._is_paused = False self._is_writing_paused = False
def sendto(self, data, addr): def sendto(self, data, addr) -> None:
if self._is_paused: if not self._is_writing_paused:
return
self._transport.sendto(data, addr) self._transport.sendto(data, addr)
def error_received(self, exc): def error_received(self, exc: Exception) -> None:
logging.error("got error %s", exc)
if isinstance(exc, PermissionError): if isinstance(exc, PermissionError):
# This exception (EPERM errno: 1) is kernel's way of saying that "you are far too fast, chill".
# It is also likely that we have received a ICMP source quench packet (meaning, that we really need to
# slow down.
#
# Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto
# -amp-amp-EPERM.html
#
# In case of congestion, decrease the maximum number of nodes to the 90% of the current value. # In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
if self.__n_max_neighbours < 200: if self.__n_max_neighbours < 200:
logging.warning("Maximum number of neighbours are now less than 200 due to congestion!") logging.warning("Maximum number of neighbours are now less than 200 due to congestion!")
else: else:
self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10 self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10
logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours) logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours)
else:
# The previous "exception" was kind of "unexceptional", but we should log anything else.
logging.error("SybilNode operational error: `%s`", exc)
async def on_tick(self) -> None: async def tick_periodically(self) -> None:
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(TICK_INTERVAL)
if len(self._routing_table) == 0: if not self._routing_table:
await self.__bootstrap() await self.__bootstrap()
self.__make_neighbours() self.__make_neighbours()
self._routing_table.clear() self._routing_table.clear()
if not self._is_paused: if not self._is_writing_paused:
self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100 self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
@ -114,7 +129,7 @@ class SybilNode:
self.__on_ANNOUNCE_PEER_query(message, addr) self.__on_ANNOUNCE_PEER_query(message, addr)
async def shutdown(self) -> None: async def shutdown(self) -> None:
tasks = list(self.__tasks.values()) tasks = list(self.__parent_futures.values())
for t in tasks: for t in tasks:
t.set_result(None) t.set_result(None)
self._tick_task.cancel() self._tick_task.cancel()
@ -197,21 +212,36 @@ class SybilNode:
if info_hash in self._complete_info_hashes: if info_hash in self._complete_info_hashes:
return return
event_loop = asyncio.get_event_loop()
# A little clarification about parent and child futures might be really useful here:
# For every info hash we are interested in, we create ONE parent future and save it under self.__tasks
# (info_hash -> task) dictionary.
# For EVERY DisposablePeer working to fetch the metadata of that info hash, we create a child future. Hence, for
# every parent future, there should be *at least* one child future.
#
# Parent and child futures are "connected" to each other through `add_done_callback` functionality:
# When a child is successfully done, it sets the result of its parent (`set_result()`), and if it was
# unsuccessful to fetch the metadata, it just checks whether there are any other child futures left and if not
# it terminates the parent future (by setting its result to None) and quits.
# When a parent future is successfully done, (through the callback) it adds the info hash to the set of
# completed metadatas and puts the metadata in the queue to be committed to the database.
# create the parent future # create the parent future
if info_hash not in self.__tasks: if info_hash not in self.__parent_futures:
parent_f = self._loop.create_future() parent_f = event_loop.create_future()
parent_f.child_count = 0 parent_f.child_count = 0
parent_f.add_done_callback(lambda f: self._parent_task_done(f, info_hash)) parent_f.add_done_callback(lambda f: self._parent_task_done(f, info_hash))
self.__tasks[info_hash] = parent_f self.__parent_futures[info_hash] = parent_f
parent_f = self.__tasks[info_hash] parent_f = self.__parent_futures[info_hash]
if parent_f.done(): if parent_f.done():
return return
if parent_f.child_count > MAX_ACTIVE_PEERS_PER_INFO_HASH: if parent_f.child_count > MAX_ACTIVE_PEERS_PER_INFO_HASH:
return return
task = asyncio.ensure_future(bittorrent.fetch_metadata( task = asyncio.ensure_future(bittorrent.fetch_metadata_from_peer(
info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT)) info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT))
task.add_done_callback(lambda task: self._got_child_result(parent_f, task)) task.add_done_callback(lambda task: self._got_child_result(parent_f, task))
parent_f.child_count += 1 parent_f.child_count += 1
@ -221,6 +251,19 @@ class SybilNode:
parent_task.child_count -= 1 parent_task.child_count -= 1
try: try:
metadata = child_task.result() metadata = child_task.result()
# Bora asked:
# Why do we check for parent_task being done here when a child got result? I mean, if parent_task is
# done before, and successful, all of its childs will be terminated and this function cannot be called
# anyway.
#
# --- https://github.com/boramalper/magnetico/pull/76#discussion_r119555423
#
# Suppose two child tasks are fetching the same metadata for a parent and they finish at the same time
# (or very close). The first one wakes up, sets the parent_task result which will cause the done
# callback to be scheduled. The scheduler might still then chooses the second child task to run next
# (why not? It's been waiting longer) before the parent has a chance to cancel it.
#
# Thus spoke Richard.
if metadata and not parent_task.done(): if metadata and not parent_task.done():
parent_task.set_result(metadata) parent_task.set_result(metadata)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -235,16 +278,17 @@ class SybilNode:
metadata = parent_task.result() metadata = parent_task.result()
if metadata: if metadata:
self._complete_info_hashes.add(info_hash) self._complete_info_hashes.add(info_hash)
self._metadata_q.put_nowait((info_hash, metadata)) self.__metadata_queue.put_nowait((info_hash, metadata))
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
del self.__tasks[info_hash] del self.__parent_futures[info_hash]
async def __bootstrap(self) -> None: async def __bootstrap(self) -> None:
event_loop = asyncio.get_event_loop()
for node in BOOTSTRAPPING_NODES: for node in BOOTSTRAPPING_NODES:
try: try:
# AF_INET means ip4 only # AF_INET means ip4 only
responses = await self._loop.getaddrinfo(*node, family=socket.AF_INET) responses = await event_loop.getaddrinfo(*node, family=socket.AF_INET)
for (family, type, proto, canonname, sockaddr) in responses: for (family, type, proto, canonname, sockaddr) in responses:
data = self.__build_FIND_NODE_query(self.__true_id) data = self.__build_FIND_NODE_query(self.__true_id)
self.sendto(data, sockaddr) self.sendto(data, sockaddr)