diff --git a/.gitignore b/.gitignore index 6726bdd..dd9dce0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ release.tar.gz admin-setup-token.txt package-lock.json bun.lock +*.db* \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index bce7d52..a2a1704 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -837,6 +837,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "foreign-types" version = "0.3.2" @@ -861,6 +867,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -905,6 +926,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -923,8 +955,10 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -981,29 +1015,24 @@ dependencies = [ "tracing", ] -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", - "allocator-api2", -] - [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "hashlink" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown 0.14.5", + "hashbrown", ] [[package]] @@ -1249,7 +1278,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown", ] [[package]] @@ -1413,12 +1442,6 @@ dependencies = [ "unicase", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "miniz_oxide" version = "0.8.3" @@ -1457,16 +1480,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2109,6 +2122,7 @@ dependencies = [ "chrono", "clap", "dotenv", + "futures", "jsonwebtoken", "lazy_static", "mime_guess", @@ -2172,21 +2186,11 @@ dependencies = [ "der", ] -[[package]] -name = "sqlformat" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" -dependencies = [ - "nom", - "unicode_categories", -] - [[package]] name = "sqlx" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcfa89bea9500db4a0d038513d7a060566bfc51d46d1c014847049a45cce85e8" +checksum = "4410e73b3c0d8442c5f99b425d7a435b5ee0ae4167b3196771dd3f7a01be745f" dependencies = [ "sqlx-core", "sqlx-macros", @@ -2197,51 +2201,44 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d06e2f2bd861719b1f3f0c7dbe1d80c30bf59e76cf019f07d9014ed7eefb8e08" +checksum = "6a007b6936676aa9ab40207cde35daab0a04b823be8ae004368c0793b96a61e0" dependencies = [ - "atoi", - "byteorder", "bytes", "chrono", "crc", "crossbeam-queue", "either", "event-listener", - "futures-channel", "futures-core", "futures-intrusive", "futures-io", "futures-util", - "hashbrown 0.14.5", + "hashbrown", "hashlink", - "hex", "indexmap", "log", "memchr", "native-tls", "once_cell", - "paste", "percent-encoding", "serde", "serde_json", "sha2", "smallvec", - "sqlformat", - "thiserror 1.0.69", + "thiserror 2.0.11", "tokio", "tokio-stream", "tracing", "url", - "uuid", ] [[package]] name = "sqlx-macros" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f998a9defdbd48ed005a89362bd40dd2117502f15294f61c8d47034107dbbdc" +checksum = "3112e2ad78643fef903618d78cf0aec1cb3134b019730edb039b69eaf531f310" dependencies = [ "proc-macro2", "quote", @@ -2252,9 +2249,9 @@ dependencies = [ [[package]] name = "sqlx-macros-core" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d100558134176a2629d46cec0c8891ba0be8910f7896abfdb75ef4ab6f4e7ce" +checksum = "4e9f90acc5ab146a99bf5061a7eb4976b573f560bc898ef3bf8435448dd5e7ad" dependencies = [ "dotenvy", "either", @@ -2278,9 +2275,9 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936cac0ab331b14cb3921c62156d913e4c15b74fb6ec0f3146bd4ef6e4fb3c12" +checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233" dependencies = [ "atoi", "base64 0.22.1", @@ -2314,17 +2311,16 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 1.0.69", + "thiserror 2.0.11", "tracing", - "uuid", "whoami", ] [[package]] name = "sqlx-postgres" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9734dbce698c67ecf67c442f768a5e90a49b2a4d61a9f1d59f73874bd4cf0710" +checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613" dependencies = [ "atoi", "base64 0.22.1", @@ -2336,7 +2332,6 @@ dependencies = [ "etcetera", "futures-channel", "futures-core", - "futures-io", "futures-util", "hex", "hkdf", @@ -2354,17 +2349,16 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 1.0.69", + "thiserror 2.0.11", "tracing", - "uuid", "whoami", ] [[package]] name = "sqlx-sqlite" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75b419c3c1b1697833dd927bdc4c6545a620bc1bbafabd44e1efbe9afcd337e" +checksum = "f85ca71d3a5b24e64e1d08dd8fe36c6c95c339a896cc33068148906784620540" dependencies = [ "atoi", "chrono", @@ -2382,7 +2376,6 @@ dependencies = [ "sqlx-core", "tracing", "url", - "uuid", ] [[package]] @@ -2706,12 +2699,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" -[[package]] -name = "unicode_categories" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" - [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index cb8fbfd..0341831 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ actix-web = "4.4" actix-files = "0.6" actix-cors = "0.6" tokio = { version = "1.36", features = ["full"] } -sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono"] } +sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "sqlite", "chrono"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" anyhow = "1.0" @@ -31,3 +31,4 @@ lazy_static = "1.4" argon2 = "0.5.3" rand = { version = "0.8", features = ["std"] } mime_guess = "2.0.5" +futures = "0.3.31" diff --git a/README.md b/README.md index a2eac0a..595d71c 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ then check /target/release for the binary named `SimpleGit` ### From Docker ```bash docker build --build-arg API_URL=http://localhost:8080 -t simplelink . -docker run simplelink -p 8080:8080 \ +docker run -p 8080:8080 \ -e JWT_SECRET=change-me-in-production \ -e DATABASE_URL=postgres://user:password@host:port/database \ simplelink diff --git a/frontend/index.html b/frontend/index.html index e4b78ea..ffe5d61 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -1,13 +1,16 @@ - - - - - Vite + React + TS - - -
- - - + + + + + + SimpleLink + + + +
+ + + + \ No newline at end of file diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 6f2d224..1cf80b3 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -15,6 +15,20 @@ api.interceptors.request.use((config) => { return config; }); +api.interceptors.response.use( + (response) => response, + (error) => { + if (error.response?.status === 401) { + localStorage.removeItem('token'); + localStorage.removeItem('user'); + + window.dispatchEvent(new Event('unauthorized')); + } + return Promise.reject(error); + } +); + + // Auth endpoints export const login = async (email: string, password: string) => { const response = await api.post('/auth/login', { diff --git a/frontend/src/context/AuthContext.tsx b/frontend/src/context/AuthContext.tsx index a4f25db..cf5a1d8 100644 --- a/frontend/src/context/AuthContext.tsx +++ b/frontend/src/context/AuthContext.tsx @@ -23,6 +23,16 @@ export function AuthProvider({ children }: { children: React.ReactNode }) { setUser(userData); } setIsLoading(false); + + const handleUnauthorized = () => { + setUser(null); + }; + + window.addEventListener('unauthorized', handleUnauthorized); + + return () => { + window.removeEventListener('unauthorized', handleUnauthorized); + }; }, []); const login = async (email: string, password: string) => { diff --git a/migrations/sqlite/20250125000000_init.sql b/migrations/sqlite/20250125000000_init.sql new file mode 100644 index 0000000..ea15aa2 --- /dev/null +++ b/migrations/sqlite/20250125000000_init.sql @@ -0,0 +1,42 @@ +-- Enable foreign key support +PRAGMA foreign_keys = ON; + +-- Add Migration Version +CREATE TABLE IF NOT EXISTS _sqlx_migrations ( + version INTEGER PRIMARY KEY, + description TEXT NOT NULL, + installed_on TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Create users table +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email VARCHAR(255) NOT NULL UNIQUE, + password_hash TEXT NOT NULL +); + +-- Create links table +CREATE TABLE links ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + original_url TEXT NOT NULL, + short_code VARCHAR(8) NOT NULL UNIQUE, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + clicks INTEGER NOT NULL DEFAULT 0, + user_id INTEGER, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- Create clicks table +CREATE TABLE clicks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + link_id INTEGER, + source TEXT, + query_source TEXT, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (link_id) REFERENCES links(id) +); + +-- Create indexes +CREATE INDEX idx_short_code ON links(short_code); +CREATE INDEX idx_user_id ON links(user_id); +CREATE INDEX idx_link_id ON clicks(link_id); diff --git a/src/handlers.rs b/src/handlers.rs index 7e91540..2e7ba74 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -2,8 +2,8 @@ use crate::auth::AuthenticatedUser; use crate::{ error::AppError, models::{ - AuthResponse, Claims, ClickStats, CreateLink, Link, LoginRequest, RegisterRequest, - SourceStats, User, UserResponse, + AuthResponse, Claims, ClickStats, CreateLink, DatabasePool, Link, LoginRequest, + RegisterRequest, SourceStats, User, UserResponse, }, AppState, }; @@ -16,6 +16,7 @@ use argon2::{Argon2, PasswordHash, PasswordHasher}; use jsonwebtoken::{encode, EncodingKey, Header}; use lazy_static::lazy_static; use regex::Regex; +use sqlx::{Postgres, Sqlite}; lazy_static! { static ref VALID_CODE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9_-]{1,32}$").unwrap(); @@ -27,53 +28,88 @@ pub async fn create_short_url( payload: web::Json, ) -> Result { tracing::debug!("Creating short URL with user_id: {}", user.user_id); - validate_url(&payload.url)?; let short_code = if let Some(ref custom_code) = payload.custom_code { validate_custom_code(custom_code)?; - tracing::debug!("Checking if custom code {} exists", custom_code); - // Check if code is already taken - if let Some(_) = sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = $1") - .bind(custom_code) - .fetch_optional(&state.db) - .await? - { + // Check if code exists using match on pool type + let exists = match &state.db { + DatabasePool::Postgres(pool) => { + sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = $1") + .bind(custom_code) + .fetch_optional(pool) + .await? + } + DatabasePool::Sqlite(pool) => { + sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = ?1") + .bind(custom_code) + .fetch_optional(pool) + .await? + } + }; + + if exists.is_some() { return Err(AppError::InvalidInput( "Custom code already taken".to_string(), )); } - custom_code.clone() } else { generate_short_code() }; - // Start transaction - let mut tx = state.db.begin().await?; + // Start transaction based on pool type + let result = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; - tracing::debug!("Inserting new link with short_code: {}", short_code); - let link = sqlx::query_as::<_, Link>( - "INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *", - ) - .bind(&payload.url) - .bind(&short_code) - .bind(user.user_id) - .fetch_one(&mut *tx) - .await?; - - if let Some(ref source) = payload.source { - tracing::debug!("Adding click source: {}", source); - sqlx::query("INSERT INTO clicks (link_id, source) VALUES ($1, $2)") - .bind(link.id) - .bind(source) - .execute(&mut *tx) + let link = sqlx::query_as::<_, Link>( + "INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *" + ) + .bind(&payload.url) + .bind(&short_code) + .bind(user.user_id) + .fetch_one(&mut *tx) .await?; - } - tx.commit().await?; - Ok(HttpResponse::Created().json(link)) + if let Some(ref source) = payload.source { + sqlx::query("INSERT INTO clicks (link_id, source) VALUES ($1, $2)") + .bind(link.id) + .bind(source) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + link + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + + let link = sqlx::query_as::<_, Link>( + "INSERT INTO links (original_url, short_code, user_id) VALUES (?1, ?2, ?3) RETURNING *" + ) + .bind(&payload.url) + .bind(&short_code) + .bind(user.user_id) + .fetch_one(&mut *tx) + .await?; + + if let Some(ref source) = payload.source { + sqlx::query("INSERT INTO clicks (link_id, source) VALUES (?1, ?2)") + .bind(link.id) + .bind(source) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + link + } + }; + + Ok(HttpResponse::Created().json(result)) } fn validate_custom_code(code: &str) -> Result<(), AppError> { @@ -120,33 +156,76 @@ pub async fn redirect_to_url( .and_then(|q| web::Query::>::from_query(q).ok()) .and_then(|params| params.get("source").cloned()); - let mut tx = state.db.begin().await?; - - let link = sqlx::query_as::<_, Link>( - "UPDATE links SET clicks = clicks + 1 WHERE short_code = $1 RETURNING *", - ) - .bind(&short_code) - .fetch_optional(&mut *tx) - .await?; + let link = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let link = sqlx::query_as::<_, Link>( + "UPDATE links SET clicks = clicks + 1 WHERE short_code = $1 RETURNING *", + ) + .bind(&short_code) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + link + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let link = sqlx::query_as::<_, Link>( + "UPDATE links SET clicks = clicks + 1 WHERE short_code = ?1 RETURNING *", + ) + .bind(&short_code) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + link + } + }; match link { Some(link) => { - // Record click with both user agent and query source - let user_agent = req - .headers() - .get("user-agent") - .and_then(|h| h.to_str().ok()) - .unwrap_or("unknown") - .to_string(); + // Handle click recording based on database type + match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let user_agent = req + .headers() + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("unknown") + .to_string(); - sqlx::query("INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)") - .bind(link.id) - .bind(user_agent) - .bind(query_source) - .execute(&mut *tx) - .await?; + sqlx::query( + "INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)", + ) + .bind(link.id) + .bind(user_agent) + .bind(query_source) + .execute(&mut *tx) + .await?; - tx.commit().await?; + tx.commit().await?; + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let user_agent = req + .headers() + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("unknown") + .to_string(); + + sqlx::query( + "INSERT INTO clicks (link_id, source, query_source) VALUES (?1, ?2, ?3)", + ) + .bind(link.id) + .bind(user_agent) + .bind(query_source) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + } + }; Ok(HttpResponse::TemporaryRedirect() .append_header(("Location", link.original_url)) @@ -160,20 +239,38 @@ pub async fn get_all_links( state: web::Data, user: AuthenticatedUser, ) -> Result { - let links = sqlx::query_as::<_, Link>( - "SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC", - ) - .bind(user.user_id) - .fetch_all(&state.db) - .await?; + let links = match &state.db { + DatabasePool::Postgres(pool) => { + sqlx::query_as::<_, Link>( + "SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC", + ) + .bind(user.user_id) + .fetch_all(pool) + .await? + } + DatabasePool::Sqlite(pool) => { + sqlx::query_as::<_, Link>( + "SELECT * FROM links WHERE user_id = ?1 ORDER BY created_at DESC", + ) + .bind(user.user_id) + .fetch_all(pool) + .await? + } + }; Ok(HttpResponse::Ok().json(links)) } pub async fn health_check(state: web::Data) -> impl Responder { - match sqlx::query("SELECT 1").execute(&state.db).await { - Ok(_) => HttpResponse::Ok().json("Healthy"), - Err(_) => HttpResponse::ServiceUnavailable().json("Database unavailable"), + let is_healthy = match &state.db { + DatabasePool::Postgres(pool) => sqlx::query("SELECT 1").execute(pool).await.is_ok(), + DatabasePool::Sqlite(pool) => sqlx::query("SELECT 1").execute(pool).await.is_ok(), + }; + + if is_healthy { + HttpResponse::Ok().json("Healthy") + } else { + HttpResponse::ServiceUnavailable().json("Database unavailable") } } @@ -190,11 +287,26 @@ pub async fn register( payload: web::Json, ) -> Result { // Check if any users exist - let user_count = sqlx::query!("SELECT COUNT(*) as count FROM users") - .fetch_one(&state.db) - .await? - .count - .unwrap_or(0); + let user_count = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let count = sqlx::query_as::("SELECT COUNT(*)::bigint FROM users") + .fetch_one(&mut *tx) + .await? + .0; + tx.commit().await?; + count + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let count = sqlx::query_as::("SELECT COUNT(*) FROM users") + .fetch_one(&mut *tx) + .await? + .0; + tx.commit().await?; + count + } + }; // If users exist, registration is closed - no exceptions if user_count > 0 { @@ -210,9 +322,27 @@ pub async fn register( } // Check if email already exists - let exists = sqlx::query!("SELECT id FROM users WHERE email = $1", payload.email) - .fetch_optional(&state.db) - .await?; + let exists = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let exists = + sqlx::query_as::("SELECT id FROM users WHERE email = $1") + .bind(&payload.email) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + exists + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let exists = sqlx::query_as::("SELECT id FROM users WHERE email = ?") + .bind(&payload.email) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + exists + } + }; if exists.is_some() { return Err(AppError::Auth("Email already registered".to_string())); @@ -225,14 +355,33 @@ pub async fn register( .map_err(|e| AppError::Auth(e.to_string()))? .to_string(); - let user = sqlx::query_as!( - User, - "INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *", - payload.email, - password_hash - ) - .fetch_one(&state.db) - .await?; + // Insert new user + let user = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let user = sqlx::query_as::( + "INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *", + ) + .bind(&payload.email) + .bind(&password_hash) + .fetch_one(&mut *tx) + .await?; + tx.commit().await?; + user + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let user = sqlx::query_as::( + "INSERT INTO users (email, password_hash) VALUES (?, ?) RETURNING *", + ) + .bind(&payload.email) + .bind(&password_hash) + .fetch_one(&mut *tx) + .await?; + tx.commit().await?; + user + } + }; let claims = Claims::new(user.id); let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string()); @@ -256,10 +405,27 @@ pub async fn login( state: web::Data, payload: web::Json, ) -> Result { - let user = sqlx::query_as!(User, "SELECT * FROM users WHERE email = $1", payload.email) - .fetch_optional(&state.db) - .await? - .ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?; + let user = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let user = sqlx::query_as::("SELECT * FROM users WHERE email = $1") + .bind(&payload.email) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + user + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let user = sqlx::query_as::("SELECT * FROM users WHERE email = ?") + .bind(&payload.email) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + user + } + } + .ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?; let argon2 = Argon2::default(); let parsed_hash = @@ -297,34 +463,69 @@ pub async fn delete_link( ) -> Result { let link_id = path.into_inner(); - // Start transaction - let mut tx = state.db.begin().await?; + match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; - // Verify the link belongs to the user - let link = sqlx::query!( - "SELECT id FROM links WHERE id = $1 AND user_id = $2", - link_id, - user.user_id - ) - .fetch_optional(&mut *tx) - .await?; + // Verify the link belongs to the user + let link = sqlx::query_as::( + "SELECT id FROM links WHERE id = $1 AND user_id = $2", + ) + .bind(link_id) + .bind(user.user_id) + .fetch_optional(&mut *tx) + .await?; - if link.is_none() { - return Err(AppError::NotFound); + if link.is_none() { + return Err(AppError::NotFound); + } + + // Delete associated clicks first due to foreign key constraint + sqlx::query("DELETE FROM clicks WHERE link_id = $1") + .bind(link_id) + .execute(&mut *tx) + .await?; + + // Delete the link + sqlx::query("DELETE FROM links WHERE id = $1") + .bind(link_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + + // Verify the link belongs to the user + let link = sqlx::query_as::( + "SELECT id FROM links WHERE id = ? AND user_id = ?", + ) + .bind(link_id) + .bind(user.user_id) + .fetch_optional(&mut *tx) + .await?; + + if link.is_none() { + return Err(AppError::NotFound); + } + + // Delete associated clicks first due to foreign key constraint + sqlx::query("DELETE FROM clicks WHERE link_id = ?") + .bind(link_id) + .execute(&mut *tx) + .await?; + + // Delete the link + sqlx::query("DELETE FROM links WHERE id = ?") + .bind(link_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + } } - // Delete associated clicks first due to foreign key constraint - sqlx::query!("DELETE FROM clicks WHERE link_id = $1", link_id) - .execute(&mut *tx) - .await?; - - // Delete the link - sqlx::query!("DELETE FROM links WHERE id = $1", link_id) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - Ok(HttpResponse::NoContent().finish()) } @@ -336,34 +537,73 @@ pub async fn get_link_clicks( let link_id = path.into_inner(); // Verify the link belongs to the user - let link = sqlx::query!( - "SELECT id FROM links WHERE id = $1 AND user_id = $2", - link_id, - user.user_id - ) - .fetch_optional(&state.db) - .await?; + let link = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let link = sqlx::query_as::( + "SELECT id FROM links WHERE id = $1 AND user_id = $2", + ) + .bind(link_id) + .bind(user.user_id) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + link + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let link = sqlx::query_as::( + "SELECT id FROM links WHERE id = ? AND user_id = ?", + ) + .bind(link_id) + .bind(user.user_id) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + link + } + }; if link.is_none() { return Err(AppError::NotFound); } - let clicks = sqlx::query_as!( - ClickStats, - r#" - SELECT - DATE(created_at)::date as "date!", - COUNT(*)::bigint as "clicks!" - FROM clicks - WHERE link_id = $1 - GROUP BY DATE(created_at) - ORDER BY DATE(created_at) ASC -- Changed from DESC to ASC - LIMIT 30 - "#, - link_id - ) - .fetch_all(&state.db) - .await?; + let clicks = match &state.db { + DatabasePool::Postgres(pool) => { + sqlx::query_as::( + r#" + SELECT + DATE(created_at)::date as "date!", + COUNT(*)::bigint as "clicks!" + FROM clicks + WHERE link_id = $1 + GROUP BY DATE(created_at) + ORDER BY DATE(created_at) ASC + LIMIT 30 + "#, + ) + .bind(link_id) + .fetch_all(pool) + .await? + } + DatabasePool::Sqlite(pool) => { + sqlx::query_as::( + r#" + SELECT + DATE(created_at) as "date!", + COUNT(*) as "clicks!" + FROM clicks + WHERE link_id = ? + GROUP BY DATE(created_at) + ORDER BY DATE(created_at) ASC + LIMIT 30 + "#, + ) + .bind(link_id) + .fetch_all(pool) + .await? + } + }; Ok(HttpResponse::Ok().json(clicks)) } @@ -376,36 +616,77 @@ pub async fn get_link_sources( let link_id = path.into_inner(); // Verify the link belongs to the user - let link = sqlx::query!( - "SELECT id FROM links WHERE id = $1 AND user_id = $2", - link_id, - user.user_id - ) - .fetch_optional(&state.db) - .await?; + let link = match &state.db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let link = sqlx::query_as::( + "SELECT id FROM links WHERE id = $1 AND user_id = $2", + ) + .bind(link_id) + .bind(user.user_id) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + link + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let link = sqlx::query_as::( + "SELECT id FROM links WHERE id = ? AND user_id = ?", + ) + .bind(link_id) + .bind(user.user_id) + .fetch_optional(&mut *tx) + .await?; + tx.commit().await?; + link + } + }; if link.is_none() { return Err(AppError::NotFound); } - let sources = sqlx::query_as!( - SourceStats, - r#" - SELECT - query_source as "source!", - COUNT(*)::bigint as "count!" - FROM clicks - WHERE link_id = $1 - AND query_source IS NOT NULL - AND query_source != '' - GROUP BY query_source - ORDER BY COUNT(*) DESC - LIMIT 10 - "#, - link_id - ) - .fetch_all(&state.db) - .await?; + let sources = match &state.db { + DatabasePool::Postgres(pool) => { + sqlx::query_as::( + r#" + SELECT + query_source as "source!", + COUNT(*)::bigint as "count!" + FROM clicks + WHERE link_id = $1 + AND query_source IS NOT NULL + AND query_source != '' + GROUP BY query_source + ORDER BY COUNT(*) DESC + LIMIT 10 + "#, + ) + .bind(link_id) + .fetch_all(pool) + .await? + } + DatabasePool::Sqlite(pool) => { + sqlx::query_as::( + r#" + SELECT + query_source as "source!", + COUNT(*) as "count!" + FROM clicks + WHERE link_id = ? + AND query_source IS NOT NULL + AND query_source != '' + GROUP BY query_source + ORDER BY COUNT(*) DESC + LIMIT 10 + "#, + ) + .bind(link_id) + .fetch_all(pool) + .await? + } + }; Ok(HttpResponse::Ok().json(sources)) } diff --git a/src/lib.rs b/src/lib.rs index a169cb7..d8bc670 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,14 @@ +use anyhow::Result; use rand::Rng; -use sqlx::PgPool; +use sqlx::migrate::MigrateDatabase; +use sqlx::postgres::PgPoolOptions; +use sqlx::{Postgres, Sqlite}; use std::fs::File; use std::io::Write; use tracing::info; +use models::DatabasePool; + pub mod auth; pub mod error; pub mod handlers; @@ -11,17 +16,90 @@ pub mod models; #[derive(Clone)] pub struct AppState { - pub db: PgPool, + pub db: DatabasePool, pub admin_token: Option, } -pub async fn check_and_generate_admin_token(pool: &sqlx::PgPool) -> anyhow::Result> { +pub async fn create_db_pool() -> Result { + let database_url = std::env::var("DATABASE_URL").ok(); + + match database_url { + Some(url) if url.starts_with("postgres://") => { + info!("Using PostgreSQL database"); + let pool = PgPoolOptions::new() + .max_connections(5) + .acquire_timeout(std::time::Duration::from_secs(3)) + .connect(&url) + .await?; + + Ok(DatabasePool::Postgres(pool)) + } + _ => { + info!("No PostgreSQL connection string found, using SQLite"); + + // Create a data directory if it doesn't exist + let data_dir = std::path::Path::new("data"); + if !data_dir.exists() { + std::fs::create_dir_all(data_dir)?; + } + + let db_path = data_dir.join("simplelink.db"); + let sqlite_url = format!("sqlite://{}", db_path.display()); + + // Check if database exists and create it if it doesn't + if !Sqlite::database_exists(&sqlite_url).await.unwrap_or(false) { + info!("Creating new SQLite database at {}", db_path.display()); + Sqlite::create_database(&sqlite_url).await?; + info!("Database created successfully"); + } else { + info!("Database already exists"); + } + + let pool = sqlx::sqlite::SqlitePoolOptions::new() + .max_connections(5) + .connect(&sqlite_url) + .await?; + + Ok(DatabasePool::Sqlite(pool)) + } + } +} + +pub async fn run_migrations(pool: &DatabasePool) -> Result<()> { + match pool { + DatabasePool::Postgres(pool) => { + // Use the root migrations directory for postgres + sqlx::migrate!().run(pool).await?; + } + DatabasePool::Sqlite(pool) => { + sqlx::migrate!("./migrations/sqlite").run(pool).await?; + } + } + Ok(()) +} + +pub async fn check_and_generate_admin_token(db: &DatabasePool) -> anyhow::Result> { // Check if any users exist - let user_count = sqlx::query!("SELECT COUNT(*) as count FROM users") - .fetch_one(pool) - .await? - .count - .unwrap_or(0); + let user_count = match db { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let count = sqlx::query_as::("SELECT COUNT(*)::bigint FROM users") + .fetch_one(&mut *tx) + .await? + .0; + tx.commit().await?; + count + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let count = sqlx::query_as::("SELECT COUNT(*) FROM users") + .fetch_one(&mut *tx) + .await? + .0; + tx.commit().await?; + count + } + }; if user_count == 0 { // Generate a random token using simple characters diff --git a/src/main.rs b/src/main.rs index ca8bc16..fc60fc4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,8 +3,8 @@ use actix_web::{web, App, HttpResponse, HttpServer}; use anyhow::Result; use rust_embed::RustEmbed; use simplelink::check_and_generate_admin_token; +use simplelink::{create_db_pool, run_migrations}; use simplelink::{handlers, AppState}; -use sqlx::postgres::PgPoolOptions; use tracing::info; #[derive(RustEmbed)] @@ -31,18 +31,9 @@ async fn main() -> Result<()> { // Initialize logging tracing_subscriber::fmt::init(); - // Database connection string from environment - let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); - // Create database connection pool - let pool = PgPoolOptions::new() - .max_connections(5) - .acquire_timeout(std::time::Duration::from_secs(3)) - .connect(&database_url) - .await?; - - // Run database migrations - sqlx::migrate!("./migrations").run(&pool).await?; + let pool = create_db_pool().await?; + run_migrations(&pool).await?; let admin_token = check_and_generate_admin_token(&pool).await?; diff --git a/src/models.rs b/src/models.rs index 05f60b6..1705dff 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,80 @@ +use anyhow::Result; +use chrono::NaiveDate; +use futures::future::BoxFuture; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgRow; +use sqlx::sqlite::SqliteRow; +use sqlx::FromRow; +use sqlx::Pool; +use sqlx::Postgres; +use sqlx::Sqlite; +use sqlx::Transaction; use std::time::{SystemTime, UNIX_EPOCH}; -use chrono::NaiveDate; -use serde::{Deserialize, Serialize}; -use sqlx::FromRow; +#[derive(Clone)] +pub enum DatabasePool { + Postgres(Pool), + Sqlite(Pool), +} + +impl DatabasePool { + pub async fn begin(&self) -> Result> { + match self { + DatabasePool::Postgres(pool) => Ok(Box::new(pool.begin().await?)), + DatabasePool::Sqlite(pool) => Ok(Box::new(pool.begin().await?)), + } + } + + pub async fn fetch_optional(&self, pg_query: &str, sqlite_query: &str) -> Result> + where + T: for<'r> FromRow<'r, PgRow> + for<'r> FromRow<'r, SqliteRow> + Send + Sync + Unpin, + { + match self { + DatabasePool::Postgres(pool) => { + Ok(sqlx::query_as(pg_query).fetch_optional(pool).await?) + } + DatabasePool::Sqlite(pool) => { + Ok(sqlx::query_as(sqlite_query).fetch_optional(pool).await?) + } + } + } + + pub async fn execute(&self, pg_query: &str, sqlite_query: &str) -> Result<()> { + match self { + DatabasePool::Postgres(pool) => { + sqlx::query(pg_query).execute(pool).await?; + Ok(()) + } + DatabasePool::Sqlite(pool) => { + sqlx::query(sqlite_query).execute(pool).await?; + Ok(()) + } + } + } + + pub async fn transaction<'a, F, R>(&'a self, f: F) -> Result + where + F: for<'c> Fn(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result> + + for<'c> Fn(&'c mut Transaction<'_, Sqlite>) -> BoxFuture<'c, Result> + + Copy, + R: Send + 'static, + { + match self { + DatabasePool::Postgres(pool) => { + let mut tx = pool.begin().await?; + let result = f(&mut tx).await?; + tx.commit().await?; + Ok(result) + } + DatabasePool::Sqlite(pool) => { + let mut tx = pool.begin().await?; + let result = f(&mut tx).await?; + tx.commit().await?; + Ok(result) + } + } + } +} #[derive(Debug, Serialize, Deserialize)] pub struct Claims {