diff --git a/src/api/admin.rs b/src/api/admin.rs index 681cf14..18178d8 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -1,7 +1,7 @@ use rocket_contrib::json::Json; use serde_json::Value; -use rocket::http::{Cookie, Cookies}; +use rocket::http::{Cookie, Cookies, SameSite}; use rocket::request::{self, FlashMessage, Form, FromRequest, Request}; use rocket::response::{content::Html, Flash, Redirect}; use rocket::{Outcome, Route}; @@ -85,6 +85,8 @@ fn post_admin_login(data: Form, mut cookies: Cookies, ip: ClientIp) - let cookie = Cookie::build(COOKIE_NAME, jwt) .path(ADMIN_PATH) + .max_age(chrono::Duration::minutes(20)) + .same_site(SameSite::Strict) .http_only(true) .finish(); diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs index 11b4e4a..08e3f97 100644 --- a/src/api/core/mod.rs +++ b/src/api/core/mod.rs @@ -11,6 +11,7 @@ pub fn routes() -> Vec { get_eq_domains, post_eq_domains, put_eq_domains, + hibp_breach, ]; let mut routes = Vec::new(); @@ -128,3 +129,20 @@ fn post_eq_domains(data: JsonUpcase, headers: Headers, conn: Db fn put_eq_domains(data: JsonUpcase, headers: Headers, conn: DbConn) -> JsonResult { post_eq_domains(data, headers, conn) } + +#[get("/hibp/breach?")] +fn hibp_breach(username: String) -> JsonResult { + let url = format!("https://haveibeenpwned.com/api/v2/breachedaccount/{}", username); + let user_agent = "Bitwarden_RS"; + + use reqwest::{header::USER_AGENT, Client}; + + let value: Value = Client::new() + .get(&url) + .header(USER_AGENT, user_agent) + .send()? + .error_for_status()? + .json()?; + + Ok(Json(value)) +} diff --git a/src/api/icons.rs b/src/api/icons.rs index 8680600..1232110 100644 --- a/src/api/icons.rs +++ b/src/api/icons.rs @@ -1,4 +1,3 @@ -use std::error::Error; use std::fs::{create_dir_all, remove_file, symlink_metadata, File}; use std::io::prelude::*; use std::time::SystemTime; @@ -9,6 +8,7 @@ use rocket::Route; use reqwest; +use crate::error::Error; use crate::CONFIG; pub fn routes() -> Vec { @@ -77,7 +77,7 @@ fn get_cached_icon(path: &str) -> Option> { None } -fn file_is_expired(path: &str, ttl: u64) -> Result> { +fn file_is_expired(path: &str, ttl: u64) -> Result { let meta = symlink_metadata(path)?; let modified = meta.modified()?; let age = SystemTime::now().duration_since(modified)?; @@ -122,7 +122,7 @@ fn get_icon_url(domain: &str) -> String { } } -fn download_icon(url: &str) -> Result, reqwest::Error> { +fn download_icon(url: &str) -> Result, Error> { info!("Downloading icon for {}...", url); let mut res = reqwest::get(url)?; diff --git a/src/error.rs b/src/error.rs index 98a7d0e..741a742 100644 --- a/src/error.rs +++ b/src/error.rs @@ -32,12 +32,14 @@ macro_rules! make_error { }; } -use diesel::result::Error as DieselError; -use jsonwebtoken::errors::Error as JwtError; -use serde_json::{Error as SerError, Value}; -use std::io::Error as IOError; +use diesel::result::Error as DieselErr; +use handlebars::RenderError as HbErr; +use jsonwebtoken::errors::Error as JWTErr; +use reqwest::Error as ReqErr; +use serde_json::{Error as SerdeErr, Value}; +use std::io::Error as IOErr; +use std::time::SystemTimeError as TimeErr; use u2f::u2ferror::U2fError as U2fErr; -use handlebars::RenderError as HbError; // Error struct // Contains a String error message, meant for the user and an enum variant, with an error of different types. @@ -49,13 +51,15 @@ make_error! { SimpleError(String): _no_source, _api_error, // Used for special return values, like 2FA errors JsonError(Value): _no_source, _serialize, - DbError(DieselError): _has_source, _api_error, + DbError(DieselErr): _has_source, _api_error, U2fError(U2fErr): _has_source, _api_error, - SerdeError(SerError): _has_source, _api_error, - JWTError(JwtError): _has_source, _api_error, - IoErrror(IOError): _has_source, _api_error, - TemplErrror(HbError): _has_source, _api_error, + SerdeError(SerdeErr): _has_source, _api_error, + JWTError(JWTErr): _has_source, _api_error, + TemplError(HbErr): _has_source, _api_error, //WsError(ws::Error): _has_source, _api_error, + IOError(IOErr): _has_source, _api_error, + TimeError(TimeErr): _has_source, _api_error, + ReqError(ReqErr): _has_source, _api_error, } impl std::fmt::Debug for Error {