diff --git a/main.go b/main.go index c477d22..c9c07a8 100644 --- a/main.go +++ b/main.go @@ -1,251 +1,319 @@ -// SPDX-FileCopyrightText: 2023 The Pion community -// SPDX-License-Identifier: MIT - -//go:build !js -// +build !js - -// save-to-disk is a simple application that shows how to record your webcam/microphone using Pion WebRTC and save VP8/Opus to disk. package main import ( + "context" + "encoding/base64" + "encoding/json" "fmt" - "io" - "os" - "strings" - "time" - - "github.com/pion/interceptor" - "github.com/pion/interceptor/pkg/intervalpli" - "github.com/pion/webrtc/v4" - + "log" + "log/slog" + "net/http" "song-recognition/shazam" - "song-recognition/signal" + "song-recognition/spotify" + "song-recognition/utils" + "song-recognition/wav" + "strings" - "github.com/pion/webrtc/v4/pkg/media" - "github.com/pion/webrtc/v4/pkg/media/oggwriter" + "github.com/gin-gonic/gin" + "github.com/mdobak/go-xerrors" + + socketio "github.com/googollee/go-socket.io" ) -func saveToDisk(i media.Writer, track *webrtc.TrackRemote) { - defer func() { - if err := i.Close(); err != nil { - panic(err) +const ( + tmpSongDir = "/home/chigozirim/Documents/my-docs/song-recognition/songs/" +) + +func GinMiddleware(allowOrigin string) gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin) + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, Content-Length, X-CSRF-Token, Token, session, Origin, Host, Connection, Accept-Encoding, Accept-Language, X-Requested-With") + + if c.Request.Method == http.MethodOptions { + c.AbortWithStatus(http.StatusNoContent) + return + } + + c.Request.Header.Del("Origin") + + c.Next() + } +} + +func downloadStatus(statusType, message string) string { + data := map[string]interface{}{"type": statusType, "message": message} + jsonData, err := json.Marshal(data) + if err != nil { + logger := utils.GetLogger() + ctx := context.Background() + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to marshal data.", slog.Any("error", err)) + return "" + } + return string(jsonData) +} + +type RecordData struct { + Audio string `json:"audio"` + Channels int `json:"channels"` + SampleRate int `json:"sampleRate"` + SampleSize int `json:"sampleSize"` +} + +func main() { + router := gin.New() + + server := socketio.NewServer(nil) + + logger := utils.GetLogger() + ctx := context.Background() + + server.OnConnect("/", func(socket socketio.Conn) error { + socket.SetContext("") + log.Println("CONNECTED: ", socket.ID()) + + return nil + }) + + server.OnEvent("/", "totalSongs", func(socket socketio.Conn) { + db, err := utils.NewDbClient() + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "error connecting to DB", slog.Any("error", err)) + return + } + defer db.Close() + + totalSongs, err := db.TotalSongs() + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "Log error getting total songs", slog.Any("error", err)) + return + } + + socket.Emit("totalSongs", totalSongs) + }) + + server.OnEvent("/", "newDownload", func(socket socketio.Conn, spotifyURL string) { + if len(spotifyURL) == 0 { + logger.Debug("Spotify URL required.") + return + } + + splitURL := strings.Split(spotifyURL, "/") + + if len(splitURL) < 2 { + logger.Debug("invalid Spotify URL.") + return + } + + spotifyID := splitURL[len(splitURL)-1] + if strings.Contains(spotifyID, "?") { + spotifyID = strings.Split(spotifyID, "?")[0] + } + + // Handle album download + if strings.Contains(spotifyURL, "album") { + tracksInAlbum, err := spotify.AlbumInfo(spotifyURL) + if err != nil { + fmt.Println("log error: ", err) + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + logger.Info(err.Error()) + } else { + err := xerrors.New(err) + logger.ErrorContext(ctx, "error getting album info", slog.Any("error", err)) + } + return + } + + statusMsg := fmt.Sprintf("%v songs found in album.", len(tracksInAlbum)) + socket.Emit("downloadStatus", downloadStatus("info", statusMsg)) + + totalTracksDownloaded, err := spotify.DlAlbum(spotifyURL, tmpSongDir) + if err != nil { + socket.Emit("downloadStatus", downloadStatus("error", "Couldn't to download album.")) + + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to download album.", slog.Any("error", err)) + return + } + + statusMsg = fmt.Sprintf("%d songs downloaded from album", totalTracksDownloaded) + socket.Emit("downloadStatus", downloadStatus("success", statusMsg)) + + } + + // Handle playlist download + if strings.Contains(spotifyURL, "playlist") { + tracksInPL, err := spotify.PlaylistInfo(spotifyURL) + if err != nil { + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + logger.Info(err.Error()) + } else { + err := xerrors.New(err) + logger.ErrorContext(ctx, "error getting album info", slog.Any("error", err)) + } + return + } + + statusMsg := fmt.Sprintf("%v songs found in playlist.", len(tracksInPL)) + socket.Emit("downloadStatus", downloadStatus("info", statusMsg)) + + totalTracksDownloaded, err := spotify.DlPlaylist(spotifyURL, tmpSongDir) + if err != nil { + socket.Emit("downloadStatus", downloadStatus("error", "Couldn't download playlist.")) + + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to download playlist.", slog.Any("error", err)) + return + } + + statusMsg = fmt.Sprintf("%d songs downloaded from playlist.", totalTracksDownloaded) + socket.Emit("downloadStatus", downloadStatus("success", statusMsg)) + + } + + // Handle track download + if strings.Contains(spotifyURL, "track") { + trackInfo, err := spotify.TrackInfo(spotifyURL) + if err != nil { + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + logger.Info(err.Error()) + } else { + err := xerrors.New(err) + logger.ErrorContext(ctx, "error getting album info", slog.Any("error", err)) + } + return + } + + // check if track already exist + db, err := utils.NewDbClient() + if err != nil { + fmt.Errorf("Log - error connecting to DB: %d", err) + } + defer db.Close() + + song, songExists, err := db.GetSongByKey(utils.GenerateSongKey(trackInfo.Title, trackInfo.Artist)) + if err == nil { + if songExists { + statusMsg := fmt.Sprintf( + "'%s' by '%s' already exists in the database (https://www.youtube.com/watch?v=%s)", + song.Title, song.Artist, song.YouTubeID) + + socket.Emit("downloadStatus", downloadStatus("error", statusMsg)) + return + } + } else { + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to get song by key.", slog.Any("error", err)) + } + + totalDownloads, err := spotify.DlSingleTrack(spotifyURL, tmpSongDir) + if err != nil { + if len(err.Error()) <= 25 { + socket.Emit("downloadStatus", downloadStatus("error", err.Error())) + logger.Info(err.Error()) + } else { + err := xerrors.New(err) + logger.ErrorContext(ctx, "error getting album info", slog.Any("error", err)) + } + return + } + + statusMsg := "" + if totalDownloads != 1 { + statusMsg = fmt.Sprintf("'%s' by '%s' failed to download", trackInfo.Title, trackInfo.Artist) + socket.Emit("downloadStatus", downloadStatus("error", statusMsg)) + } else { + statusMsg = fmt.Sprintf("'%s' by '%s' was downloaded", trackInfo.Title, trackInfo.Artist) + socket.Emit("downloadStatus", downloadStatus("success", statusMsg)) + } + + } + + return + }) + + server.OnEvent("/", "record", func(socket socketio.Conn, recordData string) { + var recData RecordData + if err := json.Unmarshal([]byte(recordData), &recData); err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "Failed to unmarshal record data.", slog.Any("error", err)) + return + } + + // Decode base64 data + decodedAudioData, err := base64.StdEncoding.DecodeString(recData.Audio) + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to decode base64 data.", slog.Any("error", err)) + return + } + + // Save the decoded data to a file + channels := recData.Channels + sampleRate := recData.SampleRate + bitsPerSample := recData.SampleSize + fmt.Println(channels, sampleRate, bitsPerSample) + + err = wav.WriteWavFile("blob.wav", decodedAudioData, sampleRate, channels, bitsPerSample) + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to write wav file.", slog.Any("error", err)) + } + + samples, err := wav.WavBytesToSamples(decodedAudioData) + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to convert decodedData to samples.", slog.Any("error", err)) + } + + matches, err := shazam.FindMatches(samples, 10.0, sampleRate) + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to get matches.", slog.Any("error", err)) + } + fmt.Println("Matches! : ", matches) + + jsonData, err := json.Marshal(matches) + if len(matches) > 10 { + jsonData, _ = json.Marshal(matches[:10]) + } + + if err != nil { + err := xerrors.New(err) + logger.ErrorContext(ctx, "failed to marshal matches.", slog.Any("error", err)) + return + } + + socket.Emit("matches", string(jsonData)) + }) + + server.OnError("/", func(s socketio.Conn, e error) { + log.Println("meet error:", e) + }) + + server.OnDisconnect("/", func(s socketio.Conn, reason string) { + log.Println("closed", reason) + }) + + go func() { + if err := server.Serve(); err != nil { + log.Fatalf("socketio listen error: %s\n", err) } }() + defer server.Close() - for { - rtpPacket, _, err := track.ReadRTP() - if err != nil { - fmt.Println(err) - return - } - if err := i.WriteRTP(rtpPacket); err != nil { - fmt.Println(err) - return - } + router.Use(GinMiddleware("http://localhost:3000")) + router.GET("/socket.io/*any", gin.WrapH(server)) + router.POST("/socket.io/*any", gin.WrapH(server)) + + if err := router.Run(":5000"); err != nil { + log.Fatal("failed run app: ", err) } } - -func saveToBytes(track *webrtc.TrackRemote) ([]byte, error) { - var audioData []byte - - for { - rtpPacket, _, err := track.ReadRTP() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - // Extract audio payload from RTP packet - payload := rtpPacket.Payload - - // Append audio payload to audioData - audioData = append(audioData, payload...) - // fmt.Println("ByteArray: ", audioData) - } - - return audioData, nil -} - -func MatchSampleAudio(track *webrtc.TrackRemote) (string, error) { - // Use time.After to stop after 15 seconds - stop := time.After(50 * time.Second) - - // Use a ticker to process sampleAudio every 2 seconds - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - - var sampleAudio []byte - - for { - select { - case <-ticker.C: - // Process sampleAudio every 2 seconds - if len(sampleAudio) > 0 { - match, err := shazam.Match(sampleAudio) - if err != nil { - fmt.Println(err) - return "", nil - } - - // Reset sampleAudio for fresh input - // sampleAudio = nil - if len(match) > 0 { - fmt.Println("FOUND A MATCH! - ", match) - // return match, nil - } - } - case <-stop: - // Stop after 15 seconds - fmt.Println("Stopped after 15 seconds") - return "", nil - default: - // Read RTP packets and accumulate sampleAudio - rtpPacket, _, err := track.ReadRTP() - if err != nil { - if err != io.EOF { - return "", fmt.Errorf("error reading RTP packet: %d", err) - } - return "", err - } - - // Extract audio payload from RTP packet - payload := rtpPacket.Payload - - sampleAudio = append(sampleAudio, payload...) - } - } -} - -// nolint:gocognit -func main() { - // Everything below is the Pion WebRTC API! Thanks for using it ❤️. - - // Create a MediaEngine object to configure the supported codec - m := &webrtc.MediaEngine{} - - // Setup the codecs you want to use. - // We'll use Opus, but you can also define your own - if err := m.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus, ClockRate: 44100, Channels: 1, SDPFmtpLine: "", RTCPFeedback: nil}, - PayloadType: 111, - }, webrtc.RTPCodecTypeAudio); err != nil { - panic(err) - } - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. - i := &interceptor.Registry{} - - // Register a intervalpli factory - // This interceptor sends a PLI every 3 seconds. A PLI causes a keyframe to be generated by the sender. - intervalPliFactory, err := intervalpli.NewReceiverInterceptor() - if err != nil { - panic(err) - } - i.Add(intervalPliFactory) - - // Use the default set of Interceptors - if err = webrtc.RegisterDefaultInterceptors(m, i); err != nil { - panic(err) - } - - // Create the API object with the MediaEngine - api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(i)) - - // Prepare the configuration - config := webrtc.Configuration{ - ICEServers: []webrtc.ICEServer{ - { - URLs: []string{"stun:stun.l.google.com:19302"}, - }, - }, - } - - // Create a new RTCPeerConnection - peerConnection, err := api.NewPeerConnection(config) - if err != nil { - panic(err) - } - - // Allow us to receive 1 audio track - if _, err = peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio); err != nil { - panic(err) - } - - // Create an Ogg file for audio output - oggFile, err := oggwriter.New("output.ogg", 44100, 1) - if err != nil { - panic(err) - } - - // Set a handler for when a new remote track starts, this handler saves buffers to disk as - // an Ogg file. - peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - codec := track.Codec() - if strings.EqualFold(codec.MimeType, webrtc.MimeTypeOpus) { - fmt.Println("Got Opus track, saving to disk as output.opus (44.1 kHz, 1 channel)") - saveToDisk(oggFile, track) - } - }) - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - fmt.Printf("Connection State has changed %s \n", connectionState.String()) - - if connectionState == webrtc.ICEConnectionStateConnected { - fmt.Println("Ctrl+C the remote client to stop the demo") - } else if connectionState == webrtc.ICEConnectionStateFailed || connectionState == webrtc.ICEConnectionStateClosed { - if closeErr := oggFile.Close(); closeErr != nil { - panic(closeErr) - } - - fmt.Println("Done writing media files") - - // Gracefully shutdown the peer connection - if closeErr := peerConnection.Close(); closeErr != nil { - panic(closeErr) - } - - os.Exit(0) - } - }) - - // Wait for the offer to be pasted - offer := webrtc.SessionDescription{} - signal.Decode(signal.MustReadStdin(), &offer) - - // Set the remote SessionDescription - err = peerConnection.SetRemoteDescription(offer) - if err != nil { - panic(err) - } - - // Create answer - answer, err := peerConnection.CreateAnswer(nil) - if err != nil { - panic(err) - } - - // Create channel that is blocked until ICE Gathering is complete - gatherComplete := webrtc.GatheringCompletePromise(peerConnection) - - // Sets the LocalDescription, and starts our UDP listeners - err = peerConnection.SetLocalDescription(answer) - if err != nil { - panic(err) - } - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - <-gatherComplete - - // Output the answer in base64 so we can paste it in browser - fmt.Println(signal.Encode(*peerConnection.LocalDescription())) - - // Block forever - select {} -}