From c07daa3eca376935200ec43cb20052da66f7f521 Mon Sep 17 00:00:00 2001 From: Bora Alper Date: Tue, 24 Jul 2018 15:41:13 +0300 Subject: [PATCH] magneticod: metadata leech refactored heavily, now much more readable <3 + persistence: we now make sure that rows are always closed using `defer` --- Gopkg.toml | 4 + cmd/magneticod/bittorrent/metadata/leech.go | 439 ++++++++++++++++ .../leech_test.go} | 2 +- .../{sinkMetadata.go => metadata/sink.go} | 55 +- cmd/magneticod/bittorrent/operations.go | 480 ------------------ cmd/magneticod/dht/mainline/protocol.go | 3 +- cmd/magneticod/dht/mainline/service.go | 41 +- cmd/magneticod/dht/mainline/transport.go | 86 +++- cmd/magneticod/dht/managers.go | 1 + cmd/magneticod/main.go | 20 +- pkg/persistence/sqlite3.go | 37 +- 11 files changed, 607 insertions(+), 561 deletions(-) create mode 100644 cmd/magneticod/bittorrent/metadata/leech.go rename cmd/magneticod/bittorrent/{operations_test.go => metadata/leech_test.go} (98%) rename cmd/magneticod/bittorrent/{sinkMetadata.go => metadata/sink.go} (63%) delete mode 100644 cmd/magneticod/bittorrent/operations.go diff --git a/Gopkg.toml b/Gopkg.toml index 980964b..2a98fca 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -56,3 +56,7 @@ [[constraint]] name = "go.uber.org/zap" version = "1.7.1" + +[[constraint]] + name = "github.com/libp2p/go-sockaddr" + version = "1.0.3" diff --git a/cmd/magneticod/bittorrent/metadata/leech.go b/cmd/magneticod/bittorrent/metadata/leech.go new file mode 100644 index 0000000..ff1c18c --- /dev/null +++ b/cmd/magneticod/bittorrent/metadata/leech.go @@ -0,0 +1,439 @@ +package metadata + +import ( + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/binary" + "fmt" + "io" + "math" + "net" + "time" + + "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/metainfo" + "github.com/pkg/errors" + + "go.uber.org/zap" + + "github.com/boramalper/magnetico/pkg/persistence" +) + +const MAX_METADATA_SIZE = 10 * 1024 * 1024 + +type rootDict struct { + M mDict `bencode:"m"` + MetadataSize int `bencode:"metadata_size"` +} + +type mDict struct { + UTMetadata int `bencode:"ut_metadata"` +} + +type extDict struct { + MsgType int `bencode:"msg_type"` + Piece int `bencode:"piece"` +} + +type Leech struct { + infoHash [20]byte + peerAddr *net.TCPAddr + ev LeechEventHandlers + + conn *net.TCPConn + clientID [20]byte + + ut_metadata uint8 + metadataReceived, metadataSize uint + metadata []byte +} + +type LeechEventHandlers struct { + OnSuccess func(Metadata) // must be supplied. args: metadata + OnError func([20]byte, error) // must be supplied. args: infohash, error +} + +func NewLeech(infoHash [20]byte, peerAddr *net.TCPAddr, ev LeechEventHandlers) *Leech { + l := new(Leech) + l.infoHash = infoHash + l.peerAddr = peerAddr + l.ev = ev + + if _, err := rand.Read(l.clientID[:]); err != nil { + panic(err.Error()) + } + + return l +} + +func (l *Leech) writeAll(b []byte) error { + for len(b) != 0 { + n, err := l.conn.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil +} + +func (l *Leech) doBtHandshake() error { + lHandshake := []byte(fmt.Sprintf( + "\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s", + l.infoHash[:], + l.clientID, + )) + + // ASSERTION + if len(lHandshake) != 68 { panic(fmt.Sprintf("len(lHandshake) == %d", len(lHandshake))) } + + + err := l.writeAll(lHandshake) + if err != nil { + return errors.Wrap(err, "writeAll lHandshake") + } + + zap.L().Debug("BitTorrent handshake sent, waiting for the remote's...") + + rHandshake, err := l.readExactly(68) + if err != nil { + return errors.Wrap(err, "readExactly rHandshake") + } + if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) { + return fmt.Errorf("corrupt BitTorrent handshake received") + } + + // TODO: maybe check for the infohash sent by the remote peer to double check? + + if (rHandshake[25] & 0x10) == 0 { + return fmt.Errorf("peer does not support the extension protocol") + } + + return nil +} + +func (l *Leech) doExHandshake() error { + err := l.writeAll([]byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee")) + if err != nil { + return errors.Wrap(err, "writeAll lHandshake") + } + + rExMessage, err := l.readExMessage() + if err != nil { + return errors.Wrap(err, "readExMessage") + } + + // Extension Handshake has the Extension Message ID = 0x00 + if rExMessage[1] != 0 { + return errors.Wrap(err, "first extension message is not an extension handshake") + } + + rRootDict := new(rootDict) + err = bencode.Unmarshal(rExMessage[2:], rRootDict) + if err != nil { + return errors.Wrap(err, "unmarshal rExMessage") + } + + if !(0 < rRootDict.MetadataSize && rRootDict.MetadataSize < MAX_METADATA_SIZE) { + return fmt.Errorf("metadata too big or its size is less than or equal zero") + } + + if !(0 < rRootDict.M.UTMetadata && rRootDict.M.UTMetadata < 255) { + return fmt.Errorf("ut_metadata is not an uint8") + } + + l.ut_metadata = uint8(rRootDict.M.UTMetadata) // Save the ut_metadata code the remote peer uses + l.metadataSize = uint(rRootDict.MetadataSize) + l.metadata = make([]byte, l.metadataSize) + + return nil +} + +func (l *Leech) requestAllPieces() error { + // Request all the pieces of metadata + nPieces := int(math.Ceil(float64(l.metadataSize) / math.Pow(2, 14))) + for piece := 0; piece < nPieces; piece++ { + // __request_metadata_piece(piece) + // ............................... + extDictDump, err := bencode.Marshal(extDict{ + MsgType: 0, + Piece: piece, + }) + if err != nil { // ASSERT + panic(errors.Wrap(err, "marshal extDict")) + } + + err = l.writeAll([]byte(fmt.Sprintf( + "%s\x14%s%s", + toBigEndian(uint(2 + len(extDictDump)), 4), + toBigEndian(uint(l.ut_metadata), 1), + extDictDump, + ))) + if err != nil { + return errors.Wrap(err, "writeAll piece request") + } + } + + return nil +} + +// readMessage returns a BitTorrent message, sans the first 4 bytes indicating its length. +func (l *Leech) readMessage() ([]byte, error) { + rLengthB, err := l.readExactly(4) + if err != nil { + return nil, errors.Wrap(err, "readExactly rLengthB") + } + + rLength := uint(binary.BigEndian.Uint32(rLengthB)) + + rMessage, err := l.readExactly(rLength) + if err != nil { + return nil, errors.Wrap(err, "readExactly rMessage") + } + + return rMessage, nil +} + +// readExMessage returns an *extension* message, sans the first 4 bytes indicating its length. +// +// It will IGNORE all non-extension messages! +func (l *Leech) readExMessage() ([]byte, error) { + for { + rMessage, err := l.readMessage() + if err != nil { + return nil, errors.Wrap(err, "readMessage") + } + + // We are interested only in extension messages, whose first byte is always 20 + if rMessage[0] == 20 { + return rMessage, nil + } + } +} + +// readUmMessage returns an ut_metadata extension message, sans the first 4 bytes indicating its +// length. +// +// It will IGNORE all non-"ut_metadata extension" messages! +func (l *Leech) readUmMessage() ([]byte, error) { + for { + rExMessage, err := l.readExMessage() + if err != nil { + return nil, errors.Wrap(err, "readExMessage") + } + + if rExMessage[1] == 0x01 { + return rExMessage, nil + } + } +} + +func (l *Leech) connect(deadline time.Time) error { + var err error + + l.conn, err = net.DialTCP("tcp", nil, l.peerAddr) + if err != nil { + return errors.Wrap(err, "dial") + } + defer l.conn.Close() + + err = l.conn.SetNoDelay(true) + if err != nil { + return errors.Wrap(err, "NODELAY") + } + + err = l.conn.SetDeadline(deadline) + if err != nil { + return errors.Wrap(err, "SetDeadline") + } + + return nil +} + +func (l *Leech) Do(deadline time.Time) { + err := l.connect(deadline) + if err != nil { + l.OnError(errors.Wrap(err, "connect")) + return + } + + err = l.doBtHandshake() + if err != nil { + l.OnError(errors.Wrap(err, "doBtHandshake")) + return + } + + err = l.doExHandshake() + if err != nil { + l.OnError(errors.Wrap(err, "doExHandshake")) + return + } + + err = l.requestAllPieces() + if err != nil { + l.OnError(errors.Wrap(err, "requestAllPieces")) + return + } + + for l.metadataReceived < l.metadataSize { + rUmMessage, err := l.readUmMessage() + if err != nil { + l.OnError(errors.Wrap(err, "readUmMessage")) + return + } + + // Run TestDecoder() function in leech_test.go in case you have any doubts. + rMessageBuf := bytes.NewBuffer(rUmMessage[2:]) + rExtDict := new(extDict) + err = bencode.NewDecoder(rMessageBuf).Decode(rExtDict) + if err != nil { + zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err)) + return + } + + if rExtDict.MsgType == 2 { // reject + l.OnError(fmt.Errorf("remote peer rejected sending metadata")) + return + } + + if rExtDict.MsgType == 1 { // data + // Get the unread bytes! + metadataPiece := rMessageBuf.Bytes() + + // BEP 9 explicitly states: + // > If the piece is the last piece of the metadata, it may be less than 16kiB. If + // > it is not the last piece of the metadata, it MUST be 16kiB. + // + // Hence... + // ... if the length of @metadataPiece is more than 16kiB, we err. + if len(metadataPiece) > 16*1024 { + l.OnError(fmt.Errorf("metadataPiece > 16kiB")) + return + } + + piece := rExtDict.Piece + // metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python + copy(l.metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece) + l.metadataReceived += uint(len(metadataPiece)) + + // ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT + // complete then we err. + if len(metadataPiece) < 16*1024 && l.metadataReceived != l.metadataSize { + l.OnError(fmt.Errorf("metadataPiece < 16 kiB but incomplete")) + return + } + + if l.metadataReceived > l.metadataSize { + l.OnError(fmt.Errorf("metadataReceived > metadataSize")) + return + } + } + } + + // Verify the checksum + sha1Sum := sha1.Sum(l.metadata) + if !bytes.Equal(sha1Sum[:], l.infoHash[:]) { + l.OnError(fmt.Errorf("infohash mismatch")) + return + } + + // Check the info dictionary + info := new(metainfo.Info) + err = bencode.Unmarshal(l.metadata, info) + if err != nil { + l.OnError(errors.Wrap(err, "unmarshal info")) + return + } + err = validateInfo(info) + if err != nil { + l.OnError(errors.Wrap(err, "validateInfo")) + return + } + + var files []persistence.File + // If there is only one file, there won't be a Files slice. That's why we need to add it here + if len(info.Files) == 0 { + files = append(files, persistence.File{ + Size: info.Length, + Path: info.Name, + }) + } else { + for _, file := range info.Files { + files = append(files, persistence.File{ + Size: file.Length, + Path: file.DisplayPath(info), + }) + } + } + + var totalSize uint64 + for _, file := range files { + if file.Size < 0 { + l.OnError(fmt.Errorf("file size less than zero")) + return + } + + totalSize += uint64(file.Size) + } + + l.ev.OnSuccess(Metadata{ + InfoHash: l.infoHash[:], + Name: info.Name, + TotalSize: totalSize, + DiscoveredOn: time.Now().Unix(), + Files: files, + }) +} + +// COPIED FROM anacrolix/torrent +func validateInfo(info *metainfo.Info) error { + if len(info.Pieces)%20 != 0 { + return errors.New("pieces has invalid length") + } + if info.PieceLength == 0 { + if info.TotalLength() != 0 { + return errors.New("zero piece length") + } + } else { + if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() { + return errors.New("piece count and file lengths are at odds") + } + } + return nil +} + +func (l *Leech) readExactly(n uint) ([]byte, error) { + b := make([]byte, n) + _, err := io.ReadFull(l.conn, b) + return b, err +} + +func (l *Leech) OnError(err error) { + l.ev.OnError(l.infoHash, err) +} + +// TODO: add bounds checking! +func toBigEndian(i uint, n int) []byte { + b := make([]byte, n) + switch n { + case 1: + b = []byte{byte(i)} + + case 2: + binary.BigEndian.PutUint16(b, uint16(i)) + + case 4: + binary.BigEndian.PutUint32(b, uint32(i)) + + default: + panic(fmt.Sprintf("n must be 1, 2, or 4!")) + } + + if len(b) != n { + panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b)) + } + + return b +} + diff --git a/cmd/magneticod/bittorrent/operations_test.go b/cmd/magneticod/bittorrent/metadata/leech_test.go similarity index 98% rename from cmd/magneticod/bittorrent/operations_test.go rename to cmd/magneticod/bittorrent/metadata/leech_test.go index 8c24fa6..9e35686 100644 --- a/cmd/magneticod/bittorrent/operations_test.go +++ b/cmd/magneticod/bittorrent/metadata/leech_test.go @@ -1,4 +1,4 @@ -package bittorrent +package metadata import ( "bytes" diff --git a/cmd/magneticod/bittorrent/sinkMetadata.go b/cmd/magneticod/bittorrent/metadata/sink.go similarity index 63% rename from cmd/magneticod/bittorrent/sinkMetadata.go rename to cmd/magneticod/bittorrent/metadata/sink.go index 9e5049e..fef4572 100644 --- a/cmd/magneticod/bittorrent/sinkMetadata.go +++ b/cmd/magneticod/bittorrent/metadata/sink.go @@ -1,4 +1,4 @@ -package bittorrent +package metadata import ( "crypto/rand" @@ -24,21 +24,17 @@ type Metadata struct { Files []persistence.File } -type Peer struct { - Addr *net.TCPAddr -} - -type MetadataSink struct { - clientID []byte - deadline time.Duration - drain chan Metadata +type Sink struct { + clientID []byte + deadline time.Duration + drain chan Metadata incomingInfoHashes map[[20]byte]struct{} - terminated bool - termination chan interface{} + terminated bool + termination chan interface{} } -func NewMetadataSink(deadline time.Duration) *MetadataSink { - ms := new(MetadataSink) +func NewSink(deadline time.Duration) *Sink { + ms := new(Sink) ms.clientID = make([]byte, 20) _, err := rand.Read(ms.clientID) @@ -52,27 +48,29 @@ func NewMetadataSink(deadline time.Duration) *MetadataSink { return ms } -func (ms *MetadataSink) Sink(res mainline.TrawlingResult) { +func (ms *Sink) Sink(res mainline.TrawlingResult) { if ms.terminated { - zap.L().Panic("Trying to Sink() an already closed MetadataSink!") + zap.L().Panic("Trying to Sink() an already closed Sink!") } if _, exists := ms.incomingInfoHashes[res.InfoHash]; exists { return } // BEWARE! - // Although not crucial, the assumption is that MetadataSink.Sink() will be called by only one + // Although not crucial, the assumption is that Sink.Sink() will be called by only one // goroutine (i.e. it's not thread-safe), lest there might be a race condition between where we // check whether res.infoHash exists in the ms.incomingInfoHashes, and where we add the infoHash // to the incomingInfoHashes at the end of this function. + zap.L().Info("Sunk!", zap.String("infoHash", res.InfoHash.String())) + IPs := res.PeerIP.String() var rhostport string if IPs == "" { - zap.L().Debug("MetadataSink.Sink: Peer IP is nil!") + zap.L().Debug("Sink.Sink: Peer IP is nil!") return } else if IPs[0] == '?' { - zap.L().Debug("MetadataSink.Sink: Peer IP is invalid!") + zap.L().Debug("Sink.Sink: Peer IP is invalid!") return } else if strings.ContainsRune(IPs, ':') { // IPv6 rhostport = fmt.Sprintf("[%s]:%d", IPs, res.PeerPort) @@ -82,29 +80,33 @@ func (ms *MetadataSink) Sink(res mainline.TrawlingResult) { raddr, err := net.ResolveTCPAddr("tcp", rhostport) if err != nil { - zap.L().Debug("MetadataSink.Sink: Couldn't resolve peer address!", zap.Error(err)) + zap.L().Debug("Sink.Sink: Couldn't resolve peer address!", zap.Error(err)) return } - go ms.awaitMetadata(res.InfoHash, Peer{Addr: raddr}) + leech := NewLeech(res.InfoHash, raddr, LeechEventHandlers{ + OnSuccess: ms.flush, + OnError: ms.onLeechError, + }) + go leech.Do(time.Now().Add(ms.deadline)) ms.incomingInfoHashes[res.InfoHash] = struct{}{} } -func (ms *MetadataSink) Drain() <-chan Metadata { +func (ms *Sink) Drain() <-chan Metadata { if ms.terminated { - zap.L().Panic("Trying to Drain() an already closed MetadataSink!") + zap.L().Panic("Trying to Drain() an already closed Sink!") } return ms.drain } -func (ms *MetadataSink) Terminate() { +func (ms *Sink) Terminate() { ms.terminated = true close(ms.termination) close(ms.drain) } -func (ms *MetadataSink) flush(result Metadata) { +func (ms *Sink) flush(result Metadata) { if !ms.terminated { ms.drain <- result // Delete the infoHash from ms.incomingInfoHashes ONLY AFTER once we've flushed the @@ -114,3 +116,8 @@ func (ms *MetadataSink) flush(result Metadata) { delete(ms.incomingInfoHashes, infoHash) } } + +func (ms *Sink) onLeechError(infoHash [20]byte, err error) { + zap.L().Debug("leech error", zap.ByteString("infoHash", infoHash[:]), zap.Error(err)) + delete(ms.incomingInfoHashes, infoHash) +} diff --git a/cmd/magneticod/bittorrent/operations.go b/cmd/magneticod/bittorrent/operations.go deleted file mode 100644 index d0e65df..0000000 --- a/cmd/magneticod/bittorrent/operations.go +++ /dev/null @@ -1,480 +0,0 @@ -package bittorrent - -import ( - "bytes" - "crypto/sha1" - "encoding/binary" - "errors" - "fmt" - "io" - "math" - "net" - "time" - - "github.com/anacrolix/torrent/bencode" - "github.com/anacrolix/torrent/metainfo" - - "go.uber.org/zap" - - "github.com/boramalper/magnetico/pkg/persistence" -) - -const MAX_METADATA_SIZE = 10 * 1024 * 1024 - -type rootDict struct { - M mDict `bencode:"m"` - MetadataSize int `bencode:"metadata_size"` -} - -type mDict struct { - UTMetadata int `bencode:"ut_metadata"` -} - -type extDict struct { - MsgType int `bencode:"msg_type"` - Piece int `bencode:"piece"` -} - -func (ms *MetadataSink) awaitMetadata(infoHash metainfo.Hash, peer Peer) { - conn, err := net.DialTCP("tcp", nil, peer.Addr) - if err != nil { - zap.L().Debug( - "awaitMetadata couldn't connect to the peer!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - defer conn.Close() - - err = conn.SetNoDelay(true) - if err != nil { - zap.L().Panic( - "Couldn't set NODELAY!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - err = conn.SetDeadline(time.Now().Add(ms.deadline)) - if err != nil { - zap.L().Panic( - "Couldn't set the deadline!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - - // State Variables - var isExtHandshakeDone, done bool - var ut_metadata, metadataReceived, metadataSize int - var metadata []byte - - lHandshake := []byte(fmt.Sprintf( - "\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s", - infoHash[:], - ms.clientID, - )) - if len(lHandshake) != 68 { - zap.L().Panic( - "Generated BitTorrent handshake is not of length 68!", - zap.ByteString("infoHash", infoHash[:]), - zap.Int("len_lHandshake", len(lHandshake)), - ) - } - err = writeAll(conn, lHandshake) - if err != nil { - zap.L().Debug( - "Couldn't write BitTorrent handshake!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - - zap.L().Debug("BitTorrent handshake sent, waiting for the remote's...") - - rHandshake, err := readExactly(conn, 68) - if err != nil { - zap.L().Debug( - "Couldn't read remote BitTorrent handshake!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) { - zap.L().Debug( - "Remote BitTorrent handshake is not what it is supposed to be!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.ByteString("rHandshake[:20]", rHandshake[:20]), - ) - return - } - - // __on_bt_handshake - // ================ - if rHandshake[25] != 16 { // TODO (later): do *not* compare the whole byte, check the bit instead! (0x10) - zap.L().Debug( - "Peer does not support the extension protocol!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - ) - return - } - - writeAll(conn, []byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee")) - zap.L().Debug( - "Extension handshake sent, waiting for the remote's...", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - ) - - // the loop! - // ========= - for !done { - rLengthB, err := readExactly(conn, 4) - if err != nil { - zap.L().Debug( - "Couldn't read the first 4 bytes from the remote peer in the loop!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - - // The messages we are interested in have the length of AT LEAST two bytes - // (TODO: actually a bit more than that but SURELY when it's less than two bytes, the - // program panics) - rLength := bigEndianToInt(rLengthB) - if rLength < 2 { - continue - } - - rMessage, err := readExactly(conn, rLength) - if err != nil { - zap.L().Debug( - "Couldn't read the rest of the message from the remote peer in the loop!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", peer.Addr.String()), - zap.Error(err), - ) - return - } - - // __on_message - // ------------ - if rMessage[0] != 0x14 { // We are interested only in extension messages, whose first byte is always 0x14 - zap.L().Debug( - "Ignoring the non-extension message.", - zap.ByteString("infoHash", infoHash[:]), - ) - continue - } - - if rMessage[1] == 0x00 { // Extension Handshake has the Extension Message ID = 0x00 - // __on_ext_handshake_message(message[2:]) - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - // TODO: continue editing log messages from here - - if isExtHandshakeDone { - return - } - - rRootDict := new(rootDict) - err := bencode.Unmarshal(rMessage[2:], rRootDict) - if err != nil { - zap.L().Debug("Couldn't unmarshal extension handshake!", zap.Error(err)) - return - } - - if rRootDict.MetadataSize <= 0 || rRootDict.MetadataSize > MAX_METADATA_SIZE { - zap.L().Debug("Unacceptable metadata size!", zap.Int("metadata_size", rRootDict.MetadataSize)) - return - } - - ut_metadata = rRootDict.M.UTMetadata // Save the ut_metadata code the remote peer uses - metadataSize = rRootDict.MetadataSize - metadata = make([]byte, metadataSize) - isExtHandshakeDone = true - - zap.L().Debug("GOT EXTENSION HANDSHAKE!", zap.Int("ut_metadata", ut_metadata), zap.Int("metadata_size", metadataSize)) - - // Request all the pieces of metadata - n_pieces := int(math.Ceil(float64(metadataSize) / math.Pow(2, 14))) - for piece := 0; piece < n_pieces; piece++ { - // __request_metadata_piece(piece) - // ............................... - extDictDump, err := bencode.Marshal(extDict{ - MsgType: 0, - Piece: piece, - }) - if err != nil { - zap.L().Warn("Couldn't marshal extDictDump!", zap.Error(err)) - return - } - writeAll(conn, []byte(fmt.Sprintf( - "%s\x14%s%s", - intToBigEndian(2+len(extDictDump), 4), - intToBigEndian(ut_metadata, 1), - extDictDump, - ))) - } - - zap.L().Warn("requested all metadata pieces!") - - } else if rMessage[1] == 0x01 { - // __on_ext_message(message[2:]) - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - // Run TestDecoder() function in operations_test.go in case you have any doubts. - rMessageBuf := bytes.NewBuffer(rMessage[2:]) - rExtDict := new(extDict) - err := bencode.NewDecoder(rMessageBuf).Decode(rExtDict) - if err != nil { - zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err)) - return - } - - if rExtDict.MsgType == 1 { // data - // Get the unread bytes! - metadataPiece := rMessageBuf.Bytes() - piece := rExtDict.Piece - // metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python - copy(metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece) - metadataReceived += len(metadataPiece) - done = metadataReceived == metadataSize - - // BEP 9 explicitly states: - // > If the piece is the last piece of the metadata, it may be less than 16kiB. If - // > it is not the last piece of the metadata, it MUST be 16kiB. - // - // Hence... - // ... if the length of @metadataPiece is more than 16kiB, we err. - if len(metadataPiece) > 16*1024 { - zap.L().Debug( - "metadataPiece is bigger than 16kiB!", - zap.Int("len_metadataPiece", len(metadataPiece)), - zap.Int("metadataReceived", metadataReceived), - zap.Int("metadataSize", metadataSize), - zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)), - ) - return - } - - // ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT - // complete (!done) then we err. - if len(metadataPiece) < 16*1024 && !done { - zap.L().Debug( - "metadataPiece is less than 16kiB and metadata is incomplete!", - zap.Int("len_metadataPiece", len(metadataPiece)), - zap.Int("metadataReceived", metadataReceived), - zap.Int("metadataSize", metadataSize), - zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)), - ) - return - } - - if metadataReceived > metadataSize { - zap.L().Debug( - "metadataReceived is greater than metadataSize!", - zap.Int("len_metadataPiece", len(metadataPiece)), - zap.Int("metadataReceived", metadataReceived), - zap.Int("metadataSize", metadataSize), - zap.Int("metadataPieceIndex", bytes.Index(rMessage, metadataPiece)), - ) - return - } - - zap.L().Debug( - "Fetching...", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", conn.RemoteAddr().String()), - zap.Int("metadataReceived", metadataReceived), - zap.Int("metadataSize", metadataSize), - ) - } else if rExtDict.MsgType == 2 { // reject - zap.L().Debug( - "Remote peer rejected sending metadata!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("remotePeerAddr", conn.RemoteAddr().String()), - ) - return - } - - } else { - zap.L().Debug( - "Message is not an ut_metadata message! (ignoring)", - zap.ByteString("msg", rMessage[:100]), - ) - // no return! - } - } - - zap.L().Debug( - "Metadata is complete, verifying the checksum...", - zap.ByteString("infoHash", infoHash[:]), - ) - - sha1Sum := sha1.Sum(metadata) - if !bytes.Equal(sha1Sum[:], infoHash[:]) { - zap.L().Debug( - "Info-hash mismatch!", - zap.ByteString("expectedInfoHash", infoHash[:]), - zap.ByteString("actualInfoHash", sha1Sum[:]), - ) - return - } - - zap.L().Debug( - "Checksum verified, checking the info dictionary...", - zap.ByteString("infoHash", infoHash[:]), - ) - - info := new(metainfo.Info) - err = bencode.Unmarshal(metadata, info) - if err != nil { - zap.L().Debug( - "Couldn't unmarshal info bytes!", - zap.ByteString("infoHash", infoHash[:]), - zap.Error(err), - ) - return - } - err = validateInfo(info) - if err != nil { - zap.L().Debug( - "Bad info dictionary!", - zap.ByteString("infoHash", infoHash[:]), - zap.Error(err), - ) - return - } - - var files []persistence.File - // If there is only one file, there won't be a Files slice. That's why we need to add it here - if len(info.Files) == 0 { - files = append(files, persistence.File{ - Size: info.Length, - Path: info.Name, - }) - } - - for _, file := range info.Files { - if file.Length < 0 { - zap.L().Debug( - "File size is less than zero!", - zap.ByteString("infoHash", infoHash[:]), - zap.String("filePath", file.DisplayPath(info)), - zap.Int64("fileSize", file.Length), - ) - return - } - - files = append(files, persistence.File{ - Size: file.Length, - Path: file.DisplayPath(info), - }) - } - - var totalSize uint64 - for _, file := range files { - totalSize += uint64(file.Size) - } - - zap.L().Debug( - "Flushing metadata...", - zap.ByteString("infoHash", infoHash[:]), - ) - - ms.flush(Metadata{ - InfoHash: infoHash[:], - Name: info.Name, - TotalSize: totalSize, - DiscoveredOn: time.Now().Unix(), - Files: files, - }) -} - -// COPIED FROM anacrolix/torrent -func validateInfo(info *metainfo.Info) error { - if len(info.Pieces)%20 != 0 { - return errors.New("pieces has invalid length") - } - if info.PieceLength == 0 { - if info.TotalLength() != 0 { - return errors.New("zero piece length") - } - } else { - if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() { - return errors.New("piece count and file lengths are at odds") - } - } - return nil -} - -func writeAll(c *net.TCPConn, b []byte) error { - for len(b) != 0 { - n, err := c.Write(b) - if err != nil { - return err - } - b = b[n:] - } - return nil -} - -func readExactly(c *net.TCPConn, n int) ([]byte, error) { - b := make([]byte, n) - _, err := io.ReadFull(c, b) - return b, err -} - -// TODO: add bounds checking! -func intToBigEndian(i int, n int) []byte { - b := make([]byte, n) - switch n { - case 1: - b = []byte{byte(i)} - - case 2: - binary.BigEndian.PutUint16(b, uint16(i)) - - case 4: - binary.BigEndian.PutUint32(b, uint32(i)) - - default: - panic(fmt.Sprintf("n must be 1, 2, or 4!")) - } - - if len(b) != n { - panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b)) - } - - return b -} - -func bigEndianToInt(b []byte) int { - switch len(b) { - case 1: - return int(b[0]) - - case 2: - return int(binary.BigEndian.Uint16(b)) - - case 4: - return int(binary.BigEndian.Uint32(b)) - - default: - panic(fmt.Sprintf("bigEndianToInt: b is too long! (%d bytes)", len(b))) - } -} diff --git a/cmd/magneticod/dht/mainline/protocol.go b/cmd/magneticod/dht/mainline/protocol.go index 53459fe..9ec2849 100644 --- a/cmd/magneticod/dht/mainline/protocol.go +++ b/cmd/magneticod/dht/mainline/protocol.go @@ -26,11 +26,12 @@ type ProtocolEventHandlers struct { OnGetPeersResponse func(*Message, net.Addr) OnFindNodeResponse func(*Message, net.Addr) OnPingORAnnouncePeerResponse func(*Message, net.Addr) + OnCongestion func() } func NewProtocol(laddr string, eventHandlers ProtocolEventHandlers) (p *Protocol) { p = new(Protocol) - p.transport = NewTransport(laddr, p.onMessage) + p.transport = NewTransport(laddr, p.onMessage, p.eventHandlers.OnCongestion) p.eventHandlers = eventHandlers p.currentTokenSecret, p.previousTokenSecret = make([]byte, 20), make([]byte, 20) diff --git a/cmd/magneticod/dht/mainline/service.go b/cmd/magneticod/dht/mainline/service.go index c5c8c0a..e0e754e 100644 --- a/cmd/magneticod/dht/mainline/service.go +++ b/cmd/magneticod/dht/mainline/service.go @@ -31,13 +31,14 @@ type TrawlingService struct { // ^~~~~~ routingTable map[string]net.Addr routingTableMutex *sync.Mutex + maxNeighbors uint } type TrawlingServiceEventHandlers struct { OnResult func(TrawlingResult) } -func NewTrawlingService(laddr string, eventHandlers TrawlingServiceEventHandlers) *TrawlingService { +func NewTrawlingService(laddr string, initialMaxNeighbors uint, eventHandlers TrawlingServiceEventHandlers) *TrawlingService { service := new(TrawlingService) service.protocol = NewProtocol( laddr, @@ -45,12 +46,14 @@ func NewTrawlingService(laddr string, eventHandlers TrawlingServiceEventHandlers OnGetPeersQuery: service.onGetPeersQuery, OnAnnouncePeerQuery: service.onAnnouncePeerQuery, OnFindNodeResponse: service.onFindNodeResponse, + OnCongestion: service.onCongestion, }, ) service.trueNodeID = make([]byte, 20) service.routingTable = make(map[string]net.Addr) service.routingTableMutex = new(sync.Mutex) service.eventHandlers = eventHandlers + service.maxNeighbors = initialMaxNeighbors _, err := rand.Read(service.trueNodeID) if err != nil { @@ -78,11 +81,14 @@ func (s *TrawlingService) Terminate() { func (s *TrawlingService) trawl() { for range time.Tick(3 * time.Second) { + s.maxNeighbors = uint(float32(s.maxNeighbors) * 1.01) + s.routingTableMutex.Lock() if len(s.routingTable) == 0 { s.bootstrap() } else { - zap.L().Debug("Latest status:", zap.Int("n", len(s.routingTable))) + zap.L().Warn("Latest status:", zap.Int("n", len(s.routingTable)), + zap.Uint("maxNeighbors", s.maxNeighbors)) s.findNeighbors() s.routingTable = make(map[string]net.Addr) } @@ -185,11 +191,34 @@ func (s *TrawlingService) onFindNodeResponse(response *Message, addr net.Addr) { s.routingTableMutex.Lock() defer s.routingTableMutex.Unlock() + zap.L().Debug("find node response!!", zap.Uint("maxNeighbors", s.maxNeighbors), + zap.Int("response.R.Nodes length", len(response.R.Nodes))) + for _, node := range response.R.Nodes { - if node.Addr.Port != 0 { // Ignore nodes who "use" port 0. - if len(s.routingTable) < 8000 { - s.routingTable[string(node.ID)] = &node.Addr - } + if uint(len(s.routingTable)) >= s.maxNeighbors { + break } + if node.Addr.Port == 0 { // Ignore nodes who "use" port 0. + zap.L().Debug("ignoring 0 port!!!") + continue + } + + s.routingTable[string(node.ID)] = &node.Addr } } + +func (s *TrawlingService) onCongestion() { + /* The Congestion Prevention Strategy: + * + * In case of congestion, decrease the maximum number of nodes to the 90% of the current value. + */ + if s.maxNeighbors < 200 { + zap.L().Warn("Max. number of neighbours are < 200 and there is still congestion!" + + "(check your network connection if this message recurs)") + return + } + + s.maxNeighbors = uint(float32(s.maxNeighbors) * 0.9) + zap.L().Debug("Max. number of neighbours updated!", + zap.Uint("s.maxNeighbors", s.maxNeighbors)) +} diff --git a/cmd/magneticod/dht/mainline/transport.go b/cmd/magneticod/dht/mainline/transport.go index 2d97187..60f37e7 100644 --- a/cmd/magneticod/dht/mainline/transport.go +++ b/cmd/magneticod/dht/mainline/transport.go @@ -4,6 +4,7 @@ import ( "net" "github.com/anacrolix/torrent/bencode" + sockaddr "github.com/libp2p/go-sockaddr/net" "go.uber.org/zap" "golang.org/x/sys/unix" ) @@ -12,23 +13,43 @@ type Transport struct { fd int laddr *net.UDPAddr started bool + buffer []byte // OnMessage is the function that will be called when Transport receives a packet that is // successfully unmarshalled as a syntactically correct Message (but -of course- the checking // the semantic correctness of the Message is left to Protocol). onMessage func(*Message, net.Addr) + // OnCongestion + onCongestion func() } -func NewTransport(laddr string, onMessage func(*Message, net.Addr)) *Transport { - transport := new(Transport) - transport.onMessage = onMessage +func NewTransport(laddr string, onMessage func(*Message, net.Addr), onCongestion func()) *Transport { + t := new(Transport) + /* The field size sets a theoretical limit of 65,535 bytes (8 byte header + 65,527 bytes of + * data) for a UDP datagram. However the actual limit for the data length, which is imposed by + * the underlying IPv4 protocol, is 65,507 bytes (65,535 − 8 byte UDP header − 20 byte IP + * header). + * + * In IPv6 jumbograms it is possible to have UDP packets of size greater than 65,535 bytes. + * RFC 2675 specifies that the length field is set to zero if the length of the UDP header plus + * UDP data is greater than 65,535. + * + * https://en.wikipedia.org/wiki/User_Datagram_Protocol + */ + t.buffer = make([]byte, 65507) + t.onMessage = onMessage + t.onCongestion = onCongestion + var err error - transport.laddr, err = net.ResolveUDPAddr("udp", laddr) + t.laddr, err = net.ResolveUDPAddr("udp", laddr) if err != nil { zap.L().Panic("Could not resolve the UDP address for the trawler!", zap.Error(err)) } + if t.laddr.IP.To4() == nil { + zap.L().Panic("IP address is not IPv4!") + } - return transport + return t } func (t *Transport) Start() { @@ -51,7 +72,12 @@ func (t *Transport) Start() { zap.L().Fatal("Could NOT create a UDP socket!", zap.Error(err)) } - unix.Bind(t.fd, unix.SockaddrInet4{Addr: t.laddr.IP, Port: t.laddr.Port}) + var ip [4]byte + copy(ip[:], t.laddr.IP.To4()) + err = unix.Bind(t.fd, &unix.SockaddrInet4{Addr: ip, Port: t.laddr.Port}) + if err != nil { + zap.L().Fatal("Could NOT bind the socket!", zap.Error(err)) + } go t.readMessages() } @@ -62,21 +88,33 @@ func (t *Transport) Terminate() { // readMessages is a goroutine! func (t *Transport) readMessages() { - buffer := make([]byte, 65536) - for { - n, from, err := unix.Recvfrom(t.fd, buffer, 0) - if err != nil { + n, fromSA, err := unix.Recvfrom(t.fd, t.buffer, 0) + if err == unix.EPERM || err == unix.ENOBUFS { // todo: are these errors possible for recvfrom? + zap.L().Warn("READ CONGESTION!", zap.Error(err)) + t.onCongestion() + } else if err != nil { // TODO: isn't there a more reliable way to detect if UDPConn is closed? zap.L().Debug("Could NOT read an UDP packet!", zap.Error(err)) } + if n == 0 { + /* Datagram sockets in various domains (e.g., the UNIX and Internet domains) permit + * zero-length datagrams. When such a datagram is received, the return value (n) is 0. + */ + zap.L().Debug("zero-length received!!") + continue + } + + from := sockaddr.SockaddrToUDPAddr(fromSA) + var msg Message - err = bencode.Unmarshal(buffer[:n], &msg) + err = bencode.Unmarshal(t.buffer[:n], &msg) if err != nil { zap.L().Debug("Could NOT unmarshal packet data!", zap.Error(err)) } + zap.L().Debug("message read! (first 20...)", zap.ByteString("msg", t.buffer[:20])) t.onMessage(&msg, from) } } @@ -87,9 +125,29 @@ func (t *Transport) WriteMessages(msg *Message, addr net.Addr) { zap.L().Panic("Could NOT marshal an outgoing message! (Programmer error.)") } - err = unix.Sendto(t.fd, data, 0, addr) - // TODO: isn't there a more reliable way to detect if UDPConn is closed? - if err != nil { + addrSA := sockaddr.NetAddrToSockaddr(addr) + + zap.L().Debug("sent message!!!") + + err = unix.Sendto(t.fd, data, 0, addrSA) + if err == unix.EPERM || err == unix.ENOBUFS { + /* EPERM (errno: 1) is kernel's way of saying that "you are far too fast, chill". It is + * also likely that we have received a ICMP source quench packet (meaning, that we *really* + * need to slow down. + * + * Read more here: http://www.archivum.info/comp.protocols.tcp-ip/2009-05/00088/UDP-socket-amp-amp-sendto-amp-amp-EPERM.html + * + * > Note On BSD systems (OS X, FreeBSD, etc.) flow control is not supported for + * > DatagramProtocol, because send failures caused by writing too many packets cannot be + * > detected easily. The socket always appears ‘ready’ and excess packets are dropped; an + * > OSError with errno set to errno.ENOBUFS may or may not be raised; if it is raised, it + * > will be reported to DatagramProtocol.error_received() but otherwise ignored. + * + * Source: https://docs.python.org/3/library/asyncio-protocol.html#flow-control-callbacks + */ + zap.L().Warn("WRITE CONGESTION!", zap.Error(err)) + t.onCongestion() + } else if err != nil { zap.L().Debug("Could NOT write an UDP packet!", zap.Error(err)) } } diff --git a/cmd/magneticod/dht/managers.go b/cmd/magneticod/dht/managers.go index c26ccaf..95990d0 100644 --- a/cmd/magneticod/dht/managers.go +++ b/cmd/magneticod/dht/managers.go @@ -18,6 +18,7 @@ func NewTrawlingManager(mlAddrs []string) *TrawlingManager { for _, addr := range mlAddrs { manager.services = append(manager.services, mainline.NewTrawlingService( addr, + 2000, mainline.TrawlingServiceEventHandlers{ OnResult: manager.onResult, }, diff --git a/cmd/magneticod/main.go b/cmd/magneticod/main.go index 0a35e4f..9d92ff1 100644 --- a/cmd/magneticod/main.go +++ b/cmd/magneticod/main.go @@ -6,7 +6,6 @@ import ( "net" "os" "os/signal" - "path" "runtime/pprof" "time" @@ -14,7 +13,7 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" - "github.com/boramalper/magnetico/cmd/magneticod/bittorrent" + "github.com/boramalper/magnetico/cmd/magneticod/bittorrent/metadata" "github.com/boramalper/magnetico/cmd/magneticod/dht" "github.com/Wessie/appdirs" @@ -75,6 +74,8 @@ func main() { zap.ReplaceGlobals(logger) + zap.L().Debug("debug message!") + switch opFlags.Profile { case "cpu": file, err := os.OpenFile("magneticod_cpu.prof", os.O_CREATE | os.O_WRONLY, 0755) @@ -96,19 +97,19 @@ func main() { interruptChan := make(chan os.Signal) signal.Notify(interruptChan, os.Interrupt) - database, err := persistence.MakeDatabase(opFlags.DatabaseURL, false, logger) + database, err := persistence.MakeDatabase(opFlags.DatabaseURL, logger) if err != nil { logger.Sugar().Fatalf("Could not open the database at `%s`: %s", opFlags.DatabaseURL, err.Error()) } trawlingManager := dht.NewTrawlingManager(opFlags.TrawlerMlAddrs) - metadataSink := bittorrent.NewMetadataSink(2 * time.Minute) + metadataSink := metadata.NewSink(2 * time.Minute) // The Event Loop for stopped := false; !stopped; { select { case result := <-trawlingManager.Output(): - zap.L().Info("Trawled!", zap.String("infoHash", result.InfoHash.String())) + zap.L().Debug("Trawled!", zap.String("infoHash", result.InfoHash.String())) exists, err := database.DoesTorrentExist(result.InfoHash[:]) if err != nil { zap.L().Fatal("Could not check whether torrent exists!", zap.Error(err)) @@ -144,10 +145,11 @@ func parseFlags() (*opFlags, error) { } if cmdF.DatabaseURL == "" { - opF.DatabaseURL = "sqlite3://" + path.Join( - appdirs.UserDataDir("magneticod", "", "", false), - "database.sqlite3", - ) + opF.DatabaseURL = + "sqlite3://" + + appdirs.UserDataDir("magneticod", "", "", false) + + "/database.sqlite3" + + "?_journal_mode=WAL" // https://github.com/mattn/go-sqlite3#connection-string } else { opF.DatabaseURL = cmdF.DatabaseURL } diff --git a/pkg/persistence/sqlite3.go b/pkg/persistence/sqlite3.go index 1579eda..131d89a 100644 --- a/pkg/persistence/sqlite3.go +++ b/pkg/persistence/sqlite3.go @@ -14,6 +14,9 @@ import ( "go.uber.org/zap" ) +// Close your rows lest you get "database table is locked" error(s)! +// See https://github.com/mattn/go-sqlite3/issues/2741 + type sqlite3Database struct { conn *sql.DB } @@ -55,15 +58,12 @@ func (db *sqlite3Database) DoesTorrentExist(infoHash []byte) (bool, error) { if err != nil { return false, err } + defer rows.Close() // If rows.Next() returns true, meaning that the torrent is in the database, return true; else // return false. exists := rows.Next() - if !exists && rows.Err() != nil { - return false, err - } - - if err = rows.Close(); err != nil { + if rows.Err() != nil { return false, err } @@ -143,6 +143,7 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) { if err != nil { return 0, err } + defer rows.Close() if rows.Next() != true { fmt.Errorf("No rows returned from `SELECT MAX(ROWID)`") @@ -153,10 +154,6 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) { return 0, err } - if err = rows.Close(); err != nil { - return 0, err - } - return n, nil } @@ -247,6 +244,7 @@ func (db *sqlite3Database) QueryTorrents( queryArgs = append(queryArgs, limit) rows, err := db.conn.Query(sqlQuery, queryArgs...) + defer rows.Close() if err != nil { return nil, fmt.Errorf("error while querying torrents: %s", err.Error()) } @@ -269,10 +267,6 @@ func (db *sqlite3Database) QueryTorrents( torrents = append(torrents, torrent) } - if err := rows.Close(); err != nil { - return nil, err - } - return torrents, nil } @@ -307,6 +301,7 @@ func (db *sqlite3Database) GetTorrent(infoHash []byte) (*TorrentMetadata, error) WHERE info_hash = ?`, infoHash, ) + defer rows.Close() if err != nil { return nil, err } @@ -320,10 +315,6 @@ func (db *sqlite3Database) GetTorrent(infoHash []byte) (*TorrentMetadata, error) return nil, err } - if err = rows.Close(); err != nil { - return nil, err - } - return &tm, nil } @@ -331,6 +322,7 @@ func (db *sqlite3Database) GetFiles(infoHash []byte) ([]File, error) { rows, err := db.conn.Query( "SELECT size, path FROM files, torrents WHERE files.torrent_id = torrents.id AND torrents.info_hash = ?;", infoHash) + defer rows.Close() if err != nil { return nil, err } @@ -344,10 +336,6 @@ func (db *sqlite3Database) GetFiles(infoHash []byte) ([]File, error) { files = append(files, file) } - if err := rows.Close(); err != nil { - return nil, err - } - return files, nil } @@ -391,6 +379,7 @@ func (db *sqlite3Database) GetStatistics(from string, n uint) (*Statistics, erro GROUP BY dt;`, timef), fromTime.Unix(), toTime.Unix()) + defer rows.Close() if err != nil { return nil, err } @@ -478,6 +467,7 @@ func (db *sqlite3Database) setupDatabase() error { if err != nil { return fmt.Errorf("sql.Tx.Query (user_version): %s", err.Error()) } + defer rows.Close() var userVersion int if rows.Next() != true { return fmt.Errorf("sql.Rows.Next (user_version): PRAGMA user_version did not return any rows!") @@ -485,11 +475,6 @@ func (db *sqlite3Database) setupDatabase() error { if err = rows.Scan(&userVersion); err != nil { return fmt.Errorf("sql.Rows.Scan (user_version): %s", err.Error()) } - // Close your rows lest you get "database table is locked" error(s)! - // See https://github.com/mattn/go-sqlite3/issues/2741 - if err = rows.Close(); err != nil { - return fmt.Errorf("sql.Rows.Close (user_version): %s", err.Error()) - } switch userVersion { case 0: // FROZEN.