add sqlite

This commit is contained in:
Wavering Ana 2025-01-29 18:52:49 -05:00
parent 8a3ee20a9d
commit daa1323b88
8 changed files with 707 additions and 254 deletions

1
.gitignore vendored
View file

@ -11,3 +11,4 @@ release.tar.gz
admin-setup-token.txt
package-lock.json
bun.lock
*.db*

137
Cargo.lock generated
View file

@ -837,6 +837,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]]
name = "foreign-types"
version = "0.3.2"
@ -861,6 +867,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -905,6 +926,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@ -923,8 +955,10 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@ -981,29 +1015,24 @@ dependencies = [
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
"allocator-api2",
]
[[package]]
name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
]
[[package]]
name = "hashlink"
version = "0.9.1"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af"
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [
"hashbrown 0.14.5",
"hashbrown",
]
[[package]]
@ -1249,7 +1278,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
dependencies = [
"equivalent",
"hashbrown 0.15.2",
"hashbrown",
]
[[package]]
@ -1413,12 +1442,6 @@ dependencies = [
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.3"
@ -1457,16 +1480,6 @@ dependencies = [
"tempfile",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@ -2109,6 +2122,7 @@ dependencies = [
"chrono",
"clap",
"dotenv",
"futures",
"jsonwebtoken",
"lazy_static",
"mime_guess",
@ -2172,21 +2186,11 @@ dependencies = [
"der",
]
[[package]]
name = "sqlformat"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790"
dependencies = [
"nom",
"unicode_categories",
]
[[package]]
name = "sqlx"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcfa89bea9500db4a0d038513d7a060566bfc51d46d1c014847049a45cce85e8"
checksum = "4410e73b3c0d8442c5f99b425d7a435b5ee0ae4167b3196771dd3f7a01be745f"
dependencies = [
"sqlx-core",
"sqlx-macros",
@ -2197,51 +2201,44 @@ dependencies = [
[[package]]
name = "sqlx-core"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d06e2f2bd861719b1f3f0c7dbe1d80c30bf59e76cf019f07d9014ed7eefb8e08"
checksum = "6a007b6936676aa9ab40207cde35daab0a04b823be8ae004368c0793b96a61e0"
dependencies = [
"atoi",
"byteorder",
"bytes",
"chrono",
"crc",
"crossbeam-queue",
"either",
"event-listener",
"futures-channel",
"futures-core",
"futures-intrusive",
"futures-io",
"futures-util",
"hashbrown 0.14.5",
"hashbrown",
"hashlink",
"hex",
"indexmap",
"log",
"memchr",
"native-tls",
"once_cell",
"paste",
"percent-encoding",
"serde",
"serde_json",
"sha2",
"smallvec",
"sqlformat",
"thiserror 1.0.69",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tracing",
"url",
"uuid",
]
[[package]]
name = "sqlx-macros"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f998a9defdbd48ed005a89362bd40dd2117502f15294f61c8d47034107dbbdc"
checksum = "3112e2ad78643fef903618d78cf0aec1cb3134b019730edb039b69eaf531f310"
dependencies = [
"proc-macro2",
"quote",
@ -2252,9 +2249,9 @@ dependencies = [
[[package]]
name = "sqlx-macros-core"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d100558134176a2629d46cec0c8891ba0be8910f7896abfdb75ef4ab6f4e7ce"
checksum = "4e9f90acc5ab146a99bf5061a7eb4976b573f560bc898ef3bf8435448dd5e7ad"
dependencies = [
"dotenvy",
"either",
@ -2278,9 +2275,9 @@ dependencies = [
[[package]]
name = "sqlx-mysql"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "936cac0ab331b14cb3921c62156d913e4c15b74fb6ec0f3146bd4ef6e4fb3c12"
checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233"
dependencies = [
"atoi",
"base64 0.22.1",
@ -2314,17 +2311,16 @@ dependencies = [
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 1.0.69",
"thiserror 2.0.11",
"tracing",
"uuid",
"whoami",
]
[[package]]
name = "sqlx-postgres"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9734dbce698c67ecf67c442f768a5e90a49b2a4d61a9f1d59f73874bd4cf0710"
checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613"
dependencies = [
"atoi",
"base64 0.22.1",
@ -2336,7 +2332,6 @@ dependencies = [
"etcetera",
"futures-channel",
"futures-core",
"futures-io",
"futures-util",
"hex",
"hkdf",
@ -2354,17 +2349,16 @@ dependencies = [
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 1.0.69",
"thiserror 2.0.11",
"tracing",
"uuid",
"whoami",
]
[[package]]
name = "sqlx-sqlite"
version = "0.8.1"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75b419c3c1b1697833dd927bdc4c6545a620bc1bbafabd44e1efbe9afcd337e"
checksum = "f85ca71d3a5b24e64e1d08dd8fe36c6c95c339a896cc33068148906784620540"
dependencies = [
"atoi",
"chrono",
@ -2382,7 +2376,6 @@ dependencies = [
"sqlx-core",
"tracing",
"url",
"uuid",
]
[[package]]
@ -2706,12 +2699,6 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "untrusted"
version = "0.9.0"

View file

@ -14,7 +14,7 @@ actix-web = "4.4"
actix-files = "0.6"
actix-cors = "0.6"
tokio = { version = "1.36", features = ["full"] }
sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono"] }
sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "sqlite", "chrono"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
@ -31,3 +31,4 @@ lazy_static = "1.4"
argon2 = "0.5.3"
rand = { version = "0.8", features = ["std"] }
mime_guess = "2.0.5"
futures = "0.3.31"

View file

@ -0,0 +1,42 @@
-- Enable foreign key support
PRAGMA foreign_keys = ON;
-- Add Migration Version
CREATE TABLE IF NOT EXISTS _sqlx_migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
installed_on TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Create users table
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email VARCHAR(255) NOT NULL UNIQUE,
password_hash TEXT NOT NULL
);
-- Create links table
CREATE TABLE links (
id INTEGER PRIMARY KEY AUTOINCREMENT,
original_url TEXT NOT NULL,
short_code VARCHAR(8) NOT NULL UNIQUE,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
clicks INTEGER NOT NULL DEFAULT 0,
user_id INTEGER,
FOREIGN KEY (user_id) REFERENCES users(id)
);
-- Create clicks table
CREATE TABLE clicks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
link_id INTEGER,
source TEXT,
query_source TEXT,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (link_id) REFERENCES links(id)
);
-- Create indexes
CREATE INDEX idx_short_code ON links(short_code);
CREATE INDEX idx_user_id ON links(user_id);
CREATE INDEX idx_link_id ON clicks(link_id);

View file

@ -2,8 +2,8 @@ use crate::auth::AuthenticatedUser;
use crate::{
error::AppError,
models::{
AuthResponse, Claims, ClickStats, CreateLink, Link, LoginRequest, RegisterRequest,
SourceStats, User, UserResponse,
AuthResponse, Claims, ClickStats, CreateLink, DatabasePool, Link, LoginRequest,
RegisterRequest, SourceStats, User, UserResponse,
},
AppState,
};
@ -16,6 +16,7 @@ use argon2::{Argon2, PasswordHash, PasswordHasher};
use jsonwebtoken::{encode, EncodingKey, Header};
use lazy_static::lazy_static;
use regex::Regex;
use sqlx::{Postgres, Sqlite};
lazy_static! {
static ref VALID_CODE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9_-]{1,32}$").unwrap();
@ -27,53 +28,88 @@ pub async fn create_short_url(
payload: web::Json<CreateLink>,
) -> Result<impl Responder, AppError> {
tracing::debug!("Creating short URL with user_id: {}", user.user_id);
validate_url(&payload.url)?;
let short_code = if let Some(ref custom_code) = payload.custom_code {
validate_custom_code(custom_code)?;
tracing::debug!("Checking if custom code {} exists", custom_code);
// Check if code is already taken
if let Some(_) = sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = $1")
.bind(custom_code)
.fetch_optional(&state.db)
.await?
{
// Check if code exists using match on pool type
let exists = match &state.db {
DatabasePool::Postgres(pool) => {
sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = $1")
.bind(custom_code)
.fetch_optional(pool)
.await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = ?1")
.bind(custom_code)
.fetch_optional(pool)
.await?
}
};
if exists.is_some() {
return Err(AppError::InvalidInput(
"Custom code already taken".to_string(),
));
}
custom_code.clone()
} else {
generate_short_code()
};
// Start transaction
let mut tx = state.db.begin().await?;
// Start transaction based on pool type
let result = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
tracing::debug!("Inserting new link with short_code: {}", short_code);
let link = sqlx::query_as::<_, Link>(
"INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *",
)
.bind(&payload.url)
.bind(&short_code)
.bind(user.user_id)
.fetch_one(&mut *tx)
.await?;
if let Some(ref source) = payload.source {
tracing::debug!("Adding click source: {}", source);
sqlx::query("INSERT INTO clicks (link_id, source) VALUES ($1, $2)")
.bind(link.id)
.bind(source)
.execute(&mut *tx)
let link = sqlx::query_as::<_, Link>(
"INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *"
)
.bind(&payload.url)
.bind(&short_code)
.bind(user.user_id)
.fetch_one(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(HttpResponse::Created().json(link))
if let Some(ref source) = payload.source {
sqlx::query("INSERT INTO clicks (link_id, source) VALUES ($1, $2)")
.bind(link.id)
.bind(source)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<_, Link>(
"INSERT INTO links (original_url, short_code, user_id) VALUES (?1, ?2, ?3) RETURNING *"
)
.bind(&payload.url)
.bind(&short_code)
.bind(user.user_id)
.fetch_one(&mut *tx)
.await?;
if let Some(ref source) = payload.source {
sqlx::query("INSERT INTO clicks (link_id, source) VALUES (?1, ?2)")
.bind(link.id)
.bind(source)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
link
}
};
Ok(HttpResponse::Created().json(result))
}
fn validate_custom_code(code: &str) -> Result<(), AppError> {
@ -120,33 +156,76 @@ pub async fn redirect_to_url(
.and_then(|q| web::Query::<std::collections::HashMap<String, String>>::from_query(q).ok())
.and_then(|params| params.get("source").cloned());
let mut tx = state.db.begin().await?;
let link = sqlx::query_as::<_, Link>(
"UPDATE links SET clicks = clicks + 1 WHERE short_code = $1 RETURNING *",
)
.bind(&short_code)
.fetch_optional(&mut *tx)
.await?;
let link = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<_, Link>(
"UPDATE links SET clicks = clicks + 1 WHERE short_code = $1 RETURNING *",
)
.bind(&short_code)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<_, Link>(
"UPDATE links SET clicks = clicks + 1 WHERE short_code = ?1 RETURNING *",
)
.bind(&short_code)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
};
match link {
Some(link) => {
// Record click with both user agent and query source
let user_agent = req
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown")
.to_string();
// Handle click recording based on database type
match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let user_agent = req
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown")
.to_string();
sqlx::query("INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)")
.bind(link.id)
.bind(user_agent)
.bind(query_source)
.execute(&mut *tx)
.await?;
sqlx::query(
"INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)",
)
.bind(link.id)
.bind(user_agent)
.bind(query_source)
.execute(&mut *tx)
.await?;
tx.commit().await?;
tx.commit().await?;
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let user_agent = req
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown")
.to_string();
sqlx::query(
"INSERT INTO clicks (link_id, source, query_source) VALUES (?1, ?2, ?3)",
)
.bind(link.id)
.bind(user_agent)
.bind(query_source)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
};
Ok(HttpResponse::TemporaryRedirect()
.append_header(("Location", link.original_url))
@ -160,20 +239,38 @@ pub async fn get_all_links(
state: web::Data<AppState>,
user: AuthenticatedUser,
) -> Result<impl Responder, AppError> {
let links = sqlx::query_as::<_, Link>(
"SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC",
)
.bind(user.user_id)
.fetch_all(&state.db)
.await?;
let links = match &state.db {
DatabasePool::Postgres(pool) => {
sqlx::query_as::<_, Link>(
"SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC",
)
.bind(user.user_id)
.fetch_all(pool)
.await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<_, Link>(
"SELECT * FROM links WHERE user_id = ?1 ORDER BY created_at DESC",
)
.bind(user.user_id)
.fetch_all(pool)
.await?
}
};
Ok(HttpResponse::Ok().json(links))
}
pub async fn health_check(state: web::Data<AppState>) -> impl Responder {
match sqlx::query("SELECT 1").execute(&state.db).await {
Ok(_) => HttpResponse::Ok().json("Healthy"),
Err(_) => HttpResponse::ServiceUnavailable().json("Database unavailable"),
let is_healthy = match &state.db {
DatabasePool::Postgres(pool) => sqlx::query("SELECT 1").execute(pool).await.is_ok(),
DatabasePool::Sqlite(pool) => sqlx::query("SELECT 1").execute(pool).await.is_ok(),
};
if is_healthy {
HttpResponse::Ok().json("Healthy")
} else {
HttpResponse::ServiceUnavailable().json("Database unavailable")
}
}
@ -190,11 +287,26 @@ pub async fn register(
payload: web::Json<RegisterRequest>,
) -> Result<impl Responder, AppError> {
// Check if any users exist
let user_count = sqlx::query!("SELECT COUNT(*) as count FROM users")
.fetch_one(&state.db)
.await?
.count
.unwrap_or(0);
let user_count = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Postgres, (i64,)>("SELECT COUNT(*)::bigint FROM users")
.fetch_one(&mut *tx)
.await?
.0;
tx.commit().await?;
count
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Sqlite, (i64,)>("SELECT COUNT(*) FROM users")
.fetch_one(&mut *tx)
.await?
.0;
tx.commit().await?;
count
}
};
// If users exist, registration is closed - no exceptions
if user_count > 0 {
@ -210,9 +322,27 @@ pub async fn register(
}
// Check if email already exists
let exists = sqlx::query!("SELECT id FROM users WHERE email = $1", payload.email)
.fetch_optional(&state.db)
.await?;
let exists = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let exists =
sqlx::query_as::<Postgres, (i32,)>("SELECT id FROM users WHERE email = $1")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
exists
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let exists = sqlx::query_as::<Sqlite, (i32,)>("SELECT id FROM users WHERE email = ?")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
exists
}
};
if exists.is_some() {
return Err(AppError::Auth("Email already registered".to_string()));
@ -225,14 +355,33 @@ pub async fn register(
.map_err(|e| AppError::Auth(e.to_string()))?
.to_string();
let user = sqlx::query_as!(
User,
"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *",
payload.email,
password_hash
)
.fetch_one(&state.db)
.await?;
// Insert new user
let user = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Postgres, User>(
"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *",
)
.bind(&payload.email)
.bind(&password_hash)
.fetch_one(&mut *tx)
.await?;
tx.commit().await?;
user
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Sqlite, User>(
"INSERT INTO users (email, password_hash) VALUES (?, ?) RETURNING *",
)
.bind(&payload.email)
.bind(&password_hash)
.fetch_one(&mut *tx)
.await?;
tx.commit().await?;
user
}
};
let claims = Claims::new(user.id);
let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string());
@ -256,10 +405,27 @@ pub async fn login(
state: web::Data<AppState>,
payload: web::Json<LoginRequest>,
) -> Result<impl Responder, AppError> {
let user = sqlx::query_as!(User, "SELECT * FROM users WHERE email = $1", payload.email)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?;
let user = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Postgres, User>("SELECT * FROM users WHERE email = $1")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
user
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Sqlite, User>("SELECT * FROM users WHERE email = ?")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
user
}
}
.ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?;
let argon2 = Argon2::default();
let parsed_hash =
@ -297,34 +463,69 @@ pub async fn delete_link(
) -> Result<impl Responder, AppError> {
let link_id = path.into_inner();
// Start transaction
let mut tx = state.db.begin().await?;
match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
// Verify the link belongs to the user
let link = sqlx::query!(
"SELECT id FROM links WHERE id = $1 AND user_id = $2",
link_id,
user.user_id
)
.fetch_optional(&mut *tx)
.await?;
// Verify the link belongs to the user
let link = sqlx::query_as::<Postgres, (i32,)>(
"SELECT id FROM links WHERE id = $1 AND user_id = $2",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
if link.is_none() {
return Err(AppError::NotFound);
if link.is_none() {
return Err(AppError::NotFound);
}
// Delete associated clicks first due to foreign key constraint
sqlx::query("DELETE FROM clicks WHERE link_id = $1")
.bind(link_id)
.execute(&mut *tx)
.await?;
// Delete the link
sqlx::query("DELETE FROM links WHERE id = $1")
.bind(link_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
// Verify the link belongs to the user
let link = sqlx::query_as::<Sqlite, (i32,)>(
"SELECT id FROM links WHERE id = ? AND user_id = ?",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
if link.is_none() {
return Err(AppError::NotFound);
}
// Delete associated clicks first due to foreign key constraint
sqlx::query("DELETE FROM clicks WHERE link_id = ?")
.bind(link_id)
.execute(&mut *tx)
.await?;
// Delete the link
sqlx::query("DELETE FROM links WHERE id = ?")
.bind(link_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
}
// Delete associated clicks first due to foreign key constraint
sqlx::query!("DELETE FROM clicks WHERE link_id = $1", link_id)
.execute(&mut *tx)
.await?;
// Delete the link
sqlx::query!("DELETE FROM links WHERE id = $1", link_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(HttpResponse::NoContent().finish())
}
@ -336,34 +537,73 @@ pub async fn get_link_clicks(
let link_id = path.into_inner();
// Verify the link belongs to the user
let link = sqlx::query!(
"SELECT id FROM links WHERE id = $1 AND user_id = $2",
link_id,
user.user_id
)
.fetch_optional(&state.db)
.await?;
let link = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Postgres, (i32,)>(
"SELECT id FROM links WHERE id = $1 AND user_id = $2",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Sqlite, (i32,)>(
"SELECT id FROM links WHERE id = ? AND user_id = ?",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
};
if link.is_none() {
return Err(AppError::NotFound);
}
let clicks = sqlx::query_as!(
ClickStats,
r#"
SELECT
DATE(created_at)::date as "date!",
COUNT(*)::bigint as "clicks!"
FROM clicks
WHERE link_id = $1
GROUP BY DATE(created_at)
ORDER BY DATE(created_at) ASC -- Changed from DESC to ASC
LIMIT 30
"#,
link_id
)
.fetch_all(&state.db)
.await?;
let clicks = match &state.db {
DatabasePool::Postgres(pool) => {
sqlx::query_as::<Postgres, ClickStats>(
r#"
SELECT
DATE(created_at)::date as "date!",
COUNT(*)::bigint as "clicks!"
FROM clicks
WHERE link_id = $1
GROUP BY DATE(created_at)
ORDER BY DATE(created_at) ASC
LIMIT 30
"#,
)
.bind(link_id)
.fetch_all(pool)
.await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<Sqlite, ClickStats>(
r#"
SELECT
DATE(created_at) as "date!",
COUNT(*) as "clicks!"
FROM clicks
WHERE link_id = ?
GROUP BY DATE(created_at)
ORDER BY DATE(created_at) ASC
LIMIT 30
"#,
)
.bind(link_id)
.fetch_all(pool)
.await?
}
};
Ok(HttpResponse::Ok().json(clicks))
}
@ -376,36 +616,77 @@ pub async fn get_link_sources(
let link_id = path.into_inner();
// Verify the link belongs to the user
let link = sqlx::query!(
"SELECT id FROM links WHERE id = $1 AND user_id = $2",
link_id,
user.user_id
)
.fetch_optional(&state.db)
.await?;
let link = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Postgres, (i32,)>(
"SELECT id FROM links WHERE id = $1 AND user_id = $2",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Sqlite, (i32,)>(
"SELECT id FROM links WHERE id = ? AND user_id = ?",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
};
if link.is_none() {
return Err(AppError::NotFound);
}
let sources = sqlx::query_as!(
SourceStats,
r#"
SELECT
query_source as "source!",
COUNT(*)::bigint as "count!"
FROM clicks
WHERE link_id = $1
AND query_source IS NOT NULL
AND query_source != ''
GROUP BY query_source
ORDER BY COUNT(*) DESC
LIMIT 10
"#,
link_id
)
.fetch_all(&state.db)
.await?;
let sources = match &state.db {
DatabasePool::Postgres(pool) => {
sqlx::query_as::<Postgres, SourceStats>(
r#"
SELECT
query_source as "source!",
COUNT(*)::bigint as "count!"
FROM clicks
WHERE link_id = $1
AND query_source IS NOT NULL
AND query_source != ''
GROUP BY query_source
ORDER BY COUNT(*) DESC
LIMIT 10
"#,
)
.bind(link_id)
.fetch_all(pool)
.await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<Sqlite, SourceStats>(
r#"
SELECT
query_source as "source!",
COUNT(*) as "count!"
FROM clicks
WHERE link_id = ?
AND query_source IS NOT NULL
AND query_source != ''
GROUP BY query_source
ORDER BY COUNT(*) DESC
LIMIT 10
"#,
)
.bind(link_id)
.fetch_all(pool)
.await?
}
};
Ok(HttpResponse::Ok().json(sources))
}

View file

@ -1,9 +1,14 @@
use anyhow::Result;
use rand::Rng;
use sqlx::PgPool;
use sqlx::migrate::MigrateDatabase;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Postgres, Sqlite};
use std::fs::File;
use std::io::Write;
use tracing::info;
use models::DatabasePool;
pub mod auth;
pub mod error;
pub mod handlers;
@ -11,17 +16,90 @@ pub mod models;
#[derive(Clone)]
pub struct AppState {
pub db: PgPool,
pub db: DatabasePool,
pub admin_token: Option<String>,
}
pub async fn check_and_generate_admin_token(pool: &sqlx::PgPool) -> anyhow::Result<Option<String>> {
pub async fn create_db_pool() -> Result<DatabasePool> {
let database_url = std::env::var("DATABASE_URL").ok();
match database_url {
Some(url) if url.starts_with("postgres://") => {
info!("Using PostgreSQL database");
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&url)
.await?;
Ok(DatabasePool::Postgres(pool))
}
_ => {
info!("No PostgreSQL connection string found, using SQLite");
// Create a data directory if it doesn't exist
let data_dir = std::path::Path::new("data");
if !data_dir.exists() {
std::fs::create_dir_all(data_dir)?;
}
let db_path = data_dir.join("simplelink.db");
let sqlite_url = format!("sqlite://{}", db_path.display());
// Check if database exists and create it if it doesn't
if !Sqlite::database_exists(&sqlite_url).await.unwrap_or(false) {
info!("Creating new SQLite database at {}", db_path.display());
Sqlite::create_database(&sqlite_url).await?;
info!("Database created successfully");
} else {
info!("Database already exists");
}
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(5)
.connect(&sqlite_url)
.await?;
Ok(DatabasePool::Sqlite(pool))
}
}
}
pub async fn run_migrations(pool: &DatabasePool) -> Result<()> {
match pool {
DatabasePool::Postgres(pool) => {
// Use the root migrations directory for postgres
sqlx::migrate!().run(pool).await?;
}
DatabasePool::Sqlite(pool) => {
sqlx::migrate!("./migrations/sqlite").run(pool).await?;
}
}
Ok(())
}
pub async fn check_and_generate_admin_token(db: &DatabasePool) -> anyhow::Result<Option<String>> {
// Check if any users exist
let user_count = sqlx::query!("SELECT COUNT(*) as count FROM users")
.fetch_one(pool)
.await?
.count
.unwrap_or(0);
let user_count = match db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Postgres, (i64,)>("SELECT COUNT(*)::bigint FROM users")
.fetch_one(&mut *tx)
.await?
.0;
tx.commit().await?;
count
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Sqlite, (i64,)>("SELECT COUNT(*) FROM users")
.fetch_one(&mut *tx)
.await?
.0;
tx.commit().await?;
count
}
};
if user_count == 0 {
// Generate a random token using simple characters

View file

@ -3,8 +3,8 @@ use actix_web::{web, App, HttpResponse, HttpServer};
use anyhow::Result;
use rust_embed::RustEmbed;
use simplelink::check_and_generate_admin_token;
use simplelink::{create_db_pool, run_migrations};
use simplelink::{handlers, AppState};
use sqlx::postgres::PgPoolOptions;
use tracing::info;
#[derive(RustEmbed)]
@ -31,18 +31,9 @@ async fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt::init();
// Database connection string from environment
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
// Create database connection pool
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&database_url)
.await?;
// Run database migrations
sqlx::migrate!("./migrations").run(&pool).await?;
let pool = create_db_pool().await?;
run_migrations(&pool).await?;
let admin_token = check_and_generate_admin_token(&pool).await?;

View file

@ -1,8 +1,80 @@
use anyhow::Result;
use chrono::NaiveDate;
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgRow;
use sqlx::sqlite::SqliteRow;
use sqlx::FromRow;
use sqlx::Pool;
use sqlx::Postgres;
use sqlx::Sqlite;
use sqlx::Transaction;
use std::time::{SystemTime, UNIX_EPOCH};
use chrono::NaiveDate;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
#[derive(Clone)]
pub enum DatabasePool {
Postgres(Pool<Postgres>),
Sqlite(Pool<Sqlite>),
}
impl DatabasePool {
pub async fn begin(&self) -> Result<Box<dyn std::any::Any + Send>> {
match self {
DatabasePool::Postgres(pool) => Ok(Box::new(pool.begin().await?)),
DatabasePool::Sqlite(pool) => Ok(Box::new(pool.begin().await?)),
}
}
pub async fn fetch_optional<T>(&self, pg_query: &str, sqlite_query: &str) -> Result<Option<T>>
where
T: for<'r> FromRow<'r, PgRow> + for<'r> FromRow<'r, SqliteRow> + Send + Sync + Unpin,
{
match self {
DatabasePool::Postgres(pool) => {
Ok(sqlx::query_as(pg_query).fetch_optional(pool).await?)
}
DatabasePool::Sqlite(pool) => {
Ok(sqlx::query_as(sqlite_query).fetch_optional(pool).await?)
}
}
}
pub async fn execute(&self, pg_query: &str, sqlite_query: &str) -> Result<()> {
match self {
DatabasePool::Postgres(pool) => {
sqlx::query(pg_query).execute(pool).await?;
Ok(())
}
DatabasePool::Sqlite(pool) => {
sqlx::query(sqlite_query).execute(pool).await?;
Ok(())
}
}
}
pub async fn transaction<'a, F, R>(&'a self, f: F) -> Result<R>
where
F: for<'c> Fn(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<R>>
+ for<'c> Fn(&'c mut Transaction<'_, Sqlite>) -> BoxFuture<'c, Result<R>>
+ Copy,
R: Send + 'static,
{
match self {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let result = f(&mut tx).await?;
tx.commit().await?;
Ok(result)
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let result = f(&mut tx).await?;
tx.commit().await?;
Ok(result)
}
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {