diff --git a/Cargo.toml b/Cargo.toml index f8dd5bc..886be79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ anyhow = "1.0.75" dotenv = "0.15.0" tokio = { version = "1.32.0", features = ["rt-multi-thread", "macros", "process"] } teloxide = { version = "0.12.2", git ="https://github.com/teloxide/teloxide", features = ["macros"] } -sqlx = { version = "0.7.3", features = [ "runtime-tokio", "tls-native-tls" ] } +sqlx = { version = "0.7.3", features = [ "runtime-tokio", "tls-native-tls", "sqlite", "sqlx-sqlite" ] } serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.113" ordered-float = "4.2.0" diff --git a/src/bot.rs b/src/bot.rs index 283a1a5..0701399 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -1,2 +1,3 @@ pub mod bot; pub mod sanitize; +pub mod util; diff --git a/src/bot/bot.rs b/src/bot/bot.rs index ba95dc6..14703a5 100644 --- a/src/bot/bot.rs +++ b/src/bot/bot.rs @@ -9,6 +9,9 @@ use teloxide::dispatching::dialogue::InMemStorage; use teloxide::dispatching::UpdateHandler; use teloxide::types::InputFile; use teloxide::{prelude::*, update_listeners::Polling, utils::command::BotCommands}; +use sqlx::SqlitePool; + +use super::util::make_database_url; use crate::dl::delete_if_exists; use crate::dl::download; @@ -33,6 +36,9 @@ where } pub async fn bot_main() -> anyhow::Result<()> { + let db_path = make_database_url(); + let db = SqlitePool::connect(&db_path).await?; + let bot = Bot::new(env::var("BOT_TOKEN")?); let listener = Polling::builder(bot.clone()) .timeout(Duration::from_secs(parse_env("POLLING_TIMEOUT"))) @@ -41,7 +47,7 @@ pub async fn bot_main() -> anyhow::Result<()> { .build(); Dispatcher::builder(bot, schema()) - .dependencies(dptree::deps![InMemStorage::::new()]) + .dependencies(dptree::deps![db]) .enable_ctrlc_handler() .build() .dispatch_with_listener( diff --git a/src/bot/sanitize.rs b/src/bot/sanitize.rs index 7954af7..ecab59a 100644 --- a/src/bot/sanitize.rs +++ b/src/bot/sanitize.rs @@ -9,7 +9,7 @@ pub fn extract_url(text: &str) -> Option<&str> { let re = Regex::new(RE_URL).unwrap(); match re.find(text) { Some(m) => Some(m.as_str()), - None => None + None => None, } } @@ -23,8 +23,14 @@ mod tests { #[test] fn test_extract_url() { - assert_eq!(extract_url("test http://www.test.com/id/1"), Some("http://www.test.com/id/1")); - assert_eq!(extract_url("https://www.test.com 3"), Some("https://www.test.com")); + assert_eq!( + extract_url("test http://www.test.com/id/1"), + Some("http://www.test.com/id/1") + ); + assert_eq!( + extract_url("https://www.test.com 3"), + Some("https://www.test.com") + ); assert_eq!(extract_url("there is no any url"), None); } @@ -36,4 +42,4 @@ mod tests { let url = parse_url("https://youtu.be/00000000000").unwrap(); assert_eq!(url.host_str().unwrap(), "youtu.be"); } -} \ No newline at end of file +} diff --git a/src/bot/util.rs b/src/bot/util.rs new file mode 100644 index 0000000..3b99502 --- /dev/null +++ b/src/bot/util.rs @@ -0,0 +1,18 @@ +use std::path::Path; + +#[cfg(debug_assertions)] +const VAR_LIB: &str = "."; + +#[cfg(not(debug_assertions))] +const VAR_LIB: &str = "/var/lib/mk-dl-bot"; + +#[cfg(debug_assertions)] +const VAR_LOG: &str = "."; + +#[cfg(not(debug_assertions))] +const VAR_LOG: &str = "/var/log/mk-dl-bot"; + +pub fn make_database_url() -> String { + let path = Path::new(VAR_LIB).join("mk-dl-bot.db"); + format!("sqlite://{}", path.as_os_str().to_str().unwrap()).to_string() +} diff --git a/src/dl/spawn.rs b/src/dl/spawn.rs index dd0a5bd..0c058fa 100644 --- a/src/dl/spawn.rs +++ b/src/dl/spawn.rs @@ -35,7 +35,7 @@ impl fmt::Display for SpawnError { } /* !!! The argument list could be exploited in a way to inject malicious arguments !!! - !!! and alter the way program executes and/or gain access to system !!! */ +!!! and alter the way program executes and/or gain access to system !!! */ pub async fn spawn(program: &str, args: I) -> Result where I: IntoIterator,