This commit is contained in:
waveringana 2025-01-25 03:16:35 -05:00
parent c048377bcc
commit 9a43049978
11 changed files with 471 additions and 66 deletions

41
src/auth.rs Normal file
View file

@ -0,0 +1,41 @@
use actix_web::{dev::Payload, FromRequest, HttpRequest};
use jsonwebtoken::{decode, DecodingKey, Validation};
use std::future::{ready, Ready};
use crate::{error::AppError, models::Claims};
pub struct AuthenticatedUser {
pub user_id: i32,
}
impl FromRequest for AuthenticatedUser {
type Error = AppError;
type Future = Ready<Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let auth_header = req.headers()
.get("Authorization")
.and_then(|h| h.to_str().ok());
if let Some(auth_header) = auth_header {
if auth_header.starts_with("Bearer ") {
let token = &auth_header[7..];
let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string());
match decode::<Claims>(
token,
&DecodingKey::from_secret(secret.as_bytes()),
&Validation::default()
) {
Ok(token_data) => {
return ready(Ok(AuthenticatedUser {
user_id: token_data.claims.sub,
}));
}
Err(_) => return ready(Err(AppError::Unauthorized)),
}
}
}
ready(Err(AppError::Unauthorized))
}
}

View file

@ -11,14 +11,22 @@ pub enum AppError {
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Authentication error: {0}")]
Auth(String),
#[error("Unauthorized")]
Unauthorized,
}
impl ResponseError for AppError {
fn error_response(&self) -> HttpResponse {
match self {
AppError::NotFound => HttpResponse::NotFound().json("Not found"),
AppError::Database(_) => HttpResponse::InternalServerError().json("Internal server error"),
AppError::Database(err) => HttpResponse::InternalServerError().json(format!("Database error: {}", err)), // Show actual error
AppError::InvalidInput(msg) => HttpResponse::BadRequest().json(msg),
AppError::Auth(msg) => HttpResponse::BadRequest().json(msg),
AppError::Unauthorized => HttpResponse::Unauthorized().json("Unauthorized"),
}
}
}
}

View file

