First crack at porting to asyncio.

This commit is contained in:
Richard Kiss 2017-05-13 22:25:34 -07:00
parent a3adf88b45
commit 8da8d20b53
3 changed files with 140 additions and 320 deletions

View File

@ -13,45 +13,25 @@
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <http://www.gnu.org/licenses/>.
import argparse
import collections
import functools
import asyncio
import logging
import ipaddress
import selectors
import textwrap
import urllib.parse
import itertools
import os
import sys
import time
import typing
import appdirs
import humanfriendly
from .constants import TICK_INTERVAL, MAX_ACTIVE_PEERS_PER_INFO_HASH, DEFAULT_MAX_METADATA_SIZE
from .constants import DEFAULT_MAX_METADATA_SIZE
from . import __version__
from . import bittorrent
from . import dht
from . import persistence
# Global variables are bad bla bla bla, BUT these variables are used so many times that I think it is justified; else
# the signatures of many functions are literally cluttered.
#
# If you are using a global variable, please always indicate that at the VERY BEGINNING of the function instead of right
# before using the variable for the first time.
selector = selectors.DefaultSelector()
database = None # type: persistence.Database
node = None
peers = collections.defaultdict(list) # type: typing.DefaultDict[dht.InfoHash, typing.List[bittorrent.DisposablePeer]]
# info hashes whose metadata is valid & complete (OR complete but deemed to be corrupt) so do NOT download them again:
complete_info_hashes = set()
def main():
global complete_info_hashes, database, node, peers, selector
arguments = parse_cmdline_arguments()
logging.basicConfig(level=arguments.loglevel, format="%(asctime)s %(levelname)-8s %(message)s")
@ -67,106 +47,30 @@ def main():
complete_info_hashes = database.get_complete_info_hashes()
node = dht.SybilNode(arguments.node_addr)
loop = asyncio.get_event_loop()
node = dht.SybilNode(arguments.node_addr, complete_info_hashes, arguments.max_metadata_size)
loop.run_until_complete(node.launch(loop))
node.when_peer_found = lambda info_hash, peer_address: on_peer_found(info_hash=info_hash,
peer_address=peer_address,
max_metadata_size=arguments.max_metadata_size)
selector.register(node, selectors.EVENT_READ)
loop.create_task(watch_q(database, node._metadata_q))
try:
loop()
loop.run_forever()
except KeyboardInterrupt:
logging.critical("Keyboard interrupt received! Exiting gracefully...")
pass
finally:
database.close()
selector.close()
node.shutdown()
for peer in itertools.chain.from_iterable(peers.values()):
peer.shutdown()
return 0
def on_peer_found(info_hash: dht.InfoHash, peer_address, max_metadata_size: int=DEFAULT_MAX_METADATA_SIZE) -> None:
global selector, peers, complete_info_hashes
if len(peers[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH or info_hash in complete_info_hashes:
return
try:
peer = bittorrent.DisposablePeer(info_hash, peer_address, max_metadata_size)
except ConnectionError:
return
selector.register(peer, selectors.EVENT_READ | selectors.EVENT_WRITE)
peer.when_metadata_found = on_metadata_found
peer.when_error = functools.partial(on_peer_error, peer, info_hash)
peers[info_hash].append(peer)
def on_metadata_found(info_hash: dht.InfoHash, metadata: bytes) -> None:
global complete_info_hashes, database, peers, selector
succeeded = database.add_metadata(info_hash, metadata)
if not succeeded:
logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex())
# When we fetch the metadata of an info hash completely, shut down all other peers who are trying to do the same.
for peer in peers[info_hash]:
selector.unregister(peer)
peer.shutdown()
del peers[info_hash]
complete_info_hashes.add(info_hash)
def on_peer_error(peer: bittorrent.DisposablePeer, info_hash: dht.InfoHash) -> None:
global peers, selector
peer.shutdown()
peers[info_hash].remove(peer)
selector.unregister(peer)
# TODO:
# Consider whether time.monotonic() is a good choice. Maybe we should use CLOCK_MONOTONIC_RAW as its not affected by NTP
# adjustments, and all we need is how many seconds passed since a certain point in time.
def loop() -> None:
global selector, node, peers
t0 = time.monotonic()
async def watch_q(database, q):
while True:
keys_and_events = selector.select(timeout=TICK_INTERVAL)
# Check if it is time to tick
delta = time.monotonic() - t0
if delta >= TICK_INTERVAL:
if not (delta < 2 * TICK_INTERVAL):
logging.warning("Belated TICK! (Δ = %d)", delta)
node.on_tick()
for peer_list in peers.values():
for peer in peer_list:
peer.on_tick()
t0 = time.monotonic()
for key, events in keys_and_events:
if events & selectors.EVENT_READ:
key.fileobj.on_receivable()
if events & selectors.EVENT_WRITE:
key.fileobj.on_sendable()
# Check for entities that would like to write to their socket
keymap = selector.get_map()
for fd in keymap:
fileobj = keymap[fd].fileobj
if fileobj.would_send():
selector.modify(fileobj, selectors.EVENT_READ | selectors.EVENT_WRITE)
else:
selector.modify(fileobj, selectors.EVENT_READ)
info_hash, metadata = await q.get()
succeeded = database.add_metadata(info_hash, metadata)
if not succeeded:
logging.info("Corrupt metadata for %s! Ignoring.", info_hash.hex())
def parse_ip_port(netloc) -> typing.Optional[typing.Tuple[str, int]]:

View File

@ -12,11 +12,10 @@
#
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <http://www.gnu.org/licenses/>.
import errno
import asyncio
import logging
import hashlib
import math
import socket
import typing
import os
@ -27,27 +26,20 @@ InfoHash = bytes
PeerAddress = typing.Tuple[str, int]
async def get_torrent_data(info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size):
loop = asyncio.get_event_loop()
peer = DisposablePeer(info_hash, peer_addr, max_metadata_size)
r = await peer.launch(loop)
return r
class DisposablePeer:
def __init__(self, info_hash: InfoHash, peer_addr: PeerAddress, max_metadata_size: int= DEFAULT_MAX_METADATA_SIZE):
self.__socket = socket.socket()
self.__socket.setblocking(False)
# To reduce the latency:
self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
if hasattr(socket, "TCP_QUICKACK"):
self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK, True)
res = self.__socket.connect_ex(peer_addr)
if res != errno.EINPROGRESS:
raise ConnectionError()
self.__peer_addr = peer_addr
self.__info_hash = info_hash
self.__max_metadata_size = max_metadata_size
self.__incoming_buffer = bytearray()
self.__outgoing_buffer = bytearray()
self.__bt_handshake_complete = False # BitTorrent Handshake
self.__ext_handshake_complete = False # Extension Handshake
self.__ut_metadata = None # Since we don't know ut_metadata code that remote peer uses...
@ -56,102 +48,55 @@ class DisposablePeer:
self.__metadata_received = 0 # Amount of metadata bytes received...
self.__metadata = None
# To prevent double shutdown
self.__shutdown = False
# After 120 ticks passed, a peer should report an error and shut itself down due to being stall.
self.__ticks_passed = 0
async def launch(self, loop):
self._loop = loop
self._metadata_future = self._loop.create_future()
self._reader, self._writer = await asyncio.open_connection(
self.__peer_addr[0], self.__peer_addr[1], loop=loop)
# Send the BitTorrent handshake message (0x13 = 19 in decimal, the length of the handshake message)
self.__outgoing_buffer += b"\x13BitTorrent protocol%s%s%s" % (
self._writer.write(b"\x13BitTorrent protocol%s%s%s" % (
b"\x00\x00\x00\x00\x00\x10\x00\x01",
self.__info_hash,
self.__random_bytes(20)
)
@staticmethod
def when_error() -> None:
raise NotImplementedError()
@staticmethod
def when_metadata_found(info_hash: InfoHash, metadata: bytes) -> None:
raise NotImplementedError()
def on_tick(self):
self.__ticks_passed += 1
if self.__ticks_passed == 120:
logging.debug("Peer failed to fetch metadata in time for info hash %s!", self.__info_hash.hex())
self.when_error()
def on_receivable(self) -> None:
while True:
try:
received = self.__socket.recv(8192)
except BlockingIOError:
break
except ConnectionResetError:
self.when_error()
return
except ConnectionRefusedError:
self.when_error()
return
except OSError: # TODO: check for "no route to host 113" error
self.when_error()
return
if not received:
self.when_error()
return
self.__incoming_buffer += received
))
# Honestly speaking, BitTorrent protocol might be one of the most poorly documented and (not the most but) badly
# designed protocols I have ever seen (I am 19 years old so what I could have seen?).
#
# Anyway, all the messages EXCEPT the handshake are length-prefixed by 4 bytes in network order, BUT the
# size of the handshake message is the 1-byte length prefix + 49 bytes, but luckily, there is only one canonical
# way of handshaking in the wild.
if not self.__bt_handshake_complete:
if len(self.__incoming_buffer) < 68:
# We are still receiving the handshake...
return
message = await self._reader.readexactly(68)
if message[1:20] != b"BitTorrent protocol":
# Erroneous handshake, possibly unknown version...
logging.debug("Erroneous BitTorrent handshake! %s", message)
self.close()
return
if self.__incoming_buffer[1:20] != b"BitTorrent protocol":
# Erroneous handshake, possibly unknown version...
logging.debug("Erroneous BitTorrent handshake! %s", self.__incoming_buffer[:68])
self.when_error()
return
self.__on_bt_handshake(message)
self.__on_bt_handshake(self.__incoming_buffer[:68])
try:
while not self._metadata_future.done():
buffer = await self._reader.readexactly(4)
length = int.from_bytes(buffer, "big")
message = await self._reader.readexactly(length)
self.__on_message(message)
except Exception as ex:
self.close()
return await self._metadata_future
self.__bt_handshake_complete = True
self.__incoming_buffer = self.__incoming_buffer[68:]
def when_metadata_found(self, info_hash: InfoHash, metadata: bytes) -> None:
self._metadata_future.set_result((info_hash, metadata))
self.close()
while len(self.__incoming_buffer) >= 4:
# Beware that while there are still messages in the incoming queue/buffer, one of previous messages might
# have caused an error that necessitates us to quit.
if self.__shutdown:
break
length = int.from_bytes(self.__incoming_buffer[:4], "big")
if len(self.__incoming_buffer) - 4 < length:
# Message is still incoming...
return
self.__on_message(self.__incoming_buffer[4:4+length])
self.__incoming_buffer = self.__incoming_buffer[4+length:]
def on_sendable(self) -> None:
while self.__outgoing_buffer:
try:
n_sent = self.__socket.send(self.__outgoing_buffer)
assert n_sent
self.__outgoing_buffer = self.__outgoing_buffer[n_sent:]
except BlockingIOError:
break
except OSError:
# In case -while looping- on_sendable is called after socket is closed (mostly because of an error)
return
def close(self):
self._writer.close()
if not self._metadata_future.done():
self._metadata_future.set_result(None)
def __on_message(self, message: bytes) -> None:
length = len(message)
@ -191,10 +136,10 @@ class DisposablePeer:
# In case you cannot read_file hex:
# 0x14 = 20 (BitTorrent ID indicating that it's an extended message)
# 0x00 = 0 (Extension ID indicating that it's the handshake message)
self.__outgoing_buffer += b"%s\x14\x00%s" % (
self._writer.write(b"%b\x14%s" % (
(2 + len(msg_dict_dump)).to_bytes(4, "big"),
msg_dict_dump
)
b'\0' + msg_dict_dump
))
def __on_ext_handshake_message(self, message: bytes) -> None:
if self.__ext_handshake_complete:
@ -284,29 +229,11 @@ class DisposablePeer:
# In case you cannot read_file hex:
# 0x14 = 20 (BitTorrent ID indicating that it's an extended message)
# 0x03 = 3 (Extension ID indicating that it's an ut_metadata message)
self.__outgoing_buffer += b"%b\x14%s%s" % (
self._writer.write(b"%b\x14%s%s" % (
(2 + len(msg_dict_dump)).to_bytes(4, "big"),
self.__ut_metadata.to_bytes(1, "big"),
msg_dict_dump
)
def shutdown(self) -> None:
if self.__shutdown:
return
try:
self.__socket.shutdown(socket.SHUT_RDWR)
except OSError:
# OSError might be raised in case the connection to the remote peer fails: nevertheless, when_error should
# be called, and the supervisor will try to shutdown the peer, and ta da: OSError!
pass
self.__socket.close()
self.__shutdown = True
def would_send(self) -> bool:
return bool(len(self.__outgoing_buffer))
def fileno(self) -> int:
return self.__socket.fileno()
))
@staticmethod
def __random_bytes(n: int) -> bytes:

View File

@ -13,15 +13,18 @@
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <http://www.gnu.org/licenses/>.
import array
import asyncio
import collections
import itertools
import zlib
import logging
import socket
import typing
import os
from .constants import BOOTSTRAPPING_NODES, DEFAULT_MAX_METADATA_SIZE
from .constants import BOOTSTRAPPING_NODES, DEFAULT_MAX_METADATA_SIZE, MAX_ACTIVE_PEERS_PER_INFO_HASH
from . import bencode
from . import bittorrent
NodeID = bytes
NodeAddress = typing.Tuple[str, int]
@ -30,107 +33,77 @@ InfoHash = bytes
class SybilNode:
def __init__(self, address: typing.Tuple[str, int]):
def __init__(self, address: typing.Tuple[str, int], complete_info_hashes, max_metadata_size):
self.__true_id = self.__random_bytes(20)
self.__socket = socket.socket(type=socket.SOCK_DGRAM)
self.__socket.bind(address)
self.__socket.setblocking(False)
self.__address = address
self.__incoming_buffer = array.array("B", (0 for _ in range(65536)))
self.__outgoing_queue = collections.deque()
self.__routing_table = {} # type: typing.Dict[NodeID, NodeAddress]
self._routing_table = {} # type: typing.Dict[NodeID, NodeAddress]
self.__token_secret = self.__random_bytes(4)
# Maximum number of neighbours (this is a THRESHOLD where, once reached, the search for new neighbours will
# stop; but until then, the total number of neighbours might exceed the threshold).
self.__n_max_neighbours = 2000
self.__peers = collections.defaultdict(
list) # type: typing.DefaultDict[dht.InfoHash, typing.List[bittorrent.DisposablePeer]]
self._complete_info_hashes = complete_info_hashes
self.__max_metadata_size = max_metadata_size
self._metadata_q = asyncio.Queue()
logging.info("SybilNode %s on %s initialized!", self.__true_id.hex().upper(), address)
@staticmethod
def when_peer_found(info_hash: InfoHash, peer_addr: PeerAddress) -> None:
raise NotImplementedError()
async def launch(self, loop):
self._loop = loop
await loop.create_datagram_endpoint(lambda: self, local_addr=self.__address)
def on_tick(self) -> None:
self.__bootstrap()
self.__make_neighbours()
self.__routing_table.clear()
def connection_made(self, transport):
self._loop.create_task(self.on_tick())
self._loop.create_task(self.increase_neighbour_task())
self._transport = transport
def on_receivable(self) -> None:
buffer = self.__incoming_buffer
while True:
try:
_, addr = self.__socket.recvfrom_into(buffer, 65536)
data = buffer.tobytes()
except BlockingIOError:
break
except ConnectionResetError:
continue
except ConnectionRefusedError:
continue
# Ignore nodes that uses port 0 (assholes).
if addr[1] == 0:
continue
try:
message = bencode.loads(data)
except bencode.BencodeDecodingError:
continue
if isinstance(message.get(b"r"), dict) and type(message[b"r"].get(b"nodes")) is bytes:
self.__on_FIND_NODE_response(message)
elif message.get(b"q") == b"get_peers":
self.__on_GET_PEERS_query(message, addr)
elif message.get(b"q") == b"announce_peer":
self.__on_ANNOUNCE_PEER_query(message, addr)
def on_sendable(self) -> None:
congestion = None
while True:
try:
addr, data = self.__outgoing_queue.pop()
except IndexError:
break
try:
self.__socket.sendto(data, addr)
except BlockingIOError:
self.__outgoing_queue.appendleft((addr, data))
break
except PermissionError:
# This exception (EPERM errno: 1) is kernel's way of saying that "you are far too fast, chill".
# It is also likely that we have received a ICMP source quench packet (meaning, that we really need to
# slow down.
#
# Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto
# -amp-amp-EPERM.html
congestion = True
break
except OSError:
# Pass in case of trying to send to port 0 (it is much faster to catch exceptions than using an
# if-statement).
pass
if congestion:
self.__outgoing_queue.clear()
def error_received(self, exc):
logging.error("got error %s", exc)
if isinstance(exc, PermissionError):
# In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
if self.__n_max_neighbours < 200:
logging.warning("Maximum number of neighbours are now less than 200 due to congestion!")
else:
self.__n_max_neighbours = self.__n_max_neighbours * 9 // 10
else:
# In case of the lack of congestion, increase the maximum number of nodes by 1%.
logging.debug("Maximum number of neighbours now %d", self.__n_max_neighbours)
async def on_tick(self) -> None:
while True:
await asyncio.sleep(1)
self.__bootstrap()
self.__make_neighbours()
self._routing_table.clear()
def datagram_received(self, data, addr) -> None:
# Ignore nodes that uses port 0 (assholes).
if addr[1] == 0:
return
try:
message = bencode.loads(data)
except bencode.BencodeDecodingError:
return
if isinstance(message.get(b"r"), dict) and type(message[b"r"].get(b"nodes")) is bytes:
self.__on_FIND_NODE_response(message)
elif message.get(b"q") == b"get_peers":
self.__on_GET_PEERS_query(message, addr)
elif message.get(b"q") == b"announce_peer":
self.__on_ANNOUNCE_PEER_query(message, addr)
async def increase_neighbour_task(self):
while True:
await asyncio.sleep(10)
self.__n_max_neighbours = self.__n_max_neighbours * 101 // 100
def would_send(self) -> bool:
""" Whether node is waiting to write on its socket or not. """
return bool(self.__outgoing_queue)
def shutdown(self) -> None:
self.__socket.close()
for peer in itertools.chain.from_iterable(self.__peers.values()):
peer.close()
self._transport.close()
def __on_FIND_NODE_response(self, message: bencode.KRPCDict) -> None:
try:
@ -145,8 +118,8 @@ class SybilNode:
return
# Add new found nodes to the routing table, assuring that we have no more than n_max_neighbours in total.
if len(self.__routing_table) < self.__n_max_neighbours:
self.__routing_table.update(nodes)
if len(self._routing_table) < self.__n_max_neighbours:
self._routing_table.update(nodes)
def __on_GET_PEERS_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None:
try:
@ -157,8 +130,7 @@ class SybilNode:
except (TypeError, KeyError, AssertionError):
return
# appendleft to prioritise GET_PEERS responses as they are the most fruitful ones!
self.__outgoing_queue.appendleft((addr, bencode.dumps({
data = bencode.dumps({
b"y": b"r",
b"t": transaction_id,
b"r": {
@ -166,7 +138,10 @@ class SybilNode:
b"nodes": b"",
b"token": self.__calculate_token(addr, info_hash)
}
})))
})
# we want to prioritise GET_PEERS responses as they are the most fruitful ones!
# but there is no easy way to do this with asyncio
self._transport.sendto(data, addr)
def __on_ANNOUNCE_PEER_query(self, message: bencode.KRPCDict, addr: NodeAddress) -> None:
try:
@ -189,31 +164,45 @@ class SybilNode:
except (TypeError, KeyError, AssertionError):
return
self.__outgoing_queue.append((addr, bencode.dumps({
data = bencode.dumps({
b"y": b"r",
b"t": transaction_id,
b"r": {
b"id": node_id[:15] + self.__true_id[:5]
}
})))
})
self._transport.sendto(data, addr)
if implied_port:
peer_addr = (addr[0], addr[1])
else:
peer_addr = (addr[0], port)
self.when_peer_found(info_hash, peer_addr)
if len(self.__peers[info_hash]) > MAX_ACTIVE_PEERS_PER_INFO_HASH or \
info_hash in self._complete_info_hashes:
return
def fileno(self) -> int:
return self.__socket.fileno()
peer = bittorrent.get_torrent_data(info_hash, peer_addr, self.__max_metadata_size)
self.__peers[info_hash].append(peer)
self._loop.create_task(peer).add_done_callback(self.metadata_found)
def metadata_found(self, future):
r = future.result()
if r:
info_hash, metadata = r
for peer in self.__peers[info_hash]:
peer.close()
self._metadata_q.put_nowait(r)
self._complete_info_hashes.add(info_hash)
def __bootstrap(self) -> None:
for addr in BOOTSTRAPPING_NODES:
self.__outgoing_queue.append((addr, self.__build_FIND_NODE_query(self.__true_id)))
data = self.__build_FIND_NODE_query(self.__true_id)
self._transport.sendto(data, addr)
def __make_neighbours(self) -> None:
for node_id, addr in self.__routing_table.items():
self.__outgoing_queue.append((addr, self.__build_FIND_NODE_query(node_id[:15] + self.__true_id[:5])))
for node_id, addr in self._routing_table.items():
self._transport.sendto(self.__build_FIND_NODE_query(node_id[:15] + self.__true_id[:5]), addr)
@staticmethod
def __decode_nodes(infos: bytes) -> typing.List[typing.Tuple[NodeID, NodeAddress]]: