magneticod: metadata leech refactored heavily, now much more readable <3

+ persistence: we now make sure that rows are always closed using `defer`
This commit is contained in:
Bora Alper 2018-07-24 15:41:13 +03:00
parent 0614e9e0f9
commit c07daa3eca
11 changed files with 607 additions and 561 deletions

View File

@ -56,3 +56,7 @@
[[constraint]] [[constraint]]
name = "go.uber.org/zap" name = "go.uber.org/zap"
version = "1.7.1" version = "1.7.1"
[[constraint]]
name = "github.com/libp2p/go-sockaddr"
version = "1.0.3"

View File

@ -0,0 +1,439 @@
package metadata
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"fmt"
"io"
"math"
"net"
"time"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/torrent/metainfo"
"github.com/pkg/errors"
"go.uber.org/zap"
"github.com/boramalper/magnetico/pkg/persistence"
)
const MAX_METADATA_SIZE = 10 * 1024 * 1024
type rootDict struct {
M mDict `bencode:"m"`
MetadataSize int `bencode:"metadata_size"`
}
type mDict struct {
UTMetadata int `bencode:"ut_metadata"`
}
type extDict struct {
MsgType int `bencode:"msg_type"`
Piece int `bencode:"piece"`
}
type Leech struct {
infoHash [20]byte
peerAddr *net.TCPAddr
ev LeechEventHandlers
conn *net.TCPConn
clientID [20]byte
ut_metadata uint8
metadataReceived, metadataSize uint
metadata []byte
}
type LeechEventHandlers struct {
OnSuccess func(Metadata) // must be supplied. args: metadata
OnError func([20]byte, error) // must be supplied. args: infohash, error
}
func NewLeech(infoHash [20]byte, peerAddr *net.TCPAddr, ev LeechEventHandlers) *Leech {
l := new(Leech)
l.infoHash = infoHash
l.peerAddr = peerAddr
l.ev = ev
if _, err := rand.Read(l.clientID[:]); err != nil {
panic(err.Error())
}
return l
}
func (l *Leech) writeAll(b []byte) error {
for len(b) != 0 {
n, err := l.conn.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}
func (l *Leech) doBtHandshake() error {
lHandshake := []byte(fmt.Sprintf(
"\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s",
l.infoHash[:],
l.clientID,
))
// ASSERTION
if len(lHandshake) != 68 { panic(fmt.Sprintf("len(lHandshake) == %d", len(lHandshake))) }
err := l.writeAll(lHandshake)
if err != nil {
return errors.Wrap(err, "writeAll lHandshake")
}
zap.L().Debug("BitTorrent handshake sent, waiting for the remote's...")
rHandshake, err := l.readExactly(68)
if err != nil {
return errors.Wrap(err, "readExactly rHandshake")
}
if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) {
return fmt.Errorf("corrupt BitTorrent handshake received")
}
// TODO: maybe check for the infohash sent by the remote peer to double check?
if (rHandshake[25] & 0x10) == 0 {
return fmt.Errorf("peer does not support the extension protocol")
}
return nil
}
func (l *Leech) doExHandshake() error {
err := l.writeAll([]byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee"))
if err != nil {
return errors.Wrap(err, "writeAll lHandshake")
}
rExMessage, err := l.readExMessage()
if err != nil {
return errors.Wrap(err, "readExMessage")
}
// Extension Handshake has the Extension Message ID = 0x00
if rExMessage[1] != 0 {
return errors.Wrap(err, "first extension message is not an extension handshake")
}
rRootDict := new(rootDict)
err = bencode.Unmarshal(rExMessage[2:], rRootDict)
if err != nil {
return errors.Wrap(err, "unmarshal rExMessage")
}
if !(0 < rRootDict.MetadataSize && rRootDict.MetadataSize < MAX_METADATA_SIZE) {
return fmt.Errorf("metadata too big or its size is less than or equal zero")
}
if !(0 < rRootDict.M.UTMetadata && rRootDict.M.UTMetadata < 255) {
return fmt.Errorf("ut_metadata is not an uint8")
}
l.ut_metadata = uint8(rRootDict.M.UTMetadata) // Save the ut_metadata code the remote peer uses
l.metadataSize = uint(rRootDict.MetadataSize)
l.metadata = make([]byte, l.metadataSize)
return nil
}
func (l *Leech) requestAllPieces() error {
// Request all the pieces of metadata
nPieces := int(math.Ceil(float64(l.metadataSize) / math.Pow(2, 14)))
for piece := 0; piece < nPieces; piece++ {
// __request_metadata_piece(piece)
// ...............................
extDictDump, err := bencode.Marshal(extDict{
MsgType: 0,
Piece: piece,
})
if err != nil { // ASSERT
panic(errors.Wrap(err, "marshal extDict"))
}
err = l.writeAll([]byte(fmt.Sprintf(
"%s\x14%s%s",
toBigEndian(uint(2 + len(extDictDump)), 4),
toBigEndian(uint(l.ut_metadata), 1),
extDictDump,
)))
if err != nil {
return errors.Wrap(err, "writeAll piece request")
}
}
return nil
}
// readMessage returns a BitTorrent message, sans the first 4 bytes indicating its length.
func (l *Leech) readMessage() ([]byte, error) {
rLengthB, err := l.readExactly(4)
if err != nil {
return nil, errors.Wrap(err, "readExactly rLengthB")
}
rLength := uint(binary.BigEndian.Uint32(rLengthB))
rMessage, err := l.readExactly(rLength)
if err != nil {
return nil, errors.Wrap(err, "readExactly rMessage")
}
return rMessage, nil
}
// readExMessage returns an *extension* message, sans the first 4 bytes indicating its length.
//
// It will IGNORE all non-extension messages!
func (l *Leech) readExMessage() ([]byte, error) {
for {
rMessage, err := l.readMessage()
if err != nil {
return nil, errors.Wrap(err, "readMessage")
}
// We are interested only in extension messages, whose first byte is always 20
if rMessage[0] == 20 {
return rMessage, nil
}
}
}
// readUmMessage returns an ut_metadata extension message, sans the first 4 bytes indicating its
// length.
//
// It will IGNORE all non-"ut_metadata extension" messages!
func (l *Leech) readUmMessage() ([]byte, error) {
for {
rExMessage, err := l.readExMessage()
if err != nil {
return nil, errors.Wrap(err, "readExMessage")
}
if rExMessage[1] == 0x01 {
return rExMessage, nil
}
}
}
func (l *Leech) connect(deadline time.Time) error {
var err error
l.conn, err = net.DialTCP("tcp", nil, l.peerAddr)
if err != nil {
return errors.Wrap(err, "dial")
}
defer l.conn.Close()
err = l.conn.SetNoDelay(true)
if err != nil {
return errors.Wrap(err, "NODELAY")
}
err = l.conn.SetDeadline(deadline)
if err != nil {
return errors.Wrap(err, "SetDeadline")
}
return nil
}
func (l *Leech) Do(deadline time.Time) {
err := l.connect(deadline)
if err != nil {
l.OnError(errors.Wrap(err, "connect"))
return
}
err = l.doBtHandshake()
if err != nil {
l.OnError(errors.Wrap(err, "doBtHandshake"))
return
}
err = l.doExHandshake()
if err != nil {
l.OnError(errors.Wrap(err, "doExHandshake"))
return
}
err = l.requestAllPieces()
if err != nil {
l.OnError(errors.Wrap(err, "requestAllPieces"))
return
}
for l.metadataReceived < l.metadataSize {
rUmMessage, err := l.readUmMessage()
if err != nil {
l.OnError(errors.Wrap(err, "readUmMessage"))
return
}
// Run TestDecoder() function in leech_test.go in case you have any doubts.
rMessageBuf := bytes.NewBuffer(rUmMessage[2:])
rExtDict := new(extDict)
err = bencode.NewDecoder(rMessageBuf).Decode(rExtDict)
if err != nil {
zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err))
return
}
if rExtDict.MsgType == 2 { // reject
l.OnError(fmt.Errorf("remote peer rejected sending metadata"))
return
}
if rExtDict.MsgType == 1 { // data
// Get the unread bytes!
metadataPiece := rMessageBuf.Bytes()
// BEP 9 explicitly states:
// > If the piece is the last piece of the metadata, it may be less than 16kiB. If
// > it is not the last piece of the metadata, it MUST be 16kiB.
//
// Hence...
// ... if the length of @metadataPiece is more than 16kiB, we err.
if len(metadataPiece) > 16*1024 {
l.OnError(fmt.Errorf("metadataPiece > 16kiB"))
return
}
piece := rExtDict.Piece
// metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python
copy(l.metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece)
l.metadataReceived += uint(len(metadataPiece))
// ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT
// complete then we err.
if len(metadataPiece) < 16*1024 && l.metadataReceived != l.metadataSize {
l.OnError(fmt.Errorf("metadataPiece < 16 kiB but incomplete"))
return
}
if l.metadataReceived > l.metadataSize {
l.OnError(fmt.Errorf("metadataReceived > metadataSize"))
return
}
}
}
// Verify the checksum
sha1Sum := sha1.Sum(l.metadata)
if !bytes.Equal(sha1Sum[:], l.infoHash[:]) {
l.OnError(fmt.Errorf("infohash mismatch"))
return
}
// Check the info dictionary
info := new(metainfo.Info)
err = bencode.Unmarshal(l.metadata, info)
if err != nil {
l.OnError(errors.Wrap(err, "unmarshal info"))
return
}
err = validateInfo(info)
if err != nil {
l.OnError(errors.Wrap(err, "validateInfo"))
return
}
var files []persistence.File
// If there is only one file, there won't be a Files slice. That's why we need to add it here
if len(info.Files) == 0 {
files = append(files, persistence.File{
Size: info.Length,
Path: info.Name,
})
} else {
for _, file := range info.Files {
files = append(files, persistence.File{
Size: file.Length,
Path: file.DisplayPath(info),
})
}
}
var totalSize uint64
for _, file := range files {
if file.Size < 0 {
l.OnError(fmt.Errorf("file size less than zero"))
return
}
totalSize += uint64(file.Size)
}
l.ev.OnSuccess(Metadata{
InfoHash: l.infoHash[:],
Name: info.Name,
TotalSize: totalSize,
DiscoveredOn: time.Now().Unix(),
Files: files,
})
}
// COPIED FROM anacrolix/torrent
func validateInfo(info *metainfo.Info) error {
if len(info.Pieces)%20 != 0 {
return errors.New("pieces has invalid length")
}
if info.PieceLength == 0 {
if info.TotalLength() != 0 {
return errors.New("zero piece length")
}
} else {
if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() {
return errors.New("piece count and file lengths are at odds")
}
}
return nil
}
func (l *Leech) readExactly(n uint) ([]byte, error) {
b := make([]byte, n)
_, err := io.ReadFull(l.conn, b)
return b, err
}
func (l *Leech) OnError(err error) {
l.ev.OnError(l.infoHash, err)
}
// TODO: add bounds checking!
func toBigEndian(i uint, n int) []byte {
b := make([]byte, n)
switch n {
case 1:
b = []byte{byte(i)}
case 2:
binary.BigEndian.PutUint16(b, uint16(i))
case 4:
binary.BigEndian.PutUint32(b, uint32(i))
default:
panic(fmt.Sprintf("n must be 1, 2, or 4!"))
}
if len(b) != n {
panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b))
}
return b
}

