general cleanup, performance improvements, bug fixes

* Removed unnecessary functions such as those that just wraps a standard
  library function (e.g. '__random_bytes()' in SybilNode), and those
  that are wrongly abstracted (e.g. `cleanup` in __main__.py)

* Created `__build_GET_PEERS_query()` and `__build_ANNOUNCE_PEER()` in
  SybilNode to eliminate the cost of calling `bencode.dumps()` in these
  critical functions.

* Added some more comments to explain the rationale behind some
  decisions in-place.

* Improved our still-primitive congestion control support for BSD-based
  OSes, including OS X.
This commit is contained in:
Bora M. Alper 2017-06-11 15:27:31 +03:00
parent a083bf40f9
commit f1f0b9531d
3 changed files with 116 additions and 97 deletions

View File

@ -31,42 +31,6 @@ from . import dht
from . import persistence from . import persistence
def create_tasks():
arguments = parse_cmdline_arguments(sys.argv[1:])
logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s")
logging.info("magneticod v%d.%d.%d started", *__version__)
# use uvloop if it's installed
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logging.info("uvloop is in use")
except ImportError:
if sys.platform not in ["linux", "darwin"]:
logging.warning("uvloop could not be imported, using the default asyncio implementation")
# noinspection PyBroadException
try:
path = arguments.database_file
database = persistence.Database(path)
except:
logging.exception("could NOT connect to the database!")
return 1
loop = asyncio.get_event_loop()
node = dht.SybilNode(arguments.node_addr, database.is_infohash_new, arguments.max_metadata_size)
loop.create_task(node.launch(loop))
metadata_queue_watcher_task = loop.create_task(metadata_queue_watcher(database, node.metadata_q()))
metadata_queue_watcher_task.add_done_callback(lambda x: clean_up(loop, database, node))
return metadata_queue_watcher_task
def clean_up(loop, database, node):
database.close()
loop.run_until_complete(node.shutdown())
async def metadata_queue_watcher(database: persistence.Database, metadata_queue: asyncio.Queue) -> None: async def metadata_queue_watcher(database: persistence.Database, metadata_queue: asyncio.Queue) -> None:
""" """
Watches for the metadata queue to commit any complete info hashes to the database. Watches for the metadata queue to commit any complete info hashes to the database.
@ -152,14 +116,42 @@ def parse_cmdline_arguments(args) -> typing.Optional[argparse.Namespace]:
return parser.parse_args(args) return parser.parse_args(args)
def main(): def main() -> int:
main_task = create_tasks() # main_task = create_tasks()
arguments = parse_cmdline_arguments(sys.argv[1:])
logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s")
logging.info("magneticod v%d.%d.%d started", *__version__)
# use uvloop if it's installed
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logging.info("uvloop is in use")
except ImportError:
if sys.platform not in ["linux", "darwin"]:
logging.warning("uvloop could not be imported, using the default asyncio implementation")
# noinspection PyBroadException
try:
database = persistence.Database(arguments.database_file)
except:
logging.exception("could NOT connect to the database!")
return 1
loop = asyncio.get_event_loop()
node = dht.SybilNode(database.is_infohash_new, arguments.max_metadata_size)
loop.create_task(node.launch(arguments.node_addr))
metadata_queue_watcher_task = loop.create_task(metadata_queue_watcher(database, node.metadata_q()))
try: try:
asyncio.get_event_loop().run_forever() asyncio.get_event_loop().run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
logging.critical("Keyboard interrupt received! Exiting gracefully...") logging.critical("Keyboard interrupt received! Exiting gracefully...")
finally: finally:
main_task.cancel() metadata_queue_watcher_task.cancel()
database.close()
loop.run_until_complete(node.shutdown())
return 0 return 0

View File

@ -63,7 +63,7 @@ class DisposablePeer:
self._writer.write(b"\x13BitTorrent protocol%s%s%s" % ( self._writer.write(b"\x13BitTorrent protocol%s%s%s" % (
b"\x00\x00\x00\x00\x00\x10\x00\x01", b"\x00\x00\x00\x00\x00\x10\x00\x01",
self.__info_hash, self.__info_hash,
self.__random_bytes(20) os.urandom(20)
)) ))
# Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) # Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but)
# badly designed protocols I have ever seen (I am 19 years old so what I could have seen?). # badly designed protocols I have ever seen (I am 19 years old so what I could have seen?).
@ -217,7 +217,3 @@ class DisposablePeer:
self.__ut_metadata.to_bytes(1, "big"), self.__ut_metadata.to_bytes(1, "big"),
msg_dict_dump msg_dict_dump
)) ))
@staticmethod
def __random_bytes(n: int) -> bytes:
return os.urandom(n)

