mainline/service done, also changed the signatures of transport signals

This commit is contained in:
Bora M. Alper 2017-07-15 23:27:45 +03:00
parent 57d466a666
commit e0241fe48c
4 changed files with 149 additions and 44 deletions

View File

@ -14,7 +14,6 @@
# <http://www.gnu.org/licenses/>. # <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import enum import enum
import functools
import typing import typing
import cerberus import cerberus
@ -23,47 +22,54 @@ from . import transport
class Protocol: class Protocol:
def __init__(self, *, client_version: bytes=b"mc00"): def __init__(self, client_version: bytes):
self.client_version = client_version self._client_version = client_version
self.transport = transport.Transport() self._transport = transport.Transport()
self.transport.on_message = functools.partial(self.__when_message, self) self._transport.on_message = self.__when_message
async def launch(self, address: transport.Address): async def launch(self, address: transport.Address):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: self.transport, local_addr=address) await asyncio.get_event_loop().create_datagram_endpoint(lambda: self._transport, local_addr=address)
# Offered Functionality # Offered Functionality
# ===================== # =====================
def make_query(self, query: BaseQuery, address: transport.Address) -> None:
return self._transport.send_message(query.to_message(b"\0\0", self._client_version), address)
@staticmethod @staticmethod
def on_ping_query(query: PingQuery) -> typing.Optional[typing.Union[PingResponse, Error]]: def on_ping_query(query: PingQuery, address: transport.Address) \
-> typing.Optional[typing.Union[PingResponse, Error]]:
pass pass
@staticmethod @staticmethod
def on_find_node_query(query: FindNodeQuery) -> typing.Optional[typing.Union[FindNodeResponse, Error]]: def on_find_node_query(query: FindNodeQuery, address: transport.Address) \
-> typing.Optional[typing.Union[FindNodeResponse, Error]]:
pass pass
@staticmethod @staticmethod
def on_get_peers_query(query: GetPeersQuery) -> typing.Optional[typing.Union[GetPeersQuery, Error]]: def on_get_peers_query(query: GetPeersQuery, address: transport.Address) \
-> typing.Optional[typing.Union[GetPeersResponse, Error]]:
pass pass
@staticmethod @staticmethod
def on_announce_peer_query(query: AnnouncePeerQuery) -> typing.Optional[typing.Union[AnnouncePeerResponse, Error]]: def on_announce_peer_query(query: AnnouncePeerQuery, address: transport.Address) \
-> typing.Optional[typing.Union[AnnouncePeerResponse, Error]]:
pass pass
@staticmethod @staticmethod
def on_ping_OR_announce_peer_response(response: PingResponse) -> None: def on_ping_OR_announce_peer_response(response: PingResponse, address: transport.Address) -> None:
pass pass
@staticmethod @staticmethod
def on_find_node_response(response: FindNodeResponse) -> None: def on_find_node_response(response: FindNodeResponse, address: transport.Address) -> None:
pass pass
@staticmethod @staticmethod
def on_get_peers_response(response: GetPeersResponse) -> None: def on_get_peers_response(response: GetPeersResponse, address: transport.Address) -> None:
pass pass
@staticmethod @staticmethod
def on_error(error: Error) -> None: def on_error(error: Error, address: transport.Address) -> None:
pass pass
# Private Functionality # Private Functionality
@ -79,18 +85,18 @@ class Protocol:
if AnnouncePeerQuery.validate_message(message): if AnnouncePeerQuery.validate_message(message):
response = self.on_announce_peer_query(AnnouncePeerQuery( response = self.on_announce_peer_query(AnnouncePeerQuery(
args[b"id"], args[b"info_hash"], args[b"port"], args[b"token"], args[b"implied_port"] args[b"id"], args[b"info_hash"], args[b"port"], args[b"token"], args[b"implied_port"]
)) ), address)
elif GetPeersQuery.validate_message(message): elif GetPeersQuery.validate_message(message):
response = self.on_get_peers_query(GetPeersQuery(args[b"id"], args[b"info_hash"])) response = self.on_get_peers_query(GetPeersQuery(args[b"id"], args[b"info_hash"]), address)
elif FindNodeQuery.validate_message(message): elif FindNodeQuery.validate_message(message):
response = self.on_find_node_query(FindNodeQuery(args[b"id"], args[b"target"])) response = self.on_find_node_query(FindNodeQuery(args[b"id"], args[b"target"]), address)
elif PingQuery.validate_message(message): elif PingQuery.validate_message(message):
response = self.on_ping_query(PingQuery(args[b"id"])) response = self.on_ping_query(PingQuery(args[b"id"]), address)
else: else:
# Unknown Query received! # Unknown Query received!
response = None response = None
if response: if response:
self.transport.send_message(response.to_message(message[b"t"], self.client_version), address) self._transport.send_message(response.to_message(message[b"t"], self._client_version), address)
elif BaseResponse.validate_message(message): elif BaseResponse.validate_message(message):
return_values = message[b"r"] return_values = message[b"r"]
@ -98,22 +104,22 @@ class Protocol:
if b"nodes" in return_values: if b"nodes" in return_values:
self.on_get_peers_response(GetPeersResponse( self.on_get_peers_response(GetPeersResponse(
return_values[b"id"], return_values[b"token"], nodes=return_values[b"nodes"] return_values[b"id"], return_values[b"token"], nodes=return_values[b"nodes"]
)) ), address)
else: else:
self.on_get_peers_response(GetPeersResponse( self.on_get_peers_response(GetPeersResponse(
return_values[b"id"], return_values[b"token"], values=return_values[b"values"] return_values[b"id"], return_values[b"token"], values=return_values[b"values"]
)) ), address)
elif FindNodeResponse.validate_message(message): elif FindNodeResponse.validate_message(message):
self.on_find_node_response(FindNodeResponse(return_values[b"id"], return_values[b"nodes"])) self.on_find_node_response(FindNodeResponse(return_values[b"id"], return_values[b"nodes"]), address)
elif PingResponse.validate_message(message): elif PingResponse.validate_message(message):
self.on_ping_OR_announce_peer_response(PingResponse(return_values[b"id"])) self.on_ping_OR_announce_peer_response(PingResponse(return_values[b"id"]), address)
else: else:
# Unknown Response received! # Unknown Response received!
pass pass
elif Error.validate_message(message): elif Error.validate_message(message):
if Error.validate_message(message): if Error.validate_message(message):
self.on_error(Error(message[b"e"][0], message[b"e"][1])) self.on_error(Error(message[b"e"][0], message[b"e"][1]), address)
else: else:
# Erroneous Error received! # Erroneous Error received!
pass pass
@ -308,11 +314,11 @@ class GetPeersResponse(BaseResponse):
} }
__validator = cerberus.Validator() __validator = cerberus.Validator()
def __init__(self, id_: NodeID, token: bytes, *, values: typing.Optional[typing.List[bytes]]=None, def __init__(self, id_: NodeID, token: bytes, *, values: typing.Optional[typing.List[transport.Address]]=None,
nodes: typing.Optional[typing.List[NodeInfo]]=None nodes: typing.Optional[typing.List[NodeInfo]]=None
): ):
if not bool(values) ^ bool(nodes): if not (values and nodes):
raise ValueError("Supply either `values` or `nodes` but not both or neither.") raise ValueError("Supply either `values` or `nodes` or neither but not both.")
super().__init__(id_) super().__init__(id_)
self.token = token self.token = token

