magneticod: metadata leech refactored heavily, now much more readable <3
+ persistence: we now make sure that rows are always closed using `defer`

parent 0614e9e0f9
commit c07daa3eca
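The persistence half of this commit boils down to the `defer rows.Close()` pattern visible in the sqlite3 hunks at the bottom of the diff. A minimal standalone sketch of the idea (table and column names here are illustrative, not taken from the actual schema):

package example

import "database/sql"

// doesTorrentExist sketches the row-handling pattern the commit adopts:
// close sql.Rows with defer so the statement is finalized on every return
// path, which avoids go-sqlite3's "database table is locked" errors.
func doesTorrentExist(conn *sql.DB, infoHash []byte) (bool, error) {
	rows, err := conn.Query("SELECT 1 FROM torrents WHERE info_hash = ?;", infoHash)
	if err != nil {
		return false, err
	}
	defer rows.Close() // runs even on the early returns below

	exists := rows.Next() // true iff at least one row matched
	if err := rows.Err(); err != nil {
		return false, err
	}
	return exists, nil
}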
@@ -56,3 +56,7 @@
 [[constraint]]
   name = "go.uber.org/zap"
   version = "1.7.1"
+
+[[constraint]]
+  name = "github.com/libp2p/go-sockaddr"
+  version = "1.0.3"
cmd/magneticod/bittorrent/metadata/leech.go — new file, 439 lines
@@ -0,0 +1,439 @@
package metadata

import (
	"bytes"
	"crypto/rand"
	"crypto/sha1"
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"net"
	"time"

	"github.com/anacrolix/torrent/bencode"
	"github.com/anacrolix/torrent/metainfo"
	"github.com/pkg/errors"

	"go.uber.org/zap"

	"github.com/boramalper/magnetico/pkg/persistence"
)

const MAX_METADATA_SIZE = 10 * 1024 * 1024

type rootDict struct {
	M            mDict `bencode:"m"`
	MetadataSize int   `bencode:"metadata_size"`
}

type mDict struct {
	UTMetadata int `bencode:"ut_metadata"`
}

type extDict struct {
	MsgType int `bencode:"msg_type"`
	Piece   int `bencode:"piece"`
}

type Leech struct {
	infoHash [20]byte
	peerAddr *net.TCPAddr
	ev       LeechEventHandlers

	conn     *net.TCPConn
	clientID [20]byte

	ut_metadata                    uint8
	metadataReceived, metadataSize uint
	metadata                       []byte
}

type LeechEventHandlers struct {
	OnSuccess func(Metadata)        // must be supplied. args: metadata
	OnError   func([20]byte, error) // must be supplied. args: infohash, error
}

func NewLeech(infoHash [20]byte, peerAddr *net.TCPAddr, ev LeechEventHandlers) *Leech {
	l := new(Leech)
	l.infoHash = infoHash
	l.peerAddr = peerAddr
	l.ev = ev

	if _, err := rand.Read(l.clientID[:]); err != nil {
		panic(err.Error())
	}

	return l
}

func (l *Leech) writeAll(b []byte) error {
	for len(b) != 0 {
		n, err := l.conn.Write(b)
		if err != nil {
			return err
		}
		b = b[n:]
	}
	return nil
}

func (l *Leech) doBtHandshake() error {
	lHandshake := []byte(fmt.Sprintf(
		"\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s",
		l.infoHash[:],
		l.clientID,
	))

	// ASSERTION
	if len(lHandshake) != 68 {
		panic(fmt.Sprintf("len(lHandshake) == %d", len(lHandshake)))
	}

	err := l.writeAll(lHandshake)
	if err != nil {
		return errors.Wrap(err, "writeAll lHandshake")
	}

	zap.L().Debug("BitTorrent handshake sent, waiting for the remote's...")

	rHandshake, err := l.readExactly(68)
	if err != nil {
		return errors.Wrap(err, "readExactly rHandshake")
	}
	if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) {
		return fmt.Errorf("corrupt BitTorrent handshake received")
	}

	// TODO: maybe check for the infohash sent by the remote peer to double check?

	if (rHandshake[25] & 0x10) == 0 {
		return fmt.Errorf("peer does not support the extension protocol")
	}

	return nil
}

func (l *Leech) doExHandshake() error {
	err := l.writeAll([]byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee"))
	if err != nil {
		return errors.Wrap(err, "writeAll lHandshake")
	}

	rExMessage, err := l.readExMessage()
	if err != nil {
		return errors.Wrap(err, "readExMessage")
	}

	// Extension Handshake has the Extension Message ID = 0x00
	if rExMessage[1] != 0 {
		return errors.Wrap(err, "first extension message is not an extension handshake")
	}

	rRootDict := new(rootDict)
	err = bencode.Unmarshal(rExMessage[2:], rRootDict)
	if err != nil {
		return errors.Wrap(err, "unmarshal rExMessage")
	}

	if !(0 < rRootDict.MetadataSize && rRootDict.MetadataSize < MAX_METADATA_SIZE) {
		return fmt.Errorf("metadata too big or its size is less than or equal zero")
	}

	if !(0 < rRootDict.M.UTMetadata && rRootDict.M.UTMetadata < 255) {
		return fmt.Errorf("ut_metadata is not an uint8")
	}

	l.ut_metadata = uint8(rRootDict.M.UTMetadata) // Save the ut_metadata code the remote peer uses
	l.metadataSize = uint(rRootDict.MetadataSize)
	l.metadata = make([]byte, l.metadataSize)

	return nil
}

func (l *Leech) requestAllPieces() error {
	// Request all the pieces of metadata
	nPieces := int(math.Ceil(float64(l.metadataSize) / math.Pow(2, 14)))
	for piece := 0; piece < nPieces; piece++ {
		// __request_metadata_piece(piece)
		// ...............................
		extDictDump, err := bencode.Marshal(extDict{
			MsgType: 0,
			Piece:   piece,
		})
		if err != nil { // ASSERT
			panic(errors.Wrap(err, "marshal extDict"))
		}

		err = l.writeAll([]byte(fmt.Sprintf(
			"%s\x14%s%s",
			toBigEndian(uint(2+len(extDictDump)), 4),
			toBigEndian(uint(l.ut_metadata), 1),
			extDictDump,
		)))
		if err != nil {
			return errors.Wrap(err, "writeAll piece request")
		}
	}

	return nil
}

// readMessage returns a BitTorrent message, sans the first 4 bytes indicating its length.
func (l *Leech) readMessage() ([]byte, error) {
	rLengthB, err := l.readExactly(4)
	if err != nil {
		return nil, errors.Wrap(err, "readExactly rLengthB")
	}

	rLength := uint(binary.BigEndian.Uint32(rLengthB))

	rMessage, err := l.readExactly(rLength)
	if err != nil {
		return nil, errors.Wrap(err, "readExactly rMessage")
	}

	return rMessage, nil
}

// readExMessage returns an *extension* message, sans the first 4 bytes indicating its length.
//
// It will IGNORE all non-extension messages!
func (l *Leech) readExMessage() ([]byte, error) {
	for {
		rMessage, err := l.readMessage()
		if err != nil {
			return nil, errors.Wrap(err, "readMessage")
		}

		// We are interested only in extension messages, whose first byte is always 20
		if rMessage[0] == 20 {
			return rMessage, nil
		}
	}
}

// readUmMessage returns an ut_metadata extension message, sans the first 4 bytes indicating its
// length.
//
// It will IGNORE all non-"ut_metadata extension" messages!
func (l *Leech) readUmMessage() ([]byte, error) {
	for {
		rExMessage, err := l.readExMessage()
		if err != nil {
			return nil, errors.Wrap(err, "readExMessage")
		}

		if rExMessage[1] == 0x01 {
			return rExMessage, nil
		}
	}
}

func (l *Leech) connect(deadline time.Time) error {
	var err error

	l.conn, err = net.DialTCP("tcp", nil, l.peerAddr)
	if err != nil {
		return errors.Wrap(err, "dial")
	}
	defer l.conn.Close()

	err = l.conn.SetNoDelay(true)
	if err != nil {
		return errors.Wrap(err, "NODELAY")
	}

	err = l.conn.SetDeadline(deadline)
	if err != nil {
		return errors.Wrap(err, "SetDeadline")
	}

	return nil
}

func (l *Leech) Do(deadline time.Time) {
	err := l.connect(deadline)
	if err != nil {
		l.OnError(errors.Wrap(err, "connect"))
		return
	}

	err = l.doBtHandshake()
	if err != nil {
		l.OnError(errors.Wrap(err, "doBtHandshake"))
		return
	}

	err = l.doExHandshake()
	if err != nil {
		l.OnError(errors.Wrap(err, "doExHandshake"))
		return
	}

	err = l.requestAllPieces()
	if err != nil {
		l.OnError(errors.Wrap(err, "requestAllPieces"))
		return
	}

	for l.metadataReceived < l.metadataSize {
		rUmMessage, err := l.readUmMessage()
		if err != nil {
			l.OnError(errors.Wrap(err, "readUmMessage"))
			return
		}

		// Run TestDecoder() function in leech_test.go in case you have any doubts.
		rMessageBuf := bytes.NewBuffer(rUmMessage[2:])
		rExtDict := new(extDict)
		err = bencode.NewDecoder(rMessageBuf).Decode(rExtDict)
		if err != nil {
			zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err))
			return
		}

		if rExtDict.MsgType == 2 { // reject
			l.OnError(fmt.Errorf("remote peer rejected sending metadata"))
			return
		}

		if rExtDict.MsgType == 1 { // data
			// Get the unread bytes!
			metadataPiece := rMessageBuf.Bytes()

			// BEP 9 explicitly states:
			//   > If the piece is the last piece of the metadata, it may be less than 16kiB. If
			//   > it is not the last piece of the metadata, it MUST be 16kiB.
			//
			// Hence...
			// ... if the length of @metadataPiece is more than 16kiB, we err.
			if len(metadataPiece) > 16*1024 {
				l.OnError(fmt.Errorf("metadataPiece > 16kiB"))
				return
			}

			piece := rExtDict.Piece
			// metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python
			copy(l.metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece)
			l.metadataReceived += uint(len(metadataPiece))

			// ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT
			// complete then we err.
			if len(metadataPiece) < 16*1024 && l.metadataReceived != l.metadataSize {
				l.OnError(fmt.Errorf("metadataPiece < 16 kiB but incomplete"))
				return
			}

			if l.metadataReceived > l.metadataSize {
				l.OnError(fmt.Errorf("metadataReceived > metadataSize"))
				return
			}
		}
	}

	// Verify the checksum
	sha1Sum := sha1.Sum(l.metadata)
	if !bytes.Equal(sha1Sum[:], l.infoHash[:]) {
		l.OnError(fmt.Errorf("infohash mismatch"))
		return
	}

	// Check the info dictionary
	info := new(metainfo.Info)
	err = bencode.Unmarshal(l.metadata, info)
	if err != nil {
		l.OnError(errors.Wrap(err, "unmarshal info"))
		return
	}
	err = validateInfo(info)
	if err != nil {
		l.OnError(errors.Wrap(err, "validateInfo"))
		return
	}

	var files []persistence.File
	// If there is only one file, there won't be a Files slice. That's why we need to add it here
	if len(info.Files) == 0 {
		files = append(files, persistence.File{
			Size: info.Length,
			Path: info.Name,
		})
	} else {
		for _, file := range info.Files {
			files = append(files, persistence.File{
				Size: file.Length,
				Path: file.DisplayPath(info),
			})
		}
	}

	var totalSize uint64
	for _, file := range files {
		if file.Size < 0 {
			l.OnError(fmt.Errorf("file size less than zero"))
			return
		}

		totalSize += uint64(file.Size)
	}

	l.ev.OnSuccess(Metadata{
		InfoHash:     l.infoHash[:],
		Name:         info.Name,
		TotalSize:    totalSize,
		DiscoveredOn: time.Now().Unix(),
		Files:        files,
	})
}

// COPIED FROM anacrolix/torrent
func validateInfo(info *metainfo.Info) error {
	if len(info.Pieces)%20 != 0 {
		return errors.New("pieces has invalid length")
	}
	if info.PieceLength == 0 {
		if info.TotalLength() != 0 {
			return errors.New("zero piece length")
		}
	} else {
		if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() {
			return errors.New("piece count and file lengths are at odds")
		}
	}
	return nil
}

func (l *Leech) readExactly(n uint) ([]byte, error) {
	b := make([]byte, n)
	_, err := io.ReadFull(l.conn, b)
	return b, err
}

func (l *Leech) OnError(err error) {
	l.ev.OnError(l.infoHash, err)
}

// TODO: add bounds checking!
func toBigEndian(i uint, n int) []byte {
	b := make([]byte, n)
	switch n {
	case 1:
		b = []byte{byte(i)}
	case 2:
		binary.BigEndian.PutUint16(b, uint16(i))
	case 4:
		binary.BigEndian.PutUint32(b, uint32(i))
	default:
		panic(fmt.Sprintf("n must be 1, 2, or 4!"))
	}

	if len(b) != n {
		panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b))
	}

	return b
}
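For orientation, here is a minimal sketch of how the new Leech is meant to be driven. The helper, handlers, and address below are made up for illustration; in the real code Sink.Sink() does this wiring, as the next diff shows:

package metadata

import (
	"fmt"
	"net"
	"time"
)

// fetchOne is a hypothetical helper showing the intended call pattern:
// construct a Leech with both callbacks supplied, then run Do in its own
// goroutine, with a wall-clock deadline bounding the whole exchange.
func fetchOne(infoHash [20]byte, addr *net.TCPAddr) {
	leech := NewLeech(infoHash, addr, LeechEventHandlers{
		OnSuccess: func(md Metadata) {
			fmt.Printf("fetched %q: %d bytes in %d file(s)\n", md.Name, md.TotalSize, len(md.Files))
		},
		OnError: func(ih [20]byte, err error) {
			fmt.Printf("%x: %v\n", ih, err)
		},
	})
	go leech.Do(time.Now().Add(2 * time.Minute))
}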
@@ -1,4 +1,4 @@
-package bittorrent
+package metadata

 import (
 	"bytes"
@@ -1,4 +1,4 @@
-package bittorrent
+package metadata

 import (
 	"crypto/rand"
@@ -24,21 +24,17 @@ type Metadata struct {
 	Files []persistence.File
 }

-type Peer struct {
-	Addr *net.TCPAddr
-}
-
-type MetadataSink struct {
+type Sink struct {
 	clientID []byte
 	deadline time.Duration
 	drain    chan Metadata
 	incomingInfoHashes map[[20]byte]struct{}
 	terminated         bool
 	termination        chan interface{}
 }

-func NewMetadataSink(deadline time.Duration) *MetadataSink {
-	ms := new(MetadataSink)
+func NewSink(deadline time.Duration) *Sink {
+	ms := new(Sink)

 	ms.clientID = make([]byte, 20)
 	_, err := rand.Read(ms.clientID)
@@ -52,27 +48,29 @@ func NewMetadataSink(deadline time.Duration) *MetadataSink {
 	return ms
 }

-func (ms *MetadataSink) Sink(res mainline.TrawlingResult) {
+func (ms *Sink) Sink(res mainline.TrawlingResult) {
 	if ms.terminated {
-		zap.L().Panic("Trying to Sink() an already closed MetadataSink!")
+		zap.L().Panic("Trying to Sink() an already closed Sink!")
 	}

 	if _, exists := ms.incomingInfoHashes[res.InfoHash]; exists {
 		return
 	}
 	// BEWARE!
-	// Although not crucial, the assumption is that MetadataSink.Sink() will be called by only one
+	// Although not crucial, the assumption is that Sink.Sink() will be called by only one
 	// goroutine (i.e. it's not thread-safe), lest there might be a race condition between where we
 	// check whether res.infoHash exists in the ms.incomingInfoHashes, and where we add the infoHash
 	// to the incomingInfoHashes at the end of this function.

+	zap.L().Info("Sunk!", zap.String("infoHash", res.InfoHash.String()))
+
 	IPs := res.PeerIP.String()
 	var rhostport string
 	if IPs == "<nil>" {
-		zap.L().Debug("MetadataSink.Sink: Peer IP is nil!")
+		zap.L().Debug("Sink.Sink: Peer IP is nil!")
 		return
 	} else if IPs[0] == '?' {
-		zap.L().Debug("MetadataSink.Sink: Peer IP is invalid!")
+		zap.L().Debug("Sink.Sink: Peer IP is invalid!")
 		return
 	} else if strings.ContainsRune(IPs, ':') { // IPv6
 		rhostport = fmt.Sprintf("[%s]:%d", IPs, res.PeerPort)
@@ -82,29 +80,33 @@ func (ms *MetadataSink) Sink(res mainline.TrawlingResult) {

 	raddr, err := net.ResolveTCPAddr("tcp", rhostport)
 	if err != nil {
-		zap.L().Debug("MetadataSink.Sink: Couldn't resolve peer address!", zap.Error(err))
+		zap.L().Debug("Sink.Sink: Couldn't resolve peer address!", zap.Error(err))
 		return
 	}

-	go ms.awaitMetadata(res.InfoHash, Peer{Addr: raddr})
+	leech := NewLeech(res.InfoHash, raddr, LeechEventHandlers{
+		OnSuccess: ms.flush,
+		OnError:   ms.onLeechError,
+	})
+	go leech.Do(time.Now().Add(ms.deadline))

 	ms.incomingInfoHashes[res.InfoHash] = struct{}{}
 }

-func (ms *MetadataSink) Drain() <-chan Metadata {
+func (ms *Sink) Drain() <-chan Metadata {
 	if ms.terminated {
-		zap.L().Panic("Trying to Drain() an already closed MetadataSink!")
+		zap.L().Panic("Trying to Drain() an already closed Sink!")
 	}
 	return ms.drain
 }

-func (ms *MetadataSink) Terminate() {
+func (ms *Sink) Terminate() {
 	ms.terminated = true
 	close(ms.termination)
 	close(ms.drain)
 }

-func (ms *MetadataSink) flush(result Metadata) {
+func (ms *Sink) flush(result Metadata) {
 	if !ms.terminated {
 		ms.drain <- result
 		// Delete the infoHash from ms.incomingInfoHashes ONLY AFTER once we've flushed the
@@ -114,3 +116,8 @@ func (ms *MetadataSink) flush(result Metadata) {
 		delete(ms.incomingInfoHashes, infoHash)
 	}
 }
+
+func (ms *Sink) onLeechError(infoHash [20]byte, err error) {
+	zap.L().Debug("leech error", zap.ByteString("infoHash", infoHash[:]), zap.Error(err))
+	delete(ms.incomingInfoHashes, infoHash)
+}
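On the consuming side, the renamed Sink is used roughly as follows. This is a sketch mirroring what main() does (see its diff further below); the surrounding plumbing is elided:

package example

import (
	"fmt"
	"time"

	"github.com/boramalper/magnetico/cmd/magneticod/bittorrent/metadata"
)

// consume sketches the Sink lifecycle: Drain exposes the channel that
// flush() feeds, and Terminate closes it. Feeding happens elsewhere, via
// sink.Sink(result) for each mainline.TrawlingResult.
func consume() {
	sink := metadata.NewSink(2 * time.Minute)
	go func() {
		for md := range sink.Drain() { // loop ends when Terminate closes the drain
			fmt.Printf("sunk: %s (%d bytes)\n", md.Name, md.TotalSize)
		}
	}()
	// ... sink.Sink(...) calls happen on the trawler's event loop ...
	sink.Terminate()
}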
(deleted file — 480 lines: the old awaitMetadata implementation in package bittorrent)
@@ -1,480 +0,0 @@
package bittorrent

import (
	"bytes"
	"crypto/sha1"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"time"

	"github.com/anacrolix/torrent/bencode"
	"github.com/anacrolix/torrent/metainfo"

	"go.uber.org/zap"

	"github.com/boramalper/magnetico/pkg/persistence"
)

const MAX_METADATA_SIZE = 10 * 1024 * 1024

type rootDict struct {
	M            mDict `bencode:"m"`
	MetadataSize int   `bencode:"metadata_size"`
}

type mDict struct {
	UTMetadata int `bencode:"ut_metadata"`
}

type extDict struct {
	MsgType int `bencode:"msg_type"`
	Piece   int `bencode:"piece"`
}

func (ms *MetadataSink) awaitMetadata(infoHash metainfo.Hash, peer Peer) {
	conn, err := net.DialTCP("tcp", nil, peer.Addr)
	if err != nil {
		zap.L().Debug(
			"awaitMetadata couldn't connect to the peer!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
			zap.Error(err),
		)
		return
	}
	defer conn.Close()

	err = conn.SetNoDelay(true)
	if err != nil {
		zap.L().Panic(
			"Couldn't set NODELAY!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
			zap.Error(err),
		)
		return
	}
	err = conn.SetDeadline(time.Now().Add(ms.deadline))
	if err != nil {
		zap.L().Panic(
			"Couldn't set the deadline!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
			zap.Error(err),
		)
		return
	}

	// State Variables
	var isExtHandshakeDone, done bool
	var ut_metadata, metadataReceived, metadataSize int
	var metadata []byte

	lHandshake := []byte(fmt.Sprintf(
		"\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s",
		infoHash[:],
		ms.clientID,
	))
	if len(lHandshake) != 68 {
		zap.L().Panic(
			"Generated BitTorrent handshake is not of length 68!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.Int("len_lHandshake", len(lHandshake)),
		)
	}
	err = writeAll(conn, lHandshake)
	if err != nil {
		zap.L().Debug(
			"Couldn't write BitTorrent handshake!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
			zap.Error(err),
		)
		return
	}

	zap.L().Debug("BitTorrent handshake sent, waiting for the remote's...")

	rHandshake, err := readExactly(conn, 68)
	if err != nil {
		zap.L().Debug(
			"Couldn't read remote BitTorrent handshake!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
			zap.Error(err),
		)
		return
	}
	if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) {
		zap.L().Debug(
			"Remote BitTorrent handshake is not what it is supposed to be!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
			zap.ByteString("rHandshake[:20]", rHandshake[:20]),
		)
		return
	}

	// __on_bt_handshake
	// ================
	if rHandshake[25] != 16 { // TODO (later): do *not* compare the whole byte, check the bit instead! (0x10)
		zap.L().Debug(
			"Peer does not support the extension protocol!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.String("remotePeerAddr", peer.Addr.String()),
		)
		return
	}

	writeAll(conn, []byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee"))
	zap.L().Debug(
		"Extension handshake sent, waiting for the remote's...",
		zap.ByteString("infoHash", infoHash[:]),
		zap.String("remotePeerAddr", peer.Addr.String()),
	)

	// the loop!
	// =========
	for !done {
		rLengthB, err := readExactly(conn, 4)
		if err != nil {
			zap.L().Debug(
				"Couldn't read the first 4 bytes from the remote peer in the loop!",
				zap.ByteString("infoHash", infoHash[:]),
				zap.String("remotePeerAddr", peer.Addr.String()),
				zap.Error(err),
			)
			return
		}

		// The messages we are interested in have the length of AT LEAST two bytes
		// (TODO: actually a bit more than that but SURELY when it's less than two bytes, the
		// program panics)
		rLength := bigEndianToInt(rLengthB)
		if rLength < 2 {
			continue
		}

		rMessage, err := readExactly(conn, rLength)
		if err != nil {
			zap.L().Debug(
				"Couldn't read the rest of the message from the remote peer in the loop!",
				zap.ByteString("infoHash", infoHash[:]),
				zap.String("remotePeerAddr", peer.Addr.String()),
				zap.Error(err),
			)
			return
		}

		// __on_message
		// ------------
		if rMessage[0] != 0x14 { // We are interested only in extension messages, whose first byte is always 0x14
			zap.L().Debug(
				"Ignoring the non-extension message.",
				zap.ByteString("infoHash", infoHash[:]),
			)
			continue
		}

		if rMessage[1] == 0x00 { // Extension Handshake has the Extension Message ID = 0x00
			// __on_ext_handshake_message(message[2:])
			// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

			// TODO: continue editing log messages from here

			if isExtHandshakeDone {
				return
			}

			rRootDict := new(rootDict)
			err := bencode.Unmarshal(rMessage[2:], rRootDict)
			if err != nil {
				zap.L().Debug("Couldn't unmarshal extension handshake!", zap.Error(err))
				return
			}

			if rRootDict.MetadataSize <= 0 || rRootDict.MetadataSize > MAX_METADATA_SIZE {
				zap.L().Debug("Unacceptable metadata size!", zap.Int("metadata_size", rRootDict.MetadataSize))
				return
			}

			ut_metadata = rRootDict.M.UTMetadata // Save the ut_metadata code the remote peer uses
			metadataSize = rRootDict.MetadataSize
			metadata = make([]byte, metadataSize)
			isExtHandshakeDone = true

			zap.L().Debug("GOT EXTENSION HANDSHAKE!", zap.Int("ut_metadata", ut_metadata), zap.Int("metadata_size", metadataSize))

			// Request all the pieces of metadata
			n_pieces := int(math.Ceil(float64(metadataSize) / math.Pow(2, 14)))
			for piece := 0; piece < n_pieces; piece++ {
				// __request_metadata_piece(piece)
				// ...............................
				extDictDump, err := bencode.Marshal(extDict{
					MsgType: 0,
					Piece:   piece,
				})
				if err != nil {
					zap.L().Warn("Couldn't marshal extDictDump!", zap.Error(err))
					return
				}
				writeAll(conn, []byte(fmt.Sprintf(
					"%s\x14%s%s",
					intToBigEndian(2+len(extDictDump), 4),
					intToBigEndian(ut_metadata, 1),
					extDictDump,
				)))
			}

			zap.L().Warn("requested all metadata pieces!")

		} else if rMessage[1] == 0x01 {
			// __on_ext_message(message[2:])
			// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

			// Run TestDecoder() function in operations_test.go in case you have any doubts.
			rMessageBuf := bytes.NewBuffer(rMessage[2:])
			rExtDict := new(extDict)
			err := bencode.NewDecoder(rMessageBuf).Decode(rExtDict)
			if err != nil {
				zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err))
				return
			}

			if rExtDict.MsgType == 1 { // data
				// Get the unread bytes!
				metadataPiece := rMessageBuf.Bytes()
				piece := rExtDict.Piece
				// metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python
				copy(metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece)
				metadataReceived += len(metadataPiece)
				done = metadataReceived == metadataSize

				// BEP 9 explicitly states:
				//   > If the piece is the last piece of the metadata, it may be less than 16kiB. If
				//   > it is not the last piece of the metadata, it MUST be 16kiB.
				//
				// Hence...
				// ... if the length of @metadataPiece is more than 16kiB, we err.
				if len(metadataPiece) > 16*1024 {
					zap.L().Debug(
						"metadataPiece is bigger than 16kiB!",
						zap.Int("len_metadataPiece", len(metadataPiece)),
						zap.Int("metadataReceived", metadataReceived),
						zap.Int("metadataSize", metadataSize),
						zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)),
					)
					return
				}

				// ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT
				// complete (!done) then we err.
				if len(metadataPiece) < 16*1024 && !done {
					zap.L().Debug(
						"metadataPiece is less than 16kiB and metadata is incomplete!",
						zap.Int("len_metadataPiece", len(metadataPiece)),
						zap.Int("metadataReceived", metadataReceived),
						zap.Int("metadataSize", metadataSize),
						zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)),
					)
					return
				}

				if metadataReceived > metadataSize {
					zap.L().Debug(
						"metadataReceived is greater than metadataSize!",
						zap.Int("len_metadataPiece", len(metadataPiece)),
						zap.Int("metadataReceived", metadataReceived),
						zap.Int("metadataSize", metadataSize),
						zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)),
					)
					return
				}

				zap.L().Debug(
					"Fetching...",
					zap.ByteString("infoHash", infoHash[:]),
					zap.String("remotePeerAddr", conn.RemoteAddr().String()),
					zap.Int("metadataReceived", metadataReceived),
					zap.Int("metadataSize", metadataSize),
				)
			} else if rExtDict.MsgType == 2 { // reject
				zap.L().Debug(
					"Remote peer rejected sending metadata!",
					zap.ByteString("infoHash", infoHash[:]),
					zap.String("remotePeerAddr", conn.RemoteAddr().String()),
				)
				return
			}

		} else {
			zap.L().Debug(
				"Message is not an ut_metadata message! (ignoring)",
				zap.ByteString("msg", rMessage[:100]),
			)
			// no return!
		}
	}

	zap.L().Debug(
		"Metadata is complete, verifying the checksum...",
		zap.ByteString("infoHash", infoHash[:]),
	)

	sha1Sum := sha1.Sum(metadata)
	if !bytes.Equal(sha1Sum[:], infoHash[:]) {
		zap.L().Debug(
			"Info-hash mismatch!",
			zap.ByteString("expectedInfoHash", infoHash[:]),
			zap.ByteString("actualInfoHash", sha1Sum[:]),
		)
		return
	}

	zap.L().Debug(
		"Checksum verified, checking the info dictionary...",
		zap.ByteString("infoHash", infoHash[:]),
	)

	info := new(metainfo.Info)
	err = bencode.Unmarshal(metadata, info)
	if err != nil {
		zap.L().Debug(
			"Couldn't unmarshal info bytes!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.Error(err),
		)
		return
	}
	err = validateInfo(info)
	if err != nil {
		zap.L().Debug(
			"Bad info dictionary!",
			zap.ByteString("infoHash", infoHash[:]),
			zap.Error(err),
		)
		return
	}

	var files []persistence.File
	// If there is only one file, there won't be a Files slice. That's why we need to add it here
	if len(info.Files) == 0 {
		files = append(files, persistence.File{
			Size: info.Length,
			Path: info.Name,
		})
	}

	for _, file := range info.Files {
		if file.Length < 0 {
			zap.L().Debug(
				"File size is less than zero!",
				zap.ByteString("infoHash", infoHash[:]),
				zap.String("filePath", file.DisplayPath(info)),
				zap.Int64("fileSize", file.Length),
			)
			return
		}

		files = append(files, persistence.File{
			Size: file.Length,
			Path: file.DisplayPath(info),
		})
	}

	var totalSize uint64
	for _, file := range files {
		totalSize += uint64(file.Size)
	}

	zap.L().Debug(
		"Flushing metadata...",
		zap.ByteString("infoHash", infoHash[:]),
	)

	ms.flush(Metadata{
		InfoHash:     infoHash[:],
		Name:         info.Name,
		TotalSize:    totalSize,
		DiscoveredOn: time.Now().Unix(),
		Files:        files,
	})
}

// COPIED FROM anacrolix/torrent
func validateInfo(info *metainfo.Info) error {
	if len(info.Pieces)%20 != 0 {
		return errors.New("pieces has invalid length")
	}
	if info.PieceLength == 0 {
		if info.TotalLength() != 0 {
			return errors.New("zero piece length")
		}
	} else {
		if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() {
			return errors.New("piece count and file lengths are at odds")
		}
	}
	return nil
}

func writeAll(c *net.TCPConn, b []byte) error {
	for len(b) != 0 {
		n, err := c.Write(b)
		if err != nil {
			return err
		}
		b = b[n:]
	}
	return nil
}

func readExactly(c *net.TCPConn, n int) ([]byte, error) {
	b := make([]byte, n)
	_, err := io.ReadFull(c, b)
	return b, err
}

// TODO: add bounds checking!
func intToBigEndian(i int, n int) []byte {
	b := make([]byte, n)
	switch n {
	case 1:
		b = []byte{byte(i)}
	case 2:
		binary.BigEndian.PutUint16(b, uint16(i))
	case 4:
		binary.BigEndian.PutUint32(b, uint32(i))
	default:
		panic(fmt.Sprintf("n must be 1, 2, or 4!"))
	}

	if len(b) != n {
		panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b))
	}

	return b
}

func bigEndianToInt(b []byte) int {
	switch len(b) {
	case 1:
		return int(b[0])
	case 2:
		return int(binary.BigEndian.Uint16(b))
	case 4:
		return int(binary.BigEndian.Uint32(b))
	default:
		panic(fmt.Sprintf("bigEndianToInt: b is too long! (%d bytes)", len(b)))
	}
}
@@ -26,11 +26,12 @@ type ProtocolEventHandlers struct {
 	OnGetPeersResponse           func(*Message, net.Addr)
 	OnFindNodeResponse           func(*Message, net.Addr)
 	OnPingORAnnouncePeerResponse func(*Message, net.Addr)
+	OnCongestion                 func()
 }

 func NewProtocol(laddr string, eventHandlers ProtocolEventHandlers) (p *Protocol) {
 	p = new(Protocol)
-	p.transport = NewTransport(laddr, p.onMessage)
+	p.transport = NewTransport(laddr, p.onMessage, p.eventHandlers.OnCongestion)
 	p.eventHandlers = eventHandlers

 	p.currentTokenSecret, p.previousTokenSecret = make([]byte, 20), make([]byte, 20)
@@ -31,13 +31,14 @@ type TrawlingService struct {
 	// ^~~~~~
 	routingTable      map[string]net.Addr
 	routingTableMutex *sync.Mutex
+	maxNeighbors      uint
 }

 type TrawlingServiceEventHandlers struct {
 	OnResult func(TrawlingResult)
 }

-func NewTrawlingService(laddr string, eventHandlers TrawlingServiceEventHandlers) *TrawlingService {
+func NewTrawlingService(laddr string, initialMaxNeighbors uint, eventHandlers TrawlingServiceEventHandlers) *TrawlingService {
 	service := new(TrawlingService)
 	service.protocol = NewProtocol(
 		laddr,
@@ -45,12 +46,14 @@ func NewTrawlingService(laddr string, eventHandlers TrawlingServiceEventHandlers
 			OnGetPeersQuery:     service.onGetPeersQuery,
 			OnAnnouncePeerQuery: service.onAnnouncePeerQuery,
 			OnFindNodeResponse:  service.onFindNodeResponse,
+			OnCongestion:        service.onCongestion,
 		},
 	)
 	service.trueNodeID = make([]byte, 20)
 	service.routingTable = make(map[string]net.Addr)
 	service.routingTableMutex = new(sync.Mutex)
 	service.eventHandlers = eventHandlers
+	service.maxNeighbors = initialMaxNeighbors

 	_, err := rand.Read(service.trueNodeID)
 	if err != nil {
@@ -78,11 +81,14 @@ func (s *TrawlingService) Terminate() {

 func (s *TrawlingService) trawl() {
 	for range time.Tick(3 * time.Second) {
+		s.maxNeighbors = uint(float32(s.maxNeighbors) * 1.01)
+
 		s.routingTableMutex.Lock()
 		if len(s.routingTable) == 0 {
 			s.bootstrap()
 		} else {
-			zap.L().Debug("Latest status:", zap.Int("n", len(s.routingTable)))
+			zap.L().Warn("Latest status:", zap.Int("n", len(s.routingTable)),
+				zap.Uint("maxNeighbors", s.maxNeighbors))
 			s.findNeighbors()
 			s.routingTable = make(map[string]net.Addr)
 		}
@@ -185,11 +191,34 @@ func (s *TrawlingService) onFindNodeResponse(response *Message, addr net.Addr) {
 	s.routingTableMutex.Lock()
 	defer s.routingTableMutex.Unlock()

+	zap.L().Debug("find node response!!", zap.Uint("maxNeighbors", s.maxNeighbors),
+		zap.Int("response.R.Nodes length", len(response.R.Nodes)))
+
 	for _, node := range response.R.Nodes {
-		if node.Addr.Port != 0 { // Ignore nodes who "use" port 0.
-			if len(s.routingTable) < 8000 {
-				s.routingTable[string(node.ID)] = &node.Addr
-			}
-		}
+		if uint(len(s.routingTable)) >= s.maxNeighbors {
+			break
+		}
+		if node.Addr.Port == 0 { // Ignore nodes who "use" port 0.
+			zap.L().Debug("ignoring 0 port!!!")
+			continue
+		}
+
+		s.routingTable[string(node.ID)] = &node.Addr
 	}
 }
+
+func (s *TrawlingService) onCongestion() {
+	/* The Congestion Prevention Strategy:
+	 *
+	 * In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
+	 */
+	if s.maxNeighbors < 200 {
+		zap.L().Warn("Max. number of neighbours are < 200 and there is still congestion!" +
+			"(check your network connection if this message recurs)")
+		return
+	}
+
+	s.maxNeighbors = uint(float32(s.maxNeighbors) * 0.9)
+	zap.L().Debug("Max. number of neighbours updated!",
+		zap.Uint("s.maxNeighbors", s.maxNeighbors))
+}
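Taken together, maxNeighbors plus onCongestion form a small multiplicative-increase/multiplicative-decrease throttle. A toy sketch of the dynamics, with constants taken from the diff (the helper itself is hypothetical, for illustration only):

package example

// step sketches the throttle the trawling service now applies: the neighbour
// cap grows ~1% on every 3-second trawl tick and shrinks to 90% whenever the
// UDP socket reports congestion (EPERM/ENOBUFS), with a floor around 200
// below which the service only warns instead of shrinking further.
func step(maxNeighbors uint, congested bool) uint {
	if congested {
		if maxNeighbors < 200 {
			return maxNeighbors // floor reached: warn, don't shrink
		}
		return uint(float32(maxNeighbors) * 0.9)
	}
	return uint(float32(maxNeighbors) * 1.01)
}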
@@ -4,6 +4,7 @@ import (
 	"net"

 	"github.com/anacrolix/torrent/bencode"
+	sockaddr "github.com/libp2p/go-sockaddr/net"
 	"go.uber.org/zap"
 	"golang.org/x/sys/unix"
 )
@@ -12,23 +13,43 @@ type Transport struct {
 	fd      int
 	laddr   *net.UDPAddr
 	started bool
+	buffer  []byte

 	// OnMessage is the function that will be called when Transport receives a packet that is
 	// successfully unmarshalled as a syntactically correct Message (but -of course- the checking
 	// the semantic correctness of the Message is left to Protocol).
 	onMessage func(*Message, net.Addr)
+	// OnCongestion
+	onCongestion func()
 }

-func NewTransport(laddr string, onMessage func(*Message, net.Addr)) *Transport {
-	transport := new(Transport)
-	transport.onMessage = onMessage
+func NewTransport(laddr string, onMessage func(*Message, net.Addr), onCongestion func()) *Transport {
+	t := new(Transport)
+	/* The field size sets a theoretical limit of 65,535 bytes (8 byte header + 65,527 bytes of
+	 * data) for a UDP datagram. However the actual limit for the data length, which is imposed by
+	 * the underlying IPv4 protocol, is 65,507 bytes (65,535 − 8 byte UDP header − 20 byte IP
+	 * header).
+	 *
+	 * In IPv6 jumbograms it is possible to have UDP packets of size greater than 65,535 bytes.
+	 * RFC 2675 specifies that the length field is set to zero if the length of the UDP header plus
+	 * UDP data is greater than 65,535.
+	 *
+	 * https://en.wikipedia.org/wiki/User_Datagram_Protocol
+	 */
+	t.buffer = make([]byte, 65507)
+	t.onMessage = onMessage
+	t.onCongestion = onCongestion

 	var err error
-	transport.laddr, err = net.ResolveUDPAddr("udp", laddr)
+	t.laddr, err = net.ResolveUDPAddr("udp", laddr)
 	if err != nil {
 		zap.L().Panic("Could not resolve the UDP address for the trawler!", zap.Error(err))
 	}
+	if t.laddr.IP.To4() == nil {
+		zap.L().Panic("IP address is not IPv4!")
+	}

-	return transport
+	return t
 }

 func (t *Transport) Start() {
@@ -51,7 +72,12 @@ func (t *Transport) Start() {
 		zap.L().Fatal("Could NOT create a UDP socket!", zap.Error(err))
 	}

-	unix.Bind(t.fd, unix.SockaddrInet4{Addr: t.laddr.IP, Port: t.laddr.Port})
+	var ip [4]byte
+	copy(ip[:], t.laddr.IP.To4())
+	err = unix.Bind(t.fd, &unix.SockaddrInet4{Addr: ip, Port: t.laddr.Port})
+	if err != nil {
+		zap.L().Fatal("Could NOT bind the socket!", zap.Error(err))
+	}

 	go t.readMessages()
 }
@@ -62,21 +88,33 @@ func (t *Transport) Terminate() {

 // readMessages is a goroutine!
 func (t *Transport) readMessages() {
-	buffer := make([]byte, 65536)
-
 	for {
-		n, from, err := unix.Recvfrom(t.fd, buffer, 0)
-		if err != nil {
+		n, fromSA, err := unix.Recvfrom(t.fd, t.buffer, 0)
+		if err == unix.EPERM || err == unix.ENOBUFS { // todo: are these errors possible for recvfrom?
+			zap.L().Warn("READ CONGESTION!", zap.Error(err))
+			t.onCongestion()
+		} else if err != nil {
 			// TODO: isn't there a more reliable way to detect if UDPConn is closed?
 			zap.L().Debug("Could NOT read an UDP packet!", zap.Error(err))
 		}

+		if n == 0 {
+			/* Datagram sockets in various domains (e.g., the UNIX and Internet domains) permit
+			 * zero-length datagrams. When such a datagram is received, the return value (n) is 0.
+			 */
+			zap.L().Debug("zero-length received!!")
+			continue
+		}
+
+		from := sockaddr.SockaddrToUDPAddr(fromSA)
+
 		var msg Message
-		err = bencode.Unmarshal(buffer[:n], &msg)
+		err = bencode.Unmarshal(t.buffer[:n], &msg)
 		if err != nil {
 			zap.L().Debug("Could NOT unmarshal packet data!", zap.Error(err))
 		}

+		zap.L().Debug("message read! (first 20...)", zap.ByteString("msg", t.buffer[:20]))
 		t.onMessage(&msg, from)
 	}
 }
@@ -87,9 +125,29 @@ func (t *Transport) WriteMessages(msg *Message, addr net.Addr) {
 		zap.L().Panic("Could NOT marshal an outgoing message! (Programmer error.)")
 	}

-	err = unix.Sendto(t.fd, data, 0, addr)
-	// TODO: isn't there a more reliable way to detect if UDPConn is closed?
-	if err != nil {
+	addrSA := sockaddr.NetAddrToSockaddr(addr)
+
+	zap.L().Debug("sent message!!!")
+
+	err = unix.Sendto(t.fd, data, 0, addrSA)
+	if err == unix.EPERM || err == unix.ENOBUFS {
+		/* EPERM (errno: 1) is kernel's way of saying that "you are far too fast, chill". It is
+		 * also likely that we have received a ICMP source quench packet (meaning, that we *really*
+		 * need to slow down.
+		 *
+		 * Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto-amp-amp-EPERM.html
+		 *
+		 * > Note On BSD systems (OS X, FreeBSD, etc.) flow control is not supported for
+		 * > DatagramProtocol, because send failures caused by writing too many packets cannot be
+		 * > detected easily. The socket always appears ‘ready’ and excess packets are dropped; an
+		 * > OSError with errno set to errno.ENOBUFS may or may not be raised; if it is raised, it
+		 * > will be reported to DatagramProtocol.error_received() but otherwise ignored.
+		 *
+		 * Source: https://docs.python.org/3/library/asyncio-protocol.html#flow-control-callbacks
+		 */
+		zap.L().Warn("WRITE CONGESTION!", zap.Error(err))
+		t.onCongestion()
+	} else if err != nil {
 		zap.L().Debug("Could NOT write an UDP packet!", zap.Error(err))
 	}
 }
@@ -18,6 +18,7 @@ func NewTrawlingManager(mlAddrs []string) *TrawlingManager {
 	for _, addr := range mlAddrs {
 		manager.services = append(manager.services, mainline.NewTrawlingService(
 			addr,
+			2000,
 			mainline.TrawlingServiceEventHandlers{
 				OnResult: manager.onResult,
 			},
@@ -6,7 +6,6 @@ import (
 	"net"
 	"os"
 	"os/signal"
-	"path"
 	"runtime/pprof"
 	"time"

@@ -14,7 +13,7 @@ import (
 	"go.uber.org/zap"
 	"go.uber.org/zap/zapcore"

-	"github.com/boramalper/magnetico/cmd/magneticod/bittorrent"
+	"github.com/boramalper/magnetico/cmd/magneticod/bittorrent/metadata"
 	"github.com/boramalper/magnetico/cmd/magneticod/dht"

 	"github.com/Wessie/appdirs"
@@ -75,6 +74,8 @@ func main() {

 	zap.ReplaceGlobals(logger)

+	zap.L().Debug("debug message!")
+
 	switch opFlags.Profile {
 	case "cpu":
 		file, err := os.OpenFile("magneticod_cpu.prof", os.O_CREATE | os.O_WRONLY, 0755)
@@ -96,19 +97,19 @@ func main() {
 	interruptChan := make(chan os.Signal)
 	signal.Notify(interruptChan, os.Interrupt)

-	database, err := persistence.MakeDatabase(opFlags.DatabaseURL, false, logger)
+	database, err := persistence.MakeDatabase(opFlags.DatabaseURL, logger)
 	if err != nil {
 		logger.Sugar().Fatalf("Could not open the database at `%s`: %s", opFlags.DatabaseURL, err.Error())
 	}

 	trawlingManager := dht.NewTrawlingManager(opFlags.TrawlerMlAddrs)
-	metadataSink := bittorrent.NewMetadataSink(2 * time.Minute)
+	metadataSink := metadata.NewSink(2 * time.Minute)

 	// The Event Loop
 	for stopped := false; !stopped; {
 		select {
 		case result := <-trawlingManager.Output():
-			zap.L().Info("Trawled!", zap.String("infoHash", result.InfoHash.String()))
+			zap.L().Debug("Trawled!", zap.String("infoHash", result.InfoHash.String()))
 			exists, err := database.DoesTorrentExist(result.InfoHash[:])
 			if err != nil {
 				zap.L().Fatal("Could not check whether torrent exists!", zap.Error(err))
@@ -144,10 +145,11 @@ func parseFlags() (*opFlags, error) {
 	}

 	if cmdF.DatabaseURL == "" {
-		opF.DatabaseURL = "sqlite3://" + path.Join(
-			appdirs.UserDataDir("magneticod", "", "", false),
-			"database.sqlite3",
-		)
+		opF.DatabaseURL =
+			"sqlite3://" +
+			appdirs.UserDataDir("magneticod", "", "", false) +
+			"/database.sqlite3" +
+			"?_journal_mode=WAL" // https://github.com/mattn/go-sqlite3#connection-string
 	} else {
 		opF.DatabaseURL = cmdF.DatabaseURL
 	}
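With the new default, the database URL magneticod builds comes out looking like the following (the home directory shown is illustrative; appdirs resolves the real per-user data path):

sqlite3:///home/user/.local/share/magneticod/database.sqlite3?_journal_mode=WAL

The `_journal_mode=WAL` query parameter is go-sqlite3's connection-string knob for enabling SQLite's write-ahead log, which tolerates concurrent readers much better than the default rollback journal.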
@@ -14,6 +14,9 @@ import (
 	"go.uber.org/zap"
 )

+// Close your rows lest you get "database table is locked" error(s)!
+// See https://github.com/mattn/go-sqlite3/issues/2741
+
 type sqlite3Database struct {
 	conn *sql.DB
 }
@@ -55,15 +58,12 @@ func (db *sqlite3Database) DoesTorrentExist(infoHash []byte) (bool, error) {
 	if err != nil {
 		return false, err
 	}
+	defer rows.Close()

 	// If rows.Next() returns true, meaning that the torrent is in the database, return true; else
 	// return false.
 	exists := rows.Next()
-	if !exists && rows.Err() != nil {
-		return false, err
-	}
-
-	if err = rows.Close(); err != nil {
+	if rows.Err() != nil {
 		return false, err
 	}

@@ -143,6 +143,7 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) {
 	if err != nil {
 		return 0, err
 	}
+	defer rows.Close()

 	if rows.Next() != true {
 		fmt.Errorf("No rows returned from `SELECT MAX(ROWID)`")
@@ -153,10 +154,6 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) {
 		return 0, err
 	}

-	if err = rows.Close(); err != nil {
-		return 0, err
-	}
-
 	return n, nil
 }

@@ -247,6 +244,7 @@ func (db *sqlite3Database) QueryTorrents(
 	queryArgs = append(queryArgs, limit)

 	rows, err := db.conn.Query(sqlQuery, queryArgs...)
+	defer rows.Close()
 	if err != nil {
 		return nil, fmt.Errorf("error while querying torrents: %s", err.Error())
 	}
@@ -269,10 +267,6 @@ func (db *sqlite3Database) QueryTorrents(
 		torrents = append(torrents, torrent)
 	}

-	if err := rows.Close(); err != nil {
-		return nil, err
-	}
-
 	return torrents, nil
 }

@@ -307,6 +301,7 @@ func (db *sqlite3Database) GetTorrent(infoHash []byte) (*TorrentMetadata, error)
 		WHERE info_hash = ?`,
 		infoHash,
 	)
+	defer rows.Close()
 	if err != nil {
 		return nil, err
 	}
@@ -320,10 +315,6 @@ func (db *sqlite3Database) GetTorrent(infoHash []byte) (*TorrentMetadata, error)
 		return nil, err
 	}

-	if err = rows.Close(); err != nil {
-		return nil, err
-	}
-
 	return &tm, nil
 }

@@ -331,6 +322,7 @@ func (db *sqlite3Database) GetFiles(infoHash []byte) ([]File, error) {
 	rows, err := db.conn.Query(
 		"SELECT size, path FROM files, torrents WHERE files.torrent_id = torrents.id AND torrents.info_hash = ?;",
 		infoHash)
+	defer rows.Close()
 	if err != nil {
 		return nil, err
 	}
@@ -344,10 +336,6 @@ func (db *sqlite3Database) GetFiles(infoHash []byte) ([]File, error) {
 		files = append(files, file)
 	}

-	if err := rows.Close(); err != nil {
-		return nil, err
-	}
-
 	return files, nil
 }

@@ -391,6 +379,7 @@ func (db *sqlite3Database) GetStatistics(from string, n uint) (*Statistics, erro
 		GROUP BY dt;`,
 		timef),
 		fromTime.Unix(), toTime.Unix())
+	defer rows.Close()
 	if err != nil {
 		return nil, err
 	}
@@ -478,6 +467,7 @@ func (db *sqlite3Database) setupDatabase() error {
 	if err != nil {
 		return fmt.Errorf("sql.Tx.Query (user_version): %s", err.Error())
 	}
+	defer rows.Close()
 	var userVersion int
 	if rows.Next() != true {
 		return fmt.Errorf("sql.Rows.Next (user_version): PRAGMA user_version did not return any rows!")
@@ -485,11 +475,6 @@ func (db *sqlite3Database) setupDatabase() error {
 	if err = rows.Scan(&userVersion); err != nil {
 		return fmt.Errorf("sql.Rows.Scan (user_version): %s", err.Error())
 	}
-	// Close your rows lest you get "database table is locked" error(s)!
-	// See https://github.com/mattn/go-sqlite3/issues/2741
-	if err = rows.Close(); err != nil {
-		return fmt.Errorf("sql.Rows.Close (user_version): %s", err.Error())
-	}

 	switch userVersion {
 	case 0: // FROZEN.