From 7dcc7b1b52ac2f7823b27386dea635c2a0db0fca Mon Sep 17 00:00:00 2001 From: Chigozirim Igweamaka Date: Fri, 26 Apr 2024 05:39:40 +0100 Subject: [PATCH] Update GitHub --- server.go | 136 ++++++++++++++++++++++++--------------- shazam/fft.go | 13 ++-- shazam/shazam.go | 143 ++++++++++++++++-------------------------- shazam/spectrogram.go | 91 +++++++++++++++++++++++++++ signal/webrtc.go | 13 ++-- spotify/downloader.go | 33 +++++++--- spotify/youtube.go | 3 +- utils/wav.go | 2 +- 8 files changed, 276 insertions(+), 158 deletions(-) create mode 100644 shazam/spectrogram.go diff --git a/server.go b/server.go index 431bae6..f727102 100644 --- a/server.go +++ b/server.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io/ioutil" "log" "net/http" "song-recognition/shazam" @@ -15,7 +14,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pion/webrtc/v4" - "github.com/pion/webrtc/v4/pkg/media/oggwriter" + "go.mongodb.org/mongo-driver/bson/primitive" socketio "github.com/googollee/go-socket.io" ) @@ -42,6 +41,21 @@ func GinMiddleware(allowOrigin string) gin.HandlerFunc { } } +type DownloadStatus struct { + Type string + Message string +} + +func downloadStatus(msgType, message string) string { + data := map[string]interface{}{"type": msgType, "message": message} + jsonData, err := json.Marshal(data) + if err != nil { + fmt.Println("Error marshalling JSON:", err) + return "" + } + return string(jsonData) +} + func main() { router := gin.New() @@ -54,13 +68,6 @@ func main() { return nil }) - server.OnEvent("/", "initOffer", func(s socketio.Conn, initEncodedOffer string) { - log.Println("initOffer: ", initEncodedOffer) - - peerConnection := signal.SetupWebRTC(initEncodedOffer) - s.Emit("initAnswer", signal.Encode(*peerConnection.LocalDescription())) - }) - server.OnEvent("/", "totalSongs", func(socket socketio.Conn) { db, err := utils.NewDbClient() if err != nil { @@ -100,41 +107,54 @@ func main() { tracksInAlbum, err := spotify.AlbumInfo(spotifyURL) if err != nil { fmt.Println("log error: ", err) + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + } return } - socket.Emit("albumStat", fmt.Sprintf("%v songs found in album.", len(tracksInAlbum))) + statusMsg := fmt.Sprintf("%v songs found in album.", len(tracksInAlbum)) + socket.Emit("downloadStatus", downloadStatus("info", statusMsg)) totalTracksDownloaded, err := spotify.DlAlbum(spotifyURL, tmpSongDir) if err != nil { - socket.Emit("downloadStatus", fmt.Sprintf("Failed to download album.")) + socket.Emit("downloadStatus", downloadStatus("error", "Couldn't to download album.")) return } - socket.Emit("downloadStatus", fmt.Sprintf("%d songs downloaded from album", totalTracksDownloaded)) + statusMsg = fmt.Sprintf("%d songs downloaded from album", totalTracksDownloaded) + socket.Emit("downloadStatus", downloadStatus("success", statusMsg)) } else if strings.Contains(spotifyURL, "playlist") { tracksInPL, err := spotify.PlaylistInfo(spotifyURL) if err != nil { fmt.Println("log error: ", err) + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + } return } - socket.Emit("playlistStat", fmt.Sprintf("%v songs found in playlist.", len(tracksInPL))) + statusMsg := fmt.Sprintf("%v songs found in playlist.", len(tracksInPL)) + socket.Emit("downloadStatus", downloadStatus("info", statusMsg)) totalTracksDownloaded, err := spotify.DlPlaylist(spotifyURL, tmpSongDir) if err != nil { fmt.Println("log errorr: ", err) - socket.Emit("downloadStatus", fmt.Sprintf("Failed to download playlist.")) + socket.Emit("downloadStatus", downloadStatus("error", "Couldn't download playlist.")) return } - socket.Emit("downloadStatus", fmt.Sprintf("%d songs downloaded from playlist", totalTracksDownloaded)) + statusMsg = fmt.Sprintf("%d songs downloaded from playlist.", totalTracksDownloaded) + socket.Emit("downloadStatus", downloadStatus("success", statusMsg)) } else if strings.Contains(spotifyURL, "track") { trackInfo, err := spotify.TrackInfo(spotifyURL) if err != nil { fmt.Println("log error: ", err) + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + } return } @@ -151,22 +171,33 @@ func main() { } if chunkTag != nil { - socket.Emit("downloadStatus", fmt.Sprintf( + statusMsg := fmt.Sprintf( "'%s' by '%s' already exists in the database (https://www.youtube.com/watch?v=%s)", - trackInfo.Title, trackInfo.Artist, chunkTag["youtubeid"])) + trackInfo.Title, trackInfo.Artist, chunkTag["youtubeid"]) + + fmt.Println("Emitting1") + + socket.Emit("downloadStatus", downloadStatus("error", statusMsg)) return } totalDownloads, err := spotify.DlSingleTrack(spotifyURL, tmpSongDir) if err != nil { - socket.Emit("downloadStatus", fmt.Sprintf("Failed to download '%s' by '%s'", trackInfo.Title, trackInfo.Artist)) + statusMsg := fmt.Sprintf("Couldn't download '%s' by '%s'", trackInfo.Title, trackInfo.Artist) + fmt.Println("Emitting2") + socket.Emit("downloadStatus", downloadStatus("error", statusMsg)) return } + statusMsg := "" if totalDownloads != 1 { - socket.Emit("downloadStatus", fmt.Sprintf("'%s' by '%s' failed to download", trackInfo.Title, trackInfo.Artist)) + statusMsg = fmt.Sprintf("'%s' by '%s' failed to download", trackInfo.Title, trackInfo.Artist) + fmt.Println("Emitting2") + socket.Emit("downloadStatus", downloadStatus("error", statusMsg)) } else { - socket.Emit("downloadStatus", fmt.Sprintf("'%s' by '%s' was downloaded", trackInfo.Title, trackInfo.Artist)) + statusMsg = fmt.Sprintf("'%s' by '%s' was downloaded", trackInfo.Title, trackInfo.Artist) + fmt.Println("Emitting3") + socket.Emit("downloadStatus", downloadStatus("success", statusMsg)) } } else { @@ -184,24 +215,32 @@ func main() { } // Save the decoded data to a file - err = ioutil.WriteFile("recorded_audio.ogg", decodedData, 0644) + sampleRate := 44100 + channels := 1 + bitsPerSample := 16 + + err = utils.WriteWavFile("blob.wav", decodedData, sampleRate, channels, bitsPerSample) if err != nil { - fmt.Println("Error: Failed to write file to disk:", err) - return + fmt.Println("Error: Failed to write wav file: ", err) } fmt.Println("Audio saved successfully.") - matches, err := shazam.Match(decodedData) + matches, err := shazam.FindMatches(decodedData) if err != nil { fmt.Println("Error: Failed to match:", err) return } - jsonData, err := json.Marshal(matches) + var matchesChunkTags []primitive.M + for _, match := range matches { + matchesChunkTags = append(matchesChunkTags, match.ChunkTag) + } - if len(matches) > 5 { - jsonData, err = json.Marshal(matches[:5]) + jsonData, err := json.Marshal(matchesChunkTags) + + if len(matchesChunkTags) > 5 { + jsonData, err = json.Marshal(matchesChunkTags[:5]) } if err != nil { @@ -210,39 +249,35 @@ func main() { } socket.Emit("matches", string(jsonData)) - - fmt.Println("BLOB: ", matches) }) - server.OnEvent("/", "engage", func(s socketio.Conn, encodedOffer string) { - log.Println("engage: ", encodedOffer) + server.OnEvent("/", "engage", func(socket socketio.Conn, encodedOffer string) { + log.Println("Offer received from client ", socket.ID()) peerConnection := signal.SetupWebRTC(encodedOffer) // Allow us to receive 1 audio track if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio); err != nil { - panic(err) - } - - // Set a handler for when a new remote track starts, this handler saves buffers to disk as - // an Ogg file. - oggFile, err := oggwriter.New("output.ogg", 48000, 1) - if err != nil { + fmt.Println("AAAAA") panic(err) } peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { codec := track.Codec() if strings.EqualFold(codec.MimeType, webrtc.MimeTypeOpus) { - // fmt.Println("Got Opus track, saving to disk as output.opus (44.1 kHz, 1 channel)") - // signal.SaveToDisk(oggFile, track) - + fmt.Println("Getting tracks") matches, err := signal.MatchSampleAudio(track) if err != nil { - panic(err) + fmt.Println("CCCCC") + fmt.Println("Error getting matches: ", err) + return } jsonData, err := json.Marshal(matches) + if err != nil { + fmt.Println("Log error: ", err) + return + } if len(matches) > 5 { jsonData, err = json.Marshal(matches[:5]) @@ -255,8 +290,8 @@ func main() { fmt.Println(string(jsonData)) - s.Emit("matches", string(jsonData)) - peerConnection.Close() + socket.Emit("matches", string(jsonData)) + // peerConnection.Close() } }) @@ -266,16 +301,17 @@ func main() { fmt.Printf("Connection State has changed %s \n", connectionState.String()) if connectionState == webrtc.ICEConnectionStateConnected { - fmt.Println("Ctrl+C the remote client to stop the demo") + fmt.Println("WebRTC Connected. Client: ", socket.ID()) } else if connectionState == webrtc.ICEConnectionStateFailed || connectionState == webrtc.ICEConnectionStateClosed { - if closeErr := oggFile.Close(); closeErr != nil { - panic(closeErr) - } - fmt.Println("Done writing media files") + if connectionState == webrtc.ICEConnectionStateFailed { + fmt.Println("WebRTC connection failed. Client: ", socket.ID()) + socket.Emit("failedToEngage", "") + } // Gracefully shutdown the peer connection if closeErr := peerConnection.Close(); closeErr != nil { + fmt.Println("Gracefully shutdown the peer connection") panic(closeErr) } @@ -284,7 +320,7 @@ func main() { }) // Emit answer in base64 - s.Emit("serverEngaged", signal.Encode(*peerConnection.LocalDescription())) + socket.Emit("serverEngaged", signal.Encode(*peerConnection.LocalDescription())) }) server.OnError("/", func(s socketio.Conn, e error) { diff --git a/shazam/fft.go b/shazam/fft.go index 22d9b3d..a3d3198 100644 --- a/shazam/fft.go +++ b/shazam/fft.go @@ -2,11 +2,16 @@ package shazam import ( "math" - "math/cmplx" ) -// fft performs the Fast Fourier Transform on the input signal. -func Fft(complexArray []complex128) []complex128 { +// Fft performs the Fast Fourier Transform on the input signal. +func FFT(input []float64) []complex128 { + // Convert input to complex128 + complexArray := make([]complex128, len(input)) + for i, v := range input { + complexArray[i] = complex(v, 0) + } + fftResult := make([]complex128, len(complexArray)) copy(fftResult, complexArray) // Copy input to result buffer return recursiveFFT(fftResult) @@ -31,7 +36,7 @@ func recursiveFFT(complexArray []complex128) []complex128 { fftResult := make([]complex128, N) for k := 0; k < N/2; k++ { - t := cmplx.Exp(-2i * math.Pi * complex(float64(k), 0) / complex(float64(N), 0)) + t := complex(math.Cos(-2*math.Pi*float64(k)/float64(N)), math.Sin(-2*math.Pi*float64(k)/float64(N))) fftResult[k] = even[k] + t*odd[k] fftResult[k+N/2] = even[k] - t*odd[k] } diff --git a/shazam/shazam.go b/shazam/shazam.go index 65f88a6..55326aa 100644 --- a/shazam/shazam.go +++ b/shazam/shazam.go @@ -32,7 +32,13 @@ type ChunkTag struct { TimeStamp string } -func Match(sampleAudio []byte) ([]primitive.M, error) { +type Match struct { + songKey string + ChunkTag primitive.M + WeightedScore float64 +} + +func FindMatches(sampleAudio []byte) ([]Match, error) { sampleChunks := Chunkify(sampleAudio) chunkFingerprints, _ := FingerprintChunks(sampleChunks, nil) @@ -63,51 +69,42 @@ func Match(sampleAudio []byte) ([]primitive.M, error) { } } - maxMatchCount := 0 - var maxMatch string - - matches := make(map[string][]int) - + var matches []Match for songKey, timestamps := range songsTimestamps { - timestampsInSeconds, err := timestampsInSeconds(timestamps) - if err != nil && err.Error() == "insufficient timestamps" { - continue - } else if err != nil { + if err != nil { return nil, err } maxPeak, differenceSum, err := getMaxPeak(timestampsInSeconds) if err != nil { - return nil, err - } - fmt.Printf("%s MaxPeak: %v, DifferenceSum: %d\n", songKey, maxPeak, differenceSum) - fmt.Println("=====================================================\n") - - differences, err := timeDifference(timestamps) - if err != nil && err.Error() == "insufficient timestamps" { - continue - } else if err != nil { - return nil, err - } - - // fmt.Printf("%s DIFFERENCES: %d\n", songKey, differences) - if len(differences) >= 2 { - matches[songKey] = differences - if len(differences) > maxMatchCount { - maxMatchCount = len(differences) - maxMatch = songKey + if err.Error() == "insufficient timestamps" || err.Error() == "no peak was identified" { + continue + } else { + return nil, err } } + + weightedScore := float64(differenceSum) / float64(len(maxPeak)) + matches = append(matches, Match{songKey, chunkTags[songKey], weightedScore}) + + fmt.Printf("%s MaxPeak: %v, DifferenceSum: %d\n", songKey, maxPeak, differenceSum) + fmt.Println("=====================================================\n") } - sortedChunkTags := sortMatchesByTimeDifference(matches, chunkTags) + sort.Slice(matches, func(i, j int) bool { + return matches[i].WeightedScore < matches[j].WeightedScore + }) - // fmt.Println("SORTED CHUNK TAGS: ", sortedChunkTags) - // fmt.Println("MATCHES: ", matches) - fmt.Println("MATCH: ", maxMatch) - // fmt.Println() - return sortedChunkTags, nil + display := make(map[string]float64) + for _, match := range matches { + key := match.songKey + display[key] = match.WeightedScore + } + + fmt.Println("New Matches: ", display) + fmt.Println("Matches: ", matches) + return matches, nil } func sortMatchesByTimeDifference(matches map[string][]int, chunkTags map[string]primitive.M) []primitive.M { @@ -174,8 +171,10 @@ func getMaxPeak(timestamps []int) ([]int, int, error) { // Ensure timestamps are in ascending order if minuend > subtrahend { - peaks = append(peaks, cluster) - cluster = nil + if len(cluster) > 0 { + peaks = append(peaks, cluster) + cluster = nil + } continue } @@ -187,11 +186,17 @@ func getMaxPeak(timestamps []int) ([]int, int, error) { } else if difference <= maxDifference { cluster = append(cluster, subtrahend) } else if difference > maxDifference { - peaks = append(peaks, cluster) - cluster = nil + if len(cluster) > 0 { + peaks = append(peaks, cluster) + cluster = nil + } } } + if len(peaks) < 1 { + return nil, 0, fmt.Errorf("no peak was identified") + } + // Identify the largest peak(s) largestPeak := [][]int{peaks[0]} for _, peak := range peaks[1:] { @@ -208,16 +213,19 @@ func getMaxPeak(timestamps []int) ([]int, int, error) { if len(largestPeak) > 1 { fmt.Println("Largest Peak > 1: ", largestPeak) - // Deduplicate largest peaks in order to get accurate sum of difference - var largestPeakDeDuplicated [][]int + // Deduplicate largest peaks to get accurate result. + // How? Consider two peaks: A: [53, 53, 53] and B: [14, 15]. + // Peak A has only one unique value (53) repeated three times, while peak B has two unique values (14 and 15). + // In this case, peak B would be prioritized over peak A + var largestPeakDeduplicated [][]int for _, peak := range largestPeak { - largestPeakDeDuplicated = append(largestPeakDeDuplicated, deduplicate(peak)) + largestPeakDeduplicated = append(largestPeakDeduplicated, deduplicate(peak)) } - fmt.Println("Largest Peak deduplicated: ", largestPeakDeDuplicated) + fmt.Println("Largest Peak deduplicated: ", largestPeakDeduplicated) minDifferenceSum := math.Inf(1) var peakWithMinDifferenceSum []int - for idx, peak := range largestPeakDeDuplicated { + for idx, peak := range largestPeakDeduplicated { if len(peak) <= 1 { continue } @@ -228,15 +236,15 @@ func getMaxPeak(timestamps []int) ([]int, int, error) { } if differenceSum < minDifferenceSum { minDifferenceSum = differenceSum - fmt.Printf("%v vs %v\n", largestPeak[idx], peak) peakWithMinDifferenceSum = largestPeak[idx] } } // In the case where no peak with the min difference sum was identified, - // probably because they were all duplicates, return the first from the largestspeaks + // probably because they are all duplicates, return the first from the largestspeaks if len(peakWithMinDifferenceSum) == 0 { peakWithMinDifferenceSum = largestPeak[0] + minDifferenceSum = 0 } return peakWithMinDifferenceSum, int(minDifferenceSum), nil @@ -252,49 +260,6 @@ func getMaxPeak(timestamps []int) ([]int, int, error) { return maxPeak, differenceSum, nil } -func timeDifference(timestamps []string) ([]int, error) { - if len(timestamps) < 2 { - return nil, fmt.Errorf("insufficient timestamps") - } - - layout := "15:04:05" - - timestampsInSeconds := make([]int, len(timestamps)) - for i, ts := range timestamps { - parsedTime, err := time.Parse(layout, ts) - if err != nil { - return nil, fmt.Errorf("error parsing timestamp %q: %w", ts, err) - } - hours := parsedTime.Hour() - minutes := parsedTime.Minute() - seconds := parsedTime.Second() - timestampsInSeconds[i] = (hours * 3600) + (minutes * 60) + seconds - } - - // sort.Ints(timestampsInSeconds) - - differencesSet := map[int]struct{}{} - var differences []int - - for i := len(timestampsInSeconds) - 1; i >= 1; i-- { - difference := timestampsInSeconds[i] - timestampsInSeconds[i-1] - // maxSeconds = 15 - if difference > 0 && difference <= 15 { - differencesSet[difference] = struct{}{} - differences = append(differences, difference) - } - } - - differencesList := []int{} - if len(differencesSet) > 0 { - for k := range differencesSet { - differencesList = append(differencesList, k) - } - } - - return timestampsInSeconds, nil -} - // Chunkify divides the input audio signal into chunks and calculates the Short-Time Fourier Transform (STFT) for each chunk. // The function returns a 2D slice containing the STFT coefficients for each chunk. func Chunkify(audio []byte) [][]complex128 { @@ -360,7 +325,7 @@ func FingerprintChunks(chunks [][]complex128, chunkTag *ChunkTag) ([]int64, map[ if chunkCount == chunksPerSecond { chunkCount = 0 chunkTime = chunkTime.Add(1 * time.Second) - fmt.Println(chunkTime.Format("15:04:05")) + // fmt.Println(chunkTime.Format("15:04:05")) } } diff --git a/shazam/spectrogram.go b/shazam/spectrogram.go new file mode 100644 index 0000000..d129da0 --- /dev/null +++ b/shazam/spectrogram.go @@ -0,0 +1,91 @@ +package shazam + +import ( + "errors" + "fmt" + "math" +) + +const ( + dspRatio = 4 + lowPassFilter = 5000.0 // 5kHz + samplesPerWindow = 1024 +) + +func Spectrogram(samples []float64, channels, sampleRate int) [][]complex128 { + lpf := NewLowPassFilter(lowPassFilter, float64(sampleRate)) + filteredSamples := lpf.Filter(samples) + + downsampledSamples, err := downsample(filteredSamples, dspRatio) + if err != nil { + fmt.Println("Couldn't downsample audio samples: ", err) + } + + hopSize := samplesPerWindow / 32 + numOfWindows := len(downsampledSamples) / (samplesPerWindow - hopSize) + spectrogram := make([][]complex128, numOfWindows) + + // Apply Hamming window function + windowSize := len(samples) + for i := 0; i < len(downsampledSamples); i++ { + downsampledSamples[i] = 0.54 - 0.46*math.Cos(2*math.Pi*float64(i)/(float64(windowSize)-1)) + } + + // Perform STFT + for i := 0; i < numOfWindows; i++ { + start := i * hopSize + end := start + samplesPerWindow + if end > len(downsampledSamples) { + end = len(downsampledSamples) + } + + spec := make([]float64, samplesPerWindow) + for j := start; j < end; j++ { + spec[j-start] = downsampledSamples[j] + } + + applyHammingWindow(spec) + spectrogram[i] = FFT(spec) + } + + return spectrogram +} + +func applyHammingWindow(samples []float64) { + windowSize := len(samples) + + for i := 0; i < windowSize; i++ { + samples[i] *= 0.54 - 0.46*math.Cos(2*math.Pi*float64(i)/(float64(windowSize)-1)) + } +} + +// Downsample downsamples a list of float64 values from 44100 Hz to a specified ratio by averaging groups of samples +func downsample(input []float64, ratio int) ([]float64, error) { + // Ensure the ratio is valid and compatible with the input length + if ratio <= 0 || len(input)%ratio != 0 { + return nil, errors.New("invalid or incompatible ratio") + } + + // Calculate the size of the output slice + outputSize := len(input) / ratio + + // Create the output slice + output := make([]float64, outputSize) + + // Iterate over the input and calculate averages for each group of samples + for i := 0; i < outputSize; i++ { + startIndex := i * ratio + endIndex := startIndex + ratio + sum := 0.0 + + // Sum up the values in the current group of samples + for j := startIndex; j < endIndex; j++ { + sum += input[j] + } + + // Calculate the average for the current group + output[i] = sum / float64(ratio) + } + + return output, nil +} diff --git a/signal/webrtc.go b/signal/webrtc.go index f770ad3..86c8a1e 100644 --- a/signal/webrtc.go +++ b/signal/webrtc.go @@ -73,17 +73,17 @@ func MatchSampleAudio(track *webrtc.TrackRemote) ([]primitive.M, error) { defer ticker.Stop() var sampleAudio []byte - var matches []primitive.M + var matches []shazam.Match for { select { case <-ticker.C: // Process sampleAudio every 2 seconds if len(sampleAudio) > 0 { - matchess, err := shazam.Match(sampleAudio) + matchess, err := shazam.FindMatches(sampleAudio) matches = matchess if err != nil { - fmt.Println(err) + fmt.Println("An Error: ", err) return nil, nil } @@ -103,7 +103,12 @@ func MatchSampleAudio(track *webrtc.TrackRemote) ([]primitive.M, error) { case <-stop: // Stop after 15 seconds fmt.Println("Stopped after 15 seconds") - return matches, nil + var matchesChunkTags []primitive.M + for _, match := range matches { + matchesChunkTags = append(matchesChunkTags, match.ChunkTag) + } + return matchesChunkTags, nil + default: // Read RTP packets and accumulate sampleAudio rtpPacket, _, err := track.ReadRTP() diff --git a/spotify/downloader.go b/spotify/downloader.go index 2a81d75..eb5077e 100644 --- a/spotify/downloader.go +++ b/spotify/downloader.go @@ -104,7 +104,11 @@ func dlTrack(tracks []Track, path string) (int, error) { } // check if song exists - songExists, _ := db.SongExists(trackCopy.Title, trackCopy.Artist, "") + songExists, err := db.SongExists(trackCopy.Title, trackCopy.Artist, "") + if err != nil { + logMessage := fmt.Sprintln("error checking song existence: ", err) + slog.Error(logMessage) + } if songExists { logMessage := fmt.Sprintf("'%s' by '%s' already downloaded\n", trackCopy.Title, trackCopy.Artist) slog.Info(logMessage) @@ -120,12 +124,19 @@ func dlTrack(tracks []Track, path string) (int, error) { } // Check if YouTube ID exists - ytIdExists, _ := db.SongExists("", "", ytID) + ytIdExists, err := db.SongExists("", "", ytID) + fmt.Printf("%s exists? = %v\n", ytID, ytIdExists) + if err != nil { + logMessage := fmt.Sprintln("error checking song existence: ", err) + slog.Error(logMessage) + } + if ytIdExists { // try to get the YouTube ID again logMessage := fmt.Sprintf("YouTube ID exists. Trying again: %s\n", ytID) + fmt.Println("WARN: ", logMessage) slog.Warn(logMessage) - ytID, err := GetYoutubeId(*trackCopy) + ytID, err = GetYoutubeId(*trackCopy) if ytID == "" || err != nil { logMessage := fmt.Sprintf("Error (1): '%s' by '%s' could not be downloaded: %s\n", trackCopy.Title, trackCopy.Artist, err) slog.Info(logMessage) @@ -133,7 +144,11 @@ func dlTrack(tracks []Track, path string) (int, error) { return } - ytIdExists, _ := db.SongExists("", "", ytID) + ytIdExists, err := db.SongExists("", "", ytID) + if err != nil { + logMessage := fmt.Sprintln("error checking song existence: ", err) + slog.Error(logMessage) + } if ytIdExists { logMessage := fmt.Sprintf("'%s' by '%s' could not be downloaded: YouTube ID (%s) exists\n", trackCopy.Title, trackCopy.Artist, ytID) slog.Error(logMessage) @@ -285,6 +300,11 @@ func processAndSaveSong(songFilePath, songTitle, songArtist, ytID string) error } defer db.Close() + err = db.RegisterSong(songTitle, songArtist, ytID) + if err != nil { + return err + } + audioBytes, err := convertStereoToMono(songFilePath) if err != nil { return fmt.Errorf("error converting song to mono: %v", err) @@ -308,11 +328,6 @@ func processAndSaveSong(songFilePath, songTitle, songArtist, ytID string) error } } - err = db.RegisterSong(songTitle, songArtist, ytID) - if err != nil { - return err - } - fmt.Println("Fingerprints saved in MongoDB successfully") return nil } diff --git a/spotify/youtube.go b/spotify/youtube.go index f2342ed..e54822c 100644 --- a/spotify/youtube.go +++ b/spotify/youtube.go @@ -79,7 +79,8 @@ func convertStringDurationToSeconds(durationStr string) int { // GetYoutubeId takes the query as string and returns the search results video ID's func GetYoutubeId(track Track) (string, error) { songDurationInSeconds := track.Duration - searchQuery := fmt.Sprintf("'%s' %s %s", track.Title, track.Artist, track.Album) + // searchQuery := fmt.Sprintf("'%s' %s %s", track.Title, track.Artist, track.Album) + searchQuery := fmt.Sprintf("'%s' %s", track.Title, track.Artist) searchResults, err := ytSearch(searchQuery, 10) if err != nil { diff --git a/utils/wav.go b/utils/wav.go index d2a2ae2..3aa215a 100644 --- a/utils/wav.go +++ b/utils/wav.go @@ -113,7 +113,7 @@ func ReadWavInfo(filename string) (*WavInfo, error) { } // WavBytesToFloat64 converts a slice of bytes from a .wav file to a slice of float64 samples -func WavBytesToFloat64(input []byte) ([]float64, error) { +func WavBytesToSamples(input []byte) ([]float64, error) { if len(input)%2 != 0 { return nil, errors.New("invalid input length") }