Merge branch 'asyncio' of https://github.com/richardkiss/magnetico into richardkiss-asyncio

Author: Bora M. Alper
Date:   2017-05-30 12:45:17 +03:00
Commit: 04e6d583f3

4 changed files with 218 additions and 348 deletions

────────────────────────────────────────────────────────────────

@@ -13,50 +13,38 @@
 # You should have received a copy of the GNU Affero General Public License along with this program. If not, see
 # <http://www.gnu.org/licenses/>.
 import argparse
-import collections
-import functools
+import asyncio
 import logging
 import ipaddress
-import selectors
 import textwrap
 import urllib.parse
-import itertools
 import os
 import sys
-import time
 import typing

 import appdirs
 import humanfriendly

-from .constants import TICK_INTERVAL, MAX_ACTIVE_PEERS_PER_INFO_HASH, DEFAULT_MAX_METADATA_SIZE
+from .constants import DEFAULT_MAX_METADATA_SIZE
 from . import __version__
-from . import bittorrent
 from . import dht
 from . import persistence

-# Global variables are bad bla bla bla, BUT these variables are used so many times that I think it is justified; else
-# the signatures of many functions are literally cluttered.
-#
-# If you are using a global variable, please always indicate that at the VERY BEGINNING of the function instead of right
-# before using the variable for the first time.
-selector = selectors.DefaultSelector()
-database = None  # type: persistence.Database
-node = None
-peers = collections.defaultdict(list)  # type: typing.DefaultDict[dht.InfoHash, typing.List[bittorrent.DisposablePeer]]
-# info hashes whose metadata is valid & complete (OR complete but deemed to be corrupt) so do NOT download them again:
-complete_info_hashes = set()


 def main():
-    global complete_info_hashes, database, node, peers, selector
-
     arguments = parse_cmdline_arguments()

     logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s")
     logging.info("magneticod v%d.%d.%d started", *__version__)

+    # use uvloop if it's installed
+    try:
+        import uvloop
+        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+        logging.info("using uvloop")
+    except ImportError:
+        pass

     # noinspection PyBroadException
     try:
         path = arguments.database_file
@@ -67,106 +55,31 @@ def main():
     complete_info_hashes = database.get_complete_info_hashes()

-    node = dht.SybilNode(arguments.node_addr)
-    node.when_peer_found = lambda info_hash, peer_address: on_peer_found(info_hash=info_hash,
-                                                                         peer_address=peer_address,
-                                                                         max_metadata_size=arguments.max_metadata_size)
-    selector.register(node, selectors.EVENT_READ)
+    loop = asyncio.get_event_loop()
+    node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size)
+    loop.run_until_complete(node.launch(loop))
+    watch_q_task = loop.create_task(watch_q(database, node._metadata_q))

     try:
-        loop()
+        loop.run_forever()
     except KeyboardInterrupt:
         logging.critical("Keyboard interrupt received! Exiting gracefully...")
-        pass
     finally:
         database.close()
-        selector.close()
-        node.shutdown()
-        for peer in itertools.chain.from_iterable(peers.values()):
-            peer.shutdown()
+        watch_q_task.cancel()
+        loop.run_until_complete(node.shutdown())
+        loop.run_until_complete(asyncio.wait([watch_q_task]))

     return 0


-def on_peer_found(info_hash: dht.InfoHash, peer_address, max_metadata_size: int=DEFAULT_MAX_METADATA_SIZE) -> None:
-    global selector, peers, complete_info_hashes
-
-    if len(peers[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH or info_hash in complete_info_hashes:
-        return
-
-    try:
-        peer = bittorrent.DisposablePeer(info_hash, peer_address, max_metadata_size)
-    except ConnectionError:
-        return
-
-    selector.register(peer, selectors.EVENT_READ | selectors.EVENT_WRITE)
-    peer.when_metadata_found = on_metadata_found
-    peer.when_error = functools.partial(on_peer_error, peer, info_hash)
-    peers[info_hash].append(peer)
-
-
-def on_metadata_found(info_hash: dht.InfoHash, metadata: bytes) -> None:
-    global complete_info_hashes, database, peers, selector
-
-    succeeded = database.add_metadata(info_hash, metadata)
-    if not succeeded:
-        logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex())
-
-    # When we fetch the metadata of an info hash completely, shut down all other peers who are trying to do the same.
-    for peer in peers[info_hash]:
-        selector.unregister(peer)
-        peer.shutdown()
-    del peers[info_hash]
-
-    complete_info_hashes.add(info_hash)
-
-
-def on_peer_error(peer: bittorrent.DisposablePeer, info_hash: dht.InfoHash) -> None:
-    global peers, selector
-
-    peer.shutdown()
-    peers[info_hash].remove(peer)
-    selector.unregister(peer)
-
-
-# TODO:
-# Consider whether time.monotonic() is a good choice. Maybe we should use CLOCK_MONOTONIC_RAW as its not affected by NTP
-# adjustments, and all we need is how many seconds passed since a certain point in time.
-def loop() -> None:
-    global selector, node, peers
-
-    t0 = time.monotonic()
+async def watch_q(database, q):
     while True:
-        keys_and_events = selector.select(timeout=TICK_INTERVAL)
-
-        # Check if it is time to tick
-        delta = time.monotonic() - t0
-        if delta >= TICK_INTERVAL:
-            if not (delta < 2 * TICK_INTERVAL):
-                logging.warning("Belated TICK! (Δ = %d)", delta)
-
-            node.on_tick()
-            for peer_list in peers.values():
-                for peer in peer_list:
-                    peer.on_tick()
-
-            t0 = time.monotonic()
-
-        for key, events in keys_and_events:
-            if events & selectors.EVENT_READ:
-                key.fileobj.on_receivable()
-            if events & selectors.EVENT_WRITE:
-                key.fileobj.on_sendable()
-
-        # Check for entities that would like to write to their socket
-        keymap = selector.get_map()
-        for fd in keymap:
-            fileobj = keymap[fd].fileobj
-            if fileobj.would_send():
-                selector.modify(fileobj, selectors.EVENT_READ | selectors.EVENT_WRITE)
-            else:
-                selector.modify(fileobj, selectors.EVENT_READ)
+        info_hash, metadata = await q.get()
+        succeeded = database.add_metadata(info_hash, metadata)
+        if not succeeded:
+            logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex())


 def parse_ip_port(netloc) -> typing.Optional[typing.Tuple[str, int]]:
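Note: the new main() wires the DHT crawler to the database through a plain asyncio.Queue: SybilNode produces (info_hash, metadata) pairs into node._metadata_q, and the watch_q task consumes them. Below is a minimal, self-contained sketch of that producer/consumer shape and of the cancel-then-wait shutdown order; the print call stands in for database.add_metadata, and every name here is illustrative rather than magneticod's API.

import asyncio

async def consumer(q: asyncio.Queue) -> None:
    # Mirrors watch_q: block on the queue forever; cancellation ends the task.
    while True:
        info_hash, metadata = await q.get()
        print("storing", info_hash.hex(), "-", len(metadata), "bytes")

async def produce_then_shutdown(q: asyncio.Queue, consumer_task) -> None:
    for i in range(3):
        await q.put((bytes([i]) * 20, b"x" * (i + 1)))  # fake (info_hash, metadata)
    await asyncio.sleep(0)               # give the consumer one cycle to drain
    consumer_task.cancel()               # same order as main(): cancel first...
    await asyncio.wait([consumer_task])  # ...then wait for it to unwind

loop = asyncio.get_event_loop()
q = asyncio.Queue()
consumer_task = loop.create_task(consumer(q))
loop.run_until_complete(produce_then_shutdown(q, consumer_task))
loop.close()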

────────────────────────────────────────────────────────────────

@@ -12,42 +12,38 @@
 #
 # You should have received a copy of the GNU Affero General Public License along with this program. If not, see
 # <http://www.gnu.org/licenses/>.
-import errno
+import asyncio
 import logging
 import hashlib
 import math
-import socket
 import typing
 import os

 from . import bencode
-from .constants import DEFAULT_MAX_METADATA_SIZE

 InfoHash = bytes
 PeerAddress = typing.Tuple[str, int]


+async def fetch_metadata(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size, timeout=None):
+    try:
+        return await asyncio.wait_for(DisposablePeer().run(
+            asyncio.get_event_loop(), info_hash, peer_addr, max_metadata_size), timeout=timeout)
+    except asyncio.TimeoutError:
+        return None
+
+
+class ProtocolError(Exception):
+    pass
+
+
 class DisposablePeer:
-    def __init__(self, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int= DEFAULT_MAX_METADATA_SIZE):
-        self.__socket = socket.socket()
-        self.__socket.setblocking(False)
-        # To reduce the latency:
-        self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
-        if hasattr(socket, "TCP_QUICKACK"):
-            self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK, True)
-
-        res = self.__socket.connect_ex(peer_addr)
-        if res != errno.EINPROGRESS:
-            raise ConnectionError()
-
+    async def run(self, loop, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int):
         self.__peer_addr = peer_addr
         self.__info_hash = info_hash

         self.__max_metadata_size = max_metadata_size

-        self.__incoming_buffer = bytearray()
-        self.__outgoing_buffer = bytearray()
-
         self.__bt_handshake_complete = False  # BitTorrent Handshake
         self.__ext_handshake_complete = False  # Extension Handshake
         self.__ut_metadata = None  # Since we don't know ut_metadata code that remote peer uses...
@@ -55,103 +51,46 @@ class DisposablePeer:
         self.__metadata_size = None
         self.__metadata_received = 0  # Amount of metadata bytes received...
         self.__metadata = None

-        # To prevent double shutdown
-        self.__shutdown = False
-
-        # After 120 ticks passed, a peer should report an error and shut itself down due to being stall.
-        self.__ticks_passed = 0
-
-        # Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message)
-        self.__outgoing_buffer += b"\x13BitTorrent protocol%s%s%s" % (
-            b"\x00\x00\x00\x00\x00\x10\x00\x01",
-            self.__info_hash,
-            self.__random_bytes(20)
-        )
-
-    @staticmethod
-    def when_error() -> None:
-        raise NotImplementedError()
-
-    @staticmethod
-    def when_metadata_found(info_hash: InfoHash, metadata: bytes) -> None:
-        raise NotImplementedError()
-
-    def on_tick(self):
-        self.__ticks_passed += 1
-
-        if self.__ticks_passed == 120:
-            logging.debug("Peer failed to fetch metadata in time for info hash %s!", self.__info_hash.hex())
-            self.when_error()
-
-    def on_receivable(self) -> None:
-        while True:
-            try:
-                received = self.__socket.recv(8192)
-            except BlockingIOError:
-                break
-            except ConnectionResetError:
-                self.when_error()
-                return
-            except ConnectionRefusedError:
-                self.when_error()
-                return
-            except OSError:  # TODO: check for "no route to host 113" error
-                self.when_error()
-                return
-
-            if not received:
-                self.when_error()
-                return
-
-            self.__incoming_buffer += received
-
-        # Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) badly
-        # designed protocols I have ever seen (I am 19 years old so what I could have seen?).
-        #
-        # Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the
-        # size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one canonical
-        # way of handshaking in the wild.
-        if not self.__bt_handshake_complete:
-            if len(self.__incoming_buffer) < 68:
-                # We are still receiving the handshake...
-                return
-
-            if self.__incoming_buffer[1:20] != b"BitTorrent protocol":
-                # Erroneous handshake, possibly unknown version...
-                logging.debug("Erroneous BitTorrent handshake! %s", self.__incoming_buffer[:68])
-                self.when_error()
-                return
-
-            self.__on_bt_handshake(self.__incoming_buffer[:68])
-
-            self.__bt_handshake_complete = True
-            self.__incoming_buffer = self.__incoming_buffer[68:]
-
-        while len(self.__incoming_buffer) >= 4:
-            # Beware that while there are still messages in the incoming queue/buffer, one of previous messages might
-            # have caused an error that necessitates us to quit.
-            if self.__shutdown:
-                break
-
-            length = int.from_bytes(self.__incoming_buffer[:4], "big")
-            if len(self.__incoming_buffer) - 4 < length:
-                # Message is still incoming...
-                return
-
-            self.__on_message(self.__incoming_buffer[4:4+length])
-            self.__incoming_buffer = self.__incoming_buffer[4+length:]
-
-    def on_sendable(self) -> None:
-        while self.__outgoing_buffer:
-            try:
-                n_sent = self.__socket.send(self.__outgoing_buffer)
-                assert n_sent
-                self.__outgoing_buffer = self.__outgoing_buffer[n_sent:]
-            except BlockingIOError:
-                break
-            except OSError:
-                # In case -while looping- on_sendable is called after socket is closed (mostly because of an error)
-                return
+        self._run_task = None
+        self._metadata_future = loop.create_future()
+        self._writer = None
+
+        try:
+            self._reader, self._writer = await asyncio.open_connection(
+                self.__peer_addr[0], self.__peer_addr[1], loop=loop)
+            # Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message)
+            self._writer.write(b"\x13BitTorrent protocol%s%s%s" % (
+                b"\x00\x00\x00\x00\x00\x10\x00\x01",
+                self.__info_hash,
+                self.__random_bytes(20)
+            ))
+            # Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) badly
+            # designed protocols I have ever seen (I am 19 years old so what I could have seen?).
+            #
+            # Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the
+            # size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one canonical
+            # way of handshaking in the wild.
+            message = await self._reader.readexactly(68)
+            if message[1:20] != b"BitTorrent protocol":
+                # Erroneous handshake, possibly unknown version...
+                raise ProtocolError("Erroneous BitTorrent handshake! %s" % message)
+
+            self.__on_bt_handshake(message)
+
+            while not self._metadata_future.done():
+                buffer = await self._reader.readexactly(4)
+                length = int.from_bytes(buffer, "big")
+                message = await self._reader.readexactly(length)
+                self.__on_message(message)
+        except Exception:
+            logging.debug("closing %s to %s", self.__info_hash.hex(), self.__peer_addr)
+        finally:
+            if not self._metadata_future.done():
+                self._metadata_future.set_result(None)
+            if self._writer:
+                self._writer.close()
+        return self._metadata_future.result()

     def __on_message(self, message: bytes) -> None:
         length = len(message)
@@ -191,10 +130,10 @@ class DisposablePeer:
         # In case you cannot read_file hex:
         #   0x14 = 20 (BitTorrent ID indicating that it's an extended message)
         #   0x00 = 0  (Extension ID indicating that it's the handshake message)
-        self.__outgoing_buffer += b"%s\x14\x00%s" % (
+        self._writer.write(b"%b\x14%s" % (
             (2 + len(msg_dict_dump)).to_bytes(4, "big"),
-            msg_dict_dump
-        )
+            b'\0' + msg_dict_dump
+        ))

     def __on_ext_handshake_message(self, message: bytes) -> None:
         if self.__ext_handshake_complete:
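Note: the 2 + len(msg_dict_dump) in the length prefix above counts the BitTorrent message-type byte (0x14) and the extension-ID byte on top of the bencoded payload. The arithmetic, factored out as a sketch; frame_extended_message is an illustrative helper, not part of this codebase:

def frame_extended_message(ext_id: int, payload: bytes) -> bytes:
    # 4-byte big-endian length prefix covering: 0x14, the extension ID, and the payload.
    return (2 + len(payload)).to_bytes(4, "big") + b"\x14" + bytes([ext_id]) + payload

frame = frame_extended_message(0, b"d1:md11:ut_metadatai3eee")  # extension handshake uses ext ID 0
assert int.from_bytes(frame[:4], "big") == len(frame) - 4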
@@ -217,21 +156,16 @@ class DisposablePeer:
                                           " {} max metadata size".format(self.__peer_addr[0],
                                                                          self.__peer_addr[1],
                                                                          self.__max_metadata_size)
-        except KeyError:
-            self.when_error()
-            return
         except AssertionError as e:
             logging.debug(str(e))
-            self.when_error()
-            return
+            raise

         self.__ut_metadata = ut_metadata
         try:
             self.__metadata = bytearray(metadata_size)
         except MemoryError:
             logging.exception("Could not allocate %.1f KiB for the metadata!", metadata_size / 1024)
-            self.when_error()
-            return
+            raise

         self.__metadata_size = metadata_size
         self.__ext_handshake_complete = True
@@ -268,13 +202,13 @@ class DisposablePeer:
             if self.__metadata_received == self.__metadata_size:
                 if hashlib.sha1(self.__metadata).digest() == self.__info_hash:
-                    self.when_metadata_found(self.__info_hash, bytes(self.__metadata))
+                    if not self._metadata_future.done():
+                        self._metadata_future.set_result(bytes(self.__metadata))
                 else:
                     logging.debug("Invalid Metadata! Ignoring.")

         elif msg_type == 2:  # reject
             logging.info("Peer rejected us.")
-            self.when_error()

     def __request_metadata_piece(self, piece: int) -> None:
         msg_dict_dump = bencode.dumps({
@@ -284,29 +218,11 @@ class DisposablePeer:
         # In case you cannot read_file hex:
         #   0x14 = 20 (BitTorrent ID indicating that it's an extended message)
         #   0x03 = 3  (Extension ID indicating that it's an ut_metadata message)
-        self.__outgoing_buffer += b"%b\x14%s%s" % (
+        self._writer.write(b"%b\x14%s%s" % (
             (2 + len(msg_dict_dump)).to_bytes(4, "big"),
             self.__ut_metadata.to_bytes(1, "big"),
             msg_dict_dump
-        )
-
-    def shutdown(self) -> None:
-        if self.__shutdown:
-            return
-        try:
-            self.__socket.shutdown(socket.SHUT_RDWR)
-        except OSError:
-            # OSError might be raised in case the connection to the remote peer fails: nevertheless, when_error should
-            # be called, and the supervisor will try to shutdown the peer, and ta da: OSError!
-            pass
-        self.__socket.close()
-        self.__shutdown = True
-
-    def would_send(self) -> bool:
-        return bool(len(self.__outgoing_buffer))
-
-    def fileno(self) -> int:
-        return self.__socket.fileno()
+        ))

     @staticmethod
     def __random_bytes(n: int) -> bytes:
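Note: two asyncio idioms carry the rewritten peer. asyncio.wait_for() replaces the old 120-tick watchdog, and StreamReader.readexactly() replaces the manual __incoming_buffer slicing for the 4-byte, big-endian, length-prefixed framing. A self-contained sketch of both, with a documentation-range placeholder address; nothing here is magneticod's own API:

import asyncio

async def read_one_message(reader: asyncio.StreamReader) -> bytes:
    prefix = await reader.readexactly(4)   # raises IncompleteReadError at EOF
    length = int.from_bytes(prefix, "big")
    return await reader.readexactly(length)

async def fetch_first_message(host: str, port: int) -> bytes:
    reader, writer = await asyncio.open_connection(host, port)
    try:
        return await read_one_message(reader)
    finally:
        writer.close()

async def demo() -> None:
    try:
        # Same convention as fetch_metadata(): return None on timeout.
        message = await asyncio.wait_for(fetch_first_message("203.0.113.1", 6881), timeout=120)
    except asyncio.TimeoutError:
        message = None
    print(message)

# asyncio.get_event_loop().run_until_complete(demo())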

────────────────────────────────────────────────────────────────

@@ -6,6 +6,7 @@ BOOTSTRAPPING_NODES = [
 ]
 PENDING_INFO_HASHES = 10  # threshold for pending info hashes before being committed to database:

-TICK_INTERVAL = 1  # in seconds (soft constraint)
-
 # maximum (inclusive) number of active (disposable) peers to fetch the metadata per info hash at the same time:
 MAX_ACTIVE_PEERS_PER_INFO_HASH = 5
+
+PEER_TIMEOUT=120 # seconds

────────────────────────────────────────────────────────────────

@@ -12,16 +12,17 @@
 #
 # You should have received a copy of the GNU Affero General Public License along with this program. If not, see
 # <http://www.gnu.org/licenses/>.
-import array
-import collections
+import asyncio
+import itertools
 import zlib
 import logging
 import socket
 import typing
 import os

-from .constants import BOOTSTRAPPING_NODES, DEFAULT_MAX_METADATA_SIZE
+from .constants import BOOTSTRAPPING_NODES, MAX_ACTIVE_PEERS_PER_INFO_HASH, PEER_TIMEOUT
 from . import bencode
+from . import bittorrent

 NodeID = bytes
 NodeAddress = typing.Tuple[str, int]
@@ -30,107 +31,95 @@ InfoHash = bytes

 class SybilNode:
-    def __init__(self, address: typing.Tuple[str, int]):
+    def __init__(self, address: typing.Tuple[str, int], complete_info_hashes, max_metadata_size):
         self.__true_id = self.__random_bytes(20)

-        self.__socket = socket.socket(type=socket.SOCK_DGRAM)
-        self.__socket.bind(address)
-        self.__socket.setblocking(False)
-
-        self.__incoming_buffer = array.array("B", (0 for _ in range(65536)))
-        self.__outgoing_queue = collections.deque()
-
-        self.__routing_table = {}  # type: typing.Dict[NodeID, NodeAddress]
+        self.__address = address
+
+        self._routing_table = {}  # type: typing.Dict[NodeID, NodeAddress]

         self.__token_secret = self.__random_bytes(4)
         # Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will
         # stop; but until then, the total number of neighbours might exceed the threshold).
         self.__n_max_neighbours = 2000

+        self.__tasks = {}  # type: typing.Dict[dht.InfoHash, asyncio.Future]
+        self._complete_info_hashes = complete_info_hashes
+        self.__max_metadata_size = max_metadata_size
+        self._metadata_q = asyncio.Queue()
+
+        self._is_paused = False
+        self._tick_task = None
+
         logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address)

-    @staticmethod
-    def when_peer_found(info_hash: InfoHash, peer_addr: PeerAddress) -> None:
-        raise NotImplementedError()
-
-    def on_tick(self) -> None:
-        self.__bootstrap()
-        self.__make_neighbours()
-        self.__routing_table.clear()
-
-    def on_receivable(self) -> None:
-        buffer = self.__incoming_buffer
-        while True:
-            try:
-                _, addr = self.__socket.recvfrom_into(buffer, 65536)
-                data = buffer.tobytes()
-            except BlockingIOError:
-                break
-            except ConnectionResetError:
-                continue
-            except ConnectionRefusedError:
-                continue
-
-            # Ignore nodes that uses port 0 (assholes).
-            if addr[1] == 0:
-                continue
-
-            try:
-                message = bencode.loads(data)
-            except bencode.BencodeDecodingError:
-                continue
-
-            if isinstance(message.get(b"r"), dict) and type(message[b"r"].get(b"nodes")) is bytes:
-                self.__on_FIND_NODE_response(message)
-            elif message.get(b"q") == b"get_peers":
-                self.__on_GET_PEERS_query(message, addr)
-            elif message.get(b"q") == b"announce_peer":
-                self.__on_ANNOUNCE_PEER_query(message, addr)
-
-    def on_sendable(self) -> None:
-        congestion = None
-        while True:
-            try:
-                addr, data = self.__outgoing_queue.pop()
-            except IndexError:
-                break
-
-            try:
-                self.__socket.sendto(data, addr)
-            except BlockingIOError:
-                self.__outgoing_queue.appendleft((addr, data))
-                break
-            except PermissionError:
-                # This exception (EPERM errno: 1) is kernel's way of saying that "you are far too fast, chill".
-                # It is also likely that we have received a ICMP source quench packet (meaning, that we really need to
-                # slow down.
-                #
-                # Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto
-                # -amp-amp-EPERM.html
-                congestion = True
-                break
-            except OSError:
-                # Pass in case of trying to send to port 0 (it is much faster to catch exceptions than using an
-                # if-statement).
-                pass
-
-        if congestion:
-            self.__outgoing_queue.clear()
+    async def launch(self, loop):
+        self._loop = loop
+        await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address)
+
+    def connection_made(self, transport):
+        self._tick_task = self._loop.create_task(self.on_tick())
+        self._transport = transport
+
+    def connection_lost(self, exc):
+        self._is_paused = True
+
+    def pause_writing(self):
+        self._is_paused = True
+
+    def resume_writing(self):
+        self._is_paused = False
+
+    def sendto(self, data, addr):
+        if self._is_paused:
+            return
+        self._transport.sendto(data, addr)
+
+    def error_received(self, exc):
+        logging.error("got error %s", exc)
+        if isinstance(exc, PermissionError):
             # In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
             if self.__n_max_neighbours < 200:
                 logging.warning("Maximum number of neighbours are now less than 200 due to congestion!")
             else:
                 self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10
-        else:
-            # In case of the lack of congestion, increase the maximum number of nodes by 1%.
-            self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100
-
-    def would_send(self) -> bool:
-        """ Whether node is waiting to write on its socket or not. """
-        return bool(self.__outgoing_queue)
-
-    def shutdown(self) -> None:
-        self.__socket.close()
+            logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours)
+
+    async def on_tick(self) -> None:
+        while True:
+            await asyncio.sleep(1)
+            if len(self._routing_table) == 0:
+                await self.__bootstrap()
+            self.__make_neighbours()
+            self._routing_table.clear()
+            if not self._is_paused:
+                self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100
+
+    def datagram_received(self, data, addr) -> None:
+        # Ignore nodes that uses port 0 (assholes).
+        if addr[1] == 0:
+            return
+
+        if self._transport.is_closing():
+            return
+
+        try:
+            message = bencode.loads(data)
+        except bencode.BencodeDecodingError:
+            return
+
+        if isinstance(message.get(b"r"), dict) and type(message[b"r"].get(b"nodes")) is bytes:
+            self.__on_FIND_NODE_response(message)
+        elif message.get(b"q") == b"get_peers":
+            self.__on_GET_PEERS_query(message, addr)
+        elif message.get(b"q") == b"announce_peer":
+            self.__on_ANNOUNCE_PEER_query(message, addr)
+
+    async def shutdown(self) -> None:
+        tasks = list(self.__tasks.values())
+        for t in tasks:
+            t.set_result(None)
+        self._tick_task.cancel()
+        await asyncio.wait([self._tick_task])
+        self._transport.close()

     def __on_FIND_NODE_response(self, message: bencode.KRPCDict) -> None:
         try:
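Note: SybilNode is now its own asyncio datagram protocol. launch() hands lambda: self to create_datagram_endpoint, so the event loop invokes connection_made, datagram_received, error_received and the pause_writing/resume_writing flow-control hooks directly on the node, and congestion is handled by dropping sends while paused instead of the old EPERM-driven queue clearing. A minimal sketch of the same pattern, using a purely illustrative echo node:

import asyncio

class EchoNode(asyncio.DatagramProtocol):  # hypothetical stand-in for SybilNode
    def __init__(self, address):
        self.__address = address
        self._is_paused = False
        self._transport = None

    async def launch(self, loop):
        # The protocol object *is* this node, exactly as with SybilNode.
        await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address)

    def connection_made(self, transport):
        self._transport = transport

    def pause_writing(self):
        self._is_paused = True    # transport's send buffer is above its high-water mark

    def resume_writing(self):
        self._is_paused = False

    def datagram_received(self, data, addr):
        if not self._is_paused:
            self._transport.sendto(data, addr)  # drop instead of queueing when paused

loop = asyncio.get_event_loop()
loop.run_until_complete(EchoNode(("127.0.0.1", 9999)).launch(loop))
# loop.run_forever()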
@@ -145,8 +134,8 @@ class SybilNode:
             return

         # Add new found nodes to the routing table, assuring that we have no more than n_max_neighbours in total.
-        if len(self.__routing_table) < self.__n_max_neighbours:
-            self.__routing_table.update(nodes)
+        if len(self._routing_table) < self.__n_max_neighbours:
+            self._routing_table.update(nodes)

     def __on_GET_PEERS_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None:
         try:
@@ -157,8 +146,7 @@ class SybilNode:
         except (TypeError, KeyError, AssertionError):
             return

-        # appendleft to prioritise GET_PEERS responses as they are the most fruitful ones!
-        self.__outgoing_queue.appendleft((addr, bencode.dumps({
+        data = bencode.dumps({
             b"y": b"r",
             b"t": transaction_id,
             b"r": {
@@ -166,7 +154,10 @@ class SybilNode:
                 b"nodes": b"",
                 b"token": self.__calculate_token(addr, info_hash)
             }
-        })))
+        })
+        # we want to prioritise GET_PEERS responses as they are the most fruitful ones!
+        # but there is no easy way to do this with asyncio
+        self.sendto(data, addr)

     def __on_ANNOUNCE_PEER_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None:
         try:
@@ -189,31 +180,80 @@ class SybilNode:
         except (TypeError, KeyError, AssertionError):
             return

-        self.__outgoing_queue.append((addr, bencode.dumps({
+        data = bencode.dumps({
             b"y": b"r",
             b"t": transaction_id,
             b"r": {
                 b"id": node_id[:15] + self.__true_id[:5]
             }
-        })))
+        })
+        self.sendto(data, addr)

         if implied_port:
             peer_addr = (addr[0], addr[1])
         else:
             peer_addr = (addr[0], port)

-        self.when_peer_found(info_hash, peer_addr)
-
-    def fileno(self) -> int:
-        return self.__socket.fileno()
-
-    def __bootstrap(self) -> None:
-        for addr in BOOTSTRAPPING_NODES:
-            self.__outgoing_queue.append((addr, self.__build_FIND_NODE_query(self.__true_id)))
+        if info_hash in self._complete_info_hashes:
+            return
+
+        # create the parent future
+        if info_hash not in self.__tasks:
+            parent_f = self._loop.create_future()
+            parent_f.child_count = 0
+            parent_f.add_done_callback(lambda f: self._parent_task_done(f, info_hash))
+            self.__tasks[info_hash] = parent_f
+
+        parent_f = self.__tasks[info_hash]
+
+        if parent_f.done():
+            return
+        if parent_f.child_count > MAX_ACTIVE_PEERS_PER_INFO_HASH:
+            return
+
+        task = asyncio.ensure_future(bittorrent.fetch_metadata(
+            info_hash, peer_addr, self.__max_metadata_size, timeout=PEER_TIMEOUT))
+        task.add_done_callback(lambda task: self._got_child_result(parent_f, task))
+        parent_f.child_count += 1
+        parent_f.add_done_callback(lambda f: task.cancel())
+
+    def _got_child_result(self, parent_task, child_task):
+        parent_task.child_count -= 1
+        try:
+            metadata = child_task.result()
+            if metadata and not parent_task.done():
+                parent_task.set_result(metadata)
+        except asyncio.CancelledError:
+            pass
+        except Exception:
+            logging.exception("child result is exception")
+        if parent_task.child_count <= 0 and not parent_task.done():
+            parent_task.set_result(None)
+
+    def _parent_task_done(self, parent_task, info_hash):
+        try:
+            metadata = parent_task.result()
+            if metadata:
+                self._complete_info_hashes.add(info_hash)
+                self._metadata_q.put_nowait((info_hash, metadata))
+        except asyncio.CancelledError:
+            pass
+        del self.__tasks[info_hash]
+
+    async def __bootstrap(self) -> None:
+        for node in BOOTSTRAPPING_NODES:
+            try:
+                # AF_INET means ip4 only
+                responses = await self._loop.getaddrinfo(*node, family=socket.AF_INET)
+                for (family, type, proto, canonname, sockaddr) in responses:
+                    data = self.__build_FIND_NODE_query(self.__true_id)
+                    self.sendto(data, sockaddr)
+            except Exception:
+                logging.exception("bootstrap problem")

     def __make_neighbours(self) -> None:
-        for node_id, addr in self.__routing_table.items():
-            self.__outgoing_queue.append((addr, self.__build_FIND_NODE_query(node_id[:15] + self.__true_id[:5])))
+        for node_id, addr in self._routing_table.items():
+            self.sendto(self.__build_FIND_NODE_query(node_id[:15] + self.__true_id[:5]), addr)

     @staticmethod
     def __decode_nodes(infos: bytes) -> typing.List[typing.Tuple[NodeID, NodeAddress]]:
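Note: the per-info-hash concurrency cap that MAX_ACTIVE_PEERS_PER_INFO_HASH used to enforce through the global peers dict is now expressed as one "parent" future per info hash plus one child task per peer: the first child that returns metadata resolves the parent, and a done-callback on the parent cancels the remaining children. A runnable sketch of just that pattern, where fetch() is an illustrative stand-in for bittorrent.fetch_metadata():

import asyncio

MAX_CHILDREN = 5  # cf. MAX_ACTIVE_PEERS_PER_INFO_HASH

async def fetch(peer_id: int) -> bytes:
    await asyncio.sleep(peer_id)  # pretend to talk to a peer
    return b"metadata from peer %d" % peer_id

def spawn_child(parent: asyncio.Future, peer_id: int) -> None:
    if parent.done() or parent.child_count > MAX_CHILDREN:
        return
    task = asyncio.ensure_future(fetch(peer_id))
    task.add_done_callback(lambda t: on_child_done(parent, t))
    parent.child_count += 1
    parent.add_done_callback(lambda f: task.cancel())  # first result cancels the siblings

def on_child_done(parent: asyncio.Future, child: asyncio.Future) -> None:
    parent.child_count -= 1
    try:
        result = child.result()
        if result and not parent.done():
            parent.set_result(result)
    except asyncio.CancelledError:
        pass
    if parent.child_count <= 0 and not parent.done():
        parent.set_result(None)  # every child failed or was cancelled

loop = asyncio.get_event_loop()
parent = loop.create_future()
parent.child_count = 0  # plain attribute on the future, exactly as the code above does
for peer_id in (1, 2, 3):
    spawn_child(parent, peer_id)
result = loop.run_until_complete(parent)
loop.run_until_complete(asyncio.sleep(0))  # let the cancelled children unwind
print(result)  # b'metadata from peer 1'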