View File

@ -12,3 +12,78 @@
# #
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see # You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <http://www.gnu.org/licenses/>. # <http://www.gnu.org/licenses/>.
import asyncio
import hashlib
import os
import socket
import typing
from magneticod import constants
from . import protocol
class TrawlingService:
def __init__(self):
self._protocol = protocol.Protocol(b"mc00")
self._protocol.on_get_peers_query = self._when_get_peers_query
self._protocol.on_announce_peer_query = self._when_announce_peer_query
self._protocol.on_find_node_response = self._when_find_node_response
self._true_node_id = os.urandom(20)
self._token_secret = os.urandom(4)
self._routing_table = {} # typing.Dict[protocol.NodeID, protocol.transport.Address]
async def launch(self, address: protocol.transport.Address):
await self._protocol.launch(address)
# Offered Functionality
# =====================
@staticmethod
def on_info_hash_and_peer(info_hash: protocol.InfoHash, address: protocol.transport.Address) -> None:
pass
# Private Functionality
# =====================
async def tick_periodically(self) -> None:
while True:
if not self._routing_table:
await self._bootstrap()
else:
self._make_neighbors()
self._routing_table.clear()
await asyncio.sleep(constants.TICK_INTERVAL)
async def _bootstrap(self) -> None:
event_loop = asyncio.get_event_loop()
for node in constants.BOOTSTRAPPING_NODES:
for *_, address in await event_loop.getaddrinfo(*node, family=socket.AF_INET):
self._protocol.make_query(protocol.FindNodeQuery(self._true_node_id, os.urandom(20)), address)
def _make_neighbors(self) -> None:
for id_, address in self._routing_table.items():
self._protocol.make_query(
protocol.FindNodeQuery(id_[:15] + self._true_node_id[:5], os.urandom(20)),
address
)
def _when_get_peers_query(self, query: protocol.GetPeersQuery, address: protocol.transport.Address) \
-> typing.Optional[typing.Union[protocol.GetPeersResponse, protocol.Error]]:
return protocol.GetPeersResponse(query.info_hash[:15] + self._true_node_id[:5], self._calculate_token(address))
def _when_announce_peer_query(self, query: protocol.AnnouncePeerQuery, address: protocol.transport.Address) \
-> typing.Optional[typing.Union[protocol.AnnouncePeerResponse, protocol.Error]]:
if query.implied_port:
peer_address = (address[0], address[1])
else:
peer_address = (address[0], query.port)
self.on_info_hash_and_peer(query.info_hash, peer_address)
return protocol.AnnouncePeerResponse(query.info_hash[:15] + self._true_node_id[:5])
def _when_find_node_response(self, response: protocol.FindNodeResponse, address: protocol.transport.Address) \
-> None:
self._routing_table.update({node.id: node.address for node in response.nodes if node.address != 0})
def _calculate_token(self, address: protocol.transport.Address) -> bytes:
return hashlib.sha1(b"%s%d" % (socket.inet_aton(address[0]), socket.htons(address[1]))).digest()

