From 84a7cfe1517fecc404664cc9badb89c2ce65364b Mon Sep 17 00:00:00 2001 From: seth Date: Fri, 26 Apr 2024 20:53:13 -0400 Subject: [PATCH] use http client from context --- src/api/dadjoke.rs | 15 ++--- src/api/github.rs | 17 ++--- src/api/mod.rs | 62 +++++++------------ src/api/paste_gg.rs | 17 ++--- src/api/pluralkit.rs | 9 ++- src/api/prism_meta.rs | 12 ++-- src/api/rory.rs | 14 ++--- src/commands/general/joke.rs | 2 +- src/commands/general/rory.rs | 2 +- src/commands/general/stars.rs | 5 +- src/commands/mod.rs | 2 +- src/commands/moderation/set_welcome.rs | 9 ++- src/handlers/event/analyze_logs/issues.rs | 5 +- src/handlers/event/analyze_logs/mod.rs | 2 +- .../event/analyze_logs/providers/0x0.rs | 6 +- .../analyze_logs/providers/attachment.rs | 9 +-- .../event/analyze_logs/providers/haste.rs | 6 +- .../event/analyze_logs/providers/mclogs.rs | 6 +- .../event/analyze_logs/providers/mod.rs | 12 ++-- .../event/analyze_logs/providers/paste_gg.rs | 8 +-- .../event/analyze_logs/providers/pastebin.rs | 6 +- src/handlers/event/expand_link.rs | 8 +-- src/handlers/event/mod.rs | 16 ++--- src/handlers/event/pluralkit.rs | 19 ++++-- src/main.rs | 14 ++++- src/utils/messages.rs | 14 +++-- 26 files changed, 148 insertions(+), 149 deletions(-) diff --git a/src/api/dadjoke.rs b/src/api/dadjoke.rs index 642d656..ac6ca5f 100644 --- a/src/api/dadjoke.rs +++ b/src/api/dadjoke.rs @@ -1,18 +1,11 @@ +use super::{HttpClient, HttpClientExt}; + use eyre::Result; -use log::debug; const DADJOKE: &str = "https://icanhazdadjoke.com"; -pub async fn get_joke() -> Result { - debug!("Making request to {DADJOKE}"); +pub async fn get_joke(http: &HttpClient) -> Result { + let joke = http.get_request(DADJOKE).await?.text().await?; - let resp = super::client() - .get(DADJOKE) - .header("Accept", "text/plain") - .send() - .await?; - resp.error_for_status_ref()?; - - let joke = resp.text().await?; Ok(joke) } diff --git a/src/api/github.rs b/src/api/github.rs index f0a1fa4..18b533a 100644 --- a/src/api/github.rs +++ b/src/api/github.rs @@ -1,18 +1,11 @@ -use std::sync::OnceLock; - -use eyre::{Context, OptionExt, Result}; +use eyre::{OptionExt, Result, WrapErr}; use log::debug; use octocrab::Octocrab; -fn octocrab() -> &'static Octocrab { - static OCTOCRAB: OnceLock = OnceLock::new(); - OCTOCRAB.get_or_init(Octocrab::default) -} - -pub async fn get_latest_prism_version() -> Result { +pub async fn get_latest_prism_version(octocrab: &Octocrab) -> Result { debug!("Fetching the latest version of Prism Launcher"); - let version = octocrab() + let version = octocrab .repos("PrismLauncher", "PrismLauncher") .releases() .get_latest() @@ -22,10 +15,10 @@ pub async fn get_latest_prism_version() -> Result { Ok(version) } -pub async fn get_prism_stargazers_count() -> Result { +pub async fn get_prism_stargazers_count(octocrab: &Octocrab) -> Result { debug!("Fetching Prism Launcher's stargazer count"); - let stargazers_count = octocrab() + let stargazers_count = octocrab .repos("PrismLauncher", "PrismLauncher") .get() .await diff --git a/src/api/mod.rs b/src/api/mod.rs index 711334c..cb1c7bf 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,9 +1,5 @@ -use std::sync::OnceLock; - -use eyre::Result; -use log::debug; -use reqwest::{Client, Response}; -use serde::de::DeserializeOwned; +use log::trace; +use reqwest::Response; pub mod dadjoke; pub mod github; @@ -12,43 +8,29 @@ pub mod pluralkit; pub mod prism_meta; pub mod rory; -pub fn client() -> &'static reqwest::Client { - static CLIENT: OnceLock = OnceLock::new(); - CLIENT.get_or_init(|| { +pub type HttpClient = reqwest::Client; + +pub trait HttpClientExt { + // sadly i can't implement the actual Default trait :/ + fn default() -> Self; + async fn get_request(&self, url: &str) -> Result; +} + +impl HttpClientExt for HttpClient { + fn default() -> Self { let version = option_env!("CARGO_PKG_VERSION").unwrap_or("development"); let user_agent = format!("refraction/{version}"); - Client::builder() + reqwest::ClientBuilder::new() .user_agent(user_agent) .build() .unwrap_or_default() - }) -} - -pub async fn get_url(url: &str) -> Result { - debug!("Making request to {url}"); - let resp = client().get(url).send().await?; - resp.error_for_status_ref()?; - - Ok(resp) -} - -pub async fn text_from_url(url: &str) -> Result { - let resp = get_url(url).await?; - - let text = resp.text().await?; - Ok(text) -} - -pub async fn bytes_from_url(url: &str) -> Result> { - let resp = get_url(url).await?; - - let bytes = resp.bytes().await?; - Ok(bytes.to_vec()) -} - -pub async fn json_from_url(url: &str) -> Result { - let resp = get_url(url).await?; - - let json = resp.json().await?; - Ok(json) + } + + async fn get_request(&self, url: &str) -> Result { + trace!("Making request to {url}"); + let resp = self.get(url).send().await?; + resp.error_for_status_ref()?; + + Ok(resp) + } } diff --git a/src/api/paste_gg.rs b/src/api/paste_gg.rs index 01df669..a4004cb 100644 --- a/src/api/paste_gg.rs +++ b/src/api/paste_gg.rs @@ -1,5 +1,6 @@ +use super::{HttpClient, HttpClientExt}; + use eyre::{eyre, OptionExt, Result}; -use log::debug; use serde::{Deserialize, Serialize}; const PASTE_GG: &str = "https://api.paste.gg/v1"; @@ -27,11 +28,9 @@ pub struct Files { pub name: Option, } -pub async fn files_from(id: &str) -> Result> { +pub async fn files_from(http: &HttpClient, id: &str) -> Result> { let url = format!("{PASTE_GG}{PASTES}/{id}/files"); - debug!("Making request to {url}"); - - let resp: Response = super::json_from_url(&url).await?; + let resp: Response = http.get_request(&url).await?.json().await?; if resp.status == Status::Error { let message = resp @@ -44,9 +43,13 @@ pub async fn files_from(id: &str) -> Result> { } } -pub async fn get_raw_file(paste_id: &str, file_id: &str) -> eyre::Result { +pub async fn get_raw_file( + http: &HttpClient, + paste_id: &str, + file_id: &str, +) -> eyre::Result { let url = format!("{PASTE_GG}{PASTES}/{paste_id}/files/{file_id}/raw"); - let text = super::text_from_url(&url).await?; + let text = http.get_request(&url).await?.text().await?; Ok(text) } diff --git a/src/api/pluralkit.rs b/src/api/pluralkit.rs index 2fdb4cb..0b16793 100644 --- a/src/api/pluralkit.rs +++ b/src/api/pluralkit.rs @@ -1,5 +1,6 @@ +use super::{HttpClient, HttpClientExt}; + use eyre::{Context, Result}; -use log::debug; use poise::serenity_prelude::{MessageId, UserId}; use serde::{Deserialize, Serialize}; @@ -11,11 +12,9 @@ pub struct Message { const PLURAL_KIT: &str = "https://api.pluralkit.me/v2"; const MESSAGES: &str = "/messages"; -pub async fn sender_from(message_id: MessageId) -> Result { +pub async fn sender_from(http: &HttpClient, message_id: MessageId) -> Result { let url = format!("{PLURAL_KIT}{MESSAGES}/{message_id}"); - debug!("Making request to {url}"); - - let resp: Message = super::json_from_url(&url).await?; + let resp: Message = http.get_request(&url).await?.json().await?; let id: u64 = resp.sender.parse().wrap_err_with(|| { diff --git a/src/api/prism_meta.rs b/src/api/prism_meta.rs index e2114b7..f7efb0a 100644 --- a/src/api/prism_meta.rs +++ b/src/api/prism_meta.rs @@ -1,5 +1,6 @@ +use super::{HttpClient, HttpClientExt}; + use eyre::{OptionExt, Result}; -use log::debug; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize)] @@ -14,14 +15,9 @@ pub struct MinecraftPackageJson { const META: &str = "https://meta.prismlauncher.org/v1"; const MINECRAFT_PACKAGEJSON: &str = "/net.minecraft/package.json"; -pub async fn latest_minecraft_version() -> Result { +pub async fn latest_minecraft_version(http: &HttpClient) -> Result { let url = format!("{META}{MINECRAFT_PACKAGEJSON}"); - - debug!("Making request to {url}"); - let resp = super::client().get(url).send().await?; - resp.error_for_status_ref()?; - - let data: MinecraftPackageJson = resp.json().await?; + let data: MinecraftPackageJson = http.get_request(&url).await?.json().await?; let version = data .recommended diff --git a/src/api/rory.rs b/src/api/rory.rs index bb64a0c..3fffbbe 100644 --- a/src/api/rory.rs +++ b/src/api/rory.rs @@ -1,5 +1,6 @@ +use super::{HttpClient, HttpClientExt}; + use eyre::{Context, Result}; -use log::debug; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize)] @@ -12,16 +13,13 @@ pub struct Response { const RORY: &str = "https://rory.cat"; const PURR: &str = "/purr"; -pub async fn get(id: Option) -> Result { +pub async fn get(http: &HttpClient, id: Option) -> Result { let target = id.map(|id| id.to_string()).unwrap_or_default(); let url = format!("{RORY}{PURR}/{target}"); - debug!("Making request to {url}"); - - let resp = super::client().get(url).send().await?; - resp.error_for_status_ref()?; - - let data: Response = resp + let data: Response = http + .get_request(&url) + .await? .json() .await .wrap_err("Couldn't parse the rory response!")?; diff --git a/src/commands/general/joke.rs b/src/commands/general/joke.rs index a064997..e08282e 100644 --- a/src/commands/general/joke.rs +++ b/src/commands/general/joke.rs @@ -9,7 +9,7 @@ pub async fn joke(ctx: Context<'_>) -> Result<(), Error> { trace!("Running joke command"); ctx.defer().await?; - let joke = dadjoke::get_joke().await?; + let joke = dadjoke::get_joke(&ctx.data().http_client).await?; ctx.say(joke).await?; Ok(()) diff --git a/src/commands/general/rory.rs b/src/commands/general/rory.rs index bc8cd63..f02783d 100644 --- a/src/commands/general/rory.rs +++ b/src/commands/general/rory.rs @@ -14,7 +14,7 @@ pub async fn rory( ctx.defer().await?; - let rory = rory::get(id).await?; + let rory = rory::get(&ctx.data().http_client, id).await?; let embed = { let embed = CreateEmbed::new(); diff --git a/src/commands/general/stars.rs b/src/commands/general/stars.rs index 703586a..603f435 100644 --- a/src/commands/general/stars.rs +++ b/src/commands/general/stars.rs @@ -8,6 +8,7 @@ use poise::CreateReply; #[poise::command(slash_command, prefix_command, track_edits = true)] pub async fn stars(ctx: Context<'_>) -> Result<(), Error> { trace!("Running stars command"); + let octocrab = &ctx.data().octocrab; ctx.defer().await?; @@ -15,13 +16,13 @@ pub async fn stars(ctx: Context<'_>) -> Result<(), Error> { if let Ok(count) = storage.launcher_stargazer_count().await { count } else { - let count = api::github::get_prism_stargazers_count().await?; + let count = api::github::get_prism_stargazers_count(octocrab).await?; storage.cache_launcher_stargazer_count(count).await?; count } } else { trace!("Not caching launcher stargazer count, as we're running without a storage backend"); - api::github::get_prism_stargazers_count().await? + api::github::get_prism_stargazers_count(octocrab).await? }; let embed = CreateEmbed::new() diff --git a/src/commands/mod.rs b/src/commands/mod.rs index a849933..2292afd 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -32,7 +32,7 @@ module_macro!(moderation); pub type Command = poise::Command; -pub fn get() -> Vec { +pub fn all() -> Vec { vec![ general!(help), general!(joke), diff --git a/src/commands/moderation/set_welcome.rs b/src/commands/moderation/set_welcome.rs index 4882af4..be35ac0 100644 --- a/src/commands/moderation/set_welcome.rs +++ b/src/commands/moderation/set_welcome.rs @@ -1,6 +1,6 @@ use std::{fmt::Write, str::FromStr}; -use crate::{api, utils, Context, Error}; +use crate::{api::HttpClientExt, utils, Context, Error}; use eyre::Result; use log::trace; @@ -138,7 +138,12 @@ pub async fn set_welcome( let downloaded = attachment.download().await?; String::from_utf8(downloaded)? } else if let Some(url) = url { - api::text_from_url(&url).await? + ctx.data() + .http_client + .get_request(&url) + .await? + .text() + .await? } else { ctx.say("A text file or URL must be provided!").await?; return Ok(()); diff --git a/src/handlers/event/analyze_logs/issues.rs b/src/handlers/event/analyze_logs/issues.rs index 730c9e9..00b0e5e 100644 --- a/src/handlers/event/analyze_logs/issues.rs +++ b/src/handlers/event/analyze_logs/issues.rs @@ -192,19 +192,20 @@ async fn outdated_launcher(log: &str, data: &Data) -> Result { return Ok(None); }; + let octocrab = &data.octocrab; let version_from_log = captures[0].replace("Prism Launcher version: ", ""); let latest_version = if let Some(storage) = &data.storage { if let Ok(version) = storage.launcher_version().await { version } else { - let version = api::github::get_latest_prism_version().await?; + let version = api::github::get_latest_prism_version(octocrab).await?; storage.cache_launcher_version(&version).await?; version } } else { trace!("Not caching launcher version, as we're running without a storage backend"); - api::github::get_latest_prism_version().await? + api::github::get_latest_prism_version(octocrab).await? }; if version_from_log < latest_version { diff --git a/src/handlers/event/analyze_logs/mod.rs b/src/handlers/event/analyze_logs/mod.rs index cb2b0b4..6b9cf3e 100644 --- a/src/handlers/event/analyze_logs/mod.rs +++ b/src/handlers/event/analyze_logs/mod.rs @@ -19,7 +19,7 @@ pub async fn handle(ctx: &Context, message: &Message, data: &Data) -> Result<()> ); let channel = message.channel_id; - let log = find_log(message).await; + let log = find_log(&data.http_client, message).await; if log.is_err() { let embed = CreateEmbed::new() diff --git a/src/handlers/event/analyze_logs/providers/0x0.rs b/src/handlers/event/analyze_logs/providers/0x0.rs index 14f7feb..300ae0e 100644 --- a/src/handlers/event/analyze_logs/providers/0x0.rs +++ b/src/handlers/event/analyze_logs/providers/0x0.rs @@ -1,4 +1,4 @@ -use crate::api; +use crate::api::{HttpClient, HttpClientExt}; use std::sync::OnceLock; @@ -21,8 +21,8 @@ impl super::LogProvider for _0x0 { .nth(0) } - async fn fetch(&self, content: &str) -> Result { - let log = api::text_from_url(content).await?; + async fn fetch(&self, http: &HttpClient, content: &str) -> Result { + let log = http.get_request(content).await?.text().await?; Ok(log) } diff --git a/src/handlers/event/analyze_logs/providers/attachment.rs b/src/handlers/event/analyze_logs/providers/attachment.rs index 25ce3cc..4e09c67 100644 --- a/src/handlers/event/analyze_logs/providers/attachment.rs +++ b/src/handlers/event/analyze_logs/providers/attachment.rs @@ -1,9 +1,9 @@ +use crate::api::{HttpClient, HttpClientExt}; + use eyre::Result; use log::trace; use poise::serenity_prelude::Message; -use crate::api; - pub struct Attachment; impl super::LogProvider for Attachment { @@ -21,9 +21,10 @@ impl super::LogProvider for Attachment { .nth(0) } - async fn fetch(&self, content: &str) -> Result { - let attachment = api::bytes_from_url(content).await?; + async fn fetch(&self, http: &HttpClient, content: &str) -> Result { + let attachment = http.get_request(content).await?.bytes().await?.to_vec(); let log = String::from_utf8(attachment)?; + Ok(log) } } diff --git a/src/handlers/event/analyze_logs/providers/haste.rs b/src/handlers/event/analyze_logs/providers/haste.rs index c4113d5..3ad1c18 100644 --- a/src/handlers/event/analyze_logs/providers/haste.rs +++ b/src/handlers/event/analyze_logs/providers/haste.rs @@ -1,4 +1,4 @@ -use crate::api; +use crate::api::{HttpClient, HttpClientExt}; use std::sync::OnceLock; @@ -22,9 +22,9 @@ impl super::LogProvider for Haste { super::get_first_capture(regex, &message.content) } - async fn fetch(&self, content: &str) -> Result { + async fn fetch(&self, http: &HttpClient, content: &str) -> Result { let url = format!("{HASTE}{RAW}/{content}"); - let log = api::text_from_url(&url).await?; + let log = http.get_request(&url).await?.text().await?; Ok(log) } diff --git a/src/handlers/event/analyze_logs/providers/mclogs.rs b/src/handlers/event/analyze_logs/providers/mclogs.rs index b0f0c35..e89009a 100644 --- a/src/handlers/event/analyze_logs/providers/mclogs.rs +++ b/src/handlers/event/analyze_logs/providers/mclogs.rs @@ -1,4 +1,4 @@ -use crate::api; +use crate::api::{HttpClient, HttpClientExt}; use std::sync::OnceLock; @@ -21,9 +21,9 @@ impl super::LogProvider for MCLogs { super::get_first_capture(regex, &message.content) } - async fn fetch(&self, content: &str) -> Result { + async fn fetch(&self, http: &HttpClient, content: &str) -> Result { let url = format!("{MCLOGS}{RAW}/{content}"); - let log = api::text_from_url(&url).await?; + let log = http.get_request(&url).await?.text().await?; Ok(log) } diff --git a/src/handlers/event/analyze_logs/providers/mod.rs b/src/handlers/event/analyze_logs/providers/mod.rs index e20547f..bb9aa4e 100644 --- a/src/handlers/event/analyze_logs/providers/mod.rs +++ b/src/handlers/event/analyze_logs/providers/mod.rs @@ -1,3 +1,5 @@ +use crate::api::HttpClient; + use std::slice::Iter; use enum_dispatch::enum_dispatch; @@ -21,7 +23,7 @@ mod pastebin; #[enum_dispatch] pub trait LogProvider { async fn find_match(&self, message: &Message) -> Option; - async fn fetch(&self, content: &str) -> Result; + async fn fetch(&self, http: &HttpClient, content: &str) -> Result; } fn get_first_capture(regex: &Regex, string: &str) -> Option { @@ -41,7 +43,7 @@ enum Provider { } impl Provider { - pub fn interator() -> Iter<'static, Provider> { + pub fn iterator() -> Iter<'static, Provider> { static PROVIDERS: [Provider; 6] = [ Provider::_0x0st(_0x0st), Provider::Attachment(Attachment), @@ -54,12 +56,12 @@ impl Provider { } } -pub async fn find_log(message: &Message) -> Result> { - let providers = Provider::interator(); +pub async fn find_log(http: &HttpClient, message: &Message) -> Result> { + let providers = Provider::iterator(); for provider in providers { if let Some(found) = provider.find_match(message).await { - let log = provider.fetch(&found).await?; + let log = provider.fetch(http, &found).await?; return Ok(Some(log)); } } diff --git a/src/handlers/event/analyze_logs/providers/paste_gg.rs b/src/handlers/event/analyze_logs/providers/paste_gg.rs index 60f3c4b..8e69514 100644 --- a/src/handlers/event/analyze_logs/providers/paste_gg.rs +++ b/src/handlers/event/analyze_logs/providers/paste_gg.rs @@ -1,4 +1,4 @@ -use crate::api::paste_gg; +use crate::api::{paste_gg, HttpClient}; use std::sync::OnceLock; @@ -18,8 +18,8 @@ impl super::LogProvider for PasteGG { super::get_first_capture(regex, &message.content) } - async fn fetch(&self, content: &str) -> Result { - let files = paste_gg::files_from(content).await?; + async fn fetch(&self, http: &HttpClient, content: &str) -> Result { + let files = paste_gg::files_from(http, content).await?; let result = files .result .ok_or_eyre("Got an empty result from paste.gg!")?; @@ -30,7 +30,7 @@ impl super::LogProvider for PasteGG { .nth(0) .ok_or_eyre("Couldn't get file id from empty paste.gg response!")?; - let log = paste_gg::get_raw_file(content, file_id).await?; + let log = paste_gg::get_raw_file(http, content, file_id).await?; Ok(log) } diff --git a/src/handlers/event/analyze_logs/providers/pastebin.rs b/src/handlers/event/analyze_logs/providers/pastebin.rs index 66feecc..4207706 100644 --- a/src/handlers/event/analyze_logs/providers/pastebin.rs +++ b/src/handlers/event/analyze_logs/providers/pastebin.rs @@ -1,4 +1,4 @@ -use crate::api; +use crate::api::{HttpClient, HttpClientExt}; use std::sync::OnceLock; @@ -22,9 +22,9 @@ impl super::LogProvider for PasteBin { super::get_first_capture(regex, &message.content) } - async fn fetch(&self, content: &str) -> Result { + async fn fetch(&self, http: &HttpClient, content: &str) -> Result { let url = format!("{PASTEBIN}{RAW}/{content}"); - let log = api::text_from_url(&url).await?; + let log = http.get_request(&url).await?.text().await?; Ok(log) } diff --git a/src/handlers/event/expand_link.rs b/src/handlers/event/expand_link.rs index 8e3517f..b336616 100644 --- a/src/handlers/event/expand_link.rs +++ b/src/handlers/event/expand_link.rs @@ -1,10 +1,10 @@ +use crate::{api::HttpClient, utils}; + use eyre::Result; use poise::serenity_prelude::{Context, CreateAllowedMentions, CreateMessage, Message}; -use crate::utils; - -pub async fn handle(ctx: &Context, message: &Message) -> Result<()> { - let embeds = utils::messages::from_message(ctx, message).await?; +pub async fn handle(ctx: &Context, http: &HttpClient, message: &Message) -> Result<()> { + let embeds = utils::messages::from_message(ctx, http, message).await?; if !embeds.is_empty() { let allowed_mentions = CreateAllowedMentions::new().replied_user(false); diff --git a/src/handlers/event/mod.rs b/src/handlers/event/mod.rs index bf77748..82e279a 100644 --- a/src/handlers/event/mod.rs +++ b/src/handlers/event/mod.rs @@ -23,7 +23,8 @@ pub async fn handle( FullEvent::Ready { data_about_bot } => { info!("Logged in as {}!", data_about_bot.user.name); - let latest_minecraft_version = api::prism_meta::latest_minecraft_version().await?; + let latest_minecraft_version = + api::prism_meta::latest_minecraft_version(&data.http_client).await?; let activity = ActivityData::playing(format!("Minecraft {latest_minecraft_version}")); info!("Setting presence to activity {activity:#?}"); @@ -37,7 +38,7 @@ pub async fn handle( } FullEvent::Message { new_message } => { - trace!("Recieved message {}", new_message.content); + trace!("Received message {}", new_message.content); // ignore new messages from bots // note: the webhook_id check allows us to still respond to PK users @@ -49,11 +50,12 @@ pub async fn handle( } if let Some(storage) = &data.storage { + let http = &data.http_client; // detect PK users first to make sure we don't respond to unproxied messages - pluralkit::handle(ctx, new_message, storage).await?; + pluralkit::handle(ctx, http, storage, new_message).await?; if storage.is_user_plural(new_message.author.id).await? - && pluralkit::is_message_proxied(new_message).await? + && pluralkit::is_message_proxied(http, new_message).await? { debug!("Not replying to unproxied PluralKit message"); return Ok(()); @@ -61,13 +63,13 @@ pub async fn handle( } eta::handle(ctx, new_message).await?; - expand_link::handle(ctx, new_message).await?; + expand_link::handle(ctx, &data.http_client, new_message).await?; analyze_logs::handle(ctx, new_message, data).await?; } FullEvent::ReactionAdd { add_reaction } => { trace!( - "Recieved reaction {} on message {} from {}", + "Received reaction {} on message {} from {}", add_reaction.emoji, add_reaction.message_id.to_string(), add_reaction.user_id.unwrap_or_default().to_string() @@ -78,7 +80,7 @@ pub async fn handle( } FullEvent::ThreadCreate { thread } => { - trace!("Recieved thread {}", thread.id); + trace!("Received thread {}", thread.id); support_onboard::handle(ctx, thread).await?; } diff --git a/src/handlers/event/pluralkit.rs b/src/handlers/event/pluralkit.rs index dce3c37..a53434c 100644 --- a/src/handlers/event/pluralkit.rs +++ b/src/handlers/event/pluralkit.rs @@ -1,4 +1,8 @@ -use crate::{api, storage::Storage}; +use crate::{ + api::{self, HttpClient}, + storage::Storage, +}; + use std::time::Duration; use eyre::Result; @@ -8,19 +12,24 @@ use tokio::time::sleep; const PK_DELAY: Duration = Duration::from_secs(1); -pub async fn is_message_proxied(message: &Message) -> Result { +pub async fn is_message_proxied(http: &HttpClient, message: &Message) -> Result { trace!( "Waiting on PluralKit API for {} seconds", PK_DELAY.as_secs() ); sleep(PK_DELAY).await; - let proxied = api::pluralkit::sender_from(message.id).await.is_ok(); + let proxied = api::pluralkit::sender_from(http, message.id).await.is_ok(); Ok(proxied) } -pub async fn handle(_: &Context, msg: &Message, storage: &Storage) -> Result<()> { +pub async fn handle( + _: &Context, + http: &HttpClient, + storage: &Storage, + msg: &Message, +) -> Result<()> { if msg.webhook_id.is_none() { return Ok(()); } @@ -36,7 +45,7 @@ pub async fn handle(_: &Context, msg: &Message, storage: &Storage) -> Result<()> ); sleep(PK_DELAY).await; - if let Ok(sender) = api::pluralkit::sender_from(msg.id).await { + if let Ok(sender) = api::pluralkit::sender_from(http, msg.id).await { storage.store_user_plurality(sender).await?; } diff --git a/src/main.rs b/src/main.rs index 4034840..4cefb2e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,6 +30,8 @@ type Context<'a> = poise::Context<'a, Data, Error>; pub struct Data { config: Config, storage: Option, + http_client: api::HttpClient, + octocrab: Arc, } async fn setup( @@ -55,7 +57,15 @@ async fn setup( trace!("Redis connection looks good!"); } - let data = Data { config, storage }; + let http_client = api::HttpClient::default(); + let octocrab = octocrab::instance(); + + let data = Data { + config, + storage, + http_client, + octocrab, + }; poise::builtins::register_globally(ctx, &framework.options().commands).await?; info!("Registered global commands!"); @@ -82,7 +92,7 @@ async fn main() -> eyre::Result<()> { serenity::GatewayIntents::non_privileged() | serenity::GatewayIntents::MESSAGE_CONTENT; let options = FrameworkOptions { - commands: commands::get(), + commands: commands::all(), on_error: |error| Box::pin(handlers::handle_error(error)), diff --git a/src/utils/messages.rs b/src/utils/messages.rs index f82073a..830e49f 100644 --- a/src/utils/messages.rs +++ b/src/utils/messages.rs @@ -1,4 +1,4 @@ -use crate::api::pluralkit; +use crate::api::{pluralkit, HttpClient}; use std::{str::FromStr, sync::OnceLock}; @@ -23,8 +23,8 @@ fn find_first_image(message: &Message) -> Option { .map(|res| res.url.clone()) } -async fn find_real_author_id(message: &Message) -> UserId { - if let Ok(sender) = pluralkit::sender_from(message.id).await { +async fn find_real_author_id(http: &HttpClient, message: &Message) -> UserId { + if let Ok(sender) = pluralkit::sender_from(http, message.id).await { sender } else { message.author.id @@ -109,7 +109,11 @@ pub async fn to_embed( Ok(embed) } -pub async fn from_message(ctx: &Context, msg: &Message) -> Result> { +pub async fn from_message( + ctx: &Context, + http: &HttpClient, + msg: &Message, +) -> Result> { static MESSAGE_PATTERN: OnceLock = OnceLock::new(); let message_pattern = MESSAGE_PATTERN.get_or_init(|| Regex::new(r"(?:https?:\/\/)?(?:canary\.|ptb\.)?discord(?:app)?\.com\/channels\/(?\d+)\/(?\d+)\/(?\d+)").unwrap()); @@ -121,7 +125,7 @@ pub async fn from_message(ctx: &Context, msg: &Message) -> Result