Merge pull request #7 from WaveringAna/v0.1

fix session expiration, add sqlite
This commit is contained in:
Wavering Ana 2025-01-29 19:34:22 -05:00 committed by GitHub
commit fd36858e20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 746 additions and 266 deletions

1
.gitignore vendored
View file

@ -11,3 +11,4 @@ release.tar.gz
admin-setup-token.txt admin-setup-token.txt
package-lock.json package-lock.json
bun.lock bun.lock
*.db*

137
Cargo.lock generated
View file

@ -837,6 +837,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]] [[package]]
name = "foreign-types" name = "foreign-types"
version = "0.3.2" version = "0.3.2"
@ -861,6 +867,21 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.31" version = "0.3.31"
@ -905,6 +926,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.31" version = "0.3.31"
@ -923,8 +955,10 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [ dependencies = [
"futures-channel",
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr", "memchr",
@ -981,29 +1015,24 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
"allocator-api2",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.2" version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
]
[[package]] [[package]]
name = "hashlink" name = "hashlink"
version = "0.9.1" version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [ dependencies = [
"hashbrown 0.14.5", "hashbrown",
] ]
[[package]] [[package]]
@ -1249,7 +1278,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown 0.15.2", "hashbrown",
] ]
[[package]] [[package]]
@ -1413,12 +1442,6 @@ dependencies = [
"unicase", "unicase",
] ]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]] [[package]]
name = "miniz_oxide" name = "miniz_oxide"
version = "0.8.3" version = "0.8.3"
@ -1457,16 +1480,6 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.46.0" version = "0.46.0"
@ -2109,6 +2122,7 @@ dependencies = [
"chrono", "chrono",
"clap", "clap",
"dotenv", "dotenv",
"futures",
"jsonwebtoken", "jsonwebtoken",
"lazy_static", "lazy_static",
"mime_guess", "mime_guess",
@ -2172,21 +2186,11 @@ dependencies = [
"der", "der",
] ]
[[package]]
name = "sqlformat"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790"
dependencies = [
"nom",
"unicode_categories",
]
[[package]] [[package]]
name = "sqlx" name = "sqlx"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcfa89bea9500db4a0d038513d7a060566bfc51d46d1c014847049a45cce85e8" checksum = "4410e73b3c0d8442c5f99b425d7a435b5ee0ae4167b3196771dd3f7a01be745f"
dependencies = [ dependencies = [
"sqlx-core", "sqlx-core",
"sqlx-macros", "sqlx-macros",
@ -2197,51 +2201,44 @@ dependencies = [
[[package]] [[package]]
name = "sqlx-core" name = "sqlx-core"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d06e2f2bd861719b1f3f0c7dbe1d80c30bf59e76cf019f07d9014ed7eefb8e08" checksum = "6a007b6936676aa9ab40207cde35daab0a04b823be8ae004368c0793b96a61e0"
dependencies = [ dependencies = [
"atoi",
"byteorder",
"bytes", "bytes",
"chrono", "chrono",
"crc", "crc",
"crossbeam-queue", "crossbeam-queue",
"either", "either",
"event-listener", "event-listener",
"futures-channel",
"futures-core", "futures-core",
"futures-intrusive", "futures-intrusive",
"futures-io", "futures-io",
"futures-util", "futures-util",
"hashbrown 0.14.5", "hashbrown",
"hashlink", "hashlink",
"hex",
"indexmap", "indexmap",
"log", "log",
"memchr", "memchr",
"native-tls", "native-tls",
"once_cell", "once_cell",
"paste",
"percent-encoding", "percent-encoding",
"serde", "serde",
"serde_json", "serde_json",
"sha2", "sha2",
"smallvec", "smallvec",
"sqlformat", "thiserror 2.0.11",
"thiserror 1.0.69",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tracing", "tracing",
"url", "url",
"uuid",
] ]
[[package]] [[package]]
name = "sqlx-macros" name = "sqlx-macros"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f998a9defdbd48ed005a89362bd40dd2117502f15294f61c8d47034107dbbdc" checksum = "3112e2ad78643fef903618d78cf0aec1cb3134b019730edb039b69eaf531f310"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2252,9 +2249,9 @@ dependencies = [
[[package]] [[package]]
name = "sqlx-macros-core" name = "sqlx-macros-core"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d100558134176a2629d46cec0c8891ba0be8910f7896abfdb75ef4ab6f4e7ce" checksum = "4e9f90acc5ab146a99bf5061a7eb4976b573f560bc898ef3bf8435448dd5e7ad"
dependencies = [ dependencies = [
"dotenvy", "dotenvy",
"either", "either",
@ -2278,9 +2275,9 @@ dependencies = [
[[package]] [[package]]
name = "sqlx-mysql" name = "sqlx-mysql"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "936cac0ab331b14cb3921c62156d913e4c15b74fb6ec0f3146bd4ef6e4fb3c12" checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233"
dependencies = [ dependencies = [
"atoi", "atoi",
"base64 0.22.1", "base64 0.22.1",
@ -2314,17 +2311,16 @@ dependencies = [
"smallvec", "smallvec",
"sqlx-core", "sqlx-core",
"stringprep", "stringprep",
"thiserror 1.0.69", "thiserror 2.0.11",
"tracing", "tracing",
"uuid",
"whoami", "whoami",
] ]
[[package]] [[package]]
name = "sqlx-postgres" name = "sqlx-postgres"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9734dbce698c67ecf67c442f768a5e90a49b2a4d61a9f1d59f73874bd4cf0710" checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613"
dependencies = [ dependencies = [
"atoi", "atoi",
"base64 0.22.1", "base64 0.22.1",
@ -2336,7 +2332,6 @@ dependencies = [
"etcetera", "etcetera",
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"futures-io",
"futures-util", "futures-util",
"hex", "hex",
"hkdf", "hkdf",
@ -2354,17 +2349,16 @@ dependencies = [
"smallvec", "smallvec",
"sqlx-core", "sqlx-core",
"stringprep", "stringprep",
"thiserror 1.0.69", "thiserror 2.0.11",
"tracing", "tracing",
"uuid",
"whoami", "whoami",
] ]
[[package]] [[package]]
name = "sqlx-sqlite" name = "sqlx-sqlite"
version = "0.8.1" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75b419c3c1b1697833dd927bdc4c6545a620bc1bbafabd44e1efbe9afcd337e" checksum = "f85ca71d3a5b24e64e1d08dd8fe36c6c95c339a896cc33068148906784620540"
dependencies = [ dependencies = [
"atoi", "atoi",
"chrono", "chrono",
@ -2382,7 +2376,6 @@ dependencies = [
"sqlx-core", "sqlx-core",
"tracing", "tracing",
"url", "url",
"uuid",
] ]
[[package]] [[package]]
@ -2706,12 +2699,6 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"

View file

@ -14,7 +14,7 @@ actix-web = "4.4"
actix-files = "0.6" actix-files = "0.6"
actix-cors = "0.6" actix-cors = "0.6"
tokio = { version = "1.36", features = ["full"] } tokio = { version = "1.36", features = ["full"] }
sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono"] } sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "sqlite", "chrono"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
anyhow = "1.0" anyhow = "1.0"
@ -31,3 +31,4 @@ lazy_static = "1.4"
argon2 = "0.5.3" argon2 = "0.5.3"
rand = { version = "0.8", features = ["std"] } rand = { version = "0.8", features = ["std"] }
mime_guess = "2.0.5" mime_guess = "2.0.5"
futures = "0.3.31"

View file

@ -30,7 +30,7 @@ then check /target/release for the binary named `SimpleGit`
### From Docker ### From Docker
```bash ```bash
docker build --build-arg API_URL=http://localhost:8080 -t simplelink . docker build --build-arg API_URL=http://localhost:8080 -t simplelink .
docker run simplelink -p 8080:8080 \ docker run -p 8080:8080 \
-e JWT_SECRET=change-me-in-production \ -e JWT_SECRET=change-me-in-production \
-e DATABASE_URL=postgres://user:password@host:port/database \ -e DATABASE_URL=postgres://user:password@host:port/database \
simplelink simplelink

View file

@ -1,13 +1,16 @@
<!doctype html> <!doctype html>
<html lang="en"> <html lang="en">
<head>
<head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" /> <link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vite + React + TS</title> <title>SimpleLink</title>
</head> </head>
<body>
<body>
<div id="root"></div> <div id="root"></div>
<script type="module" src="/src/main.tsx"></script> <script type="module" src="/src/main.tsx"></script>
</body> </body>
</html> </html>

View file

@ -15,6 +15,20 @@ api.interceptors.request.use((config) => {
return config; return config;
}); });
api.interceptors.response.use(
(response) => response,
(error) => {
if (error.response?.status === 401) {
localStorage.removeItem('token');
localStorage.removeItem('user');
window.dispatchEvent(new Event('unauthorized'));
}
return Promise.reject(error);
}
);
// Auth endpoints // Auth endpoints
export const login = async (email: string, password: string) => { export const login = async (email: string, password: string) => {
const response = await api.post<AuthResponse>('/auth/login', { const response = await api.post<AuthResponse>('/auth/login', {

View file

@ -23,6 +23,16 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
setUser(userData); setUser(userData);
} }
setIsLoading(false); setIsLoading(false);
const handleUnauthorized = () => {
setUser(null);
};
window.addEventListener('unauthorized', handleUnauthorized);
return () => {
window.removeEventListener('unauthorized', handleUnauthorized);
};
}, []); }, []);
const login = async (email: string, password: string) => { const login = async (email: string, password: string) => {

View file

@ -0,0 +1,42 @@
-- Enable foreign key support
PRAGMA foreign_keys = ON;
-- Add Migration Version
CREATE TABLE IF NOT EXISTS _sqlx_migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
installed_on TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Create users table
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email VARCHAR(255) NOT NULL UNIQUE,
password_hash TEXT NOT NULL
);
-- Create links table
CREATE TABLE links (
id INTEGER PRIMARY KEY AUTOINCREMENT,
original_url TEXT NOT NULL,
short_code VARCHAR(8) NOT NULL UNIQUE,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
clicks INTEGER NOT NULL DEFAULT 0,
user_id INTEGER,
FOREIGN KEY (user_id) REFERENCES users(id)
);
-- Create clicks table
CREATE TABLE clicks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
link_id INTEGER,
source TEXT,
query_source TEXT,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (link_id) REFERENCES links(id)
);
-- Create indexes
CREATE INDEX idx_short_code ON links(short_code);
CREATE INDEX idx_user_id ON links(user_id);
CREATE INDEX idx_link_id ON clicks(link_id);

View file

@ -2,8 +2,8 @@ use crate::auth::AuthenticatedUser;
use crate::{ use crate::{
error::AppError, error::AppError,
models::{ models::{
AuthResponse, Claims, ClickStats, CreateLink, Link, LoginRequest, RegisterRequest, AuthResponse, Claims, ClickStats, CreateLink, DatabasePool, Link, LoginRequest,
SourceStats, User, UserResponse, RegisterRequest, SourceStats, User, UserResponse,
}, },
AppState, AppState,
}; };
@ -16,6 +16,7 @@ use argon2::{Argon2, PasswordHash, PasswordHasher};
use jsonwebtoken::{encode, EncodingKey, Header}; use jsonwebtoken::{encode, EncodingKey, Header};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use sqlx::{Postgres, Sqlite};
lazy_static! { lazy_static! {
static ref VALID_CODE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9_-]{1,32}$").unwrap(); static ref VALID_CODE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9_-]{1,32}$").unwrap();
@ -27,35 +28,44 @@ pub async fn create_short_url(
payload: web::Json<CreateLink>, payload: web::Json<CreateLink>,
) -> Result<impl Responder, AppError> { ) -> Result<impl Responder, AppError> {
tracing::debug!("Creating short URL with user_id: {}", user.user_id); tracing::debug!("Creating short URL with user_id: {}", user.user_id);
validate_url(&payload.url)?; validate_url(&payload.url)?;
let short_code = if let Some(ref custom_code) = payload.custom_code { let short_code = if let Some(ref custom_code) = payload.custom_code {
validate_custom_code(custom_code)?; validate_custom_code(custom_code)?;
tracing::debug!("Checking if custom code {} exists", custom_code); // Check if code exists using match on pool type
// Check if code is already taken let exists = match &state.db {
if let Some(_) = sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = $1") DatabasePool::Postgres(pool) => {
sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = $1")
.bind(custom_code) .bind(custom_code)
.fetch_optional(&state.db) .fetch_optional(pool)
.await? .await?
{ }
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<_, Link>("SELECT * FROM links WHERE short_code = ?1")
.bind(custom_code)
.fetch_optional(pool)
.await?
}
};
if exists.is_some() {
return Err(AppError::InvalidInput( return Err(AppError::InvalidInput(
"Custom code already taken".to_string(), "Custom code already taken".to_string(),
)); ));
} }
custom_code.clone() custom_code.clone()
} else { } else {
generate_short_code() generate_short_code()
}; };
// Start transaction // Start transaction based on pool type
let mut tx = state.db.begin().await?; let result = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
tracing::debug!("Inserting new link with short_code: {}", short_code);
let link = sqlx::query_as::<_, Link>( let link = sqlx::query_as::<_, Link>(
"INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *", "INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *"
) )
.bind(&payload.url) .bind(&payload.url)
.bind(&short_code) .bind(&short_code)
@ -64,7 +74,6 @@ pub async fn create_short_url(
.await?; .await?;
if let Some(ref source) = payload.source { if let Some(ref source) = payload.source {
tracing::debug!("Adding click source: {}", source);
sqlx::query("INSERT INTO clicks (link_id, source) VALUES ($1, $2)") sqlx::query("INSERT INTO clicks (link_id, source) VALUES ($1, $2)")
.bind(link.id) .bind(link.id)
.bind(source) .bind(source)
@ -73,7 +82,34 @@ pub async fn create_short_url(
} }
tx.commit().await?; tx.commit().await?;
Ok(HttpResponse::Created().json(link)) link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<_, Link>(
"INSERT INTO links (original_url, short_code, user_id) VALUES (?1, ?2, ?3) RETURNING *"
)
.bind(&payload.url)
.bind(&short_code)
.bind(user.user_id)
.fetch_one(&mut *tx)
.await?;
if let Some(ref source) = payload.source {
sqlx::query("INSERT INTO clicks (link_id, source) VALUES (?1, ?2)")
.bind(link.id)
.bind(source)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
link
}
};
Ok(HttpResponse::Created().json(result))
} }
fn validate_custom_code(code: &str) -> Result<(), AppError> { fn validate_custom_code(code: &str) -> Result<(), AppError> {
@ -120,18 +156,37 @@ pub async fn redirect_to_url(
.and_then(|q| web::Query::<std::collections::HashMap<String, String>>::from_query(q).ok()) .and_then(|q| web::Query::<std::collections::HashMap<String, String>>::from_query(q).ok())
.and_then(|params| params.get("source").cloned()); .and_then(|params| params.get("source").cloned());
let mut tx = state.db.begin().await?; let link = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<_, Link>( let link = sqlx::query_as::<_, Link>(
"UPDATE links SET clicks = clicks + 1 WHERE short_code = $1 RETURNING *", "UPDATE links SET clicks = clicks + 1 WHERE short_code = $1 RETURNING *",
) )
.bind(&short_code) .bind(&short_code)
.fetch_optional(&mut *tx) .fetch_optional(&mut *tx)
.await?; .await?;
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<_, Link>(
"UPDATE links SET clicks = clicks + 1 WHERE short_code = ?1 RETURNING *",
)
.bind(&short_code)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
};
match link { match link {
Some(link) => { Some(link) => {
// Record click with both user agent and query source // Handle click recording based on database type
match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let user_agent = req let user_agent = req
.headers() .headers()
.get("user-agent") .get("user-agent")
@ -139,7 +194,9 @@ pub async fn redirect_to_url(
.unwrap_or("unknown") .unwrap_or("unknown")
.to_string(); .to_string();
sqlx::query("INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)") sqlx::query(
"INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)",
)
.bind(link.id) .bind(link.id)
.bind(user_agent) .bind(user_agent)
.bind(query_source) .bind(query_source)
@ -147,6 +204,28 @@ pub async fn redirect_to_url(
.await?; .await?;
tx.commit().await?; tx.commit().await?;
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let user_agent = req
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown")
.to_string();
sqlx::query(
"INSERT INTO clicks (link_id, source, query_source) VALUES (?1, ?2, ?3)",
)
.bind(link.id)
.bind(user_agent)
.bind(query_source)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
};
Ok(HttpResponse::TemporaryRedirect() Ok(HttpResponse::TemporaryRedirect()
.append_header(("Location", link.original_url)) .append_header(("Location", link.original_url))
@ -160,20 +239,38 @@ pub async fn get_all_links(
state: web::Data<AppState>, state: web::Data<AppState>,
user: AuthenticatedUser, user: AuthenticatedUser,
) -> Result<impl Responder, AppError> { ) -> Result<impl Responder, AppError> {
let links = sqlx::query_as::<_, Link>( let links = match &state.db {
DatabasePool::Postgres(pool) => {
sqlx::query_as::<_, Link>(
"SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC", "SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC",
) )
.bind(user.user_id) .bind(user.user_id)
.fetch_all(&state.db) .fetch_all(pool)
.await?; .await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<_, Link>(
"SELECT * FROM links WHERE user_id = ?1 ORDER BY created_at DESC",
)
.bind(user.user_id)
.fetch_all(pool)
.await?
}
};
Ok(HttpResponse::Ok().json(links)) Ok(HttpResponse::Ok().json(links))
} }
pub async fn health_check(state: web::Data<AppState>) -> impl Responder { pub async fn health_check(state: web::Data<AppState>) -> impl Responder {
match sqlx::query("SELECT 1").execute(&state.db).await { let is_healthy = match &state.db {
Ok(_) => HttpResponse::Ok().json("Healthy"), DatabasePool::Postgres(pool) => sqlx::query("SELECT 1").execute(pool).await.is_ok(),
Err(_) => HttpResponse::ServiceUnavailable().json("Database unavailable"), DatabasePool::Sqlite(pool) => sqlx::query("SELECT 1").execute(pool).await.is_ok(),
};
if is_healthy {
HttpResponse::Ok().json("Healthy")
} else {
HttpResponse::ServiceUnavailable().json("Database unavailable")
} }
} }
@ -190,11 +287,26 @@ pub async fn register(
payload: web::Json<RegisterRequest>, payload: web::Json<RegisterRequest>,
) -> Result<impl Responder, AppError> { ) -> Result<impl Responder, AppError> {
// Check if any users exist // Check if any users exist
let user_count = sqlx::query!("SELECT COUNT(*) as count FROM users") let user_count = match &state.db {
.fetch_one(&state.db) DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Postgres, (i64,)>("SELECT COUNT(*)::bigint FROM users")
.fetch_one(&mut *tx)
.await? .await?
.count .0;
.unwrap_or(0); tx.commit().await?;
count
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Sqlite, (i64,)>("SELECT COUNT(*) FROM users")
.fetch_one(&mut *tx)
.await?
.0;
tx.commit().await?;
count
}
};
// If users exist, registration is closed - no exceptions // If users exist, registration is closed - no exceptions
if user_count > 0 { if user_count > 0 {
@ -210,9 +322,27 @@ pub async fn register(
} }
// Check if email already exists // Check if email already exists
let exists = sqlx::query!("SELECT id FROM users WHERE email = $1", payload.email) let exists = match &state.db {
.fetch_optional(&state.db) DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let exists =
sqlx::query_as::<Postgres, (i32,)>("SELECT id FROM users WHERE email = $1")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?; .await?;
tx.commit().await?;
exists
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let exists = sqlx::query_as::<Sqlite, (i32,)>("SELECT id FROM users WHERE email = ?")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
exists
}
};
if exists.is_some() { if exists.is_some() {
return Err(AppError::Auth("Email already registered".to_string())); return Err(AppError::Auth("Email already registered".to_string()));
@ -225,14 +355,33 @@ pub async fn register(
.map_err(|e| AppError::Auth(e.to_string()))? .map_err(|e| AppError::Auth(e.to_string()))?
.to_string(); .to_string();
let user = sqlx::query_as!( // Insert new user
User, let user = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Postgres, User>(
"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *", "INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *",
payload.email,
password_hash
) )
.fetch_one(&state.db) .bind(&payload.email)
.bind(&password_hash)
.fetch_one(&mut *tx)
.await?; .await?;
tx.commit().await?;
user
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Sqlite, User>(
"INSERT INTO users (email, password_hash) VALUES (?, ?) RETURNING *",
)
.bind(&payload.email)
.bind(&password_hash)
.fetch_one(&mut *tx)
.await?;
tx.commit().await?;
user
}
};
let claims = Claims::new(user.id); let claims = Claims::new(user.id);
let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string()); let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string());
@ -256,9 +405,26 @@ pub async fn login(
state: web::Data<AppState>, state: web::Data<AppState>,
payload: web::Json<LoginRequest>, payload: web::Json<LoginRequest>,
) -> Result<impl Responder, AppError> { ) -> Result<impl Responder, AppError> {
let user = sqlx::query_as!(User, "SELECT * FROM users WHERE email = $1", payload.email) let user = match &state.db {
.fetch_optional(&state.db) DatabasePool::Postgres(pool) => {
.await? let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Postgres, User>("SELECT * FROM users WHERE email = $1")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
user
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let user = sqlx::query_as::<Sqlite, User>("SELECT * FROM users WHERE email = ?")
.bind(&payload.email)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
user
}
}
.ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?; .ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?;
let argon2 = Argon2::default(); let argon2 = Argon2::default();
@ -297,15 +463,16 @@ pub async fn delete_link(
) -> Result<impl Responder, AppError> { ) -> Result<impl Responder, AppError> {
let link_id = path.into_inner(); let link_id = path.into_inner();
// Start transaction match &state.db {
let mut tx = state.db.begin().await?; DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
// Verify the link belongs to the user // Verify the link belongs to the user
let link = sqlx::query!( let link = sqlx::query_as::<Postgres, (i32,)>(
"SELECT id FROM links WHERE id = $1 AND user_id = $2", "SELECT id FROM links WHERE id = $1 AND user_id = $2",
link_id,
user.user_id
) )
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx) .fetch_optional(&mut *tx)
.await?; .await?;
@ -314,16 +481,50 @@ pub async fn delete_link(
} }
// Delete associated clicks first due to foreign key constraint // Delete associated clicks first due to foreign key constraint
sqlx::query!("DELETE FROM clicks WHERE link_id = $1", link_id) sqlx::query("DELETE FROM clicks WHERE link_id = $1")
.bind(link_id)
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
// Delete the link // Delete the link
sqlx::query!("DELETE FROM links WHERE id = $1", link_id) sqlx::query("DELETE FROM links WHERE id = $1")
.bind(link_id)
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
tx.commit().await?; tx.commit().await?;
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
// Verify the link belongs to the user
let link = sqlx::query_as::<Sqlite, (i32,)>(
"SELECT id FROM links WHERE id = ? AND user_id = ?",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
if link.is_none() {
return Err(AppError::NotFound);
}
// Delete associated clicks first due to foreign key constraint
sqlx::query("DELETE FROM clicks WHERE link_id = ?")
.bind(link_id)
.execute(&mut *tx)
.await?;
// Delete the link
sqlx::query("DELETE FROM links WHERE id = ?")
.bind(link_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
}
Ok(HttpResponse::NoContent().finish()) Ok(HttpResponse::NoContent().finish())
} }
@ -336,20 +537,40 @@ pub async fn get_link_clicks(
let link_id = path.into_inner(); let link_id = path.into_inner();
// Verify the link belongs to the user // Verify the link belongs to the user
let link = sqlx::query!( let link = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Postgres, (i32,)>(
"SELECT id FROM links WHERE id = $1 AND user_id = $2", "SELECT id FROM links WHERE id = $1 AND user_id = $2",
link_id,
user.user_id
) )
.fetch_optional(&state.db) .bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?; .await?;
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Sqlite, (i32,)>(
"SELECT id FROM links WHERE id = ? AND user_id = ?",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
};
if link.is_none() { if link.is_none() {
return Err(AppError::NotFound); return Err(AppError::NotFound);
} }
let clicks = sqlx::query_as!( let clicks = match &state.db {
ClickStats, DatabasePool::Postgres(pool) => {
sqlx::query_as::<Postgres, ClickStats>(
r#" r#"
SELECT SELECT
DATE(created_at)::date as "date!", DATE(created_at)::date as "date!",
@ -357,13 +578,32 @@ pub async fn get_link_clicks(
FROM clicks FROM clicks
WHERE link_id = $1 WHERE link_id = $1
GROUP BY DATE(created_at) GROUP BY DATE(created_at)
ORDER BY DATE(created_at) ASC -- Changed from DESC to ASC ORDER BY DATE(created_at) ASC
LIMIT 30 LIMIT 30
"#, "#,
link_id
) )
.fetch_all(&state.db) .bind(link_id)
.await?; .fetch_all(pool)
.await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<Sqlite, ClickStats>(
r#"
SELECT
DATE(created_at) as "date!",
COUNT(*) as "clicks!"
FROM clicks
WHERE link_id = ?
GROUP BY DATE(created_at)
ORDER BY DATE(created_at) ASC
LIMIT 30
"#,
)
.bind(link_id)
.fetch_all(pool)
.await?
}
};
Ok(HttpResponse::Ok().json(clicks)) Ok(HttpResponse::Ok().json(clicks))
} }
@ -376,20 +616,40 @@ pub async fn get_link_sources(
let link_id = path.into_inner(); let link_id = path.into_inner();
// Verify the link belongs to the user // Verify the link belongs to the user
let link = sqlx::query!( let link = match &state.db {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Postgres, (i32,)>(
"SELECT id FROM links WHERE id = $1 AND user_id = $2", "SELECT id FROM links WHERE id = $1 AND user_id = $2",
link_id,
user.user_id
) )
.fetch_optional(&state.db) .bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?; .await?;
tx.commit().await?;
link
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let link = sqlx::query_as::<Sqlite, (i32,)>(
"SELECT id FROM links WHERE id = ? AND user_id = ?",
)
.bind(link_id)
.bind(user.user_id)
.fetch_optional(&mut *tx)
.await?;
tx.commit().await?;
link
}
};
if link.is_none() { if link.is_none() {
return Err(AppError::NotFound); return Err(AppError::NotFound);
} }
let sources = sqlx::query_as!( let sources = match &state.db {
SourceStats, DatabasePool::Postgres(pool) => {
sqlx::query_as::<Postgres, SourceStats>(
r#" r#"
SELECT SELECT
query_source as "source!", query_source as "source!",
@ -402,10 +662,31 @@ pub async fn get_link_sources(
ORDER BY COUNT(*) DESC ORDER BY COUNT(*) DESC
LIMIT 10 LIMIT 10
"#, "#,
link_id
) )
.fetch_all(&state.db) .bind(link_id)
.await?; .fetch_all(pool)
.await?
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<Sqlite, SourceStats>(
r#"
SELECT
query_source as "source!",
COUNT(*) as "count!"
FROM clicks
WHERE link_id = ?
AND query_source IS NOT NULL
AND query_source != ''
GROUP BY query_source
ORDER BY COUNT(*) DESC
LIMIT 10
"#,
)
.bind(link_id)
.fetch_all(pool)
.await?
}
};
Ok(HttpResponse::Ok().json(sources)) Ok(HttpResponse::Ok().json(sources))
} }

View file

@ -1,9 +1,14 @@
use anyhow::Result;
use rand::Rng; use rand::Rng;
use sqlx::PgPool; use sqlx::migrate::MigrateDatabase;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Postgres, Sqlite};
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use tracing::info; use tracing::info;
use models::DatabasePool;
pub mod auth; pub mod auth;
pub mod error; pub mod error;
pub mod handlers; pub mod handlers;
@ -11,17 +16,90 @@ pub mod models;
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pub db: PgPool, pub db: DatabasePool,
pub admin_token: Option<String>, pub admin_token: Option<String>,
} }
pub async fn check_and_generate_admin_token(pool: &sqlx::PgPool) -> anyhow::Result<Option<String>> { pub async fn create_db_pool() -> Result<DatabasePool> {
let database_url = std::env::var("DATABASE_URL").ok();
match database_url {
Some(url) if url.starts_with("postgres://") => {
info!("Using PostgreSQL database");
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&url)
.await?;
Ok(DatabasePool::Postgres(pool))
}
_ => {
info!("No PostgreSQL connection string found, using SQLite");
// Create a data directory if it doesn't exist
let data_dir = std::path::Path::new("data");
if !data_dir.exists() {
std::fs::create_dir_all(data_dir)?;
}
let db_path = data_dir.join("simplelink.db");
let sqlite_url = format!("sqlite://{}", db_path.display());
// Check if database exists and create it if it doesn't
if !Sqlite::database_exists(&sqlite_url).await.unwrap_or(false) {
info!("Creating new SQLite database at {}", db_path.display());
Sqlite::create_database(&sqlite_url).await?;
info!("Database created successfully");
} else {
info!("Database already exists");
}
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(5)
.connect(&sqlite_url)
.await?;
Ok(DatabasePool::Sqlite(pool))
}
}
}
pub async fn run_migrations(pool: &DatabasePool) -> Result<()> {
match pool {
DatabasePool::Postgres(pool) => {
// Use the root migrations directory for postgres
sqlx::migrate!().run(pool).await?;
}
DatabasePool::Sqlite(pool) => {
sqlx::migrate!("./migrations/sqlite").run(pool).await?;
}
}
Ok(())
}
pub async fn check_and_generate_admin_token(db: &DatabasePool) -> anyhow::Result<Option<String>> {
// Check if any users exist // Check if any users exist
let user_count = sqlx::query!("SELECT COUNT(*) as count FROM users") let user_count = match db {
.fetch_one(pool) DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Postgres, (i64,)>("SELECT COUNT(*)::bigint FROM users")
.fetch_one(&mut *tx)
.await? .await?
.count .0;
.unwrap_or(0); tx.commit().await?;
count
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let count = sqlx::query_as::<Sqlite, (i64,)>("SELECT COUNT(*) FROM users")
.fetch_one(&mut *tx)
.await?
.0;
tx.commit().await?;
count
}
};
if user_count == 0 { if user_count == 0 {
// Generate a random token using simple characters // Generate a random token using simple characters

View file

@ -3,8 +3,8 @@ use actix_web::{web, App, HttpResponse, HttpServer};
use anyhow::Result; use anyhow::Result;
use rust_embed::RustEmbed; use rust_embed::RustEmbed;
use simplelink::check_and_generate_admin_token; use simplelink::check_and_generate_admin_token;
use simplelink::{create_db_pool, run_migrations};
use simplelink::{handlers, AppState}; use simplelink::{handlers, AppState};
use sqlx::postgres::PgPoolOptions;
use tracing::info; use tracing::info;
#[derive(RustEmbed)] #[derive(RustEmbed)]
@ -31,18 +31,9 @@ async fn main() -> Result<()> {
// Initialize logging // Initialize logging
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
// Database connection string from environment
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
// Create database connection pool // Create database connection pool
let pool = PgPoolOptions::new() let pool = create_db_pool().await?;
.max_connections(5) run_migrations(&pool).await?;
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&database_url)
.await?;
// Run database migrations
sqlx::migrate!("./migrations").run(&pool).await?;
let admin_token = check_and_generate_admin_token(&pool).await?; let admin_token = check_and_generate_admin_token(&pool).await?;

View file

@ -1,8 +1,80 @@
use anyhow::Result;
use chrono::NaiveDate;
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgRow;
use sqlx::sqlite::SqliteRow;
use sqlx::FromRow;
use sqlx::Pool;
use sqlx::Postgres;
use sqlx::Sqlite;
use sqlx::Transaction;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use chrono::NaiveDate; #[derive(Clone)]
use serde::{Deserialize, Serialize}; pub enum DatabasePool {
use sqlx::FromRow; Postgres(Pool<Postgres>),
Sqlite(Pool<Sqlite>),
}
impl DatabasePool {
pub async fn begin(&self) -> Result<Box<dyn std::any::Any + Send>> {
match self {
DatabasePool::Postgres(pool) => Ok(Box::new(pool.begin().await?)),
DatabasePool::Sqlite(pool) => Ok(Box::new(pool.begin().await?)),
}
}
pub async fn fetch_optional<T>(&self, pg_query: &str, sqlite_query: &str) -> Result<Option<T>>
where
T: for<'r> FromRow<'r, PgRow> + for<'r> FromRow<'r, SqliteRow> + Send + Sync + Unpin,
{
match self {
DatabasePool::Postgres(pool) => {
Ok(sqlx::query_as(pg_query).fetch_optional(pool).await?)
}
DatabasePool::Sqlite(pool) => {
Ok(sqlx::query_as(sqlite_query).fetch_optional(pool).await?)
}
}
}
pub async fn execute(&self, pg_query: &str, sqlite_query: &str) -> Result<()> {
match self {
DatabasePool::Postgres(pool) => {
sqlx::query(pg_query).execute(pool).await?;
Ok(())
}
DatabasePool::Sqlite(pool) => {
sqlx::query(sqlite_query).execute(pool).await?;
Ok(())
}
}
}
pub async fn transaction<'a, F, R>(&'a self, f: F) -> Result<R>
where
F: for<'c> Fn(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<R>>
+ for<'c> Fn(&'c mut Transaction<'_, Sqlite>) -> BoxFuture<'c, Result<R>>
+ Copy,
R: Send + 'static,
{
match self {
DatabasePool::Postgres(pool) => {
let mut tx = pool.begin().await?;
let result = f(&mut tx).await?;
tx.commit().await?;
Ok(result)
}
DatabasePool::Sqlite(pool) => {
let mut tx = pool.begin().await?;
let result = f(&mut tx).await?;
tx.commit().await?;
Ok(result)
}
}
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Claims { pub struct Claims {