mainline/service done, also changed the signatures of transport signals
This commit is contained in:
parent
57d466a666
commit
e0241fe48c
@ -14,7 +14,6 @@
|
||||
# <http://www.gnu.org/licenses/>.
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import cerberus
|
||||
@ -23,47 +22,54 @@ from . import transport
|
||||
|
||||
|
||||
class Protocol:
|
||||
def __init__(self, *, client_version: bytes=b"mc00"):
|
||||
self.client_version = client_version
|
||||
self.transport = transport.Transport()
|
||||
def __init__(self, client_version: bytes):
|
||||
self._client_version = client_version
|
||||
self._transport = transport.Transport()
|
||||
|
||||
self.transport.on_message = functools.partial(self.__when_message, self)
|
||||
self._transport.on_message = self.__when_message
|
||||
|
||||
async def launch(self, address: transport.Address):
|
||||
await asyncio.get_event_loop().create_datagram_endpoint(lambda: self.transport, local_addr=address)
|
||||
await asyncio.get_event_loop().create_datagram_endpoint(lambda: self._transport, local_addr=address)
|
||||
|
||||
# Offered Functionality
|
||||
# =====================
|
||||
def make_query(self, query: BaseQuery, address: transport.Address) -> None:
|
||||
return self._transport.send_message(query.to_message(b"\0\0", self._client_version), address)
|
||||
|
||||
@staticmethod
|
||||
def on_ping_query(query: PingQuery) -> typing.Optional[typing.Union[PingResponse, Error]]:
|
||||
def on_ping_query(query: PingQuery, address: transport.Address) \
|
||||
-> typing.Optional[typing.Union[PingResponse, Error]]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_find_node_query(query: FindNodeQuery) -> typing.Optional[typing.Union[FindNodeResponse, Error]]:
|
||||
def on_find_node_query(query: FindNodeQuery, address: transport.Address) \
|
||||
-> typing.Optional[typing.Union[FindNodeResponse, Error]]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_get_peers_query(query: GetPeersQuery) -> typing.Optional[typing.Union[GetPeersQuery, Error]]:
|
||||
def on_get_peers_query(query: GetPeersQuery, address: transport.Address) \
|
||||
-> typing.Optional[typing.Union[GetPeersResponse, Error]]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_announce_peer_query(query: AnnouncePeerQuery) -> typing.Optional[typing.Union[AnnouncePeerResponse, Error]]:
|
||||
def on_announce_peer_query(query: AnnouncePeerQuery, address: transport.Address) \
|
||||
-> typing.Optional[typing.Union[AnnouncePeerResponse, Error]]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_ping_OR_announce_peer_response(response: PingResponse) -> None:
|
||||
def on_ping_OR_announce_peer_response(response: PingResponse, address: transport.Address) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_find_node_response(response: FindNodeResponse) -> None:
|
||||
def on_find_node_response(response: FindNodeResponse, address: transport.Address) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_get_peers_response(response: GetPeersResponse) -> None:
|
||||
def on_get_peers_response(response: GetPeersResponse, address: transport.Address) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def on_error(error: Error) -> None:
|
||||
def on_error(error: Error, address: transport.Address) -> None:
|
||||
pass
|
||||
|
||||
# Private Functionality
|
||||
@ -79,18 +85,18 @@ class Protocol:
|
||||
if AnnouncePeerQuery.validate_message(message):
|
||||
response = self.on_announce_peer_query(AnnouncePeerQuery(
|
||||
args[b"id"], args[b"info_hash"], args[b"port"], args[b"token"], args[b"implied_port"]
|
||||
))
|
||||
), address)
|
||||
elif GetPeersQuery.validate_message(message):
|
||||
response = self.on_get_peers_query(GetPeersQuery(args[b"id"], args[b"info_hash"]))
|
||||
response = self.on_get_peers_query(GetPeersQuery(args[b"id"], args[b"info_hash"]), address)
|
||||
elif FindNodeQuery.validate_message(message):
|
||||
response = self.on_find_node_query(FindNodeQuery(args[b"id"], args[b"target"]))
|
||||
response = self.on_find_node_query(FindNodeQuery(args[b"id"], args[b"target"]), address)
|
||||
elif PingQuery.validate_message(message):
|
||||
response = self.on_ping_query(PingQuery(args[b"id"]))
|
||||
response = self.on_ping_query(PingQuery(args[b"id"]), address)
|
||||
else:
|
||||
# Unknown Query received!
|
||||
response = None
|
||||
if response:
|
||||
self.transport.send_message(response.to_message(message[b"t"], self.client_version), address)
|
||||
self._transport.send_message(response.to_message(message[b"t"], self._client_version), address)
|
||||
|
||||
elif BaseResponse.validate_message(message):
|
||||
return_values = message[b"r"]
|
||||
@ -98,22 +104,22 @@ class Protocol:
|
||||
if b"nodes" in return_values:
|
||||
self.on_get_peers_response(GetPeersResponse(
|
||||
return_values[b"id"], return_values[b"token"], nodes=return_values[b"nodes"]
|
||||
))
|
||||
), address)
|
||||
else:
|
||||
self.on_get_peers_response(GetPeersResponse(
|
||||
return_values[b"id"], return_values[b"token"], values=return_values[b"values"]
|
||||
))
|
||||
), address)
|
||||
elif FindNodeResponse.validate_message(message):
|
||||
self.on_find_node_response(FindNodeResponse(return_values[b"id"], return_values[b"nodes"]))
|
||||
self.on_find_node_response(FindNodeResponse(return_values[b"id"], return_values[b"nodes"]), address)
|
||||
elif PingResponse.validate_message(message):
|
||||
self.on_ping_OR_announce_peer_response(PingResponse(return_values[b"id"]))
|
||||
self.on_ping_OR_announce_peer_response(PingResponse(return_values[b"id"]), address)
|
||||
else:
|
||||
# Unknown Response received!
|
||||
pass
|
||||
|
||||
elif Error.validate_message(message):
|
||||
if Error.validate_message(message):
|
||||
self.on_error(Error(message[b"e"][0], message[b"e"][1]))
|
||||
self.on_error(Error(message[b"e"][0], message[b"e"][1]), address)
|
||||
else:
|
||||
# Erroneous Error received!
|
||||
pass
|
||||
@ -308,11 +314,11 @@ class GetPeersResponse(BaseResponse):
|
||||
}
|
||||
__validator = cerberus.Validator()
|
||||
|
||||
def __init__(self, id_: NodeID, token: bytes, *, values: typing.Optional[typing.List[bytes]]=None,
|
||||
def __init__(self, id_: NodeID, token: bytes, *, values: typing.Optional[typing.List[transport.Address]]=None,
|
||||
nodes: typing.Optional[typing.List[NodeInfo]]=None
|
||||
):
|
||||
if not bool(values) ^ bool(nodes):
|
||||
raise ValueError("Supply either `values` or `nodes` but not both or neither.")
|
||||
if not (values and nodes):
|
||||
raise ValueError("Supply either `values` or `nodes` or neither but not both.")
|
||||
|
||||
super().__init__(id_)
|
||||
self.token = token
|
||||
|
@ -12,3 +12,78 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
|
||||
# <http://www.gnu.org/licenses/>.
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from magneticod import constants
|
||||
from . import protocol
|
||||
|
||||
|
||||
class TrawlingService:
|
||||
def __init__(self):
|
||||
self._protocol = protocol.Protocol(b"mc00")
|
||||
|
||||
self._protocol.on_get_peers_query = self._when_get_peers_query
|
||||
self._protocol.on_announce_peer_query = self._when_announce_peer_query
|
||||
self._protocol.on_find_node_response = self._when_find_node_response
|
||||
|
||||
self._true_node_id = os.urandom(20)
|
||||
self._token_secret = os.urandom(4)
|
||||
self._routing_table = {} # typing.Dict[protocol.NodeID, protocol.transport.Address]
|
||||
|
||||
async def launch(self, address: protocol.transport.Address):
|
||||
await self._protocol.launch(address)
|
||||
|
||||
# Offered Functionality
|
||||
# =====================
|
||||
@staticmethod
|
||||
def on_info_hash_and_peer(info_hash: protocol.InfoHash, address: protocol.transport.Address) -> None:
|
||||
pass
|
||||
|
||||
# Private Functionality
|
||||
# =====================
|
||||
async def tick_periodically(self) -> None:
|
||||
while True:
|
||||
if not self._routing_table:
|
||||
await self._bootstrap()
|
||||
else:
|
||||
self._make_neighbors()
|
||||
self._routing_table.clear()
|
||||
await asyncio.sleep(constants.TICK_INTERVAL)
|
||||
|
||||
async def _bootstrap(self) -> None:
|
||||
event_loop = asyncio.get_event_loop()
|
||||
for node in constants.BOOTSTRAPPING_NODES:
|
||||
for *_, address in await event_loop.getaddrinfo(*node, family=socket.AF_INET):
|
||||
self._protocol.make_query(protocol.FindNodeQuery(self._true_node_id, os.urandom(20)), address)
|
||||
|
||||
def _make_neighbors(self) -> None:
|
||||
for id_, address in self._routing_table.items():
|
||||
self._protocol.make_query(
|
||||
protocol.FindNodeQuery(id_[:15] + self._true_node_id[:5], os.urandom(20)),
|
||||
address
|
||||
)
|
||||
|
||||
def _when_get_peers_query(self, query: protocol.GetPeersQuery, address: protocol.transport.Address) \
|
||||
-> typing.Optional[typing.Union[protocol.GetPeersResponse, protocol.Error]]:
|
||||
return protocol.GetPeersResponse(query.info_hash[:15] + self._true_node_id[:5], self._calculate_token(address))
|
||||
|
||||
def _when_announce_peer_query(self, query: protocol.AnnouncePeerQuery, address: protocol.transport.Address) \
|
||||
-> typing.Optional[typing.Union[protocol.AnnouncePeerResponse, protocol.Error]]:
|
||||
if query.implied_port:
|
||||
peer_address = (address[0], address[1])
|
||||
else:
|
||||
peer_address = (address[0], query.port)
|
||||
self.on_info_hash_and_peer(query.info_hash, peer_address)
|
||||
|
||||
return protocol.AnnouncePeerResponse(query.info_hash[:15] + self._true_node_id[:5])
|
||||
|
||||
def _when_find_node_response(self, response: protocol.FindNodeResponse, address: protocol.transport.Address) \
|
||||
-> None:
|
||||
self._routing_table.update({node.id: node.address for node in response.nodes if node.address != 0})
|
||||
|
||||
def _calculate_token(self, address: protocol.transport.Address) -> bytes:
|
||||
return hashlib.sha1(b"%s%d" % (socket.inet_aton(address[0]), socket.htons(address[1]))).digest()
|
||||
|
@ -42,18 +42,18 @@ class Transport(asyncio.DatagramProtocol):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.__datagram_transport = asyncio.DatagramTransport()
|
||||
self.__write_allowed = asyncio.Event()
|
||||
self.__queue_nonempty = asyncio.Event()
|
||||
self.__message_queue = collections.deque() # type: typing.Deque[MessageQueueEntry]
|
||||
self.__messenger_task = asyncio.Task(self.__send_messages())
|
||||
self._datagram_transport = asyncio.DatagramTransport()
|
||||
self._write_allowed = asyncio.Event()
|
||||
self._queue_nonempty = asyncio.Event()
|
||||
self._message_queue = collections.deque() # type: typing.Deque[MessageQueueEntry]
|
||||
self._messenger_task = asyncio.Task(self._send_messages())
|
||||
|
||||
# Offered Functionality
|
||||
# =====================
|
||||
def send_message(self, message, address: Address) -> None:
|
||||
self.__message_queue.append(MessageQueueEntry(time.monotonic(), message, address))
|
||||
if not self.__queue_nonempty.is_set():
|
||||
self.__queue_nonempty.set()
|
||||
self._message_queue.append(MessageQueueEntry(time.monotonic(), message, address))
|
||||
if not self._queue_nonempty.is_set():
|
||||
self._queue_nonempty.set()
|
||||
|
||||
@staticmethod
|
||||
def on_message(message: dict, address: Address):
|
||||
@ -62,10 +62,15 @@ class Transport(asyncio.DatagramProtocol):
|
||||
# Private Functionality
|
||||
# =====================
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
||||
self.__datagram_transport = transport
|
||||
self.__write_allowed.set()
|
||||
self._datagram_transport = transport
|
||||
self._write_allowed.set()
|
||||
|
||||
def datagram_received(self, data: bytes, address: Address) -> None:
|
||||
# Ignore nodes that "uses" port 0, as we cannot communicate with them reliably across the different systems.
|
||||
# See https://tools.cisco.com/security/center/viewAlert.x?alertId=19935 for slightly more details
|
||||
if address[1] == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
message = codec.decode(data)
|
||||
except codec.EncodeError:
|
||||
@ -80,10 +85,10 @@ class Transport(asyncio.DatagramProtocol):
|
||||
logging.debug("Mainline DHT received error!", exc_info=exc)
|
||||
|
||||
def pause_writing(self):
|
||||
self.__write_allowed.clear()
|
||||
self._write_allowed.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self.__write_allowed.set()
|
||||
self._write_allowed.set()
|
||||
|
||||
def connection_lost(self, exc: Exception):
|
||||
if exc:
|
||||
@ -94,16 +99,16 @@ class Transport(asyncio.DatagramProtocol):
|
||||
logging.fatal("Mainline DHT lost connection!")
|
||||
sys.exit(1)
|
||||
|
||||
async def __send_messages(self) -> None:
|
||||
async def _send_messages(self) -> None:
|
||||
while True:
|
||||
await asyncio.wait([self.__write_allowed.wait(), self.__queue_nonempty.wait()])
|
||||
await asyncio.wait([self._write_allowed.wait(), self._queue_nonempty.wait()])
|
||||
try:
|
||||
queued_on, message, address = self.__message_queue.pop()
|
||||
queued_on, message, address = self._message_queue.pop()
|
||||
except IndexError:
|
||||
self.__queue_nonempty.clear()
|
||||
self._queue_nonempty.clear()
|
||||
continue
|
||||
|
||||
if time.monotonic() - queued_on > 60:
|
||||
return
|
||||
|
||||
self.__datagram_transport.sendto(message, address)
|
||||
self._datagram_transport.sendto(message, address)
|
||||
|
@ -0,0 +1,19 @@
|
||||
# magneticod - Autonomous BitTorrent DHT crawler and metadata fetcher.
|
||||
# Copyright (C) 2017 Mert Bora ALPER <bora@boramalper.org>
|
||||
# Dedicated to Cemile Binay, in whose hands I thrived.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General
|
||||
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any
|
||||
# later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
|
||||
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
|
||||
# details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
|
||||
# <http://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
class InfoHashSink:
|
||||
def __init__(self):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user