diff --git a/src/magneticod/bittorrent/operations.go b/src/magneticod/bittorrent/operations.go index fec2768..c9bcf78 100644 --- a/src/magneticod/bittorrent/operations.go +++ b/src/magneticod/bittorrent/operations.go @@ -4,6 +4,7 @@ import ( "time" "strings" + "github.com/anacrolix/missinggo" "github.com/anacrolix/torrent" "github.com/anacrolix/torrent/metainfo" "go.uber.org/zap" @@ -18,19 +19,18 @@ func (ms *MetadataSink) awaitMetadata(infoHash metainfo.Hash, peer torrent.Peer) // fetched. t.AddPeers([]torrent.Peer{peer}) if !isNew { - // If the recently added torrent is not new, then quit as we do not want multiple - // awaitMetadata goroutines waiting on the same torrent. + // Return immediately if we are trying to await on an ongoing metadata-fetching operation. + // Each ongoing operation should have one and only one "await*" function waiting on it. return - } else { - // Drop the torrent once we return from this function, whether we got the metadata or an - // error. - defer t.Drop() } // Wait for the torrent client to receive the metadata for the torrent, meanwhile allowing // termination to be handled gracefully. + var info *metainfo.Info select { case <- t.GotInfo(): + info = t.Info() + t.Drop() case <-time.After(5 * time.Minute): zap.L().Sugar().Debugf("Fetcher timeout! %x", infoHash) @@ -40,7 +40,6 @@ func (ms *MetadataSink) awaitMetadata(infoHash metainfo.Hash, peer torrent.Peer) return } - info := t.Info() var files []metainfo.FileInfo if len(info.Files) == 0 { if strings.ContainsRune(info.Name, '/') { @@ -75,3 +74,111 @@ func (ms *MetadataSink) awaitMetadata(infoHash metainfo.Hash, peer torrent.Peer) Files: files, }) } + + +func (fs *FileSink) awaitFile(infoHash []byte, filePath string, peer *torrent.Peer) { + var infoHash_ [20]byte + copy(infoHash_[:], infoHash) + t, isNew := fs.client.AddTorrentInfoHash(infoHash_) + if peer != nil { + t.AddPeers([]torrent.Peer{*peer}) + } + if !isNew { + // Return immediately if we are trying to await on an ongoing file-downloading operation. + // Each ongoing operation should have one and only one "await*" function waiting on it. + return + } + + // Setup & start the timeout timer. + timeout := time.After(fs.timeout) + + // Once we return from this function, drop the torrent from the client. + // TODO: Check if dropping a torrent also cancels any outstanding read operations? + defer t.Drop() + + select { + case <-t.GotInfo(): + + case <- timeout: + return + } + + var match *torrent.File + for _, file := range t.Files() { + if file.Path() == filePath { + match = &file + } else { + file.Cancel() + } + } + if match == nil { + var filePaths []string + for _, file := range t.Files() { filePaths = append(filePaths, file.Path()) } + + zap.L().Warn( + "The leech (FileSink) has been requested to download a file which does not exist!", + zap.ByteString("torrent", infoHash), + zap.String("requestedFile", filePath), + zap.Strings("allFiles", filePaths), + ) + } + + + reader := t.NewReader() + defer reader.Close() + + fileDataChan := make(chan []byte) + go downloadFile(*match, reader, fileDataChan) + + select { + case fileData := <-fileDataChan: + if fileData != nil { + fs.flush(File{ + torrentInfoHash: infoHash, + path: match.Path(), + data: fileData, + }) + } + + case <- timeout: + zap.L().Debug( + "Timeout while downloading a file!", + zap.ByteString("torrent", infoHash), + zap.String("file", filePath), + ) + } +} + + +func downloadFile(file torrent.File, reader *torrent.Reader, fileDataChan chan<- []byte) { + readSeeker := missinggo.NewSectionReadSeeker(reader, file.Offset(), file.Length()) + + fileData := make([]byte, file.Length()) + n, err := readSeeker.Read(fileData) + if int64(n) != file.Length() { + zap.L().Debug( + "Not all of a file could be read!", + zap.ByteString("torrent", file.Torrent().InfoHash()[:]), + zap.String("file", file.Path()), + zap.Int64("fileLength", file.Length()), + zap.Int("n", n), + ) + fileDataChan <- nil + return + } + if err != nil { + zap.L().Debug( + "Error while downloading a file!", + zap.Error(err), + zap.ByteString("torrent", file.Torrent().InfoHash()[:]), + zap.String("file", file.Path()), + zap.Int64("fileLength", file.Length()), + zap.Int("n", n), + ) + fileDataChan <- nil + return + } + + fileDataChan <- fileData +} + diff --git a/src/magneticod/bittorrent/sinkFile.go b/src/magneticod/bittorrent/sinkFile.go new file mode 100644 index 0000000..2592fea --- /dev/null +++ b/src/magneticod/bittorrent/sinkFile.go @@ -0,0 +1,105 @@ +package bittorrent + +import ( + "net" + "path" + "time" + + "github.com/anacrolix/dht" + "github.com/anacrolix/torrent" + "github.com/anacrolix/torrent/storage" + "github.com/Wessie/appdirs" + "go.uber.org/zap" +) + + +type File struct{ + torrentInfoHash []byte + path string + data []byte +} + + +type FileSink struct { + client *torrent.Client + drain chan File + terminated bool + termination chan interface{} + + timeout time.Duration +} + +// NewFileSink creates a new FileSink. +// +// cAddr : client address +// mlAddr: mainline DHT node address +func NewFileSink(cAddr, mlAddr string, timeout time.Duration) *FileSink { + fs := new(FileSink) + + mlUDPAddr, err := net.ResolveUDPAddr("udp", mlAddr) + if err != nil { + zap.L().Fatal("Could NOT resolve UDP addr!", zap.Error(err)) + return nil + } + + // Make sure to close the mlUDPConn before returning from this function in case of an error. + mlUDPConn, err := net.ListenUDP("udp", mlUDPAddr) + if err != nil { + zap.L().Fatal("Could NOT listen UDP (file sink)!", zap.Error(err)) + return nil + } + + fs.client, err = torrent.NewClient(&torrent.Config{ + ListenAddr: cAddr, + DisableTrackers: true, + DHTConfig: dht.ServerConfig{ + Conn: mlUDPConn, + Passive: true, + NoSecurity: true, + }, + DefaultStorage: storage.NewFileByInfoHash(path.Join( + appdirs.UserCacheDir("magneticod", "", "", true), + "downloads", + )), + }) + if err != nil { + zap.L().Fatal("Leech could NOT create a new torrent client!", zap.Error(err)) + mlUDPConn.Close() + return nil + } + + fs.drain = make(chan File) + fs.termination = make(chan interface{}) + fs.timeout = timeout + + return fs +} + + +// peer might be nil +func (fs *FileSink) Sink(infoHash []byte, filePath string, peer *torrent.Peer) { + go fs.awaitFile(infoHash, filePath, peer) +} + + +func (fs *FileSink) Drain() <-chan File { + if fs.terminated { + zap.L().Panic("Trying to Drain() an already closed FileSink!") + } + return fs.drain +} + + +func (fs *FileSink) Terminate() { + fs.terminated = true + close(fs.termination) + fs.client.Close() + close(fs.drain) +} + + +func (fs *FileSink) flush(result File) { + if !fs.terminated { + fs.drain <- result + } +} diff --git a/src/magneticod/bittorrent/sink.go b/src/magneticod/bittorrent/sinkMetadata.go similarity index 90% rename from src/magneticod/bittorrent/sink.go rename to src/magneticod/bittorrent/sinkMetadata.go index 09f7192..a01c9fe 100644 --- a/src/magneticod/bittorrent/sink.go +++ b/src/magneticod/bittorrent/sinkMetadata.go @@ -1,12 +1,13 @@ package bittorrent import ( + "net" + "go.uber.org/zap" "github.com/anacrolix/torrent" "github.com/anacrolix/torrent/metainfo" "magneticod/dht/mainline" - "net" ) @@ -23,7 +24,6 @@ type Metadata struct { type MetadataSink struct { - activeInfoHashes []metainfo.Hash client *torrent.Client drain chan Metadata terminated bool @@ -58,7 +58,6 @@ func (ms *MetadataSink) Sink(res mainline.TrawlingResult) { zap.L().Panic("Trying to Sink() an already closed MetadataSink!") } - ms.activeInfoHashes = append(ms.activeInfoHashes, res.InfoHash) go ms.awaitMetadata(res.InfoHash, res.Peer) } @@ -67,7 +66,6 @@ func (ms *MetadataSink) Drain() <-chan Metadata { if ms.terminated { zap.L().Panic("Trying to Drain() an already closed MetadataSink!") } - return ms.drain } @@ -80,8 +78,8 @@ func (ms *MetadataSink) Terminate() { } -func (ms *MetadataSink) flush(metadata Metadata) { +func (ms *MetadataSink) flush(result Metadata) { if !ms.terminated { - ms.drain <- metadata + ms.drain <- result } } diff --git a/src/magneticod/dht/mainline/service.go b/src/magneticod/dht/mainline/service.go index 0533e14..305371b 100644 --- a/src/magneticod/dht/mainline/service.go +++ b/src/magneticod/dht/mainline/service.go @@ -194,7 +194,9 @@ func (s *TrawlingService) onFindNodeResponse(response *Message, addr net.Addr) { for _, node := range response.R.Nodes { if node.Addr.Port != 0 { // Ignore nodes who "use" port 0. - s.routingTable[string(node.ID)] = &node.Addr + if len(s.routingTable) < 10000 { + s.routingTable[string(node.ID)] = &node.Addr + } } } } diff --git a/src/magneticod/main.go b/src/magneticod/main.go index 2c743c2..c7ab3f3 100644 --- a/src/magneticod/main.go +++ b/src/magneticod/main.go @@ -6,37 +6,40 @@ import ( "os/signal" "regexp" + "github.com/jessevdk/go-flags" + "github.com/pkg/profile" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "github.com/jessevdk/go-flags" "magneticod/bittorrent" "magneticod/dht" + "fmt" + "time" ) - type cmdFlags struct { - Database string `long:"database" description:"URL of the database."` + DatabaseURL string `long:"database" description:"URL of the database."` - MlTrawlerAddrs []string `long:"ml-trawler-addrs" description:"Address(es) to be used by trawling DHT (Mainline) nodes." default:"0.0.0.0:0"` - TrawlingInterval uint `long:"trawling-interval" description:"Trawling interval in integer seconds."` + TrawlerMlAddrs []string `long:"trawler-ml-addr" description:"Address(es) to be used by trawling DHT (Mainline) nodes." default:"0.0.0.0:0"` + TrawlerMlInterval uint `long:"trawler-ml-interval" description:"Trawling interval in integer deciseconds (one tenth of a second)."` // TODO: is this even supported by anacrolix/torrent? - FetcherAddr string `long:"fetcher-addr" description:"Address(es) to be used by ephemeral peers fetching torrent metadata." default:"0.0.0.0:0"` - FetcherTimeout uint `long:"fetcher-timeout" description:"Number of integer seconds before a fetcher timeouts."` - // TODO: is this even supported by anacrolix/torrent? - MaxMetadataSize uint `long:"max-metadata-size" description:"Maximum metadata size -which must be greater than zero- in bytes."` + FetcherAddr string `long:"fetcher-addr" description:"Address(es) to be used by ephemeral peers fetching torrent metadata." default:"0.0.0.0:0"` + FetcherTimeout uint `long:"fetcher-timeout" description:"Number of integer seconds before a fetcher timeouts."` - MlStatisticianAddrs []string `long:"ml-statistician-addrs" description:"Address(es) to be used by ephemeral nodes fetching latest statistics about individual torrents." default:"0.0.0.0:0"` - StatisticianTimeout uint `long:"statistician-timeout" description:"Number of integer seconds before a statistician timeouts."` + StatistMlAddrs []string `long:"statist-ml-addr" description:"Address(es) to be used by ephemeral nodes fetching latest statistics about individual torrents." default:"0.0.0.0:0"` + StatistMlTimeout uint `long:"statist-ml-timeout" description:"Number of integer seconds before a statist timeouts."` // TODO: is this even supported by anacrolix/torrent? - LeechAddr string `long:"leech-addr" description:"Address(es) to be used by ephemeral peers fetching README files." default:"0.0.0.0:0"` - LeechTimeout uint `long:"leech-timeout" description:"Number of integer seconds before a leech timeouts."` - MaxDescriptionSize uint `long:"max-description-size" description:"Maximum size -which must be greater than zero- of a description file in bytes"` - DescriptionNames []string `long:"description-names" description:"Regular expression(s) which will be tested against the name of the description files, in the supplied order."` + LeechClAddr string `long:"leech-cl-addr" description:"Address to be used by the peer fetching README files." default:"0.0.0.0:0"` + LeechMlAddr string `long:"leech-ml-addr" descrition:"Address to be used by the mainline DHT node for fetching README files." default:"0.0.0.0:0"` + LeechTimeout uint `long:"leech-timeout" description:"Number of integer seconds to pass before a leech timeouts." default:"300"` + ReadmeMaxSize uint `long:"readme-max-size" description:"Maximum size -which must be greater than zero- of a description file in bytes." default:"20480"` + ReadmeRegexes []string `long:"readme-regex" description:"Regular expression(s) which will be tested against the name of the README files, in the supplied order."` - Verbose []bool `short:"v" long:"verbose" description:"Increases verbosity."` + Verbose []bool `short:"v" long:"verbose" description:"Increases verbosity."` + + Profile string `long:"profile" description:"Enable profiling." default:""` // ==== Deprecated Flags ==== // TODO: don't even support deprecated flags! @@ -53,30 +56,39 @@ type cmdFlags struct { // DatabaseFile string } +const ( + PROFILE_BLOCK = 1 + PROFILE_CPU + PROFILE_MEM + PROFILE_MUTEX + PROFILE_A +) type opFlags struct { - Database string + DatabaseURL string - MlTrawlerAddrs []net.UDPAddr - TrawlingInterval uint + TrawlerMlAddrs []string + TrawlerMlInterval time.Duration - FetcherAddr net.TCPAddr - FetcherTimeout uint // TODO: is this even supported by anacrolix/torrent? - MaxMetadataSize uint + FetcherAddr string + FetcherTimeout time.Duration - MlStatisticianAddrs []net.UDPAddr - StatisticianTimeout uint + StatistMlAddrs []string + StatistMlTimeout time.Duration - LeechAddr net.TCPAddr - LeechTimeout uint - MaxDescriptionSize uint - DescriptionNames []regexp.Regexp + // TODO: is this even supported by anacrolix/torrent? + LeechClAddr string + LeechMlAddr string + LeechTimeout time.Duration + ReadmeMaxSize uint + ReadmeRegexes []*regexp.Regexp - Verbosity uint + Verbosity int + + Profile string } - func main() { atom := zap.NewAtomicLevel() // Logging levels: ("debug", "info", "warn", "error", "dpanic", "panic", and "fatal"). @@ -88,6 +100,8 @@ func main() { defer logger.Sync() zap.ReplaceGlobals(logger) + defer profile.Start(profile.CPUProfile, profile.ProfilePath(".")).Stop() + zap.L().Info("magneticod v0.7.0 has been started.") zap.L().Info("Copyright (C) 2017 Mert Bora ALPER .") zap.L().Info("Dedicated to Cemile Binay, in whose hands I thrived.") @@ -95,8 +109,6 @@ func main() { // opFlags is the "operational flags" opFlags := parseFlags() - logger.Sugar().Warn(">>>", opFlags.MlTrawlerAddrs) - switch opFlags.Verbosity { case 0: atom.SetLevel(zap.WarnLevel) @@ -110,10 +122,10 @@ func main() { zap.ReplaceGlobals(logger) /* - updating_manager := nil - statistics_sink := nil - completing_manager := nil - file_sink := nil + updating_manager := nil + statistics_sink := nil + completing_manager := nil + file_sink := nil */ // Handle Ctrl-C gracefully. interrupt_chan := make(chan os.Signal) @@ -124,14 +136,15 @@ func main() { logger.Sugar().Fatalf("Could not open the database at `%s`: %s", opFlags.Database, err.Error()) } - go func() { - trawlingManager := dht.NewTrawlingManager(opFlags.MlTrawlerAddrs) - metadataSink := bittorrent.NewMetadataSink(opFlags.FetcherAddr) + trawlingManager := dht.NewTrawlingManager(opFlags.MlTrawlerAddrs) + metadataSink := bittorrent.NewMetadataSink(opFlags.FetcherAddr) + fileSink := bittorrent.NewFileSink() + go func() { for { select { case result := <-trawlingManager.Output(): - logger.Info("result: ", zap.String("hash", result.InfoHash.String())) + logger.Debug("result: ", zap.String("hash", result.InfoHash.String())) if !database.DoesExist(result.InfoHash[:]) { metadataSink.Sink(result) } @@ -160,105 +173,116 @@ func main() { }() /* - for { - select { + for { + select { - case updating_manager.Output(): + case updating_manager.Output(): - case statistics_sink.Sink(): + case statistics_sink.Sink(): - case completing_manager.Output(): + case completing_manager.Output(): - case file_sink.Sink(): + case file_sink.Sink(): */ <-interrupt_chan } - -func parseFlags() (opFlags) { +func parseFlags() (opF opFlags) { var cmdF cmdFlags _, err := flags.Parse(&cmdF) if err != nil { - zap.L().Fatal("Error while parsing command-line flags: ", zap.Error(err)) + zap.S().Fatalf("Could not parse command-line flags! %s", err.Error()) } - mlTrawlerAddrs, err := hostPortsToUDPAddrs(cmdF.MlTrawlerAddrs) - if err != nil { - zap.L().Fatal("Erroneous ml-trawler-addrs argument supplied: ", zap.Error(err)) + // TODO: Check Database URL here + opF.DatabaseURL = cmdF.DatabaseURL + + if err = checkAddrs(cmdF.TrawlerMlAddrs); err != nil { + zap.S().Fatalf("Of argument (list) `trawler-ml-addr` %s", err.Error()) + } else { + opF.TrawlerMlAddrs = cmdF.TrawlerMlAddrs } - fetcherAddr, err := hostPortsToTCPAddr(cmdF.FetcherAddr) - if err != nil { - zap.L().Fatal("Erroneous fetcher-addr argument supplied: ", zap.Error(err)) + if cmdF.TrawlerMlInterval <= 0 { + zap.L().Fatal("Argument `trawler-ml-interval` must be greater than zero, if supplied.") + } else { + // 1 decisecond = 100 milliseconds = 0.1 seconds + opF.TrawlerMlInterval = time.Duration(cmdF.TrawlerMlInterval) * 100 * time.Millisecond } - mlStatisticianAddrs, err := hostPortsToUDPAddrs(cmdF.MlStatisticianAddrs) - if err != nil { - zap.L().Fatal("Erroneous ml-statistician-addrs argument supplied: ", zap.Error(err)) + if err = checkAddrs([]string{cmdF.FetcherAddr}); err != nil { + zap.S().Fatalf("Of argument `fetcher-addr` %s", err.Error()) + } else { + opF.FetcherAddr = cmdF.FetcherAddr } - leechAddr, err := hostPortsToTCPAddr(cmdF.LeechAddr) - if err != nil { - zap.L().Fatal("Erroneous leech-addrs argument supplied: ", zap.Error(err)) + if cmdF.FetcherTimeout <= 0 { + zap.L().Fatal("Argument `fetcher-timeout` must be greater than zero, if supplied.") + } else { + opF.FetcherTimeout = time.Duration(cmdF.FetcherTimeout) * time.Second } - var descriptionNames []regexp.Regexp - for _, expr := range cmdF.DescriptionNames { - regex, err := regexp.Compile(expr) + if err = checkAddrs(cmdF.StatistMlAddrs); err != nil { + zap.S().Fatalf("Of argument (list) `statist-ml-addr` %s", err.Error()) + } else { + opF.StatistMlAddrs = cmdF.StatistMlAddrs + } + + if cmdF.StatistMlTimeout <= 0 { + zap.L().Fatal("Argument `statist-ml-timeout` must be greater than zero, if supplied.") + } else { + opF.StatistMlTimeout = time.Duration(cmdF.StatistMlTimeout) * time.Second + } + + if err = checkAddrs([]string{cmdF.LeechClAddr}); err != nil { + zap.S().Fatal("Of argument `leech-cl-addr` %s", err.Error()) + } else { + opF.LeechClAddr = cmdF.LeechClAddr + } + + if err = checkAddrs([]string{cmdF.LeechMlAddr}); err != nil { + zap.S().Fatal("Of argument `leech-ml-addr` %s", err.Error()) + } else { + opF.LeechMlAddr = cmdF.LeechMlAddr + } + + if cmdF.LeechTimeout <= 0 { + zap.L().Fatal("Argument `leech-timeout` must be greater than zero, if supplied.") + } else { + opF.LeechTimeout = time.Duration(cmdF.LeechTimeout) * time.Second + } + + if cmdF.ReadmeMaxSize <= 0 { + zap.L().Fatal("Argument `readme-max-size` must be greater than zero, if supplied.") + } else { + opF.ReadmeMaxSize = cmdF.ReadmeMaxSize + } + + for i, s := range cmdF.ReadmeRegexes { + regex, err := regexp.Compile(s) if err != nil { - zap.L().Fatal("Erroneous description-names argument supplied: ", zap.Error(err)) + zap.S().Fatalf("Of argument `readme-regex` with %d(th) regex `%s`: %s", i + 1, s, err.Error()) + } else { + opF.ReadmeRegexes = append(opF.ReadmeRegexes, regex) } - descriptionNames = append(descriptionNames, *regex) } + opF.Verbosity = len(cmdF.Verbose) + opF.Profile = cmdF.Profile - opF := opFlags{ - Database: cmdF.Database, - - MlTrawlerAddrs: mlTrawlerAddrs, - TrawlingInterval: cmdF.TrawlingInterval, - - FetcherAddr: fetcherAddr, - FetcherTimeout: cmdF.FetcherTimeout, - MaxMetadataSize: cmdF.MaxMetadataSize, - - MlStatisticianAddrs: mlStatisticianAddrs, - StatisticianTimeout: cmdF.StatisticianTimeout, - - LeechAddr: leechAddr, - LeechTimeout: cmdF.LeechTimeout, - MaxDescriptionSize: cmdF.MaxDescriptionSize, - DescriptionNames: descriptionNames, - - Verbosity: uint(len(cmdF.Verbose)), - } - - return opF + return } - -func hostPortsToUDPAddrs(hostport []string) ([]net.UDPAddr, error) { - udpAddrs := make([]net.UDPAddr, len(hostport)) - - for i, hp := range hostport { - udpAddr, err := net.ResolveUDPAddr("udp", hp) +func checkAddrs(addrs []string) error { + for i, addr := range addrs { + // We are using ResolveUDPAddr but it works equally well for checking TCPAddr(esses) as + // well. + _, err := net.ResolveUDPAddr("udp", addr) if err != nil { - return nil, err + return fmt.Errorf("with %d(th) address `%s`: %s", i + 1, addr, err.Error()) } - udpAddrs[i] = *udpAddr } - - return udpAddrs, nil -} - - -func hostPortsToTCPAddr(hostport string) (net.TCPAddr, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", hostport) - if err != nil { - return net.TCPAddr{}, err - } - - return *tcpAddr, nil + return nil } diff --git a/src/magneticod/main_test.go b/src/magneticod/main_test.go new file mode 100644 index 0000000..d67e761 --- /dev/null +++ b/src/magneticod/main_test.go @@ -0,0 +1,24 @@ +package main + +import ( + "testing" + + "github.com/Wessie/appdirs" +) + +func TestAppdirs(t *testing.T) { + var expected, returned string + + returned = appdirs.UserDataDir("magneticod", "", "", false) + expected = appdirs.ExpandUser("~/.local/share/magneticod") + if returned != expected { + t.Errorf("UserDataDir returned an unexpected value! `%s`", returned) + } + + returned = appdirs.UserCacheDir("magneticod", "", "", true) + expected = appdirs.ExpandUser("~/.cache/magneticod") + if returned != expected { + t.Errorf("UserCacheDir returned an unexpected value! `%s`", returned) + } +} + diff --git a/src/magneticod/persistence.go b/src/magneticod/persistence.go index ade719f..1a938fa 100644 --- a/src/magneticod/persistence.go +++ b/src/magneticod/persistence.go @@ -1,19 +1,18 @@ package main import ( - "fmt" + "bytes" "database/sql" + "fmt" "net/url" + "path" + "os" _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "go.uber.org/zap" "magneticod/bittorrent" - - "path" - "os" - "bytes" ) type engineType uint8 @@ -226,7 +225,13 @@ func setupSqliteDatabase(database *sql.DB) error { return err } - _, err = database.Exec( + tx, err := database.Begin() + if err != nil { + return err + } + + // Essential, and valid for all user_version`s: + _, err = tx.Exec( `CREATE TABLE IF NOT EXISTS torrents ( id INTEGER PRIMARY KEY, info_hash BLOB NOT NULL UNIQUE, @@ -242,12 +247,42 @@ func setupSqliteDatabase(database *sql.DB) error { torrent_id INTEGER REFERENCES torrents ON DELETE CASCADE ON UPDATE RESTRICT, size INTEGER NOT NULL, path TEXT NOT NULL - );`, + ); + `, ) if err != nil { return err } + // Get the user_version: + res, err := tx.Query( + `PRAGMA user_version;`, + ) + if err != nil { + return err + } + var userVersion int; + res.Next() + res.Scan(&userVersion) + + // Upgrade to the latest schema: + switch userVersion { + // Upgrade from user_version 0 to 1 + case 0: + _, err = tx.Exec( + `ALTER TABLE torrents ADD COLUMN readme TEXT; + PRAGMA user_version = 1;`, + ) + if err != nil { + return err + } + // Add `fallthrough`s as needed to keep upgrading... + } + + if err = tx.Commit(); err != nil { + return err + } + return nil } diff --git a/src/magneticod/persistence_test.go b/src/magneticod/persistence_test.go new file mode 100644 index 0000000..716f216 --- /dev/null +++ b/src/magneticod/persistence_test.go @@ -0,0 +1,20 @@ +package main + +import ( + "path" + "testing" +) + + +// TestPathJoin tests the assumption we made in flushNewTorrents() function where we assumed path +// separator to be the `/` (slash), and not `\` (backslash) character (which is used by Windows). +// +// Golang seems to use slash character on both platforms but we need to check that slash character +// is used in all cases. As a rule of thumb in secure programming, always check ONLY for the valid +// case AND IGNORE THE REST (e.g. do not check for backslashes but check for slashes). +func TestPathJoin(t *testing.T) { + if path.Join("a", "b", "c") != "a/b/c" { + t.Errorf("path.Join uses a different character than `/` (slash) character as path separator! (path: `%s`)", + path.Join("a", "b", "c")) + } +}