View File

@ -42,18 +42,18 @@ class Transport(asyncio.DatagramProtocol):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.__datagram_transport = asyncio.DatagramTransport() self._datagram_transport = asyncio.DatagramTransport()
self.__write_allowed = asyncio.Event() self._write_allowed = asyncio.Event()
self.__queue_nonempty = asyncio.Event() self._queue_nonempty = asyncio.Event()
self.__message_queue = collections.deque() # type: typing.Deque[MessageQueueEntry] self._message_queue = collections.deque() # type: typing.Deque[MessageQueueEntry]
self.__messenger_task = asyncio.Task(self.__send_messages()) self._messenger_task = asyncio.Task(self._send_messages())
# Offered Functionality # Offered Functionality
# ===================== # =====================
def send_message(self, message, address: Address) -> None: def send_message(self, message, address: Address) -> None:
self.__message_queue.append(MessageQueueEntry(time.monotonic(), message, address)) self._message_queue.append(MessageQueueEntry(time.monotonic(), message, address))
if not self.__queue_nonempty.is_set(): if not self._queue_nonempty.is_set():
self.__queue_nonempty.set() self._queue_nonempty.set()
@staticmethod @staticmethod
def on_message(message: dict, address: Address): def on_message(message: dict, address: Address):
@ -62,10 +62,15 @@ class Transport(asyncio.DatagramProtocol):
# Private Functionality # Private Functionality
# ===================== # =====================
def connection_made(self, transport: asyncio.DatagramTransport) -> None: def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self.__datagram_transport = transport self._datagram_transport = transport
self.__write_allowed.set() self._write_allowed.set()
def datagram_received(self, data: bytes, address: Address) -> None: def datagram_received(self, data: bytes, address: Address) -> None:
# Ignore nodes that "uses" port 0, as we cannot communicate with them reliably across the different systems.
# See https://tools.cisco.com/security/center/viewAlert.x?alertId=19935 for slightly more details
if address[1] == 0:
return
try: try:
message = codec.decode(data) message = codec.decode(data)
except codec.EncodeError: except codec.EncodeError:
@ -80,10 +85,10 @@ class Transport(asyncio.DatagramProtocol):
logging.debug("Mainline DHT received error!", exc_info=exc) logging.debug("Mainline DHT received error!", exc_info=exc)
def pause_writing(self): def pause_writing(self):
self.__write_allowed.clear() self._write_allowed.clear()
def resume_writing(self): def resume_writing(self):
self.__write_allowed.set() self._write_allowed.set()
def connection_lost(self, exc: Exception): def connection_lost(self, exc: Exception):
if exc: if exc:
@ -94,16 +99,16 @@ class Transport(asyncio.DatagramProtocol):
logging.fatal("Mainline DHT lost connection!") logging.fatal("Mainline DHT lost connection!")
sys.exit(1) sys.exit(1)
async def __send_messages(self) -> None: async def _send_messages(self) -> None:
while True: while True:
await asyncio.wait([self.__write_allowed.wait(), self.__queue_nonempty.wait()]) await asyncio.wait([self._write_allowed.wait(), self._queue_nonempty.wait()])
try: try:
queued_on, message, address = self.__message_queue.pop() queued_on, message, address = self._message_queue.pop()
except IndexError: except IndexError:
self.__queue_nonempty.clear() self._queue_nonempty.clear()
continue continue
if time.monotonic() - queued_on > 60: if time.monotonic() - queued_on > 60:
return return
self.__datagram_transport.sendto(message, address) self._datagram_transport.sendto(message, address)

View File

@ -0,0 +1,19 @@
# magneticod - Autonomous BitTorrent DHT crawler and metadata fetcher.
# Copyright (C) 2017 Mert Bora ALPER <bora@boramalper.org>
# Dedicated to Cemile Binay, in whose hands I thrived.
#
# This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <http://www.gnu.org/licenses/>.
class InfoHashSink:
def __init__(self):
pass