diff --git a/cmdHandlers.go b/cmdHandlers.go index df43683..71e64d0 100644 --- a/cmdHandlers.go +++ b/cmdHandlers.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path/filepath" + "song-recognition/db" "song-recognition/shazam" "song-recognition/spotify" "song-recognition/utils" @@ -195,7 +196,7 @@ func erase(songsDir string) { ctx := context.Background() // wipe db - dbClient, err := utils.NewDbClient() + dbClient, err := db.NewDBClient() if err != nil { msg := fmt.Sprintf("Error creating DB client: %v\n", err) logger.ErrorContext(ctx, msg, slog.Any("error", err)) diff --git a/db/client.go b/db/client.go new file mode 100644 index 0000000..00fb725 --- /dev/null +++ b/db/client.go @@ -0,0 +1,54 @@ +package db + +import ( + "fmt" + "song-recognition/models" + "song-recognition/utils" +) + +type DBClient interface { + Close() error + StoreFingerprints(fingerprints map[uint32]models.Couple) error + GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) + TotalSongs() (int, error) + RegisterSong(songTitle, songArtist, ytID string) (uint32, error) + GetSong(filterKey string, value interface{}) (Song, bool, error) + GetSongByID(songID uint32) (Song, bool, error) + GetSongByYTID(ytID string) (Song, bool, error) + GetSongByKey(key string) (Song, bool, error) + DeleteSongByID(songID uint32) error + DeleteCollection(collectionName string) error +} + +type Song struct { + Title string + Artist string + YouTubeID string +} + +var DBtype = utils.GetEnv("DB_TYPE", "sqlite") // Can be "sqlite" or "mongo" + +func NewDBClient() (DBClient, error) { + switch DBtype { + case "mongo": + var ( + dbUsername = utils.GetEnv("DB_USER") + dbPassword = utils.GetEnv("DB_PASS") + dbName = utils.GetEnv("DB_NAME") + dbHost = utils.GetEnv("DB_HOST") + dbPort = utils.GetEnv("DB_PORT") + + dbUri = "mongodb://" + dbUsername + ":" + dbPassword + "@" + dbHost + ":" + dbPort + "/" + dbName + ) + if dbUsername == "" || dbPassword == "" { + dbUri = "mongodb://localhost:27017" + } + return NewMongoClient(dbUri) + + case "sqlite": + return NewSQLiteClient("db.sqlite3") + + default: + return nil, fmt.Errorf("unsupported database type: %s", DBtype) + } +} diff --git a/utils/dbClient.go b/db/mongo.go similarity index 71% rename from utils/dbClient.go rename to db/mongo.go index 0dde6ed..16f4177 100644 --- a/utils/dbClient.go +++ b/db/mongo.go @@ -1,10 +1,11 @@ -package utils +package db import ( "context" "errors" "fmt" "song-recognition/models" + "song-recognition/utils" "strings" "go.mongodb.org/mongo-driver/bson" @@ -13,46 +14,27 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// godotenv.Load(".env") - -var ( - dbUsername = GetEnv("DB_USER") - dbPassword = GetEnv("DB_PASS") - dbName = GetEnv("DB_NAME") - dbHost = GetEnv("DB_HOST") - dbPort = GetEnv("DB_PORT") - - dbUri = "mongodb://" + dbUsername + ":" + dbPassword + "@" + dbHost + ":" + dbPort + "/" + dbName -) - -// DbClient represents a MongoDB client -type DbClient struct { +type MongoClient struct { client *mongo.Client } -// NewDbClient creates a new instance of DbClient -func NewDbClient() (*DbClient, error) { - if dbUsername == "" || dbPassword == "" { - dbUri = "mongodb://localhost:27017" - } - - clientOptions := options.Client().ApplyURI(dbUri) +func NewMongoClient(uri string) (*MongoClient, error) { + clientOptions := options.Client().ApplyURI(uri) client, err := mongo.Connect(context.Background(), clientOptions) if err != nil { - return nil, fmt.Errorf("error connecting to MongoDB: %d", err) + return nil, fmt.Errorf("error connecting to MongoDB: %s", err) } - return &DbClient{client: client}, nil + return &MongoClient{client: client}, nil } -// Close closes the underlying MongoDB client -func (db *DbClient) Close() error { +func (db *MongoClient) Close() error { if db.client != nil { return db.client.Disconnect(context.Background()) } return nil } -func (db *DbClient) StoreFingerprints(fingerprints map[uint32]models.Couple) error { +func (db *MongoClient) StoreFingerprints(fingerprints map[uint32]models.Couple) error { collection := db.client.Database("song-recognition").Collection("fingerprints") for address, couple := range fingerprints { @@ -76,7 +58,7 @@ func (db *DbClient) StoreFingerprints(fingerprints map[uint32]models.Couple) err return nil } -func (db *DbClient) GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) { +func (db *MongoClient) GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) { collection := db.client.Database("song-recognition").Collection("fingerprints") couples := make(map[uint32][]models.Couple) @@ -117,7 +99,7 @@ func (db *DbClient) GetCouples(addresses []uint32) (map[uint32][]models.Couple, return couples, nil } -func (db *DbClient) TotalSongs() (int, error) { +func (db *MongoClient) TotalSongs() (int, error) { existingSongsCollection := db.client.Database("song-recognition").Collection("songs") total, err := existingSongsCollection.CountDocuments(context.Background(), bson.D{}) if err != nil { @@ -127,7 +109,7 @@ func (db *DbClient) TotalSongs() (int, error) { return int(total), nil } -func (db *DbClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, error) { +func (db *MongoClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, error) { existingSongsCollection := db.client.Database("song-recognition").Collection("songs") // Create a compound unique index on ytID and key, if it doesn't already exist @@ -141,8 +123,8 @@ func (db *DbClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, er } // Attempt to insert the song with ytID and key - songID := GenerateUniqueID() - key := GenerateSongKey(songTitle, songArtist) + songID := utils.GenerateUniqueID() + key := utils.GenerateSongKey(songTitle, songArtist) _, err = existingSongsCollection.InsertOne(context.Background(), bson.M{"_id": songID, "key": key, "ytID": ytID}) if err != nil { if mongo.IsDuplicateKeyError(err) { @@ -155,16 +137,10 @@ func (db *DbClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, er return songID, nil } -type Song struct { - Title string - Artist string - YouTubeID string -} +var mongofilterKeys = "_id | ytID | key" -const FILTER_KEYS = "_id | ytID | key" - -func (db *DbClient) GetSong(filterKey string, value interface{}) (s Song, songExists bool, e error) { - if !strings.Contains(FILTER_KEYS, filterKey) { +func (db *MongoClient) GetSong(filterKey string, value interface{}) (s Song, songExists bool, e error) { + if !strings.Contains(mongofilterKeys, filterKey) { return Song{}, false, errors.New("invalid filter key") } @@ -190,19 +166,19 @@ func (db *DbClient) GetSong(filterKey string, value interface{}) (s Song, songEx return songInstance, true, nil } -func (db *DbClient) GetSongByID(songID uint32) (Song, bool, error) { +func (db *MongoClient) GetSongByID(songID uint32) (Song, bool, error) { return db.GetSong("_id", songID) } -func (db *DbClient) GetSongByYTID(ytID string) (Song, bool, error) { +func (db *MongoClient) GetSongByYTID(ytID string) (Song, bool, error) { return db.GetSong("ytID", ytID) } -func (db *DbClient) GetSongByKey(key string) (Song, bool, error) { +func (db *MongoClient) GetSongByKey(key string) (Song, bool, error) { return db.GetSong("key", key) } -func (db *DbClient) DeleteSongByID(songID uint32) error { +func (db *MongoClient) DeleteSongByID(songID uint32) error { songsCollection := db.client.Database("song-recognition").Collection("songs") filter := bson.M{"_id": songID} @@ -215,7 +191,7 @@ func (db *DbClient) DeleteSongByID(songID uint32) error { return nil } -func (db *DbClient) DeleteCollection(collectionName string) error { +func (db *MongoClient) DeleteCollection(collectionName string) error { collection := db.client.Database("song-recognition").Collection(collectionName) err := collection.Drop(context.Background()) if err != nil { diff --git a/db/sqlite.go b/db/sqlite.go new file mode 100644 index 0000000..d2d03c6 --- /dev/null +++ b/db/sqlite.go @@ -0,0 +1,207 @@ +package db + +import ( + "database/sql" + "fmt" + "song-recognition/models" + "song-recognition/utils" + "strings" + + "github.com/mattn/go-sqlite3" +) + +type SQLiteClient struct { + db *sql.DB +} + +func NewSQLiteClient(dataSourceName string) (*SQLiteClient, error) { + db, err := sql.Open("sqlite3", dataSourceName) + if err != nil { + return nil, fmt.Errorf("error connecting to SQLite: %s", err) + } + + err = createTables(db) + if err != nil { + return nil, fmt.Errorf("error creating tables: %s", err) + } + + return &SQLiteClient{db: db}, nil +} + +// createTables creates the required tables if they don't exist +func createTables(db *sql.DB) error { + createSongsTable := ` + CREATE TABLE IF NOT EXISTS songs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + artist TEXT NOT NULL, + ytID TEXT NOT NULL UNIQUE, + key TEXT NOT NULL UNIQUE + ); + ` + + createFingerprintsTable := ` + CREATE TABLE IF NOT EXISTS fingerprints ( + address INTEGER NOT NULL, + anchorTimeMs INTEGER NOT NULL, + songID INTEGER NOT NULL, + PRIMARY KEY (address, anchorTimeMs, songID) + ); + ` + + _, err := db.Exec(createSongsTable) + if err != nil { + return fmt.Errorf("error creating songs table: %s", err) + } + + _, err = db.Exec(createFingerprintsTable) + if err != nil { + return fmt.Errorf("error creating fingerprints table: %s", err) + } + + return nil +} + +func (db *SQLiteClient) Close() error { + if db.db != nil { + return db.db.Close() + } + return nil +} + +func (db *SQLiteClient) StoreFingerprints(fingerprints map[uint32]models.Couple) error { + tx, err := db.db.Begin() + if err != nil { + return fmt.Errorf("error starting transaction: %s", err) + } + + stmt, err := tx.Prepare("INSERT OR REPLACE INTO fingerprints (address, anchorTimeMs, songID) VALUES (?, ?, ?)") + if err != nil { + tx.Rollback() + return fmt.Errorf("error preparing statement: %s", err) + } + defer stmt.Close() + + for address, couple := range fingerprints { + if _, err := stmt.Exec(address, couple.AnchorTimeMs, couple.SongID); err != nil { + tx.Rollback() + return fmt.Errorf("error executing statement: %s", err) + } + } + + return tx.Commit() +} + +func (db *SQLiteClient) GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) { + couples := make(map[uint32][]models.Couple) + + for _, address := range addresses { + rows, err := db.db.Query("SELECT anchorTimeMs, songID FROM fingerprints WHERE address = ?", address) + if err != nil { + return nil, fmt.Errorf("error querying database: %s", err) + } + defer rows.Close() + + var docCouples []models.Couple + for rows.Next() { + var couple models.Couple + if err := rows.Scan(&couple.AnchorTimeMs, &couple.SongID); err != nil { + return nil, fmt.Errorf("error scanning row: %s", err) + } + docCouples = append(docCouples, couple) + } + couples[address] = docCouples + } + + return couples, nil +} + +func (db *SQLiteClient) TotalSongs() (int, error) { + var count int + err := db.db.QueryRow("SELECT COUNT(*) FROM songs").Scan(&count) + if err != nil { + return 0, fmt.Errorf("error counting songs: %s", err) + } + return count, nil +} + +func (db *SQLiteClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, error) { + tx, err := db.db.Begin() + if err != nil { + return 0, fmt.Errorf("error starting transaction: %s", err) + } + + stmt, err := tx.Prepare("INSERT INTO songs (id, title, artist, ytID, key) VALUES (?, ?, ?, ?, ?)") + if err != nil { + tx.Rollback() + return 0, fmt.Errorf("error preparing statement: %s", err) + } + defer stmt.Close() + + songID := utils.GenerateUniqueID() + songKey := utils.GenerateSongKey(songTitle, songArtist) + if _, err := stmt.Exec(songID, songTitle, songArtist, ytID, songKey); err != nil { + tx.Rollback() + if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.Code == sqlite3.ErrConstraint { + return 0, fmt.Errorf("song with ytID or key already exists: %v", err) + } + return 0, fmt.Errorf("failed to register song: %v", err) + } + + return songID, tx.Commit() +} + +var sqlitefilterKeys = "id | ytID | key" + +// GetSong retrieves a song by filter key +func (s *SQLiteClient) GetSong(filterKey string, value interface{}) (Song, bool, error) { + + if !strings.Contains(sqlitefilterKeys, filterKey) { + return Song{}, false, fmt.Errorf("invalid filter key") + } + + query := fmt.Sprintf("SELECT title, artist, ytID FROM songs WHERE %s = ?", filterKey) + + row := s.db.QueryRow(query, value) + + var song Song + err := row.Scan(&song.Title, &song.Artist, &song.YouTubeID) + if err != nil { + if err == sql.ErrNoRows { + return Song{}, false, nil + } + return Song{}, false, fmt.Errorf("failed to retrieve song: %s", err) + } + + return song, true, nil +} + +func (db *SQLiteClient) GetSongByID(songID uint32) (Song, bool, error) { + return db.GetSong("id", songID) +} + +func (db *SQLiteClient) GetSongByYTID(ytID string) (Song, bool, error) { + return db.GetSong("ytID", ytID) +} + +func (db *SQLiteClient) GetSongByKey(key string) (Song, bool, error) { + return db.GetSong("key", key) +} + +// DeleteSongByID deletes a song by ID +func (db *SQLiteClient) DeleteSongByID(songID uint32) error { + _, err := db.db.Exec("DELETE FROM songs WHERE id = ?", songID) + if err != nil { + return fmt.Errorf("failed to delete song: %v", err) + } + return nil +} + +// DeleteCollection deletes a collection (table) from the database +func (db *SQLiteClient) DeleteCollection(collectionName string) error { + _, err := db.db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", collectionName)) + if err != nil { + return fmt.Errorf("error deleting collection: %v", err) + } + return nil +} diff --git a/go.mod b/go.mod index 9b0fc22..7478c63 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/klauspost/compress v1.17.6 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mjibson/go-dsp v0.0.0-20180508042940-11479a337f12 // indirect github.com/montanaflynn/stats v0.7.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index c5e63f8..0dd0434 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdobak/go-xerrors v0.3.1 h1:XfqaLMNN5T4qsHSlLHGJ35f6YlDTVeINSYYeeuK4VpQ= github.com/mdobak/go-xerrors v0.3.1/go.mod h1:nIR+HMAJuj/uNqyp5+MTN6PJ7ymuIJq3UVs9QCgAHbY= github.com/mjibson/go-dsp v0.0.0-20180508042940-11479a337f12 h1:dd7vnTDfjtwCETZDrRe+GPYNLA1jBtbZeyfyE8eZCyk= diff --git a/shazam/shazam.go b/shazam/shazam.go index 12e7d3a..a3e5f68 100644 --- a/shazam/shazam.go +++ b/shazam/shazam.go @@ -3,6 +3,7 @@ package shazam import ( "fmt" "math" + "song-recognition/db" "song-recognition/utils" "sort" "time" @@ -35,7 +36,7 @@ func FindMatches(audioSamples []float64, audioDuration float64, sampleRate int) addresses = append(addresses, address) } - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { return nil, time.Since(startTime), err } diff --git a/shazam/shazamInit.go b/shazam/shazamInit.go index 75655be..f162663 100644 --- a/shazam/shazamInit.go +++ b/shazam/shazamInit.go @@ -2,6 +2,7 @@ package shazam import ( "fmt" + "song-recognition/db" "song-recognition/models" "song-recognition/utils" "sort" @@ -30,7 +31,7 @@ func Search(audioSamples []float64, audioDuration float64, sampleRate int) ([]Ma addresses = append(addresses, address) } - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { return nil, err } diff --git a/socketHandlers.go b/socketHandlers.go index 66f3105..75ce0c6 100644 --- a/socketHandlers.go +++ b/socketHandlers.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log/slog" + "song-recognition/db" "song-recognition/models" "song-recognition/shazam" "song-recognition/spotify" @@ -32,7 +33,7 @@ func handleTotalSongs(socket socketio.Conn) { logger := utils.GetLogger() ctx := context.Background() - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { err := xerrors.New(err) logger.ErrorContext(ctx, "error connecting to DB", slog.Any("error", err)) @@ -130,7 +131,7 @@ func handleSongDownload(socket socketio.Conn, spotifyURL string) { } // check if track already exist - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { fmt.Errorf("Log - error connecting to DB: %d", err) } diff --git a/spotify/downloader.go b/spotify/downloader.go index fac7471..5601909 100644 --- a/spotify/downloader.go +++ b/spotify/downloader.go @@ -10,6 +10,7 @@ import ( "os/exec" "path/filepath" "runtime" + "song-recognition/db" "song-recognition/shazam" "song-recognition/utils" "song-recognition/wav" @@ -88,7 +89,7 @@ func dlTrack(tracks []Track, path string) (int, error) { ctx := context.Background() - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { return 0, err } @@ -269,11 +270,11 @@ func addTags(file string, track Track) error { } func ProcessAndSaveSong(songFilePath, songTitle, songArtist, ytID string) error { - db, err := utils.NewDbClient() + dbclient, err := db.NewDBClient() if err != nil { return err } - defer db.Close() + defer dbclient.Close() wavFilePath, err := wav.ConvertToWAV(songFilePath, 1) if err != nil { @@ -295,7 +296,7 @@ func ProcessAndSaveSong(songFilePath, songTitle, songArtist, ytID string) error return fmt.Errorf("error creating spectrogram: %v", err) } - songID, err := db.RegisterSong(songTitle, songArtist, ytID) + songID, err := dbclient.RegisterSong(songTitle, songArtist, ytID) if err != nil { return err } @@ -303,13 +304,13 @@ func ProcessAndSaveSong(songFilePath, songTitle, songArtist, ytID string) error peaks := shazam.ExtractPeaks(spectro, wavInfo.Duration) fingerprints := shazam.Fingerprint(peaks, songID) - err = db.StoreFingerprints(fingerprints) + err = dbclient.StoreFingerprints(fingerprints) if err != nil { - db.DeleteSongByID(songID) + dbclient.DeleteSongByID(songID) return fmt.Errorf("error to storing fingerpring: %v", err) } - fmt.Println("Fingerprints saved in MongoDB successfully") + fmt.Printf("Fingerprint for %v by %v saved in DB successfully\n", songTitle, songArtist) return nil } diff --git a/spotify/utils.go b/spotify/utils.go index e5a21dd..e9ff497 100644 --- a/spotify/utils.go +++ b/spotify/utils.go @@ -8,7 +8,7 @@ import ( "os/exec" "path/filepath" "runtime" - "song-recognition/utils" + "song-recognition/db" "strings" ) @@ -40,7 +40,7 @@ func GetFileSize(file string) (int64, error) { } func SongKeyExists(key string) (bool, error) { - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { return false, err } @@ -55,7 +55,7 @@ func SongKeyExists(key string) (bool, error) { } func YtIDExists(ytID string) (bool, error) { - db, err := utils.NewDbClient() + db, err := db.NewDBClient() if err != nil { return false, err }