From 24bdf0593ea70dc069c05ea9b550c27c9bf16ccc Mon Sep 17 00:00:00 2001 From: mykola2312 Date: Sat, 2 Mar 2024 08:39:15 +0200 Subject: [PATCH] implement user opping (elevating permissions to admin), use Arc because dptree seems to clone deps everytime a dispatch happens --- src/bot/bot.rs | 6 ++++-- src/bot/op.rs | 37 +++++++++++++++++++++++++++++++------ src/db.rs | 9 ++++++++- src/db/user.rs | 22 +++++++++++++++++++--- src/main.rs | 4 +++- 5 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/bot/bot.rs b/src/bot/bot.rs index 0ed28c1..6312438 100644 --- a/src/bot/bot.rs +++ b/src/bot/bot.rs @@ -4,8 +4,10 @@ use std::env; use std::fmt; use std::str; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use teloxide::dispatching::{dialogue, dialogue::InMemStorage, UpdateHandler}; +use teloxide::types::Recipient; use teloxide::{prelude::*, update_listeners::Polling, utils::command::BotCommands}; use tracing::{event, Level}; @@ -28,7 +30,7 @@ where .expect(format!("env '{}' parse error", name).as_str()) } -pub async fn bot_main(db: SqlitePool) -> anyhow::Result<()> { +pub async fn bot_main(db: DbPool) -> anyhow::Result<()> { event!(Level::INFO, "start"); let bot = Bot::new(env::var("BOT_TOKEN")?); @@ -76,7 +78,7 @@ enum Command { Download(String), #[command(alias = "op")] - OP + OP, } async fn cmd_test(bot: Bot, msg: Message, _db: DbPool) -> HandlerResult { diff --git a/src/bot/op.rs b/src/bot/op.rs index 975c0da..e4ba59f 100644 --- a/src/bot/op.rs +++ b/src/bot/op.rs @@ -3,26 +3,51 @@ use teloxide::prelude::*; use tracing::{event, Level}; use super::types::HandlerResult; -use crate::db::{user::create_user, DbPool}; +use crate::db::user::{create_user, find_or_create_user}; +use crate::db::DbPool; pub async fn cmd_op(bot: Bot, msg: Message, db: DbPool) -> HandlerResult { let admins: i64 = sqlx::query("SELECT COUNT(*) FROM user WHERE is_admin = 1") - .fetch_one(&db) + .fetch_one(db.as_ref()) .await? .get(0); - if let Some(user) = msg.from() { + if let Some(tg_user) = msg.from() { if admins == 0 { - let user = create_user(db, user, true, true).await?; + let user = create_user(&db, tg_user, true, true).await?; + event!( Level::INFO, "opped {} - {}", user.tg_id, - user.username.unwrap_or(user.first_name) + user.username_or_name() ); bot.send_message(msg.chat.id, "Now you're an admin").await?; } else { - bot.send_message(msg.chat.id, "You can't do that anymore").await?; + let user = find_or_create_user(&db, tg_user).await?; + if user.is_admin == 1 { + if let Some(target) = msg.reply_to_message().and_then(|m| m.from()) { + let target = find_or_create_user(&db, target).await?; + sqlx::query("UPDATE user SET can_download = 1, is_admin = 1 WHERE id = $1;") + .bind(target.id) + .execute(db.as_ref()) + .await?; + + event!( + Level::INFO, + "opped {} - {}", + target.tg_id, + target.username_or_name() + ); + bot.send_message(msg.chat.id, "opped").await?; + } else { + bot.send_message(msg.chat.id, "You have to reply on target's message") + .await?; + } + } else { + bot.send_message(msg.chat.id, "You can't do that bruh") + .await?; + } } } diff --git a/src/db.rs b/src/db.rs index 4462801..5e9bc0c 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,9 +1,10 @@ use sqlx::migrate::MigrateDatabase; use sqlx::{Sqlite, SqlitePool}; +use std::sync::Arc; use super::util::make_database_url; -pub type DbPool = SqlitePool; +pub type DbPool = Arc; #[derive(sqlx::FromRow)] pub struct User { @@ -16,6 +17,12 @@ pub struct User { pub is_admin: i64, } +impl User { + pub fn username_or_name(&self) -> &String { + self.username.as_ref().unwrap_or(&self.first_name) + } +} + pub mod user; #[derive(sqlx::FromRow)] diff --git a/src/db/user.rs b/src/db/user.rs index d9413df..5a8bc33 100644 --- a/src/db/user.rs +++ b/src/db/user.rs @@ -3,7 +3,7 @@ use teloxide::types; use super::{DbPool, User}; pub async fn create_user( - db: DbPool, + db: &DbPool, user: &types::User, can_download: bool, is_admin: bool, @@ -19,12 +19,28 @@ pub async fn create_user( .bind(&user.last_name) .bind(can_download as i64) .bind(is_admin as i64) - .execute(&db) + .execute(db.as_ref()) .await?; let user: User = sqlx::query_as("SELECT * FROM user WHERE tg_id = $1 LIMIT 1;") .bind(user.id.0 as i64) - .fetch_one(&db) + .fetch_one(db.as_ref()) .await?; Ok(user) } + +pub async fn find_or_create_user(db: &DbPool, user: &types::User) -> Result { + let res: Result = + sqlx::query_as("SELECT * FROM user WHERE tg_id = $1 LIMIT 1;") + .bind(user.id.0 as i64) + .fetch_one(db.as_ref()) + .await; + + match res { + Ok(user) => return Ok(user), + Err(e) => match e { + sqlx::Error::RowNotFound => create_user(db, user, false, false).await, + _ => Err(e), + }, + } +} diff --git a/src/main.rs b/src/main.rs index 4bb61dc..18e411d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use dotenv::dotenv; mod bot; @@ -20,6 +22,6 @@ async fn main() -> anyhow::Result<()> { log_init(); let db = db_init().await; - bot_main(db).await?; + bot_main(Arc::from(db)).await?; Ok(()) }