View File

@ -13,6 +13,7 @@
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see # You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <http://www.gnu.org/licenses/>. # <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import errno
import zlib import zlib
import logging import logging
import socket import socket
@ -30,15 +31,13 @@ InfoHash = bytes
Metadata = bytes Metadata = bytes
class SybilNode: class SybilNode(asyncio.DatagramProtocol):
def __init__(self, address: typing.Tuple[str, int], is_infohash_new, max_metadata_size): def __init__(self, is_infohash_new, max_metadata_size):
self.__true_id = self.__random_bytes(20) self.__true_id = os.urandom(20)
self.__address = address
self._routing_table = {} # type: typing.Dict[NodeID, NodeAddress] self._routing_table = {} # type: typing.Dict[NodeID, NodeAddress]
self.__token_secret = self.__random_bytes(4) self.__token_secret = os.urandom(4)
# Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will
# stop; but until then, the total number of neighbours might exceed the threshold). # stop; but until then, the total number of neighbours might exceed the threshold).
self.__n_max_neighbours = 2000 self.__n_max_neighbours = 2000
@ -50,18 +49,17 @@ class SybilNode:
self._is_writing_paused = False self._is_writing_paused = False
self._tick_task = None self._tick_task = None
logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address) logging.info("SybilNode %s initialized!", self.__true_id.hex().upper())
def metadata_q(self): def metadata_q(self):
return self.__metadata_queue return self.__metadata_queue
async def launch(self, loop): async def launch(self, address):
self._loop = loop await asyncio.get_event_loop().create_datagram_endpoint(lambda: self, local_addr=address)
await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address) logging.info("SybliNode is launched on %s!", address)
def connection_made(self, transport: asyncio.DatagramTransport) -> None: def connection_made(self, transport: asyncio.DatagramTransport) -> None:
event_loop = asyncio.get_event_loop() self._tick_task = asyncio.get_event_loop().create_task(self.tick_periodically())
self._tick_task = event_loop.create_task(self.tick_periodically())
self._transport = transport self._transport = transport
def connection_lost(self, exc) -> None: def connection_lost(self, exc) -> None:
@ -82,17 +80,24 @@ class SybilNode:
self._transport.sendto(data, addr) self._transport.sendto(data, addr)
def error_received(self, exc: Exception) -> None: def error_received(self, exc: Exception) -> None:
if isinstance(exc, PermissionError): if isinstance(exc, PermissionError) or (isinstance(exc, OSError) and errno.ENOBUFS):
# This exception (EPERM errno: 1) is kernel's way of saying that "you are far too fast, chill". # This exception (EPERM errno: 1) is kernel's way of saying that "you are far too fast, chill".
# It is also likely that we have received a ICMP source quench packet (meaning, that we really need to # It is also likely that we have received a ICMP source quench packet (meaning, that we really need to
# slow down. # slow down.
# #
# Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto # Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto
# -amp-amp-EPERM.html # -amp-amp-EPERM.html
#
# > Note On BSD systems (OS X, FreeBSD, etc.) flow control is not supported for DatagramProtocol, because
# > send failures caused by writing too many packets cannot be detected easily. The socket always appears
# > ready and excess packets are dropped; an OSError with errno set to errno.ENOBUFS may or may not be
# > raised; if it is raised, it will be reported to DatagramProtocol.error_received() but otherwise ignored.
# Source: https://docs.python.org/3/library/asyncio-protocol.html#flow-control-callbacks
# In case of congestion, decrease the maximum number of nodes to the 90% of the current value. # In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
if self.__n_max_neighbours < 200: if self.__n_max_neighbours < 200:
logging.warning("Maximum number of neighbours are now less than 200 due to congestion!") logging.warning("Max. number of neighbours are < 200 and there is still congestion! (check your network "
"connection if this message recurs)")
else: else:
self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10 self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10
logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours) logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours)
@ -103,6 +108,8 @@ class SybilNode:
async def tick_periodically(self) -> None: async def tick_periodically(self) -> None:
while True: while True:
await asyncio.sleep(TICK_INTERVAL) await asyncio.sleep(TICK_INTERVAL)
# Bootstrap (by querying the bootstrapping servers) ONLY IF the routing table is empty (i.e. we don't have
# any neighbours). Otherwise we'll increase the load on those central servers by querying them every second.
if not self._routing_table: if not self._routing_table:
await self.__bootstrap() await self.__bootstrap()
self.__make_neighbours() self.__make_neighbours()
@ -114,7 +121,8 @@ class SybilNode:
logging.debug("asyncio task count: %d", len(asyncio.Task.all_tasks())) logging.debug("asyncio task count: %d", len(asyncio.Task.all_tasks()))
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
# Ignore nodes that uses port 0 (assholes). # Ignore nodes that "uses" port 0, as we cannot communicate with them reliably across the different systems.
# See https://tools.cisco.com/security/center/viewAlert.x?alertId=19935 for slightly more details
if addr[1] == 0: if addr[1] == 0:
return return
@ -134,9 +142,9 @@ class SybilNode:
self.__on_ANNOUNCE_PEER_query(message, addr) self.__on_ANNOUNCE_PEER_query(message, addr)
async def shutdown(self) -> None: async def shutdown(self) -> None:
tasks = list(self.__parent_futures.values()) parent_futures = list(self.__parent_futures.values())
for t in tasks: for pf in parent_futures:
t.set_result(None) pf.set_result(None)
self._tick_task.cancel() self._tick_task.cancel()
await asyncio.wait([self._tick_task]) await asyncio.wait([self._tick_task])
self._transport.close() self._transport.close()
@ -158,7 +166,7 @@ class SybilNode:
# Add new found nodes to the routing table, assuring that we have no more than n_max_neighbours in total. # Add new found nodes to the routing table, assuring that we have no more than n_max_neighbours in total.
if len(self._routing_table) < self.__n_max_neighbours: if len(self._routing_table) < self.__n_max_neighbours:
self._routing_table.update(nodes) self._routing_table.update(nodes[:self.__n_max_neighbours - len(self._routing_table)])
def __on_GET_PEERS_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None: def __on_GET_PEERS_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None:
try: try:
@ -169,17 +177,15 @@ class SybilNode:
except (TypeError, KeyError, AssertionError): except (TypeError, KeyError, AssertionError):
return return
data = bencode.dumps({ data = self.__build_GET_PEERS_query(
b"y": b"r", info_hash[:15] + self.__true_id[:5], transaction_id, self.__calculate_token(addr, info_hash)
b"t": transaction_id, )
b"r": {
b"id": info_hash[:15] + self.__true_id[:5], # TODO:
b"nodes": b"", # We would like to prioritise GET_PEERS responses as they are the most fruitful ones, i.e., that leads to the
b"token": self.__calculate_token(addr, info_hash) # discovery of an info hash & metadata! But there is no easy way to do this with asyncio...
} # Maybe use priority queues to prioritise certain messages and let them accumulate, and dispatch them to the
}) # transport at every tick?
# we want to prioritise GET_PEERS responses as they are the most fruitful ones!
# but there is no easy way to do this with asyncio
self.sendto(data, addr) self.sendto(data, addr)
def __on_ANNOUNCE_PEER_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None: def __on_ANNOUNCE_PEER_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None:
@ -203,13 +209,7 @@ class SybilNode:
except (TypeError, KeyError, AssertionError): except (TypeError, KeyError, AssertionError):
return return
data = bencode.dumps({ data = self.__build_ANNOUNCE_PEER_query(node_id[:15] + self.__true_id[:5], transaction_id)
b"y": b"r",
b"t": transaction_id,
b"r": {
b"id": node_id[:15] + self.__true_id[:5]
}
})
self.sendto(data, addr) self.sendto(data, addr)
if implied_port: if implied_port:
@ -300,7 +300,7 @@ class SybilNode:
data = self.__build_FIND_NODE_query(self.__true_id) data = self.__build_FIND_NODE_query(self.__true_id)
self.sendto(data, sockaddr) self.sendto(data, sockaddr)
except Exception: except Exception:
logging.exception("bootstrap problem") logging.exception("An exception occurred during bootstrapping!")
def __make_neighbours(self) -> None: def __make_neighbours(self) -> None:
for node_id, addr in self._routing_table.items(): for node_id, addr in self._routing_table.items():
@ -308,7 +308,7 @@ class SybilNode:
@staticmethod @staticmethod
def __decode_nodes(infos: bytes) -> typing.List[typing.Tuple[NodeID, NodeAddress]]: def __decode_nodes(infos: bytes) -> typing.List[typing.Tuple[NodeID, NodeAddress]]:
""" REFERENCE IMPLEMENTATION """ Reference Implementation:
nodes = [] nodes = []
for i in range(0, len(infos), 26): for i in range(0, len(infos), 26):
info = infos[i: i + 26] info = infos[i: i + 26]
@ -318,8 +318,8 @@ class SybilNode:
nodes.append((node_id, (node_host, node_port))) nodes.append((node_id, (node_host, node_port)))
return nodes return nodes
""" """
""" Optimized Version: """
""" Optimized Version """ # Because dot-access also has a cost
inet_ntoa = socket.inet_ntoa inet_ntoa = socket.inet_ntoa
int_from_bytes = int.from_bytes int_from_bytes = int.from_bytes
return [ return [
@ -327,29 +327,60 @@ class SybilNode:
for i in range(0, len(infos), 26) for i in range(0, len(infos), 26)
] ]
def __calculate_token(self, addr: NodeAddress, info_hash: InfoHash): def __calculate_token(self, addr: NodeAddress, info_hash: InfoHash) -> bytes:
# Believe it or not, faster than using built-in hash (including conversion from int -> bytes of course) # Believe it or not, faster than using built-in hash (including conversion from int -> bytes of course)
return zlib.adler32(b"%s%s%d%s" % (self.__token_secret, socket.inet_aton(addr[0]), addr[1], info_hash)) checksum = zlib.adler32(b"%s%s%d%s" % (self.__token_secret, socket.inet_aton(addr[0]), addr[1], info_hash))
return checksum.to_bytes(4, "big")
@staticmethod @staticmethod
def __random_bytes(n: int) -> bytes: def __build_FIND_NODE_query(id_: bytes) -> bytes:
return os.urandom(n) """ Reference Implementation:
def __build_FIND_NODE_query(self, id_: bytes) -> bytes:
""" BENCODE IMPLEMENTATION
bencode.dumps({ bencode.dumps({
b"y": b"q", b"y": b"q",
b"q": b"find_node", b"q": b"find_node",
b"t": self.__random_bytes(2), b"t": b"aa",
b"a": { b"a": {
b"id": id_, b"id": id_,
b"target": self.__random_bytes(20) b"target": self.__random_bytes(20)
} }
}) })
""" """
""" Optimized Version: """
""" Optimized Version """
return b"d1:ad2:id20:%s6:target20:%se1:q9:find_node1:t2:aa1:y1:qe" % ( return b"d1:ad2:id20:%s6:target20:%se1:q9:find_node1:t2:aa1:y1:qe" % (
id_, id_,
self.__random_bytes(20) os.urandom(20)
) )
@staticmethod
def __build_GET_PEERS_query(id_: bytes, transaction_id: bytes, token: bytes) -> bytes:
""" Reference Implementation:
bencode.dumps({
b"y": b"r",
b"t": transaction_id,
b"r": {
b"id": info_hash[:15] + self.__true_id[:5],
b"nodes": b"",
b"token": self.__calculate_token(addr, info_hash)
}
})
"""
""" Optimized Version: """
return b"d1:rd2:id20:%s5:nodes0:5:token%d:%se1:t%d:%s1:y1:re" % (
id_, len(token), token, len(transaction_id), transaction_id
)
@staticmethod
def __build_ANNOUNCE_PEER_query(id_: bytes, transaction_id: bytes) -> bytes:
""" Reference Implementation:
bencode.dumps({
b"y": b"r",
b"t": transaction_id,
b"r": {
b"id": node_id[:15] + self.__true_id[:5]
}
})
"""
""" Optimized Version: """
return b"d1:rd2:id20:%se1:t%d:%s1:y1:re" % (id_, len(transaction_id), transaction_id)