diff --git a/spotify/utils.go b/spotify/utils.go index 0b28e07..9336c36 100644 --- a/spotify/utils.go +++ b/spotify/utils.go @@ -54,16 +54,12 @@ func SongKeyExists(key string) (bool, error) { } defer db.Close() - song, err := db.GetSongByKey(key) - if err != nil && !strings.Contains(err.Error(), "song not found") { + _, songExists, err := db.GetSongByKey(key) + if err != nil { return false, err } - if song.Title == "" { - return false, nil - } - - return true, nil + return songExists, nil } func YtIDExists(ytID string) (bool, error) { @@ -73,16 +69,12 @@ func YtIDExists(ytID string) (bool, error) { } defer db.Close() - song, err := db.GetSongByYTID(ytID) - if err != nil && !strings.Contains(err.Error(), "song not found") { + _, songExits, err := db.GetSongByYTID(ytID) + if err != nil { return false, err } - if song.Title == "" { - return false, nil - } - - return true, nil + return songExits, nil } /* fixes some invalid file names (windows is the capricious one) */ diff --git a/utils/dbClient.go b/utils/dbClient.go index 8842012..552da9b 100644 --- a/utils/dbClient.go +++ b/utils/dbClient.go @@ -2,6 +2,7 @@ package utils import ( "context" + "errors" "fmt" "song-recognition/models" "strings" @@ -166,9 +167,14 @@ type Song struct { YouTubeID string } -func (db *DbClient) GetSong(filterKey string, value any) (Song, error) { - songsCollection := db.client.Database("song-recognition").Collection("songs") +const FILTER_KEYS = "_id | ytID | key" +func (db *DbClient) GetSong(filterKey string, value interface{}) (s Song, songExists bool, e error) { + if !strings.Contains(FILTER_KEYS, filterKey) { + return Song{}, false, errors.New("invalid filter key") + } + + songsCollection := db.client.Database("song-recognition").Collection("songs") var song bson.M filter := bson.M{filterKey: value} @@ -176,9 +182,9 @@ func (db *DbClient) GetSong(filterKey string, value any) (Song, error) { err := songsCollection.FindOne(context.Background(), filter).Decode(&song) if err != nil { if err == mongo.ErrNoDocuments { - return Song{}, fmt.Errorf("song (%v: %v) not found", filterKey, value) + return Song{}, false, nil } - return Song{}, fmt.Errorf("failed to retrieve song: %v", err) + return Song{}, false, fmt.Errorf("failed to retrieve song: %v", err) } ytID := song["ytID"].(string) @@ -187,18 +193,18 @@ func (db *DbClient) GetSong(filterKey string, value any) (Song, error) { songInstance := Song{title, artist, ytID} - return songInstance, nil + return songInstance, true, nil } -func (db *DbClient) GetSongByID(songID uint32) (Song, error) { +func (db *DbClient) GetSongByID(songID uint32) (Song, bool, error) { return db.GetSong("_id", songID) } -func (db *DbClient) GetSongByYTID(ytID string) (Song, error) { +func (db *DbClient) GetSongByYTID(ytID string) (Song, bool, error) { return db.GetSong("ytID", ytID) } -func (db *DbClient) GetSongByKey(key string) (Song, error) { +func (db *DbClient) GetSongByKey(key string) (Song, bool, error) { return db.GetSong("key", key) }