diff --git a/shazam/shazam.go b/shazam/shazam.go index 77623cc..c8f84b2 100644 --- a/shazam/shazam.go +++ b/shazam/shazam.go @@ -3,6 +3,7 @@ package shazam import ( "fmt" "math" + "song-recognition/models" "song-recognition/utils" "sort" "time" @@ -17,7 +18,6 @@ type Match struct { Score float64 } -// FindMatches processes the audio samples and finds matches in the database func FindMatches(audioSamples []float64, audioDuration float64, sampleRate int) ([]Match, time.Duration, error) { startTime := time.Now() logger := utils.GetLogger() @@ -30,9 +30,11 @@ func FindMatches(audioSamples []float64, audioDuration float64, sampleRate int) peaks := ExtractPeaks(spectrogram, audioDuration) fingerprints := Fingerprint(peaks, utils.GenerateUniqueID()) + var sampleCouples []models.Couple addresses := make([]uint32, 0, len(fingerprints)) for address := range fingerprints { addresses = append(addresses, address) + sampleCouples = append(sampleCouples, fingerprints[address]) } db, err := utils.NewDbClient() @@ -41,61 +43,103 @@ func FindMatches(audioSamples []float64, audioDuration float64, sampleRate int) } defer db.Close() - m, err := db.GetCouples(addresses) + couplesMap, err := db.GetCouples(addresses) if err != nil { return nil, time.Since(startTime), err } - matches := map[uint32][][2]uint32{} // songID -> [(sampleTime, dbTime)] - timestamps := map[uint32]uint32{} - - for address, couples := range m { + // Count occurrences of each couple to derive potential target zones + coupleCounts := make(map[uint32]map[uint32]int) + for _, couples := range couplesMap { for _, couple := range couples { - matches[couple.SongID] = append(matches[couple.SongID], [2]uint32{fingerprints[address].AnchorTimeMs, couple.AnchorTimeMs}) - timestamps[couple.SongID] = couple.AnchorTimeMs + key := (couple.SongID << 32) | uint32(couple.AnchorTimeMs) + if _, exists := coupleCounts[couple.SongID]; !exists { + coupleCounts[couple.SongID] = make(map[uint32]int) + } + coupleCounts[couple.SongID][key]++ } } - scores := analyzeRelativeTiming(matches) - - var matchList []Match - for songID, points := range scores { - song, songExists, err := db.GetSongByID(songID) - if !songExists { - logger.Info(fmt.Sprintf("song with ID (%v) doesn't exist", songID)) - continue + // Filter target zones with targets (couples) meeting or exceeding the threshold + threshold := 4 + filteredCouples := make(map[uint32][]models.Couple) + for songID, counts := range coupleCounts { + for key, count := range counts { + if count >= threshold { + filteredCouples[songID] = append(filteredCouples[songID], models.Couple{ + AnchorTimeMs: key & 0xFFFFFFFF, + SongID: songID, + }) + } } + } + + // Score matches by calculating mean absolute difference + var matches []Match + for songID, songCouples := range filteredCouples { + song, songExists, err := db.GetSongByID(songID) if err != nil { logger.Info(fmt.Sprintf("failed to get song by ID (%v): %v", songID, err)) continue } + if !songExists { + logger.Info(fmt.Sprintf("song with ID (%v) doesn't exist", songID)) + continue + } - match := Match{songID, song.Title, song.Artist, song.YouTubeID, timestamps[songID], points} - matchList = append(matchList, match) + m_a_d := meanAbsoluteDifference(songCouples, sampleCouples) + + tstamp := songCouples[len(songCouples)-1].AnchorTimeMs + match := Match{songID, song.Title, song.Artist, song.YouTubeID, tstamp, m_a_d} + matches = append(matches, match) } - sort.Slice(matchList, func(i, j int) bool { - return matchList[i].Score > matchList[j].Score + sort.Slice(matches, func(i, j int) bool { + return matches[i].Score > matches[j].Score }) - return matchList, time.Since(startTime), nil + // TODO: hanld case when there's no match for cmdHandlers + + return matches, time.Since(startTime), nil } -// AnalyzeRelativeTiming checks for consistent relative timing and returns a score -func analyzeRelativeTiming(matches map[uint32][][2]uint32) map[uint32]float64 { - scores := make(map[uint32]float64) - for songID, times := range matches { - count := 0 - for i := 0; i < len(times); i++ { - for j := i + 1; j < len(times); j++ { - sampleDiff := math.Abs(float64(times[i][0] - times[j][0])) - dbDiff := math.Abs(float64(times[i][1] - times[j][1])) - if math.Abs(sampleDiff-dbDiff) < 100 { // Allow some tolerance - count++ - } - } - } - scores[songID] = float64(count) +func meanAbsoluteDifference(A, B []models.Couple) float64 { + minLen := len(A) + if len(B) < minLen { + minLen = len(B) } - return scores + + var sumDiff float64 + for i := 0; i < minLen; i++ { + diff := math.Abs(float64(A[i].AnchorTimeMs - B[i].AnchorTimeMs)) + sumDiff += diff + } + + meanAbsDiff := sumDiff / float64(minLen) + return meanAbsDiff +} + +// Function to calculate Dynamic Time Warping distance +func dynamicTimeWarping(A, B []models.Couple) float64 { + lenA := len(A) + lenB := len(B) + + // Create a 2D array to store DTW distances + dtw := make([][]float64, lenA+1) + for i := range dtw { + dtw[i] = make([]float64, lenB+1) + for j := range dtw[i] { + dtw[i][j] = math.Inf(1) + } + } + dtw[0][0] = 0 + + for i := 1; i <= lenA; i++ { + for j := 1; j <= lenB; j++ { + cost := math.Abs(float64(A[i-1].AnchorTimeMs - B[j-1].AnchorTimeMs)) + dtw[i][j] = cost + math.Min(math.Min(dtw[i-1][j], dtw[i][j-1]), dtw[i-1][j-1]) + } + } + + return dtw[lenA][lenB] }