@ -1,7 +1,10 @@
use actix_web::{web, HttpResponse, Responder, HttpRequest};
use crate::{AppState, error::AppError, models::{CreateLink, Link}};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation, errors::Error as JwtError};use crate::{error::AppError, models::{AuthResponse, Claims, CreateLink, Link, LoginRequest, RegisterRequest, User, UserResponse}, AppState};
use regex::Regex;
use argon2::{password_hash::{rand_core::OsRng, SaltString}, PasswordVerifier};
use lazy_static::lazy_static;
use argon2::{Argon2, PasswordHash, PasswordHasher};
use crate::auth::{AuthenticatedUser};
lazy_static! {
static ref VALID_CODE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9_-]{1,32}$").unwrap();
@ -9,14 +12,17 @@ lazy_static! {
pub async fn create_short_url(
state: web::Data<AppState>,
user: AuthenticatedUser,
payload: web::Json<CreateLink>,
req: HttpRequest,
) -> Result<impl Responder, AppError> {
tracing::debug!("Creating short URL with user_id: {}", user.user_id);
validate_url(&payload.url)?;
let short_code = if let Some(ref custom_code) = payload.custom_code {
validate_custom_code(custom_code)?;
tracing::debug!("Checking if custom code {} exists", custom_code);
// Check if code is already taken
if let Some(_) = sqlx::query_as::<_, Link>(
"SELECT * FROM links WHERE short_code = $1"
@ -36,16 +42,19 @@ pub async fn create_short_url(
// Start transaction
let mut tx = state.db.begin().await?;
tracing::debug!("Inserting new link with short_code: {}", short_code);
let link = sqlx::query_as::<_, Link>(
"INSERT INTO links (original_url, short_code) VALUES ($1, $2) RETURNING *"
"INSERT INTO links (original_url, short_code, user_id) VALUES ($1, $2, $3) RETURNING *"
)
.bind(&payload.url)
.bind(&short_code)
.bind(user.user_id)
.fetch_one(&mut *tx)
.await?;
if let Some(ref source) = payload.source {
tracing::debug!("Adding click source: {}", source);
sqlx::query(
"INSERT INTO clicks (link_id, source) VALUES ($1, $2)"
)
@ -54,7 +63,7 @@ pub async fn create_short_url(
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(HttpResponse::Created().json(link))
}
@ -94,6 +103,12 @@ pub async fn redirect_to_url(
) -> Result<impl Responder, AppError> {
let short_code = path.into_inner();
// Extract query source if present
let query_source = req.uri()
.query()
.and_then(|q| web::Query::<std::collections::HashMap<String, String>>::from_query(q).ok())
.and_then(|params| params.get("source").cloned());
let mut tx = state.db.begin().await?;
let link = sqlx::query_as::<_, Link>(
@ -105,7 +120,7 @@ pub async fn redirect_to_url(
match link {
Some(link) => {
// Record click with user agent as source
// Record click with both user agent and query source
let user_agent = req.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
@ -113,10 +128,11 @@ pub async fn redirect_to_url(
.to_string();
sqlx::query(
"INSERT INTO clicks (link_id, source) VALUES ($1, $2)"
"INSERT INTO clicks (link_id, source, query_source) VALUES ($1, $2, $3)"
)
.bind(link.id)
.bind(user_agent)
.bind(query_source)
.execute(&mut *tx)
.await?;
@ -132,10 +148,12 @@ pub async fn redirect_to_url(
pub async fn get_all_links(
state: web::Data<AppState>,
user: AuthenticatedUser,
) -> Result<impl Responder, AppError> {
let links = sqlx::query_as::<_, Link>(
"SELECT * FROM links ORDER BY created_at DESC"
"SELECT * FROM links WHERE user_id = $1 ORDER BY created_at DESC"
)
.bind(user.user_id)
.fetch_all(&state.db)
.await?;
@ -158,3 +176,88 @@ fn generate_short_code() -> String {
let uuid = Uuid::new_v4();
encode(uuid.as_u128() as u64).chars().take(8).collect()
}
pub async fn register(
state: web::Data<AppState>,
payload: web::Json<RegisterRequest>,
) -> Result<impl Responder, AppError> {
let exists = sqlx::query!(
"SELECT id FROM users WHERE email = $1",
payload.email
)
.fetch_optional(&state.db)
.await?;
if exists.is_some() {
return Err(AppError::Auth("Email already registered".to_string()));
}
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2.hash_password(payload.password.as_bytes(), &salt)
.map_err(|e| AppError::Auth(e.to_string()))?
.to_string();
let user = sqlx::query_as!(
User,
"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING *",
payload.email,
password_hash
)
.fetch_one(&state.db)
.await?;
let claims = Claims::new(user.id);
let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string());
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_bytes())
).map_err(|e| AppError::Auth(e.to_string()))?;
Ok(HttpResponse::Ok().json(AuthResponse {
token,
user: UserResponse {
id: user.id,
email: user.email,
},
}))
}
pub async fn login(
state: web::Data<AppState>,
payload: web::Json<LoginRequest>,
) -> Result<impl Responder, AppError> {
let user = sqlx::query_as!(
User,
"SELECT * FROM users WHERE email = $1",
payload.email
)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| AppError::Auth("Invalid credentials".to_string()))?;
let argon2 = Argon2::default();
let parsed_hash = PasswordHash::new(&user.password_hash)
.map_err(|e| AppError::Auth(e.to_string()))?;
if argon2.verify_password(payload.password.as_bytes(), &parsed_hash).is_err() {
return Err(AppError::Auth("Invalid credentials".to_string()));
}
let claims = Claims::new(user.id);
let secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string());
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_bytes())
).map_err(|e| AppError::Auth(e.to_string()))?;
Ok(HttpResponse::Ok().json(AuthResponse {
token,
user: UserResponse {
id: user.id,
email: user.email,
},
}))
}

View file

@ -7,6 +7,7 @@ use tracing::info;
mod error;
mod handlers;
mod models;
mod auth;
#[derive(Clone)]
pub struct AppState {
@ -35,7 +36,7 @@ async fn main() -> Result<()> {
.await?;
// Run database migrations
sqlx::migrate!("./migrations").run(&pool).await?;
//sqlx::migrate!("./migrations").run(&pool).await?;
let state = AppState { db: pool };
@ -55,7 +56,9 @@ async fn main() -> Result<()> {
.service(
web::scope("/api")
.route("/shorten", web::post().to(handlers::create_short_url))
.route("/links", web::get().to(handlers::get_all_links)),
.route("/links", web::get().to(handlers::get_all_links))
.route("/auth/register", web::post().to(handlers::register))
.route("/auth/login", web::post().to(handlers::login)),
)
.service(

View file

@ -0,0 +1,30 @@
-- Create users table
CREATE TABLE users (
id SERIAL PRIMARY KEY,
email VARCHAR(255) NOT NULL UNIQUE,
password_hash TEXT NOT NULL
);
-- Create links table with user_id from the start
CREATE TABLE links (
id SERIAL PRIMARY KEY,
original_url TEXT NOT NULL,
short_code VARCHAR(8) NOT NULL UNIQUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
clicks BIGINT NOT NULL DEFAULT 0,
user_id INTEGER REFERENCES users(id)
);
-- Create clicks table for tracking
CREATE TABLE clicks (
id SERIAL PRIMARY KEY,
link_id INTEGER REFERENCES links(id),
source TEXT,
query_source TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Create indexes
CREATE INDEX idx_short_code ON links(short_code);
CREATE INDEX idx_user_id ON links(user_id);
CREATE INDEX idx_link_id ON clicks(link_id);

View file

@ -1,6 +1,28 @@
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: i32, // user id
pub exp: usize,
}
impl Claims {
pub fn new(user_id: i32) -> Self {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as usize + 24 * 60 * 60; // 24 hours from now
Self {
sub: user_id,
exp,
}
}
}
#[derive(Deserialize)]
pub struct CreateLink {
pub url: String,
@ -11,7 +33,7 @@ pub struct CreateLink {
#[derive(Serialize, FromRow)]
pub struct Link {
pub id: i32,
pub user_id: i32,
pub user_id: Option<i32>,
pub original_url: String,
pub short_code: String,
pub created_at: chrono::DateTime<chrono::Utc>,