diff --git a/db/client.go b/db/client.go new file mode 100644 index 0000000..00fb725 --- /dev/null +++ b/db/client.go @@ -0,0 +1,54 @@ +package db + +import ( + "fmt" + "song-recognition/models" + "song-recognition/utils" +) + +type DBClient interface { + Close() error + StoreFingerprints(fingerprints map[uint32]models.Couple) error + GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) + TotalSongs() (int, error) + RegisterSong(songTitle, songArtist, ytID string) (uint32, error) + GetSong(filterKey string, value interface{}) (Song, bool, error) + GetSongByID(songID uint32) (Song, bool, error) + GetSongByYTID(ytID string) (Song, bool, error) + GetSongByKey(key string) (Song, bool, error) + DeleteSongByID(songID uint32) error + DeleteCollection(collectionName string) error +} + +type Song struct { + Title string + Artist string + YouTubeID string +} + +var DBtype = utils.GetEnv("DB_TYPE", "sqlite") // Can be "sqlite" or "mongo" + +func NewDBClient() (DBClient, error) { + switch DBtype { + case "mongo": + var ( + dbUsername = utils.GetEnv("DB_USER") + dbPassword = utils.GetEnv("DB_PASS") + dbName = utils.GetEnv("DB_NAME") + dbHost = utils.GetEnv("DB_HOST") + dbPort = utils.GetEnv("DB_PORT") + + dbUri = "mongodb://" + dbUsername + ":" + dbPassword + "@" + dbHost + ":" + dbPort + "/" + dbName + ) + if dbUsername == "" || dbPassword == "" { + dbUri = "mongodb://localhost:27017" + } + return NewMongoClient(dbUri) + + case "sqlite": + return NewSQLiteClient("db.sqlite3") + + default: + return nil, fmt.Errorf("unsupported database type: %s", DBtype) + } +} diff --git a/db/mongo.go b/db/mongo.go new file mode 100644 index 0000000..16f4177 --- /dev/null +++ b/db/mongo.go @@ -0,0 +1,201 @@ +package db + +import ( + "context" + "errors" + "fmt" + "song-recognition/models" + "song-recognition/utils" + "strings" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +type MongoClient struct { + client *mongo.Client +} + +func NewMongoClient(uri string) (*MongoClient, error) { + clientOptions := options.Client().ApplyURI(uri) + client, err := mongo.Connect(context.Background(), clientOptions) + if err != nil { + return nil, fmt.Errorf("error connecting to MongoDB: %s", err) + } + return &MongoClient{client: client}, nil +} + +func (db *MongoClient) Close() error { + if db.client != nil { + return db.client.Disconnect(context.Background()) + } + return nil +} + +func (db *MongoClient) StoreFingerprints(fingerprints map[uint32]models.Couple) error { + collection := db.client.Database("song-recognition").Collection("fingerprints") + + for address, couple := range fingerprints { + filter := bson.M{"_id": address} + update := bson.M{ + "$push": bson.M{ + "couples": bson.M{ + "anchorTimeMs": couple.AnchorTimeMs, + "songID": couple.SongID, + }, + }, + } + opts := options.Update().SetUpsert(true) + + _, err := collection.UpdateOne(context.Background(), filter, update, opts) + if err != nil { + return fmt.Errorf("error upserting document: %s", err) + } + } + + return nil +} + +func (db *MongoClient) GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) { + collection := db.client.Database("song-recognition").Collection("fingerprints") + + couples := make(map[uint32][]models.Couple) + + for _, address := range addresses { + // Find the document corresponding to the address + var result bson.M + err := collection.FindOne(context.Background(), bson.M{"_id": address}).Decode(&result) + if err != nil { + if err == mongo.ErrNoDocuments { + continue + } + return nil, fmt.Errorf("error retrieving document for address %d: %s", address, err) + } + + // Extract couples from the document and append them to the couples map + var docCouples []models.Couple + couplesList, ok := result["couples"].(primitive.A) + if !ok { + return nil, fmt.Errorf("couples field in document for address %d is not valid", address) + } + + for _, item := range couplesList { + itemMap, ok := item.(primitive.M) + if !ok { + return nil, fmt.Errorf("invalid couple format in document for address %d", address) + } + + couple := models.Couple{ + AnchorTimeMs: uint32(itemMap["anchorTimeMs"].(int64)), + SongID: uint32(itemMap["songID"].(int64)), + } + docCouples = append(docCouples, couple) + } + couples[address] = docCouples + } + + return couples, nil +} + +func (db *MongoClient) TotalSongs() (int, error) { + existingSongsCollection := db.client.Database("song-recognition").Collection("songs") + total, err := existingSongsCollection.CountDocuments(context.Background(), bson.D{}) + if err != nil { + return 0, err + } + + return int(total), nil +} + +func (db *MongoClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, error) { + existingSongsCollection := db.client.Database("song-recognition").Collection("songs") + + // Create a compound unique index on ytID and key, if it doesn't already exist + indexModel := mongo.IndexModel{ + Keys: bson.D{{"ytID", 1}, {"key", 1}}, + Options: options.Index().SetUnique(true), + } + _, err := existingSongsCollection.Indexes().CreateOne(context.Background(), indexModel) + if err != nil { + return 0, fmt.Errorf("failed to create unique index: %v", err) + } + + // Attempt to insert the song with ytID and key + songID := utils.GenerateUniqueID() + key := utils.GenerateSongKey(songTitle, songArtist) + _, err = existingSongsCollection.InsertOne(context.Background(), bson.M{"_id": songID, "key": key, "ytID": ytID}) + if err != nil { + if mongo.IsDuplicateKeyError(err) { + return 0, fmt.Errorf("song with ytID or key already exists: %v", err) + } else { + return 0, fmt.Errorf("failed to register song: %v", err) + } + } + + return songID, nil +} + +var mongofilterKeys = "_id | ytID | key" + +func (db *MongoClient) GetSong(filterKey string, value interface{}) (s Song, songExists bool, e error) { + if !strings.Contains(mongofilterKeys, filterKey) { + return Song{}, false, errors.New("invalid filter key") + } + + songsCollection := db.client.Database("song-recognition").Collection("songs") + var song bson.M + + filter := bson.M{filterKey: value} + + err := songsCollection.FindOne(context.Background(), filter).Decode(&song) + if err != nil { + if err == mongo.ErrNoDocuments { + return Song{}, false, nil + } + return Song{}, false, fmt.Errorf("failed to retrieve song: %v", err) + } + + ytID := song["ytID"].(string) + title := strings.Split(song["key"].(string), "---")[0] + artist := strings.Split(song["key"].(string), "---")[1] + + songInstance := Song{title, artist, ytID} + + return songInstance, true, nil +} + +func (db *MongoClient) GetSongByID(songID uint32) (Song, bool, error) { + return db.GetSong("_id", songID) +} + +func (db *MongoClient) GetSongByYTID(ytID string) (Song, bool, error) { + return db.GetSong("ytID", ytID) +} + +func (db *MongoClient) GetSongByKey(key string) (Song, bool, error) { + return db.GetSong("key", key) +} + +func (db *MongoClient) DeleteSongByID(songID uint32) error { + songsCollection := db.client.Database("song-recognition").Collection("songs") + + filter := bson.M{"_id": songID} + + _, err := songsCollection.DeleteOne(context.Background(), filter) + if err != nil { + return fmt.Errorf("failed to delete song: %v", err) + } + + return nil +} + +func (db *MongoClient) DeleteCollection(collectionName string) error { + collection := db.client.Database("song-recognition").Collection(collectionName) + err := collection.Drop(context.Background()) + if err != nil { + return fmt.Errorf("error deleting collection: %v", err) + } + return nil +} diff --git a/db/sqlite.go b/db/sqlite.go new file mode 100644 index 0000000..d2d03c6 --- /dev/null +++ b/db/sqlite.go @@ -0,0 +1,207 @@ +package db + +import ( + "database/sql" + "fmt" + "song-recognition/models" + "song-recognition/utils" + "strings" + + "github.com/mattn/go-sqlite3" +) + +type SQLiteClient struct { + db *sql.DB +} + +func NewSQLiteClient(dataSourceName string) (*SQLiteClient, error) { + db, err := sql.Open("sqlite3", dataSourceName) + if err != nil { + return nil, fmt.Errorf("error connecting to SQLite: %s", err) + } + + err = createTables(db) + if err != nil { + return nil, fmt.Errorf("error creating tables: %s", err) + } + + return &SQLiteClient{db: db}, nil +} + +// createTables creates the required tables if they don't exist +func createTables(db *sql.DB) error { + createSongsTable := ` + CREATE TABLE IF NOT EXISTS songs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + artist TEXT NOT NULL, + ytID TEXT NOT NULL UNIQUE, + key TEXT NOT NULL UNIQUE + ); + ` + + createFingerprintsTable := ` + CREATE TABLE IF NOT EXISTS fingerprints ( + address INTEGER NOT NULL, + anchorTimeMs INTEGER NOT NULL, + songID INTEGER NOT NULL, + PRIMARY KEY (address, anchorTimeMs, songID) + ); + ` + + _, err := db.Exec(createSongsTable) + if err != nil { + return fmt.Errorf("error creating songs table: %s", err) + } + + _, err = db.Exec(createFingerprintsTable) + if err != nil { + return fmt.Errorf("error creating fingerprints table: %s", err) + } + + return nil +} + +func (db *SQLiteClient) Close() error { + if db.db != nil { + return db.db.Close() + } + return nil +} + +func (db *SQLiteClient) StoreFingerprints(fingerprints map[uint32]models.Couple) error { + tx, err := db.db.Begin() + if err != nil { + return fmt.Errorf("error starting transaction: %s", err) + } + + stmt, err := tx.Prepare("INSERT OR REPLACE INTO fingerprints (address, anchorTimeMs, songID) VALUES (?, ?, ?)") + if err != nil { + tx.Rollback() + return fmt.Errorf("error preparing statement: %s", err) + } + defer stmt.Close() + + for address, couple := range fingerprints { + if _, err := stmt.Exec(address, couple.AnchorTimeMs, couple.SongID); err != nil { + tx.Rollback() + return fmt.Errorf("error executing statement: %s", err) + } + } + + return tx.Commit() +} + +func (db *SQLiteClient) GetCouples(addresses []uint32) (map[uint32][]models.Couple, error) { + couples := make(map[uint32][]models.Couple) + + for _, address := range addresses { + rows, err := db.db.Query("SELECT anchorTimeMs, songID FROM fingerprints WHERE address = ?", address) + if err != nil { + return nil, fmt.Errorf("error querying database: %s", err) + } + defer rows.Close() + + var docCouples []models.Couple + for rows.Next() { + var couple models.Couple + if err := rows.Scan(&couple.AnchorTimeMs, &couple.SongID); err != nil { + return nil, fmt.Errorf("error scanning row: %s", err) + } + docCouples = append(docCouples, couple) + } + couples[address] = docCouples + } + + return couples, nil +} + +func (db *SQLiteClient) TotalSongs() (int, error) { + var count int + err := db.db.QueryRow("SELECT COUNT(*) FROM songs").Scan(&count) + if err != nil { + return 0, fmt.Errorf("error counting songs: %s", err) + } + return count, nil +} + +func (db *SQLiteClient) RegisterSong(songTitle, songArtist, ytID string) (uint32, error) { + tx, err := db.db.Begin() + if err != nil { + return 0, fmt.Errorf("error starting transaction: %s", err) + } + + stmt, err := tx.Prepare("INSERT INTO songs (id, title, artist, ytID, key) VALUES (?, ?, ?, ?, ?)") + if err != nil { + tx.Rollback() + return 0, fmt.Errorf("error preparing statement: %s", err) + } + defer stmt.Close() + + songID := utils.GenerateUniqueID() + songKey := utils.GenerateSongKey(songTitle, songArtist) + if _, err := stmt.Exec(songID, songTitle, songArtist, ytID, songKey); err != nil { + tx.Rollback() + if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.Code == sqlite3.ErrConstraint { + return 0, fmt.Errorf("song with ytID or key already exists: %v", err) + } + return 0, fmt.Errorf("failed to register song: %v", err) + } + + return songID, tx.Commit() +} + +var sqlitefilterKeys = "id | ytID | key" + +// GetSong retrieves a song by filter key +func (s *SQLiteClient) GetSong(filterKey string, value interface{}) (Song, bool, error) { + + if !strings.Contains(sqlitefilterKeys, filterKey) { + return Song{}, false, fmt.Errorf("invalid filter key") + } + + query := fmt.Sprintf("SELECT title, artist, ytID FROM songs WHERE %s = ?", filterKey) + + row := s.db.QueryRow(query, value) + + var song Song + err := row.Scan(&song.Title, &song.Artist, &song.YouTubeID) + if err != nil { + if err == sql.ErrNoRows { + return Song{}, false, nil + } + return Song{}, false, fmt.Errorf("failed to retrieve song: %s", err) + } + + return song, true, nil +} + +func (db *SQLiteClient) GetSongByID(songID uint32) (Song, bool, error) { + return db.GetSong("id", songID) +} + +func (db *SQLiteClient) GetSongByYTID(ytID string) (Song, bool, error) { + return db.GetSong("ytID", ytID) +} + +func (db *SQLiteClient) GetSongByKey(key string) (Song, bool, error) { + return db.GetSong("key", key) +} + +// DeleteSongByID deletes a song by ID +func (db *SQLiteClient) DeleteSongByID(songID uint32) error { + _, err := db.db.Exec("DELETE FROM songs WHERE id = ?", songID) + if err != nil { + return fmt.Errorf("failed to delete song: %v", err) + } + return nil +} + +// DeleteCollection deletes a collection (table) from the database +func (db *SQLiteClient) DeleteCollection(collectionName string) error { + _, err := db.db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", collectionName)) + if err != nil { + return fmt.Errorf("error deleting collection: %v", err) + } + return nil +}