View File

@ -1,4 +1,4 @@
package bittorrent package metadata
import ( import (
"bytes" "bytes"

View File

@ -1,4 +1,4 @@
package bittorrent package metadata
import ( import (
"crypto/rand" "crypto/rand"
@ -24,21 +24,17 @@ type Metadata struct {
Files []persistence.File Files []persistence.File
} }
type Peer struct { type Sink struct {
Addr *net.TCPAddr clientID []byte
} deadline time.Duration
drain chan Metadata
type MetadataSink struct {
clientID []byte
deadline time.Duration
drain chan Metadata
incomingInfoHashes map[[20]byte]struct{} incomingInfoHashes map[[20]byte]struct{}
terminated bool terminated bool
termination chan interface{} termination chan interface{}
} }
func NewMetadataSink(deadline time.Duration) *MetadataSink { func NewSink(deadline time.Duration) *Sink {
ms := new(MetadataSink) ms := new(Sink)
ms.clientID = make([]byte, 20) ms.clientID = make([]byte, 20)
_, err := rand.Read(ms.clientID) _, err := rand.Read(ms.clientID)
@ -52,27 +48,29 @@ func NewMetadataSink(deadline time.Duration) *MetadataSink {
return ms return ms
} }
func (ms *MetadataSink) Sink(res mainline.TrawlingResult) { func (ms *Sink) Sink(res mainline.TrawlingResult) {
if ms.terminated { if ms.terminated {
zap.L().Panic("Trying to Sink() an already closed MetadataSink!") zap.L().Panic("Trying to Sink() an already closed Sink!")
} }
if _, exists := ms.incomingInfoHashes[res.InfoHash]; exists { if _, exists := ms.incomingInfoHashes[res.InfoHash]; exists {
return return
} }
// BEWARE! // BEWARE!
// Although not crucial, the assumption is that MetadataSink.Sink() will be called by only one // Although not crucial, the assumption is that Sink.Sink() will be called by only one
// goroutine (i.e. it's not thread-safe), lest there might be a race condition between where we // goroutine (i.e. it's not thread-safe), lest there might be a race condition between where we
// check whether res.infoHash exists in the ms.incomingInfoHashes, and where we add the infoHash // check whether res.infoHash exists in the ms.incomingInfoHashes, and where we add the infoHash
// to the incomingInfoHashes at the end of this function. // to the incomingInfoHashes at the end of this function.
zap.L().Info("Sunk!", zap.String("infoHash", res.InfoHash.String()))
IPs := res.PeerIP.String() IPs := res.PeerIP.String()
var rhostport string var rhostport string
if IPs == "<nil>" { if IPs == "<nil>" {
zap.L().Debug("MetadataSink.Sink: Peer IP is nil!") zap.L().Debug("Sink.Sink: Peer IP is nil!")
return return
} else if IPs[0] == '?' { } else if IPs[0] == '?' {
zap.L().Debug("MetadataSink.Sink: Peer IP is invalid!") zap.L().Debug("Sink.Sink: Peer IP is invalid!")
return return
} else if strings.ContainsRune(IPs, ':') { // IPv6 } else if strings.ContainsRune(IPs, ':') { // IPv6
rhostport = fmt.Sprintf("[%s]:%d", IPs, res.PeerPort) rhostport = fmt.Sprintf("[%s]:%d", IPs, res.PeerPort)
@ -82,29 +80,33 @@ func (ms *MetadataSink) Sink(res mainline.TrawlingResult) {
raddr, err := net.ResolveTCPAddr("tcp", rhostport) raddr, err := net.ResolveTCPAddr("tcp", rhostport)
if err != nil { if err != nil {
zap.L().Debug("MetadataSink.Sink: Couldn't resolve peer address!", zap.Error(err)) zap.L().Debug("Sink.Sink: Couldn't resolve peer address!", zap.Error(err))
return return
} }
go ms.awaitMetadata(res.InfoHash, Peer{Addr: raddr}) leech := NewLeech(res.InfoHash, raddr, LeechEventHandlers{
OnSuccess: ms.flush,
OnError: ms.onLeechError,
})
go leech.Do(time.Now().Add(ms.deadline))
ms.incomingInfoHashes[res.InfoHash] = struct{}{} ms.incomingInfoHashes[res.InfoHash] = struct{}{}
} }
func (ms *MetadataSink) Drain() <-chan Metadata { func (ms *Sink) Drain() <-chan Metadata {
if ms.terminated { if ms.terminated {
zap.L().Panic("Trying to Drain() an already closed MetadataSink!") zap.L().Panic("Trying to Drain() an already closed Sink!")
} }
return ms.drain return ms.drain
} }
func (ms *MetadataSink) Terminate() { func (ms *Sink) Terminate() {
ms.terminated = true ms.terminated = true
close(ms.termination) close(ms.termination)
close(ms.drain) close(ms.drain)
} }
func (ms *MetadataSink) flush(result Metadata) { func (ms *Sink) flush(result Metadata) {
if !ms.terminated { if !ms.terminated {
ms.drain <- result ms.drain <- result
// Delete the infoHash from ms.incomingInfoHashes ONLY AFTER once we've flushed the // Delete the infoHash from ms.incomingInfoHashes ONLY AFTER once we've flushed the
@ -114,3 +116,8 @@ func (ms *MetadataSink) flush(result Metadata) {
delete(ms.incomingInfoHashes, infoHash) delete(ms.incomingInfoHashes, infoHash)
} }
} }
func (ms *Sink) onLeechError(infoHash [20]byte, err error) {
zap.L().Debug("leech error", zap.ByteString("infoHash", infoHash[:]), zap.Error(err))
delete(ms.incomingInfoHashes, infoHash)
}

View File

@ -1,480 +0,0 @@
package bittorrent
import (
"bytes"
"crypto/sha1"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
"time"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/torrent/metainfo"
"go.uber.org/zap"
"github.com/boramalper/magnetico/pkg/persistence"
)
const MAX_METADATA_SIZE = 10 * 1024 * 1024
type rootDict struct {
M mDict `bencode:"m"`
MetadataSize int `bencode:"metadata_size"`
}
type mDict struct {
UTMetadata int `bencode:"ut_metadata"`
}
type extDict struct {
MsgType int `bencode:"msg_type"`
Piece int `bencode:"piece"`
}
func (ms *MetadataSink) awaitMetadata(infoHash metainfo.Hash, peer Peer) {
conn, err := net.DialTCP("tcp", nil, peer.Addr)
if err != nil {
zap.L().Debug(
"awaitMetadata couldn't connect to the peer!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
defer conn.Close()
err = conn.SetNoDelay(true)
if err != nil {
zap.L().Panic(
"Couldn't set NODELAY!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
err = conn.SetDeadline(time.Now().Add(ms.deadline))
if err != nil {
zap.L().Panic(
"Couldn't set the deadline!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
// State Variables
var isExtHandshakeDone, done bool
var ut_metadata, metadataReceived, metadataSize int
var metadata []byte
lHandshake := []byte(fmt.Sprintf(
"\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s",
infoHash[:],
ms.clientID,
))
if len(lHandshake) != 68 {
zap.L().Panic(
"Generated BitTorrent handshake is not of length 68!",
zap.ByteString("infoHash", infoHash[:]),
zap.Int("len_lHandshake", len(lHandshake)),
)
}
err = writeAll(conn, lHandshake)
if err != nil {
zap.L().Debug(
"Couldn't write BitTorrent handshake!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
zap.L().Debug("BitTorrent handshake sent, waiting for the remote's...")
rHandshake, err := readExactly(conn, 68)
if err != nil {
zap.L().Debug(
"Couldn't read remote BitTorrent handshake!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) {
zap.L().Debug(
"Remote BitTorrent handshake is not what it is supposed to be!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.ByteString("rHandshake[:20]", rHandshake[:20]),
)
return
}
// __on_bt_handshake
// ================
if rHandshake[25] != 16 { // TODO (later): do *not* compare the whole byte, check the bit instead! (0x10)
zap.L().Debug(
"Peer does not support the extension protocol!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
)
return
}
writeAll(conn, []byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee"))
zap.L().Debug(
"Extension handshake sent, waiting for the remote's...",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
)
// the loop!
// =========
for !done {
rLengthB, err := readExactly(conn, 4)
if err != nil {
zap.L().Debug(
"Couldn't read the first 4 bytes from the remote peer in the loop!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
// The messages we are interested in have the length of AT LEAST two bytes
// (TODO: actually a bit more than that but SURELY when it's less than two bytes, the
// program panics)
rLength := bigEndianToInt(rLengthB)
if rLength < 2 {
continue
}
rMessage, err := readExactly(conn, rLength)
if err != nil {
zap.L().Debug(
"Couldn't read the rest of the message from the remote peer in the loop!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", peer.Addr.String()),
zap.Error(err),
)
return
}
// __on_message
// ------------
if rMessage[0] != 0x14 { // We are interested only in extension messages, whose first byte is always 0x14
zap.L().Debug(
"Ignoring the non-extension message.",
zap.ByteString("infoHash", infoHash[:]),
)
continue
}
if rMessage[1] == 0x00 { // Extension Handshake has the Extension Message ID = 0x00
// __on_ext_handshake_message(message[2:])
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// TODO: continue editing log messages from here
if isExtHandshakeDone {
return
}
rRootDict := new(rootDict)
err := bencode.Unmarshal(rMessage[2:], rRootDict)
if err != nil {
zap.L().Debug("Couldn't unmarshal extension handshake!", zap.Error(err))
return
}
if rRootDict.MetadataSize <= 0 || rRootDict.MetadataSize > MAX_METADATA_SIZE {
zap.L().Debug("Unacceptable metadata size!", zap.Int("metadata_size", rRootDict.MetadataSize))
return
}
ut_metadata = rRootDict.M.UTMetadata // Save the ut_metadata code the remote peer uses
metadataSize = rRootDict.MetadataSize
metadata = make([]byte, metadataSize)
isExtHandshakeDone = true
zap.L().Debug("GOT EXTENSION HANDSHAKE!", zap.Int("ut_metadata", ut_metadata), zap.Int("metadata_size", metadataSize))
// Request all the pieces of metadata
n_pieces := int(math.Ceil(float64(metadataSize) / math.Pow(2, 14)))
for piece := 0; piece < n_pieces; piece++ {
// __request_metadata_piece(piece)
// ...............................
extDictDump, err := bencode.Marshal(extDict{
MsgType: 0,
Piece: piece,
})
if err != nil {
zap.L().Warn("Couldn't marshal extDictDump!", zap.Error(err))
return
}
writeAll(conn, []byte(fmt.Sprintf(
"%s\x14%s%s",
intToBigEndian(2+len(extDictDump), 4),
intToBigEndian(ut_metadata, 1),
extDictDump,
)))
}
zap.L().Warn("requested all metadata pieces!")
} else if rMessage[1] == 0x01 {
// __on_ext_message(message[2:])
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Run TestDecoder() function in operations_test.go in case you have any doubts.
rMessageBuf := bytes.NewBuffer(rMessage[2:])
rExtDict := new(extDict)
err := bencode.NewDecoder(rMessageBuf).Decode(rExtDict)
if err != nil {
zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err))
return
}
if rExtDict.MsgType == 1 { // data
// Get the unread bytes!
metadataPiece := rMessageBuf.Bytes()
piece := rExtDict.Piece
// metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python
copy(metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece)
metadataReceived += len(metadataPiece)
done = metadataReceived == metadataSize
// BEP 9 explicitly states:
// > If the piece is the last piece of the metadata, it may be less than 16kiB. If
// > it is not the last piece of the metadata, it MUST be 16kiB.
//
// Hence...
// ... if the length of @metadataPiece is more than 16kiB, we err.
if len(metadataPiece) > 16*1024 {
zap.L().Debug(
"metadataPiece is bigger than 16kiB!",
zap.Int("len_metadataPiece", len(metadataPiece)),
zap.Int("metadataReceived", metadataReceived),
zap.Int("metadataSize", metadataSize),
zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)),
)
return
}
// ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT
// complete (!done) then we err.
if len(metadataPiece) < 16*1024 && !done {
zap.L().Debug(
"metadataPiece is less than 16kiB and metadata is incomplete!",
zap.Int("len_metadataPiece", len(metadataPiece)),
zap.Int("metadataReceived", metadataReceived),
zap.Int("metadataSize", metadataSize),
zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)),
)
return
}
if metadataReceived > metadataSize {
zap.L().Debug(
"metadataReceived is greater than metadataSize!",
zap.Int("len_metadataPiece", len(metadataPiece)),
zap.Int("metadataReceived", metadataReceived),
zap.Int("metadataSize", metadataSize),
zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)),
)
return
}
zap.L().Debug(
"Fetching...",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", conn.RemoteAddr().String()),
zap.Int("metadataReceived", metadataReceived),
zap.Int("metadataSize", metadataSize),
)
} else if rExtDict.MsgType == 2 { // reject
zap.L().Debug(
"Remote peer rejected sending metadata!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("remotePeerAddr", conn.RemoteAddr().String()),
)
return
}
} else {
zap.L().Debug(
"Message is not an ut_metadata message! (ignoring)",
zap.ByteString("msg", rMessage[:100]),
)
// no return!
}
}
zap.L().Debug(
"Metadata is complete, verifying the checksum...",
zap.ByteString("infoHash", infoHash[:]),
)
sha1Sum := sha1.Sum(metadata)
if !bytes.Equal(sha1Sum[:], infoHash[:]) {
zap.L().Debug(
"Info-hash mismatch!",
zap.ByteString("expectedInfoHash", infoHash[:]),
zap.ByteString("actualInfoHash", sha1Sum[:]),
)
return
}
zap.L().Debug(
"Checksum verified, checking the info dictionary...",
zap.ByteString("infoHash", infoHash[:]),
)
info := new(metainfo.Info)
err = bencode.Unmarshal(metadata, info)
if err != nil {
zap.L().Debug(
"Couldn't unmarshal info bytes!",
zap.ByteString("infoHash", infoHash[:]),
zap.Error(err),
)
return
}
err = validateInfo(info)
if err != nil {
zap.L().Debug(
"Bad info dictionary!",
zap.ByteString("infoHash", infoHash[:]),
zap.Error(err),
)
return
}
var files []persistence.File
// If there is only one file, there won't be a Files slice. That's why we need to add it here
if len(info.Files) == 0 {
files = append(files, persistence.File{
Size: info.Length,
Path: info.Name,
})
}
for _, file := range info.Files {
if file.Length < 0 {
zap.L().Debug(
"File size is less than zero!",
zap.ByteString("infoHash", infoHash[:]),
zap.String("filePath", file.DisplayPath(info)),
zap.Int64("fileSize", file.Length),
)
return
}
files = append(files, persistence.File{
Size: file.Length,
Path: file.DisplayPath(info),
})
}
var totalSize uint64
for _, file := range files {
totalSize += uint64(file.Size)
}
zap.L().Debug(
"Flushing metadata...",
zap.ByteString("infoHash", infoHash[:]),
)
ms.flush(Metadata{
InfoHash: infoHash[:],
Name: info.Name,
TotalSize: totalSize,
DiscoveredOn: time.Now().Unix(),
Files: files,
})
}
// COPIED FROM anacrolix/torrent
func validateInfo(info *metainfo.Info) error {
if len(info.Pieces)%20 != 0 {
return errors.New("pieces has invalid length")
}
if info.PieceLength == 0 {
if info.TotalLength() != 0 {
return errors.New("zero piece length")
}
} else {
if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() {
return errors.New("piece count and file lengths are at odds")
}
}
return nil
}
func writeAll(c *net.TCPConn, b []byte) error {
for len(b) != 0 {
n, err := c.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}
func readExactly(c *net.TCPConn, n int) ([]byte, error) {
b := make([]byte, n)
_, err := io.ReadFull(c, b)
return b, err
}
// TODO: add bounds checking!
func intToBigEndian(i int, n int) []byte {
b := make([]byte, n)
switch n {
case 1:
b = []byte{byte(i)}
case 2:
binary.BigEndian.PutUint16(b, uint16(i))
case 4:
binary.BigEndian.PutUint32(b, uint32(i))
default:
panic(fmt.Sprintf("n must be 1, 2, or 4!"))
}
if len(b) != n {
panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b))
}
return b
}
func bigEndianToInt(b []byte) int {
switch len(b) {
case 1:
return int(b[0])
case 2:
return int(binary.BigEndian.Uint16(b))
case 4:
return int(binary.BigEndian.Uint32(b))
default:
panic(fmt.Sprintf("bigEndianToInt: b is too long! (%d bytes)", len(b)))
}
}

View File

@ -26,11 +26,12 @@ type ProtocolEventHandlers struct {
OnGetPeersResponse func(*Message, net.Addr) OnGetPeersResponse func(*Message, net.Addr)
OnFindNodeResponse func(*Message, net.Addr) OnFindNodeResponse func(*Message, net.Addr)
OnPingORAnnouncePeerResponse func(*Message, net.Addr) OnPingORAnnouncePeerResponse func(*Message, net.Addr)
OnCongestion func()
} }
func NewProtocol(laddr string, eventHandlers ProtocolEventHandlers) (p *Protocol) { func NewProtocol(laddr string, eventHandlers ProtocolEventHandlers) (p *Protocol) {
p = new(Protocol) p = new(Protocol)
p.transport = NewTransport(laddr, p.onMessage) p.transport = NewTransport(laddr, p.onMessage, p.eventHandlers.OnCongestion)
p.eventHandlers = eventHandlers p.eventHandlers = eventHandlers
p.currentTokenSecret, p.previousTokenSecret = make([]byte, 20), make([]byte, 20) p.currentTokenSecret, p.previousTokenSecret = make([]byte, 20), make([]byte, 20)

View File

@ -31,13 +31,14 @@ type TrawlingService struct {
// ^~~~~~ // ^~~~~~
routingTable map[string]net.Addr routingTable map[string]net.Addr
routingTableMutex *sync.Mutex routingTableMutex *sync.Mutex
maxNeighbors uint
} }
type TrawlingServiceEventHandlers struct { type TrawlingServiceEventHandlers struct {
OnResult func(TrawlingResult) OnResult func(TrawlingResult)
} }
func NewTrawlingService(laddr string, eventHandlers TrawlingServiceEventHandlers) *TrawlingService { func NewTrawlingService(laddr string, initialMaxNeighbors uint, eventHandlers TrawlingServiceEventHandlers) *TrawlingService {
service := new(TrawlingService) service := new(TrawlingService)
service.protocol = NewProtocol( service.protocol = NewProtocol(
laddr, laddr,
@ -45,12 +46,14 @@ func NewTrawlingService(laddr string, eventHandlers TrawlingServiceEventHandlers
OnGetPeersQuery: service.onGetPeersQuery, OnGetPeersQuery: service.onGetPeersQuery,
OnAnnouncePeerQuery: service.onAnnouncePeerQuery, OnAnnouncePeerQuery: service.onAnnouncePeerQuery,
OnFindNodeResponse: service.onFindNodeResponse, OnFindNodeResponse: service.onFindNodeResponse,
OnCongestion: service.onCongestion,
}, },
) )
service.trueNodeID = make([]byte, 20) service.trueNodeID = make([]byte, 20)
service.routingTable = make(map[string]net.Addr) service.routingTable = make(map[string]net.Addr)
service.routingTableMutex = new(sync.Mutex) service.routingTableMutex = new(sync.Mutex)
service.eventHandlers = eventHandlers service.eventHandlers = eventHandlers
service.maxNeighbors = initialMaxNeighbors
_, err := rand.Read(service.trueNodeID) _, err := rand.Read(service.trueNodeID)
if err != nil { if err != nil {
@ -78,11 +81,14 @@ func (s *TrawlingService) Terminate() {
func (s *TrawlingService) trawl() { func (s *TrawlingService) trawl() {
for range time.Tick(3 * time.Second) { for range time.Tick(3 * time.Second) {
s.maxNeighbors = uint(float32(s.maxNeighbors) * 1.01)
s.routingTableMutex.Lock() s.routingTableMutex.Lock()
if len(s.routingTable) == 0 { if len(s.routingTable) == 0 {
s.bootstrap() s.bootstrap()
} else { } else {
zap.L().Debug("Latest status:", zap.Int("n", len(s.routingTable))) zap.L().Warn("Latest status:", zap.Int("n", len(s.routingTable)),
zap.Uint("maxNeighbors", s.maxNeighbors))
s.findNeighbors() s.findNeighbors()
s.routingTable = make(map[string]net.Addr) s.routingTable = make(map[string]net.Addr)
} }
@ -185,11 +191,34 @@ func (s *TrawlingService) onFindNodeResponse(response *Message, addr net.Addr) {
s.routingTableMutex.Lock() s.routingTableMutex.Lock()
defer s.routingTableMutex.Unlock() defer s.routingTableMutex.Unlock()
zap.L().Debug("find node response!!", zap.Uint("maxNeighbors", s.maxNeighbors),
zap.Int("response.R.Nodes length", len(response.R.Nodes)))
for _, node := range response.R.Nodes { for _, node := range response.R.Nodes {
if node.Addr.Port != 0 { // Ignore nodes who "use" port 0. if uint(len(s.routingTable)) >= s.maxNeighbors {
if len(s.routingTable) < 8000 { break
s.routingTable[string(node.ID)] = &node.Addr
}
} }
if node.Addr.Port == 0 { // Ignore nodes who "use" port 0.
zap.L().Debug("ignoring 0 port!!!")
continue
}
s.routingTable[string(node.ID)] = &node.Addr
} }
} }
func (s *TrawlingService) onCongestion() {
/* The Congestion Prevention Strategy:
*
* In case of congestion, decrease the maximum number of nodes to the 90% of the current value.
*/
if s.maxNeighbors < 200 {
zap.L().Warn("Max. number of neighbours are < 200 and there is still congestion!" +
"(check your network connection if this message recurs)")
return
}
s.maxNeighbors = uint(float32(s.maxNeighbors) * 0.9)
zap.L().Debug("Max. number of neighbours updated!",
zap.Uint("s.maxNeighbors", s.maxNeighbors))
}

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"github.com/anacrolix/torrent/bencode" "github.com/anacrolix/torrent/bencode"
sockaddr "github.com/libp2p/go-sockaddr/net"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -12,23 +13,43 @@ type Transport struct {
fd int fd int
laddr *net.UDPAddr laddr *net.UDPAddr
started bool started bool
buffer []byte
// OnMessage is the function that will be called when Transport receives a packet that is // OnMessage is the function that will be called when Transport receives a packet that is
// successfully unmarshalled as a syntactically correct Message (but -of course- the checking // successfully unmarshalled as a syntactically correct Message (but -of course- the checking
// the semantic correctness of the Message is left to Protocol). // the semantic correctness of the Message is left to Protocol).
onMessage func(*Message, net.Addr) onMessage func(*Message, net.Addr)
// OnCongestion
onCongestion func()
} }
func NewTransport(laddr string, onMessage func(*Message, net.Addr)) *Transport { func NewTransport(laddr string, onMessage func(*Message, net.Addr), onCongestion func()) *Transport {
transport := new(Transport) t := new(Transport)
transport.onMessage = onMessage /* The field size sets a theoretical limit of 65,535 bytes (8 byte header + 65,527 bytes of
* data) for a UDP datagram. However the actual limit for the data length, which is imposed by
* the underlying IPv4 protocol, is 65,507 bytes (65,535 8 byte UDP header 20 byte IP
* header).
*
* In IPv6 jumbograms it is possible to have UDP packets of size greater than 65,535 bytes.
* RFC 2675 specifies that the length field is set to zero if the length of the UDP header plus
* UDP data is greater than 65,535.
*
* https://en.wikipedia.org/wiki/User_Datagram_Protocol
*/
t.buffer = make([]byte, 65507)
t.onMessage = onMessage
t.onCongestion = onCongestion
var err error var err error
transport.laddr, err = net.ResolveUDPAddr("udp", laddr) t.laddr, err = net.ResolveUDPAddr("udp", laddr)
if err != nil { if err != nil {
zap.L().Panic("Could not resolve the UDP address for the trawler!", zap.Error(err)) zap.L().Panic("Could not resolve the UDP address for the trawler!", zap.Error(err))
} }
if t.laddr.IP.To4() == nil {
zap.L().Panic("IP address is not IPv4!")
}
return transport return t
} }
func (t *Transport) Start() { func (t *Transport) Start() {
@ -51,7 +72,12 @@ func (t *Transport) Start() {
zap.L().Fatal("Could NOT create a UDP socket!", zap.Error(err)) zap.L().Fatal("Could NOT create a UDP socket!", zap.Error(err))
} }
unix.Bind(t.fd, unix.SockaddrInet4{Addr: t.laddr.IP, Port: t.laddr.Port}) var ip [4]byte
copy(ip[:], t.laddr.IP.To4())
err = unix.Bind(t.fd, &unix.SockaddrInet4{Addr: ip, Port: t.laddr.Port})
if err != nil {
zap.L().Fatal("Could NOT bind the socket!", zap.Error(err))
}
go t.readMessages() go t.readMessages()
} }
@ -62,21 +88,33 @@ func (t *Transport) Terminate() {
// readMessages is a goroutine! // readMessages is a goroutine!
func (t *Transport) readMessages() { func (t *Transport) readMessages() {
buffer := make([]byte, 65536)
for { for {
n, from, err := unix.Recvfrom(t.fd, buffer, 0) n, fromSA, err := unix.Recvfrom(t.fd, t.buffer, 0)
if err != nil { if err == unix.EPERM || err == unix.ENOBUFS { // todo: are these errors possible for recvfrom?
zap.L().Warn("READ CONGESTION!", zap.Error(err))
t.onCongestion()
} else if err != nil {
// TODO: isn't there a more reliable way to detect if UDPConn is closed? // TODO: isn't there a more reliable way to detect if UDPConn is closed?
zap.L().Debug("Could NOT read an UDP packet!", zap.Error(err)) zap.L().Debug("Could NOT read an UDP packet!", zap.Error(err))
} }
if n == 0 {
/* Datagram sockets in various domains (e.g., the UNIX and Internet domains) permit
* zero-length datagrams. When such a datagram is received, the return value (n) is 0.
*/
zap.L().Debug("zero-length received!!")
continue
}
from := sockaddr.SockaddrToUDPAddr(fromSA)
var msg Message var msg Message
err = bencode.Unmarshal(buffer[:n], &msg) err = bencode.Unmarshal(t.buffer[:n], &msg)
if err != nil { if err != nil {
zap.L().Debug("Could NOT unmarshal packet data!", zap.Error(err)) zap.L().Debug("Could NOT unmarshal packet data!", zap.Error(err))
} }
zap.L().Debug("message read! (first 20...)", zap.ByteString("msg", t.buffer[:20]))
t.onMessage(&msg, from) t.onMessage(&msg, from)
} }
} }
@ -87,9 +125,29 @@ func (t *Transport) WriteMessages(msg *Message, addr net.Addr) {
zap.L().Panic("Could NOT marshal an outgoing message! (Programmer error.)") zap.L().Panic("Could NOT marshal an outgoing message! (Programmer error.)")
} }
err = unix.Sendto(t.fd, data, 0, addr) addrSA := sockaddr.NetAddrToSockaddr(addr)
// TODO: isn't there a more reliable way to detect if UDPConn is closed?
if err != nil { zap.L().Debug("sent message!!!")
err = unix.Sendto(t.fd, data, 0, addrSA)
if err == unix.EPERM || err == unix.ENOBUFS {
/* EPERM (errno: 1) is kernel's way of saying that "you are far too fast, chill". It is
* also likely that we have received a ICMP source quench packet (meaning, that we *really*
* need to slow down.
*
* Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto-amp-amp-EPERM.html
*
* > Note On BSD systems (OS X, FreeBSD, etc.) flow control is not supported for
* > DatagramProtocol, because send failures caused by writing too many packets cannot be
* > detected easily. The socket always appears ready and excess packets are dropped; an
* > OSError with errno set to errno.ENOBUFS may or may not be raised; if it is raised, it
* > will be reported to DatagramProtocol.error_received() but otherwise ignored.
*
* Source: https://docs.python.org/3/library/asyncio-protocol.html#flow-control-callbacks
*/
zap.L().Warn("WRITE CONGESTION!", zap.Error(err))
t.onCongestion()
} else if err != nil {
zap.L().Debug("Could NOT write an UDP packet!", zap.Error(err)) zap.L().Debug("Could NOT write an UDP packet!", zap.Error(err))
} }
} }

View File

@ -18,6 +18,7 @@ func NewTrawlingManager(mlAddrs []string) *TrawlingManager {
for _, addr := range mlAddrs { for _, addr := range mlAddrs {
manager.services = append(manager.services, mainline.NewTrawlingService( manager.services = append(manager.services, mainline.NewTrawlingService(
addr, addr,
2000,
mainline.TrawlingServiceEventHandlers{ mainline.TrawlingServiceEventHandlers{
OnResult: manager.onResult, OnResult: manager.onResult,
}, },

View File

@ -6,7 +6,6 @@ import (
"net" "net"
"os" "os"
"os/signal" "os/signal"
"path"
"runtime/pprof" "runtime/pprof"
"time" "time"
@ -14,7 +13,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
"github.com/boramalper/magnetico/cmd/magneticod/bittorrent" "github.com/boramalper/magnetico/cmd/magneticod/bittorrent/metadata"
"github.com/boramalper/magnetico/cmd/magneticod/dht" "github.com/boramalper/magnetico/cmd/magneticod/dht"
"github.com/Wessie/appdirs" "github.com/Wessie/appdirs"
@ -75,6 +74,8 @@ func main() {
zap.ReplaceGlobals(logger) zap.ReplaceGlobals(logger)
zap.L().Debug("debug message!")
switch opFlags.Profile { switch opFlags.Profile {
case "cpu": case "cpu":
file, err := os.OpenFile("magneticod_cpu.prof", os.O_CREATE | os.O_WRONLY, 0755) file, err := os.OpenFile("magneticod_cpu.prof", os.O_CREATE | os.O_WRONLY, 0755)
@ -96,19 +97,19 @@ func main() {
interruptChan := make(chan os.Signal) interruptChan := make(chan os.Signal)
signal.Notify(interruptChan, os.Interrupt) signal.Notify(interruptChan, os.Interrupt)
database, err := persistence.MakeDatabase(opFlags.DatabaseURL, false, logger) database, err := persistence.MakeDatabase(opFlags.DatabaseURL, logger)
if err != nil { if err != nil {
logger.Sugar().Fatalf("Could not open the database at `%s`: %s", opFlags.DatabaseURL, err.Error()) logger.Sugar().Fatalf("Could not open the database at `%s`: %s", opFlags.DatabaseURL, err.Error())
} }
trawlingManager := dht.NewTrawlingManager(opFlags.TrawlerMlAddrs) trawlingManager := dht.NewTrawlingManager(opFlags.TrawlerMlAddrs)
metadataSink := bittorrent.NewMetadataSink(2 * time.Minute) metadataSink := metadata.NewSink(2 * time.Minute)
// The Event Loop // The Event Loop
for stopped := false; !stopped; { for stopped := false; !stopped; {
select { select {
case result := <-trawlingManager.Output(): case result := <-trawlingManager.Output():
zap.L().Info("Trawled!", zap.String("infoHash", result.InfoHash.String())) zap.L().Debug("Trawled!", zap.String("infoHash", result.InfoHash.String()))
exists, err := database.DoesTorrentExist(result.InfoHash[:]) exists, err := database.DoesTorrentExist(result.InfoHash[:])
if err != nil { if err != nil {
zap.L().Fatal("Could not check whether torrent exists!", zap.Error(err)) zap.L().Fatal("Could not check whether torrent exists!", zap.Error(err))
@ -144,10 +145,11 @@ func parseFlags() (*opFlags, error) {
} }
if cmdF.DatabaseURL == "" { if cmdF.DatabaseURL == "" {
opF.DatabaseURL = "sqlite3://" + path.Join( opF.DatabaseURL =
appdirs.UserDataDir("magneticod", "", "", false), "sqlite3://" +
"database.sqlite3", appdirs.UserDataDir("magneticod", "", "", false) +
) "/database.sqlite3" +
"?_journal_mode=WAL" // https://github.com/mattn/go-sqlite3#connection-string
} else { } else {
opF.DatabaseURL = cmdF.DatabaseURL opF.DatabaseURL = cmdF.DatabaseURL
} }

View File

@ -14,6 +14,9 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// Close your rows lest you get "database table is locked" error(s)!
// See https://github.com/mattn/go-sqlite3/issues/2741
type sqlite3Database struct { type sqlite3Database struct {
conn *sql.DB conn *sql.DB
} }
@ -55,15 +58,12 @@ func (db *sqlite3Database) DoesTorrentExist(infoHash []byte) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
defer rows.Close()
// If rows.Next() returns true, meaning that the torrent is in the database, return true; else // If rows.Next() returns true, meaning that the torrent is in the database, return true; else
// return false. // return false.
exists := rows.Next() exists := rows.Next()
if !exists && rows.Err() != nil { if rows.Err() != nil {
return false, err
}
if err = rows.Close(); err != nil {
return false, err return false, err
} }
@ -143,6 +143,7 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer rows.Close()
if rows.Next() != true { if rows.Next() != true {
fmt.Errorf("No rows returned from `SELECT MAX(ROWID)`") fmt.Errorf("No rows returned from `SELECT MAX(ROWID)`")
@ -153,10 +154,6 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) {
return 0, err return 0, err
} }
if err = rows.Close(); err != nil {
return 0, err
}
return n, nil return n, nil
} }
@ -247,6 +244,7 @@ func (db *sqlite3Database) QueryTorrents(
queryArgs = append(queryArgs, limit) queryArgs = append(queryArgs, limit)
rows, err := db.conn.Query(sqlQuery, queryArgs...) rows, err := db.conn.Query(sqlQuery, queryArgs...)
defer rows.Close()
if err != nil { if err != nil {
return nil, fmt.Errorf("error while querying torrents: %s", err.Error()) return nil, fmt.Errorf("error while querying torrents: %s", err.Error())
} }
@ -269,10 +267,6 @@ func (db *sqlite3Database) QueryTorrents(
torrents = append(torrents, torrent) torrents = append(torrents, torrent)
} }
if err := rows.Close(); err != nil {
return nil, err
}
return torrents, nil return torrents, nil
} }
@ -307,6 +301,7 @@ func (db *sqlite3Database) GetTorrent(infoHash []byte) (*TorrentMetadata, error)
WHERE info_hash = ?`, WHERE info_hash = ?`,
infoHash, infoHash,
) )
defer rows.Close()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -320,10 +315,6 @@ func (db *sqlite3Database) GetTorrent(infoHash []byte) (*TorrentMetadata, error)
return nil, err return nil, err
} }
if err = rows.Close(); err != nil {
return nil, err
}
return &tm, nil return &tm, nil
} }
@ -331,6 +322,7 @@ func (db *sqlite3Database) GetFiles(infoHash []byte) ([]File, error) {
rows, err := db.conn.Query( rows, err := db.conn.Query(
"SELECT size, path FROM files, torrents WHERE files.torrent_id = torrents.id AND torrents.info_hash = ?;", "SELECT size, path FROM files, torrents WHERE files.torrent_id = torrents.id AND torrents.info_hash = ?;",
infoHash) infoHash)
defer rows.Close()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -344,10 +336,6 @@ func (db *sqlite3Database) GetFiles(infoHash []byte) ([]File, error) {
files = append(files, file) files = append(files, file)
} }
if err := rows.Close(); err != nil {
return nil, err
}
return files, nil return files, nil
} }
@ -391,6 +379,7 @@ func (db *sqlite3Database) GetStatistics(from string, n uint) (*Statistics, erro
GROUP BY dt;`, GROUP BY dt;`,
timef), timef),
fromTime.Unix(), toTime.Unix()) fromTime.Unix(), toTime.Unix())
defer rows.Close()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -478,6 +467,7 @@ func (db *sqlite3Database) setupDatabase() error {
if err != nil { if err != nil {
return fmt.Errorf("sql.Tx.Query (user_version): %s", err.Error()) return fmt.Errorf("sql.Tx.Query (user_version): %s", err.Error())
} }
defer rows.Close()
var userVersion int var userVersion int
if rows.Next() != true { if rows.Next() != true {
return fmt.Errorf("sql.Rows.Next (user_version): PRAGMA user_version did not return any rows!") return fmt.Errorf("sql.Rows.Next (user_version): PRAGMA user_version did not return any rows!")
@ -485,11 +475,6 @@ func (db *sqlite3Database) setupDatabase() error {
if err = rows.Scan(&userVersion); err != nil { if err = rows.Scan(&userVersion); err != nil {
return fmt.Errorf("sql.Rows.Scan (user_version): %s", err.Error()) return fmt.Errorf("sql.Rows.Scan (user_version): %s", err.Error())
} }
// Close your rows lest you get "database table is locked" error(s)!
// See https://github.com/mattn/go-sqlite3/issues/2741
if err = rows.Close(); err != nil {
return fmt.Errorf("sql.Rows.Close (user_version): %s", err.Error())
}
switch userVersion { switch userVersion {
case 0: // FROZEN. case 0: // FROZEN.