diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..b54e408 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +.gitignore +**/data +**/docs +.env? +**/migrations +**/scripts diff --git a/.env b/.env new file mode 100644 index 0000000..36aba1e --- /dev/null +++ b/.env @@ -0,0 +1,13 @@ +# for sqlx offline to run properly +POSTGRES_USER=myapp +POSTGRES_PASSWORD=password +DATABASE_URL=postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@localhost:5432/myapp + +# for configuration.rs +APP_ENVIRONMENT=local +APP_DATABASE__USERNAME=${POSTGRES_USER} +APP_DATABASE__PASSWORD=${POSTGRES_PASSWORD} +APP_SECURITY__SECRET_KEY=secret + +APP_database__username=${POSTGRES_USER} +APP_DATABASE__PASSWORD=${POSTGRES_PASSWORD} diff --git a/.env.sample b/.env.sample new file mode 100644 index 0000000..36aba1e --- /dev/null +++ b/.env.sample @@ -0,0 +1,13 @@ +# for sqlx offline to run properly +POSTGRES_USER=myapp +POSTGRES_PASSWORD=password +DATABASE_URL=postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@localhost:5432/myapp + +# for configuration.rs +APP_ENVIRONMENT=local +APP_DATABASE__USERNAME=${POSTGRES_USER} +APP_DATABASE__PASSWORD=${POSTGRES_PASSWORD} +APP_SECURITY__SECRET_KEY=secret + +APP_database__username=${POSTGRES_USER} +APP_DATABASE__PASSWORD=${POSTGRES_PASSWORD} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..e69de29 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..64bf0eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +# RustRover +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.sqlx/query-0239c4a304b887f116249a2aa7ffb390c83bbfd13e6174f7153b52bed6a2c553.json b/.sqlx/query-0239c4a304b887f116249a2aa7ffb390c83bbfd13e6174f7153b52bed6a2c553.json new file mode 100644 index 0000000..a103cea --- /dev/null +++ b/.sqlx/query-0239c4a304b887f116249a2aa7ffb390c83bbfd13e6174f7153b52bed6a2c553.json @@ -0,0 +1,58 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT users.id, users.name, users.email, users.preferred_name,\n users.is_active, users.is_verified, users.is_superuser\n FROM users\n WHERE users.id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "preferred_name", + "type_info": "Varchar" + }, + { + "ordinal": 4, + "name": "is_active", + "type_info": "Bool" + }, + { + "ordinal": 5, + "name": "is_verified", + "type_info": "Bool" + }, + { + "ordinal": 6, + "name": "is_superuser", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "0239c4a304b887f116249a2aa7ffb390c83bbfd13e6174f7153b52bed6a2c553" +} diff --git a/.sqlx/query-2d5c7904eb9eb959e7d85e0bc19ed21cf61c8472dcfef4e5cf5a8ed88339c4a0.json b/.sqlx/query-2d5c7904eb9eb959e7d85e0bc19ed21cf61c8472dcfef4e5cf5a8ed88339c4a0.json new file mode 100644 index 0000000..950e9d2 --- /dev/null +++ b/.sqlx/query-2d5c7904eb9eb959e7d85e0bc19ed21cf61c8472dcfef4e5cf5a8ed88339c4a0.json @@ -0,0 +1,25 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO users (email, hashed_password, name, preferred_name) VALUES ($1, $2, $3, $4) RETURNING id", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Varchar", + "Varchar", + "Varchar", + "Varchar" + ] + }, + "nullable": [ + false + ] + }, + "hash": "2d5c7904eb9eb959e7d85e0bc19ed21cf61c8472dcfef4e5cf5a8ed88339c4a0" +} diff --git a/.sqlx/query-5327eaad266e3c258806e25e376ad951c8cd6c1303d9afa4f60537756b520ffe.json b/.sqlx/query-5327eaad266e3c258806e25e376ad951c8cd6c1303d9afa4f60537756b520ffe.json new file mode 100644 index 0000000..2338848 --- /dev/null +++ b/.sqlx/query-5327eaad266e3c258806e25e376ad951c8cd6c1303d9afa4f60537756b520ffe.json @@ -0,0 +1,56 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT users.id, users.name, users.email, users.preferred_name,\n users.is_active, users.is_verified, users.is_superuser\n FROM users", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "preferred_name", + "type_info": "Varchar" + }, + { + "ordinal": 4, + "name": "is_active", + "type_info": "Bool" + }, + { + "ordinal": 5, + "name": "is_verified", + "type_info": "Bool" + }, + { + "ordinal": 6, + "name": "is_superuser", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "5327eaad266e3c258806e25e376ad951c8cd6c1303d9afa4f60537756b520ffe" +} diff --git a/.sqlx/query-87fcf318ae48bb33d6c484c26d251ad991048342ad460bbb78c54898bbfff49a.json b/.sqlx/query-87fcf318ae48bb33d6c484c26d251ad991048342ad460bbb78c54898bbfff49a.json new file mode 100644 index 0000000..e0fdf0e --- /dev/null +++ b/.sqlx/query-87fcf318ae48bb33d6c484c26d251ad991048342ad460bbb78c54898bbfff49a.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, hashed_password, is_superuser, is_verified from users\n WHERE email = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "hashed_password", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "is_superuser", + "type_info": "Bool" + }, + { + "ordinal": 3, + "name": "is_verified", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "87fcf318ae48bb33d6c484c26d251ad991048342ad460bbb78c54898bbfff49a" +} diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..27e490e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "myapp" +version = "0.1.0" +edition = "2021" +default-run = "myapp" + +[dependencies] +serde = { version = "1.0.204", features = ["derive"] } +serde_json = "1.0.122" +tokio = { version = "1.39.2", features = ["full"] } +axum = { version = "0.7.5", features = [ + "macros", + "form", + "multipart", + "query", +] } +rayon = "1.10.0" +thiserror = "1.0.63" +tracing = { version = "0.1.40", features = ["attributes"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +jsonwebtoken = "9.3.0" +once_cell = "1.19.0" +bcrypt = "0.15.1" +chrono = { version = "0.4.38", features = ["clock", "serde"] } +tower-http = { version = "0.5.2", features = [ + "cors", + "trace", + "timeout", + "limit", +] } +axum-extra = { version = "0.9.3", features = ["cookie", "typed-header"] } +sqlx = { version = "0.8.0", features = [ + "runtime-tokio-rustls", + "postgres", + "macros", + "migrate", + "uuid", + "chrono", + "json", +] } +tower = "0.4.13" +uuid = { version = "1.10.0", features = ["serde", "v4"] } +async-trait = "0.1.81" +axum-macros = "0.4.1" +anyhow = "1.0.86" +dotenvy = "0.15.7" +# to generate random number +rand = "0.8.5" +itertools = "0.13.0" +cookie = "0.18.1" +config = "0.14.0" +# to opt-in password instead of opt-out +secrecy = { version = "0.8", features = ["serde"] } +# environment variables are strings for the config crate and it will fail to pick up integers +serde-aux = "4.5.0" +tracing-log = "0.2.0" +num = "0.4.3" +num-format = "0.4.4" +regex = "1.10.6" +rand_pcg = "0.3.1" +tracing-appender = "0.2.3" +log = "0.4.22" +env_logger = "0.11.5" +time = { version = "0.3.36", features = ["macros"] } +serde_with = "3.9.0" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..85fffb8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +# syntax=docker/dockerfile:experimental +# The above step enabled specific experimental syntax, namely the +# "--mount-type=cache parameter, a named cache volume managed by +# Docker BuildKit +# https://juliankrieger.dev/blog/combining-docker-buildkit-cargo-chef + +FROM lukemathwalker/cargo-chef:latest as chef + +FROM chef as planner +WORKDIR /app +COPY . . +# Compute a lock-like file for our project +RUN cargo chef prepare --recipe-path recipe.json + + +FROM chef as cacher +WORKDIR /app +# Get the recipe file +COPY --from=planner /app/recipe.json recipe.json +# Build our project dependencies, not our application! +# this is the caching Docker layer! +# Cache dependencies +# I don't know why target=/app/target doesn't work +# but target=/target does +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/target \ + cargo chef cook --release --recipe-path recipe.json + + +FROM chef as builder +WORKDIR /app +# Copy built dependencies over cache +COPY --from=cacher /app/target target +# Copy cargo folder from cache. This includes the package registry and downloaded sources +COPY --from=cacher $CARGO_HOME $CARGO_HOME +COPY . . +# needa run cargo sqlx prepare +ENV SQLX_OFFLINE true +# Build the binary +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + cargo build --release --bin myapp + + +# We do not need the Rust toolchain to run the binary! +FROM debian:bookworm-slim as runtime +WORKDIR /app +# Install OpenSSL - it is dynamically linked by some of our dependencies +# Install ca-certificates - it is needed to verify TLS certificates +# when establishing HTTPS connections +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends openssl ca-certificates \ + # Clean up + && apt-get autoremove -y \ + && apt-get clean -y \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=builder /app/target/release/myapp myapp +# Custom config +COPY configuration configuration +ENV APP_ENVIRONMENT development +ENTRYPOINT ["./myapp"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..32de7bc --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# fullstack-crud diff --git a/configuration/base.toml b/configuration/base.toml new file mode 100644 index 0000000..de54b1c --- /dev/null +++ b/configuration/base.toml @@ -0,0 +1,18 @@ +[application] +# We are using 127.0.0.1 as our host in address, +# we are instructing our application to only accept connec-ions coming from the same machine +host = "127.0.0.1" +port = 3000 +rust_log = "rust_axum=debug,axum=debug,tower_http=debug,bullai=debug" + +[database] +# for development using docker compose, we need to use the service name as host +# for api in docker-compose, host need to be service name, e.g. "postgres" +host = "localhost" +username = "it_should_be_override_by_.env" +password = "it_should_be_override_by_.env" +port = 5432 +require_ssl = false + +[security] +secret_key = "it_should_be_override_by_.env" diff --git a/configuration/development.toml b/configuration/development.toml new file mode 100644 index 0000000..345ca07 --- /dev/null +++ b/configuration/development.toml @@ -0,0 +1,6 @@ +[application] +host = "0.0.0.0" +port = "80" + +[database] +database_name = "myapp_dev" diff --git a/configuration/local.toml b/configuration/local.toml new file mode 100644 index 0000000..90c649f --- /dev/null +++ b/configuration/local.toml @@ -0,0 +1,6 @@ +[application] +# IPV6 localhost [::] is for macos +host = "[::]" + +[database] +database_name = "myapp" diff --git a/configuration/production.toml b/configuration/production.toml new file mode 100644 index 0000000..179e63d --- /dev/null +++ b/configuration/production.toml @@ -0,0 +1,6 @@ +[application] +host = "0.0.0.0" +port = "80" + +[database] +database_name = "myapp_prod" diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..e1e13a9 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,16 @@ +version: "3.9" + +services: + postgres: + image: postgres + restart: always + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data/ + env_file: ".env" + environment: + POSTGRES_USER: ${POSTGRES_USER} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} +volumes: + postgres_data: diff --git a/migrations/20240902113425_create_user_table.down.sql b/migrations/20240902113425_create_user_table.down.sql new file mode 100644 index 0000000..cdbb0cc --- /dev/null +++ b/migrations/20240902113425_create_user_table.down.sql @@ -0,0 +1,2 @@ +-- Add down migration script here +DROP TABLE IF EXISTS user; diff --git a/migrations/20240902113425_create_user_table.up.sql b/migrations/20240902113425_create_user_table.up.sql new file mode 100644 index 0000000..0e10b68 --- /dev/null +++ b/migrations/20240902113425_create_user_table.up.sql @@ -0,0 +1,15 @@ +-- Add up migration script here +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +CREATE TABLE users ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + email varchar(255) UNIQUE NOT NULL, + hashed_password varchar(255) NOT NULL, + name varchar(255) NOT NULL, + is_active boolean NOT NULL DEFAULT true, + is_verified boolean NOT NULL DEFAULT false, + is_superuser boolean NOT NULL DEFAULT false, + created_at timestamp NOT NULL DEFAULT current_timestamp, + updated_at timestamp NOT NULL DEFAULT current_timestamp, + last_login timestamp +); diff --git a/scripts/build.sh b/scripts/build.sh new file mode 100755 index 0000000..1b3ab81 --- /dev/null +++ b/scripts/build.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +docker compose build && docker image prune -f && docker compose up --remove-orphans \ No newline at end of file diff --git a/sqlx-data.json b/sqlx-data.json new file mode 100644 index 0000000..95c8c85 --- /dev/null +++ b/sqlx-data.json @@ -0,0 +1,3 @@ +{ + "db": "PostgreSQL" +} \ No newline at end of file diff --git a/src/catalog.rs b/src/catalog.rs new file mode 100644 index 0000000..885d887 --- /dev/null +++ b/src/catalog.rs @@ -0,0 +1,5 @@ +pub mod entity; +pub mod error; +pub mod handler; +pub mod pokemon; +pub mod service; diff --git a/src/catalog/entity.rs b/src/catalog/entity.rs new file mode 100644 index 0000000..e90320e --- /dev/null +++ b/src/catalog/entity.rs @@ -0,0 +1,41 @@ +use num::{Bounded, Num}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Range { + pub min: T, + pub q25: T, + pub q50: T, + pub q75: T, + pub max: T, +} + +pub fn get_range(values: &mut [T]) -> Range { + let len = values.len(); + if len == 0 { + tracing::error!("Creating range from zero len array, it is possibly a bug, returning a invalid range with all zeros"); + return Range { + min: T::zero(), + q25: T::zero(), + q50: T::zero(), + q75: T::zero(), + max: T::zero(), + }; + } + values.sort(); + // we needa clone the quantiles + // because we need the values being stored as Range + // while all of them should be remaining in the vector + let min = *values.iter().min().unwrap_or(&T::zero()); + let max = *values.iter().max().unwrap_or(&T::max_value()); + let q25 = *values.get(len / 4).unwrap_or(&T::zero()); + let q50 = *values.get(len / 2).unwrap_or(&T::zero()); + let q75 = *values.get(len * 3 / 4).unwrap_or(&T::zero()); + Range { + min, + q25, + q50, + q75, + max, + } +} diff --git a/src/catalog/error.rs b/src/catalog/error.rs new file mode 100644 index 0000000..2893818 --- /dev/null +++ b/src/catalog/error.rs @@ -0,0 +1,30 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CatalogError { + #[error("Resource not found")] + NotFound, + #[error("Resource not ready")] + NotImplemented, + #[error(transparent)] + UnexpectedError(#[from] anyhow::Error), +} + +impl IntoResponse for CatalogError { + fn into_response(self) -> Response { + let status = match self { + CatalogError::NotFound | CatalogError::NotImplemented => StatusCode::NOT_FOUND, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + let body = Json(json!({ + "message": self.to_string(), + })); + (status, body).into_response() + } +} diff --git a/src/catalog/handler.rs b/src/catalog/handler.rs new file mode 100644 index 0000000..94ccf5b --- /dev/null +++ b/src/catalog/handler.rs @@ -0,0 +1,57 @@ +use crate::catalog::error::CatalogError; +use crate::catalog::service::{CatalogService, HasCatalogService}; +use crate::common::entity::Legal; +use crate::common::entity::{AppState, QueryName}; +use anyhow::Context; +use async_trait::async_trait; +use axum::{ + extract::{Json, Path, State}, + routing::{get, put}, + Router, +}; +use sqlx::PgPool; + +pub struct CatalogHandlers { + _service: std::marker::PhantomData, +} + +impl HasCatalogHandlers for CatalogHandlers { + type Service = CatalogService; +} + +#[async_trait] +pub trait HasCatalogHandlers: 'static + Send + Sync { + type Service: HasCatalogService + Send; + + fn create_router() -> Router { + Router::new() + .route("/professionals", get(Self::show_professionals)) + .route("/professionals/:id", put(Self::show_professional)) + } + + async fn show_professionals( + State(pool): State, + q_name: Option, + ) -> Result::Item>>, CatalogError> { + let professionals = if let Some(QueryName { name }) = q_name { + Self::Service::query_items_by_name(&pool, &name) + .await + .context("Failed to get professionals")? + } else { + Self::Service::query_items(&pool) + .await + .context("Failed to get professionals")? + }; + Ok(Json(professionals)) + } + + async fn show_professional( + State(pool): State, + Path(id): Path, + ) -> Result::Item>, CatalogError> { + let professional = Self::Service::query_item(&pool, id) + .await + .context("Failed to get professional")?; + Ok(Json(professional)) + } +} diff --git a/src/catalog/pokemon.rs b/src/catalog/pokemon.rs new file mode 100644 index 0000000..b997965 --- /dev/null +++ b/src/catalog/pokemon.rs @@ -0,0 +1,2 @@ +mod entity; +mod service; diff --git a/src/catalog/pokemon/entity.rs b/src/catalog/pokemon/entity.rs new file mode 100644 index 0000000..4835dfa --- /dev/null +++ b/src/catalog/pokemon/entity.rs @@ -0,0 +1,4 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Item {} diff --git a/src/catalog/pokemon/service.rs b/src/catalog/pokemon/service.rs new file mode 100644 index 0000000..ec85bb7 --- /dev/null +++ b/src/catalog/pokemon/service.rs @@ -0,0 +1,42 @@ +use super::entity::Item; +use crate::catalog::service::CatalogService; +use crate::catalog::service::HasCatalogService; +use crate::common::entity::Legal; +use async_trait::async_trait; +use sqlx::PgPool; + +#[async_trait] +impl HasCatalogService for CatalogService { + type CreateItem = Item; + type Item = Item; + + async fn query_items(pool: &PgPool) -> Result, sqlx::Error> { + unimplemented!(); + } + + async fn query_items_by_name( + pool: &PgPool, + name: &str, + ) -> Result, sqlx::Error> { + unimplemented!(); + } + + async fn query_item(pool: &PgPool, id: uuid::Uuid) -> Result { + unimplemented!(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::db::postgres::get_postgres_pool; + + #[tokio::test] + async fn get_professionals() { + let pool = get_postgres_pool().await; + let professionals = CatalogService::::query_items(pool) + .await + .expect("Failed to get firms"); + println!("{:?}", professionals); + } +} diff --git a/src/catalog/service.rs b/src/catalog/service.rs new file mode 100644 index 0000000..e41a1f9 --- /dev/null +++ b/src/catalog/service.rs @@ -0,0 +1,43 @@ +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use serde::Serialize; +use sqlx::PgPool; + +pub struct CatalogService { + _service: std::marker::PhantomData, +} + +#[async_trait] +pub trait HasCatalogService: 'static { + // Send is required for async future to be pass around + + type CreateItem: Send + DeserializeOwned; + type Item: Send + DeserializeOwned + Serialize; + + #[allow(unused_variables)] + async fn insert_professional(pool: &PgPool, firm: Self::CreateItem) -> Result<(), sqlx::Error> { + Err(sqlx::Error::RowNotFound) + } + #[allow(unused_variables)] + async fn query_items(pool: &PgPool) -> Result, sqlx::Error> { + Err(sqlx::Error::RowNotFound) + } + #[allow(unused_variables)] + async fn query_items_by_name( + pool: &PgPool, + name: &str, + ) -> Result, sqlx::Error> { + Err(sqlx::Error::RowNotFound) + } + #[allow(unused_variables)] + async fn query_professionals_by_ids( + pool: &PgPool, + ids: &[uuid::Uuid], + ) -> Result, sqlx::Error> { + Err(sqlx::Error::RowNotFound) + } + #[allow(unused_variables)] + async fn query_item(pool: &PgPool, id: uuid::Uuid) -> Result { + Err(sqlx::Error::RowNotFound) + } +} diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..3bb815a --- /dev/null +++ b/src/common.rs @@ -0,0 +1,4 @@ +pub mod date; +pub mod db; +pub mod entity; +pub mod error; diff --git a/src/common/date.rs b/src/common/date.rs new file mode 100644 index 0000000..0363581 --- /dev/null +++ b/src/common/date.rs @@ -0,0 +1,43 @@ +use chrono::Datelike; +use chrono::{Duration, NaiveDate, Weekday}; + +pub fn get_next_working_day(date: NaiveDate) -> NaiveDate { + match date.weekday() { + Weekday::Sat => date + Duration::days(2), + Weekday::Sun => date + Duration::days(1), + _ => date, + } +} + +pub fn get_end_date(start: NaiveDate, days: i64) -> NaiveDate { + let start = get_next_working_day(start); + + let num_weeks = days / 5; + let mut end = start + Duration::weeks(num_weeks); + + end = get_next_working_day(end); + + let remaining_needed_working_days = days % 5; + for _ in 0..remaining_needed_working_days { + end += Duration::days(1); + end = get_next_working_day(end); + } + + // TODO: check holidays, add back number of holidays contained in the period + end +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn check_end_date() { + let start = NaiveDate::from_ymd_opt(2023, 7, 18).expect("Invalid date"); + + let computed_end_date = get_end_date(start, 7); + let proper_end_date = NaiveDate::from_ymd_opt(2023, 7, 27).expect("Invalid date"); + + assert_eq!(computed_end_date, proper_end_date); + } +} diff --git a/src/common/db.rs b/src/common/db.rs new file mode 100644 index 0000000..26e9103 --- /dev/null +++ b/src/common/db.rs @@ -0,0 +1 @@ +pub mod postgres; diff --git a/src/common/db/postgres.rs b/src/common/db/postgres.rs new file mode 100644 index 0000000..66c2cf4 --- /dev/null +++ b/src/common/db/postgres.rs @@ -0,0 +1,38 @@ +use crate::configuration::get_configuration; +use sqlx::postgres::{PgPool, PgPoolOptions}; +use tokio::sync::OnceCell; + +static POOL: OnceCell = OnceCell::const_new(); + +pub async fn get_postgres_pool() -> &'static PgPool { + POOL.get_or_init(|| async { + let configuration = get_configuration().expect("Failed to read configuration."); + PgPoolOptions::new() + .acquire_timeout(std::time::Duration::from_secs(2)) + .connect_lazy_with(configuration.database.with_db()) + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + use sqlx::Connection; + + #[sqlx::test] + async fn test_postgres_connection(pool: PgPool) { + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = conn.ping().await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_get_postgres_pool() { + let pool = get_postgres_pool().await; + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = conn.ping().await; + + assert!(result.is_ok()); + } +} diff --git a/src/common/entity.rs b/src/common/entity.rs new file mode 100644 index 0000000..f491287 --- /dev/null +++ b/src/common/entity.rs @@ -0,0 +1,74 @@ +use async_trait::async_trait; +use axum::extract::{FromRef, FromRequestParts, Query}; +use axum::http::request::Parts; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; + +use super::error::CommonError; + +pub struct Legal {} + +pub trait HasService { + const SERVICE: Service; +} + +// `#[derive(FromRef)]` makes them sub states so they can be extracted independently +#[derive(Clone, FromRef)] +pub struct AppState { + pub pool: PgPool, +} + +#[derive(Debug, Serialize, Deserialize, sqlx::Type, PartialEq, Eq)] +// sql type +#[sqlx(type_name = "service")] +pub enum Service { + #[serde(rename = "legal-advisory")] + #[sqlx(rename = "legal-advisory")] + Legal, +} + +impl std::fmt::Display for Service { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Service::Legal => write!(f, "legal-advisory"), + } + } +} + +impl std::str::FromStr for Service { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "legal-advisory" => Ok(Service::Legal), + _ => Err(()), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct QueryName { + pub name: String, +} + +#[async_trait] +impl FromRequestParts for QueryName +where + S: Send + Sync, +{ + type Rejection = CommonError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let query = Query::::from_request_parts(parts, state) + .await + .map_err(|_| CommonError::ValidationError("Cannot get query param".into()))?; + + let name = query.name.clone(); + if name.len() < 2 { + return Err(CommonError::ValidationError( + "Name must be at least 2 characters long".into(), + )); + } + Ok(Self { name }) + } +} diff --git a/src/common/error.rs b/src/common/error.rs new file mode 100644 index 0000000..bbbc053 --- /dev/null +++ b/src/common/error.rs @@ -0,0 +1,57 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CommonError { + // ref: zero2prod 8.4.2 anyhow Or thiserror? + #[error("Resource not found")] + NotFound, + #[error("Validation error: {0}")] + ValidationError(String), + #[error(transparent)] + UnexpectedError(#[from] anyhow::Error), +} + +impl IntoResponse for CommonError { + fn into_response(self) -> Response { + let status = match self { + CommonError::NotFound => StatusCode::NOT_FOUND, + CommonError::ValidationError(_) => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + let body = Json(json!({ + "message": self.to_string(), + })); + (status, body).into_response() + } +} + +// Make our own error that wraps `anyhow::Error`. +pub struct AnyError(anyhow::Error); + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AnyError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Something went wrong: {}", self.0), + ) + .into_response() + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl From for AnyError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} diff --git a/src/common/time.rs b/src/common/time.rs new file mode 100644 index 0000000..c3efaf2 --- /dev/null +++ b/src/common/time.rs @@ -0,0 +1,12 @@ +use time::Duration; +use time::OffsetDateTime; + +pub fn get_expiry(minutes: i64) -> OffsetDateTime { + let now = get_now(); + let ttl = minutes + 5; + now + Duration::minutes(ttl) +} + +pub fn get_now() -> OffsetDateTime { + OffsetDateTime::now_utc() - Duration::minutes(5) +} diff --git a/src/configuration.rs b/src/configuration.rs new file mode 100644 index 0000000..f29f343 --- /dev/null +++ b/src/configuration.rs @@ -0,0 +1,137 @@ +use dotenvy::dotenv; +use secrecy::{ExposeSecret, Secret}; +use serde::Deserialize; +use serde_aux::field_attributes::deserialize_number_from_string; +use sqlx::postgres::{PgConnectOptions, PgSslMode}; +use sqlx::ConnectOptions; + +#[derive(Deserialize)] +pub struct Settings { + pub application: ApplicationSettings, + pub database: DatabaseSettings, + pub security: SecuritySettings, +} + +#[derive(Deserialize)] +pub struct ApplicationSettings { + pub host: String, + #[serde(deserialize_with = "deserialize_number_from_string")] + pub port: u16, + pub rust_log: String, +} + +#[derive(Deserialize)] +pub struct DatabaseSettings { + pub username: String, + pub password: Secret, + #[serde(deserialize_with = "deserialize_number_from_string")] + pub port: u16, + pub host: String, + pub database_name: String, + pub require_ssl: bool, +} + +#[derive(Deserialize)] +pub struct SecuritySettings { + pub secret_key: String, +} + +pub fn get_environment() -> Environment { + let environment: Environment = std::env::var("APP_ENVIRONMENT") + .expect("APP_ENVIRONMENT not set in .env.") + .try_into() + .expect("Failed to parse APP_ENVIRONMENT."); + environment +} + +pub fn get_configuration() -> Result { + // - A base configuration file, for values that are shared across our local and production environment + // (e.g. database name); + // - A collection of environment-specific configuration files, specifying values for fields that require cus- + // tomisation on a per-environment basis (e.g. host); + // - An environment variable, APP_ENVIRONMENT, to determine the running environment (e.g. produc- + // tion or local). + // - All configuration files will live in the same top-level directory, configuration. + + let base_path = std::env::current_dir().expect("Failed to determine the current directory"); + let configuration_directory = base_path.join("configuration"); + // Detect the running environment. + // load dotenv + dotenv().expect("Failed to load .env file"); + let environment = get_environment(); + let environment_filename = format!("{}.toml", environment.as_str()); + let settings = config::Config::builder() + .add_source(config::File::from( + configuration_directory.join("base.toml"), + )) + .add_source(config::File::from( + configuration_directory.join(environment_filename), + )) + // Add in settings from environment variables (with a prefix of APP and + // '__' as separator) + // E.g. `APP_APPLICATION__PORT=5001 would set `Settings.application.port` + .add_source( + config::Environment::with_prefix("APP") + .prefix_separator("_") + .separator("__"), + ) + .build()?; + + // Try to convert the configuration values it read into + // our Settings type + settings.try_deserialize::() +} + +impl DatabaseSettings { + pub fn with_db(&self) -> PgConnectOptions { + let mut options = self.without_db().database(&self.database_name); + options = options.log_statements(tracing_log::log::LevelFilter::Trace); + options + } + pub fn without_db(&self) -> PgConnectOptions { + let ssl_mode = if self.require_ssl { + PgSslMode::Require + } else { + // Try an encrypted connection, fallback to unencrypted if it fails + PgSslMode::Prefer + }; + PgConnectOptions::new() + .host(&self.host) + .username(&self.username) + .password(self.password.expose_secret()) + .port(self.port) + .ssl_mode(ssl_mode) + } +} + +/// The possible runtime environment for our application. +#[derive(PartialEq)] +pub enum Environment { + Local, + Development, + Production, +} +impl Environment { + pub fn as_str(&self) -> &'static str { + match self { + Environment::Local => "local", + Environment::Development => "development", + Environment::Production => "production", + } + } +} +impl TryFrom for Environment { + type Error = String; + fn try_from(s: String) -> Result { + match s.to_lowercase().as_str() { + "local" => Ok(Self::Local), + "development" => Ok(Self::Development), + "production" => Ok(Self::Production), + other => Err(format!( + "{} is not a supported environment. \ + Use either `local`, `development`, or `production`.", + other + )), + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..2a3c9b2 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,4 @@ +pub mod catalog; +pub mod common; +pub mod configuration; +pub mod user_mgmt; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..01ffbf9 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,125 @@ +use axum::http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT}; +use axum::http::{Method, StatusCode, Uri}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::Router; +use myapp::{ + catalog::handler::{CatalogHandlers, HasCatalogHandlers}, + common::{ + db::postgres::get_postgres_pool, + entity::{AppState, Legal, Service}, + }, + configuration::get_configuration, + user_mgmt::{ + auth::{login, logout, me_handler}, + handler::{create_user, show_users}, + }, +}; +use std::time::Duration; +use tokio::net::TcpListener; +use tower_http::cors::CorsLayer; +use tower_http::limit::RequestBodyLimitLayer; +use tower_http::{timeout::TimeoutLayer, trace::TraceLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + let configuration = get_configuration().expect("Failed to read configuration."); + + // initialize tracing + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::new( + configuration.application.rust_log, + )) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let pool = get_postgres_pool().await.clone(); + let ping = sqlx::query("SELECT 1").fetch_one(&pool).await; + + match ping { + Ok(_) => { + tracing::info!("Connected to postgres on {}", configuration.database.host); + } + Err(e) => { + tracing::error!("Failed to connect to postgres: {}", e); + panic!("Failed to connect to postgres"); + } + } + + // Setup app state for the entire app + let state = AppState { pool }; + + let origins = ["http://localhost:5173".parse().unwrap()]; + + // Setup CORS + let cors = CorsLayer::new() + .allow_methods([ + Method::OPTIONS, + Method::GET, + Method::PUT, + Method::POST, + Method::PATCH, + Method::DELETE, + ]) + .allow_headers([ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT]) + .allow_origin(origins) + .allow_credentials(true); + + let user_routes = Router::new() + .route("/", post(create_user)) + .route("/", get(show_users)); + + // Note that the middleware is only applied to existing routes. + // So you have to first add your routes (and / or fallback) + // and then call layer afterwards. + // Additional routes added after layer is called will not have the middleware added. + + let legal_handlers = create_module_router::(); + + let base_app = Router::new() + .route("/", get(root)) + .route("/auth/login", post(login)) + .route("/auth/logout", post(logout)) + .route("/me", get(me_handler)) + .nest("/users", user_routes) + .nest(format!("/{}", Service::Legal).as_str(), legal_handlers) + .layer(cors) + .layer(TraceLayer::new_for_http()) + // timeout requests after 10 secs, returning 408 status code + .layer(TimeoutLayer::new(Duration::from_secs(20))) + .layer(RequestBodyLimitLayer::new(4096)) + .with_state(state); + + let app = Router::new().nest("/api/v1", base_app).fallback(fallback); + + // For macos, listen to IPV4 and IPV6 + let addr_str = format!( + "{}:{}", + configuration.application.host, configuration.application.port + ); + let addr = addr_str.parse::().unwrap(); + tracing::debug!("Listening on {}", addr); + let listener = TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app) + .await + .expect("Cannot start the server"); +} + +fn create_module_router() -> Router +where + CatalogHandlers: HasCatalogHandlers, +{ + Router::new().merge(CatalogHandlers::::create_router()) +} + +// basic handler that responds with a static string +async fn root() -> &'static str { + "Hello, World!" +} + +async fn fallback(uri: Uri) -> impl IntoResponse { + let message = format!("No route for {}", uri); + tracing::debug!(message); + (StatusCode::NOT_FOUND, message) +} diff --git a/src/user_mgmt.rs b/src/user_mgmt.rs new file mode 100644 index 0000000..858f457 --- /dev/null +++ b/src/user_mgmt.rs @@ -0,0 +1,6 @@ +pub mod auth; +mod encryption; +pub mod entity; +mod error; +pub mod handler; +mod jwt; diff --git a/src/user_mgmt/auth.rs b/src/user_mgmt/auth.rs new file mode 100644 index 0000000..9cd81e7 --- /dev/null +++ b/src/user_mgmt/auth.rs @@ -0,0 +1,246 @@ +use crate::configuration::{get_environment, Environment}; + +use super::encryption::verify; +use super::error::AuthError; +use super::handler::query_user; +use super::jwt::{decode, encode}; +pub use super::jwt::{Claims, Role}; +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts, State}, + http::request::Parts, + Json, RequestPartsExt, +}; +use axum_extra::{ + extract::cookie::{Cookie, CookieJar, SameSite}, + headers::{authorization::Bearer, Authorization}, + TypedHeader, +}; +use cookie::time::Duration; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; + +pub async fn me_handler(user: CurrentUser) -> Result { + Ok(user.to_string()) +} + +// TODO: should it be in auth.rs? or handler.rs? messy code separation +// tracing::instrument is a wrapper +// it shows only if there are logs inside. +#[tracing::instrument(name="Logging in", skip(jar, pool, payload), fields(username = %payload.username))] +pub async fn login( + jar: CookieJar, + State(pool): State, + // Json must be placed at the end of the parameters + Json(payload): Json, + // Json must be placed at the end of the Result tuple +) -> Result<(CookieJar, Json), AuthError> { + // Check if the user sent the credentials + if payload.username.is_empty() || payload.password.is_empty() { + return Err(AuthError::MissingCredentials); + } + // Here you can check the user credentials from a database + // debug payload + let (user_id, role) = validate_user(&pool, payload).await?; + // Validated, now create jwt claims + let claims = Claims::new(user_id, role); + // Create the authorization token + let token = encode(&claims).map_err(|_| AuthError::TokenCreation)?; + + // check env for local client to bypass the secure flag, cuz we don't need https on localhost + let env = get_environment(); + // Create a http_only cookie to store the token + let cookie = Cookie::build(("access_token", token.clone())) + .http_only(true) + .secure(env != Environment::Local) + .same_site(SameSite::None) + .max_age(Duration::hours(1)) + .path("/") + .build(); + // Store and Send the authorized token + Ok((jar.add(cookie), Json(AuthBody::new(token)))) +} + +pub async fn logout(jar: CookieJar) -> CookieJar { + jar.remove(Cookie::from("access_token")) +} + +async fn validate_user( + pool: &PgPool, + credentials: AuthPayload, +) -> Result<(uuid::Uuid, Role), AuthError> { + let user = sqlx::query_as!( + UserToClaim, + "SELECT id, hashed_password, is_superuser, is_verified from users + WHERE email = $1 + ", + credentials.username + ) + .fetch_optional(pool) + .await? + .ok_or(AuthError::WrongCredentials)?; + + let role = if user.is_superuser { + Role::Admin + } else { + Role::User + }; + + if !verify(credentials.password, user.hashed_password).await? { + return Err(AuthError::WrongCredentials); + } + + if !user.is_verified { + return Err(AuthError::UnverifiedUser); + } + + Ok((user.id, role)) +} + +// Extract the Claim from a request body +#[async_trait] +impl FromRequestParts for Claims +where + S: Send + Sync, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + // The token might come from a cookie or from the authorization header + // Note: It is invalid if the token from cookie is correct + // but the token from the header is not + let token = if let Ok(TypedHeader(Authorization(bearer))) = + parts.extract::>>().await + { + // Extract the token from the authorization header + bearer.token().to_string() + } else if let Ok(jar) = parts.extract::().await { + // Extract the token from the cookie + jar.get("access_token") + .ok_or(AuthError::MissingCredentials)? + .value() + .to_string() + } else { + return Err(AuthError::MissingCredentials); + }; + + let claims = decode(&token)?.claims; + Ok(claims) + } +} + +// Extract the Claim from a request body +#[async_trait] +impl FromRequestParts for CurrentUser +where + S: Send + Sync, + // From AppState, which implements FromRef + PgPool: FromRef, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let claims = Claims::from_request_parts(parts, state).await?; + let pool = PgPool::from_ref(state); + + let user = query_user(&pool, claims.sub) + .await + .inspect_err(|e| tracing::error!("Failed to query current user from jwt: {e}"))?; + + let current_user = Self { + id: user.id, + name: user.name, + email: user.email, + role: claims.role, + }; + + Ok(current_user) + } +} + +#[derive(Debug, Serialize)] +pub struct CurrentUser { + pub id: uuid::Uuid, + pub name: String, + pub email: String, + pub role: Role, +} + +impl std::fmt::Display for CurrentUser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "User: {},\n Email: {}", self.name, self.email) + } +} + +#[derive(Debug, Serialize)] +pub struct AuthBody { + access_token: String, + token_type: String, +} + +impl AuthBody { + fn new(access_token: String) -> Self { + Self { + access_token, + token_type: "Bearer".to_string(), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct AuthPayload { + username: String, + password: String, +} + +#[derive(Deserialize)] +struct UserToClaim { + id: uuid::Uuid, + hashed_password: String, + is_superuser: bool, + is_verified: bool, +} + +pub async fn get_current_user_from_id(pool: &PgPool, user_id: &uuid::Uuid) -> CurrentUser { + let user = query_user(pool, *user_id) + .await + .inspect_err(|e| tracing::error!("Failed to query current user from jwt: {e}")) + .expect("Failed to query current user"); + + let role = if user.is_superuser { + Role::Admin + } else { + Role::User + }; + + CurrentUser { + id: user.id, + name: user.name, + email: user.email, + role, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::db::postgres::get_postgres_pool; + + #[tokio::test] + async fn test_validate_user_from_db() { + let pool = get_postgres_pool().await; + let payload = AuthPayload { + username: "admin@example.com".to_string(), + password: "password".to_string(), + }; + let result = validate_user(pool, payload).await; + + assert!(result.is_ok()) + } + + #[tokio::test] + async fn test_authorize() { + // https://github.com/tokio-rs/axum/blob/main/examples/testing/src/main.rs + todo!("write the test to test the authorize function") + } +} diff --git a/src/user_mgmt/encryption.rs b/src/user_mgmt/encryption.rs new file mode 100644 index 0000000..f647eea --- /dev/null +++ b/src/user_mgmt/encryption.rs @@ -0,0 +1,34 @@ +use bcrypt::DEFAULT_COST; + +use super::error::AuthError; + +// consume password value to make it unusable +pub async fn hash(password: String) -> Result { + let (send, recv) = tokio::sync::oneshot::channel(); + rayon::spawn(move || { + let result = bcrypt::hash(password, DEFAULT_COST); + let _ = send.send(result); + }); + Ok(recv.await??) +} + +pub async fn verify(password: String, hash: String) -> Result { + let (send, recv) = tokio::sync::oneshot::channel(); + rayon::spawn(move || { + let result = bcrypt::verify(password, &hash); + let _ = send.send(result); + }); + Ok(recv.await??) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn it_works() { + let hashed = hash("hunter2".to_string()).await.unwrap(); + let valid = verify("hunter2".to_string(), hashed).await.unwrap(); + assert!(valid) + } +} diff --git a/src/user_mgmt/entity.rs b/src/user_mgmt/entity.rs new file mode 100644 index 0000000..e2d38b0 --- /dev/null +++ b/src/user_mgmt/entity.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize)] +pub struct User { + pub id: uuid::Uuid, + pub email: String, + pub name: String, + pub is_active: bool, + pub is_verified: bool, + pub is_superuser: bool, +} + +#[derive(Deserialize)] +pub struct CreateUser { + pub email: String, + pub password: String, + pub name: String, +} diff --git a/src/user_mgmt/error.rs b/src/user_mgmt/error.rs new file mode 100644 index 0000000..964ab75 --- /dev/null +++ b/src/user_mgmt/error.rs @@ -0,0 +1,53 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error("Wrong credentials")] + WrongCredentials, + #[error("Missing credentials")] + MissingCredentials, + #[error("Unverified user")] + UnverifiedUser, + #[error("Invalid token")] + InvalidToken, + #[error("Token creation error")] + TokenCreation, + #[error(transparent)] + CommonError(#[from] crate::common::error::CommonError), + #[error(transparent)] + BcryptError(#[from] bcrypt::BcryptError), + #[error(transparent)] + JwtError(#[from] jsonwebtoken::errors::Error), + #[error(transparent)] + TokioRecvError(#[from] tokio::sync::oneshot::error::RecvError), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), + #[error(transparent)] + UnexpectedError(#[from] anyhow::Error), +} + +impl IntoResponse for AuthError { + fn into_response(self) -> Response { + let status = match self { + AuthError::WrongCredentials + | AuthError::InvalidToken + | AuthError::MissingCredentials + | AuthError::UnverifiedUser => StatusCode::UNAUTHORIZED, + AuthError::JwtError(ref e) => match e.kind() { + jsonwebtoken::errors::ErrorKind::ExpiredSignature => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + let body = Json(json!({ + "message": self.to_string(), + })); + (status, body).into_response() + } +} diff --git a/src/user_mgmt/handler.rs b/src/user_mgmt/handler.rs new file mode 100644 index 0000000..4853d86 --- /dev/null +++ b/src/user_mgmt/handler.rs @@ -0,0 +1,83 @@ +use super::encryption::hash; +use super::entity::{CreateUser, User}; +use super::error::AuthError; +use axum::extract::{Json, State}; +use sqlx::PgConnection; +use sqlx::PgPool; + +pub async fn create_user( + State(pool): State, + Json(user): Json, +) -> Result<(), AuthError> { + let mut tx = pool.begin().await?; + + insert_user(&mut tx, user) + .await + .inspect_err(|e| tracing::error!("Failed to create user: {e}"))?; + tx.commit().await?; + + Ok(()) +} + +pub async fn query_users(pool: &PgPool) -> Result, sqlx::Error> { + sqlx::query_as!( + User, + r#"SELECT users.id, users.name, users.email, + users.is_active, users.is_verified, users.is_superuser + FROM users"# + ) + .fetch_all(pool) + .await +} + +pub async fn query_user(pool: &PgPool, user_id: uuid::Uuid) -> Result { + sqlx::query_as!( + User, + r#"SELECT users.id, users.name, users.email, + users.is_active, users.is_verified, users.is_superuser + FROM users + WHERE users.id = $1"#, + user_id + ) + .fetch_one(pool) + .await +} + +pub async fn show_users(State(pool): State) -> Result>, AuthError> { + let users = query_users(&pool).await?; + + Ok(Json(users)) +} + +pub async fn insert_user( + tx: &mut PgConnection, + payload: CreateUser, +) -> Result { + let hashed_password = hash(payload.password).await?; + + let row = sqlx::query!( + "INSERT INTO users (email, hashed_password, name) VALUES ($1, $2, $3) RETURNING id", + payload.email, + hashed_password, + payload.name, + ) + .fetch_one(tx) + .await + .inspect_err(|e| tracing::error!("Failed to insert user: {e}"))?; + + Ok(row.id) +} + +#[cfg(test)] +mod tests { + use super::*; + use sqlx::Connection; + + #[sqlx::test] + async fn test_postgres_connection(pool: PgPool) { + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = conn.ping().await; + + assert!(result.is_ok()); + } +} diff --git a/src/user_mgmt/jwt.rs b/src/user_mgmt/jwt.rs new file mode 100644 index 0000000..03e4a6d --- /dev/null +++ b/src/user_mgmt/jwt.rs @@ -0,0 +1,84 @@ +use crate::configuration::get_configuration; +use chrono::{Duration, Utc}; +use jsonwebtoken::{errors::Result, DecodingKey, EncodingKey, Header, TokenData, Validation}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + Admin, + User, +} + +/// Our claims struct, it needs to derive `Serialize` and/or `Deserialize` +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct Claims { + // aud: String, // Optional. Audience + exp: i64, // Required (validate_exp defaults to true in validation). Expiration time (as UTC timestamp) + iat: i64, // Optional. Issued at (as UTC timestamp) + // iss: String, // Optional. Issuer + // nbf: usize, // Optional. Not Before (as UTC timestamp) + pub sub: uuid::Uuid, // Optional. Subject (whom token refers to) + pub role: Role, +} + +impl Display for Claims { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "User ID: {}", self.sub) + } +} + +impl Claims { + pub fn new(sub: uuid::Uuid, role: Role) -> Self { + let now = Utc::now(); + Self { + exp: (now + Duration::hours(1)).timestamp(), + iat: now.timestamp(), + sub, + role, + } + } +} + +pub fn encode(claims: &Claims) -> Result { + jsonwebtoken::encode(&Header::default(), claims, &KEYS.encoding) +} + +pub fn decode(token: &str) -> Result> { + jsonwebtoken::decode::(token, &KEYS.decoding, &Validation::default()) +} + +static KEYS: Lazy = Lazy::new(|| { + let configuration = get_configuration().expect("Failed to load configuration."); + let secret = configuration.security.secret_key; + Keys::new(secret.as_bytes()) +}); + +struct Keys { + encoding: EncodingKey, + decoding: DecodingKey, +} + +impl Keys { + fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn check_jwt_encoding_and_decoding() { + let my_claims = Claims::new(uuid::Uuid::new_v4(), Role::User); + let token = encode(&my_claims).unwrap(); + let claims = decode(&token).unwrap().claims; + assert_eq!(my_claims, claims) + } +}