package metadata import ( "bytes" "crypto/rand" "crypto/sha1" "encoding/binary" "fmt" "io" "math" "net" "time" "github.com/anacrolix/torrent/bencode" "github.com/anacrolix/torrent/metainfo" "github.com/pkg/errors" "go.uber.org/zap" "github.com/boramalper/magnetico/pkg/persistence" ) const MAX_METADATA_SIZE = 10 * 1024 * 1024 type rootDict struct { M mDict `bencode:"m"` MetadataSize int `bencode:"metadata_size"` } type mDict struct { UTMetadata int `bencode:"ut_metadata"` } type extDict struct { MsgType int `bencode:"msg_type"` Piece int `bencode:"piece"` } type Leech struct { infoHash [20]byte peerAddr *net.TCPAddr ev LeechEventHandlers conn *net.TCPConn clientID [20]byte ut_metadata uint8 metadataReceived, metadataSize uint metadata []byte } type LeechEventHandlers struct { OnSuccess func(Metadata) // must be supplied. args: metadata OnError func([20]byte, error) // must be supplied. args: infohash, error } func NewLeech(infoHash [20]byte, peerAddr *net.TCPAddr, ev LeechEventHandlers) *Leech { l := new(Leech) l.infoHash = infoHash l.peerAddr = peerAddr l.ev = ev if _, err := rand.Read(l.clientID[:]); err != nil { panic(err.Error()) } return l } func (l *Leech) writeAll(b []byte) error { for len(b) != 0 { n, err := l.conn.Write(b) if err != nil { return err } b = b[n:] } return nil } func (l *Leech) doBtHandshake() error { lHandshake := []byte(fmt.Sprintf( "\x13BitTorrent protocol\x00\x00\x00\x00\x00\x10\x00\x01%s%s", l.infoHash, l.clientID, )) // ASSERTION if len(lHandshake) != 68 { panic(fmt.Sprintf("len(lHandshake) == %d", len(lHandshake))) } err := l.writeAll(lHandshake) if err != nil { return errors.Wrap(err, "writeAll lHandshake") } rHandshake, err := l.readExactly(68) if err != nil { return errors.Wrap(err, "readExactly rHandshake") } if !bytes.HasPrefix(rHandshake, []byte("\x13BitTorrent protocol")) { return fmt.Errorf("corrupt BitTorrent handshake received") } // TODO: maybe check for the infohash sent by the remote peer to double check? if (rHandshake[25] & 0x10) == 0 { return fmt.Errorf("peer does not support the extension protocol") } return nil } func (l *Leech) doExHandshake() error { err := l.writeAll([]byte("\x00\x00\x00\x1a\x14\x00d1:md11:ut_metadatai1eee")) if err != nil { return errors.Wrap(err, "writeAll lHandshake") } rExMessage, err := l.readExMessage() if err != nil { return errors.Wrap(err, "readExMessage") } // Extension Handshake has the Extension Message ID = 0x00 if rExMessage[1] != 0 { return errors.Wrap(err, "first extension message is not an extension handshake") } rRootDict := new(rootDict) err = bencode.Unmarshal(rExMessage[2:], rRootDict) if err != nil { return errors.Wrap(err, "unmarshal rExMessage") } if !(0 < rRootDict.MetadataSize && rRootDict.MetadataSize < MAX_METADATA_SIZE) { return fmt.Errorf("metadata too big or its size is less than or equal zero") } if !(0 < rRootDict.M.UTMetadata && rRootDict.M.UTMetadata < 255) { return fmt.Errorf("ut_metadata is not an uint8") } l.ut_metadata = uint8(rRootDict.M.UTMetadata) // Save the ut_metadata code the remote peer uses l.metadataSize = uint(rRootDict.MetadataSize) l.metadata = make([]byte, l.metadataSize) return nil } func (l *Leech) requestAllPieces() error { // Request all the pieces of metadata nPieces := int(math.Ceil(float64(l.metadataSize) / math.Pow(2, 14))) for piece := 0; piece < nPieces; piece++ { // __request_metadata_piece(piece) // ............................... extDictDump, err := bencode.Marshal(extDict{ MsgType: 0, Piece: piece, }) if err != nil { // ASSERT panic(errors.Wrap(err, "marshal extDict")) } err = l.writeAll([]byte(fmt.Sprintf( "%s\x14%s%s", toBigEndian(uint(2+len(extDictDump)), 4), toBigEndian(uint(l.ut_metadata), 1), extDictDump, ))) if err != nil { return errors.Wrap(err, "writeAll piece request") } } return nil } // readMessage returns a BitTorrent message, sans the first 4 bytes indicating its length. func (l *Leech) readMessage() ([]byte, error) { rLengthB, err := l.readExactly(4) if err != nil { return nil, errors.Wrap(err, "readExactly rLengthB") } rLength := uint(binary.BigEndian.Uint32(rLengthB)) rMessage, err := l.readExactly(rLength) if err != nil { return nil, errors.Wrap(err, "readExactly rMessage") } return rMessage, nil } // readExMessage returns an *extension* message, sans the first 4 bytes indicating its length. // // It will IGNORE all non-extension messages! func (l *Leech) readExMessage() ([]byte, error) { for { rMessage, err := l.readMessage() if err != nil { return nil, errors.Wrap(err, "readMessage") } // Every extension message has at least 2 bytes. if len(rMessage) < 2 { continue } // We are interested only in extension messages, whose first byte is always 20 if rMessage[0] == 20 { return rMessage, nil } } } // readUmMessage returns an ut_metadata extension message, sans the first 4 bytes indicating its // length. // // It will IGNORE all non-"ut_metadata extension" messages! func (l *Leech) readUmMessage() ([]byte, error) { for { rExMessage, err := l.readExMessage() if err != nil { return nil, errors.Wrap(err, "readExMessage") } if rExMessage[1] == 0x01 { return rExMessage, nil } } } func (l *Leech) connect(deadline time.Time) error { var err error l.conn, err = net.DialTCP("tcp4", nil, l.peerAddr) if err != nil { return errors.Wrap(err, "dial") } // > If sec == 0, operating system discards any unsent or unacknowledged data [after Close() // > has been called]. err = l.conn.SetLinger(0) if err != nil { if err := l.conn.Close(); err != nil { zap.L().Panic("couldn't close leech connection!", zap.Error(err)) } return errors.Wrap(err, "SetLinger") } err = l.conn.SetKeepAlive(true) if err != nil { if err := l.conn.Close(); err != nil { zap.L().Panic("couldn't close leech connection!", zap.Error(err)) } return errors.Wrap(err, "SetKeepAlive") } err = l.conn.SetKeepAlivePeriod(10 * time.Second) if err != nil { if err := l.conn.Close(); err != nil { zap.L().Panic("couldn't close leech connection!", zap.Error(err)) } return errors.Wrap(err, "SetKeepAlivePeriod") } err = l.conn.SetNoDelay(true) if err != nil { if err := l.conn.Close(); err != nil { zap.L().Panic("couldn't close leech connection!", zap.Error(err)) } return errors.Wrap(err, "NODELAY") } err = l.conn.SetDeadline(deadline) if err != nil { if err := l.conn.Close(); err != nil { zap.L().Panic("couldn't close leech connection!", zap.Error(err)) } return errors.Wrap(err, "SetDeadline") } return nil } func (l *Leech) Do(deadline time.Time) { err := l.connect(deadline) if err != nil { l.OnError(errors.Wrap(err, "connect")) return } defer func() { if err := l.conn.Close(); err != nil { zap.L().Panic("couldn't close leech connection!", zap.Error(err)) } }() err = l.doBtHandshake() if err != nil { l.OnError(errors.Wrap(err, "doBtHandshake")) return } err = l.doExHandshake() if err != nil { l.OnError(errors.Wrap(err, "doExHandshake")) return } err = l.requestAllPieces() if err != nil { l.OnError(errors.Wrap(err, "requestAllPieces")) return } for l.metadataReceived < l.metadataSize { rUmMessage, err := l.readUmMessage() if err != nil { l.OnError(errors.Wrap(err, "readUmMessage")) return } // Run TestDecoder() function in leech_test.go in case you have any doubts. rMessageBuf := bytes.NewBuffer(rUmMessage[2:]) rExtDict := new(extDict) err = bencode.NewDecoder(rMessageBuf).Decode(rExtDict) if err != nil { zap.L().Warn("Couldn't decode extension message in the loop!", zap.Error(err)) return } if rExtDict.MsgType == 2 { // reject l.OnError(fmt.Errorf("remote peer rejected sending metadata")) return } if rExtDict.MsgType == 1 { // data // Get the unread bytes! metadataPiece := rMessageBuf.Bytes() // BEP 9 explicitly states: // > If the piece is the last piece of the metadata, it may be less than 16kiB. If // > it is not the last piece of the metadata, it MUST be 16kiB. // // Hence... // ... if the length of @metadataPiece is more than 16kiB, we err. if len(metadataPiece) > 16*1024 { l.OnError(fmt.Errorf("metadataPiece > 16kiB")) return } piece := rExtDict.Piece // metadata[piece * 2**14: piece * 2**14 + len(metadataPiece)] = metadataPiece is how it'd be done in Python copy(l.metadata[piece*int(math.Pow(2, 14)):piece*int(math.Pow(2, 14))+len(metadataPiece)], metadataPiece) l.metadataReceived += uint(len(metadataPiece)) // ... if the length of @metadataPiece is less than 16kiB AND metadata is NOT // complete then we err. if len(metadataPiece) < 16*1024 && l.metadataReceived != l.metadataSize { l.OnError(fmt.Errorf("metadataPiece < 16 kiB but incomplete")) return } if l.metadataReceived > l.metadataSize { l.OnError(fmt.Errorf("metadataReceived > metadataSize")) return } } } // Verify the checksum sha1Sum := sha1.Sum(l.metadata) if !bytes.Equal(sha1Sum[:], l.infoHash[:]) { l.OnError(fmt.Errorf("infohash mismatch")) return } // Check the info dictionary info := new(metainfo.Info) err = bencode.Unmarshal(l.metadata, info) if err != nil { l.OnError(errors.Wrap(err, "unmarshal info")) return } err = validateInfo(info) if err != nil { l.OnError(errors.Wrap(err, "validateInfo")) return } var files []persistence.File // If there is only one file, there won't be a Files slice. That's why we need to add it here if len(info.Files) == 0 { files = append(files, persistence.File{ Size: info.Length, Path: info.Name, }) } else { for _, file := range info.Files { files = append(files, persistence.File{ Size: file.Length, Path: file.DisplayPath(info), }) } } var totalSize uint64 for _, file := range files { if file.Size < 0 { l.OnError(fmt.Errorf("file size less than zero")) return } totalSize += uint64(file.Size) } l.ev.OnSuccess(Metadata{ InfoHash: l.infoHash[:], Name: info.Name, TotalSize: totalSize, DiscoveredOn: time.Now().Unix(), Files: files, }) } // COPIED FROM anacrolix/torrent func validateInfo(info *metainfo.Info) error { if len(info.Pieces)%20 != 0 { return errors.New("pieces has invalid length") } if info.PieceLength == 0 { if info.TotalLength() != 0 { return errors.New("zero piece length") } } else { if int((info.TotalLength()+info.PieceLength-1)/info.PieceLength) != info.NumPieces() { return errors.New("piece count and file lengths are at odds") } } return nil } func (l *Leech) readExactly(n uint) ([]byte, error) { b := make([]byte, n) _, err := io.ReadFull(l.conn, b) return b, err } func (l *Leech) OnError(err error) { l.ev.OnError(l.infoHash, err) } // TODO: add bounds checking! func toBigEndian(i uint, n int) []byte { b := make([]byte, n) switch n { case 1: b = []byte{byte(i)} case 2: binary.BigEndian.PutUint16(b, uint16(i)) case 4: binary.BigEndian.PutUint32(b, uint32(i)) default: panic(fmt.Sprintf("n must be 1, 2, or 4!")) } if len(b) != n { panic(fmt.Sprintf("postcondition failed: len(b) != n in intToBigEndian (i %d, n %d, len b %d, b %s)", i, n, len(b), b)) } return b }