only show register if no users, otherwise show only login

This commit is contained in:
Wavering Ana 2025-01-29 20:47:04 -05:00
parent daa1323b88
commit 3585ca70e8
6 changed files with 116 additions and 185 deletions

120
Cargo.lock generated
View file

@ -606,16 +606,6 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.7" version = "0.8.7"
@ -843,21 +833,6 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "form_urlencoded" name = "form_urlencoded"
version = "1.2.1" version = "1.2.1"
@ -1463,23 +1438,6 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "native-tls"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.46.0" version = "0.46.0"
@ -1568,50 +1526,6 @@ version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "openssl"
version = "0.10.68"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5"
dependencies = [
"bitflags",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.104"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
@ -1953,44 +1867,12 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "schannel"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.2.0" version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "1.0.25" version = "1.0.25"
@ -2220,7 +2102,6 @@ dependencies = [
"indexmap", "indexmap",
"log", "log",
"memchr", "memchr",
"native-tls",
"once_cell", "once_cell",
"percent-encoding", "percent-encoding",
"serde", "serde",
@ -2741,7 +2622,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b"
dependencies = [ dependencies = [
"getrandom", "getrandom",
"serde",
] ]
[[package]] [[package]]

View file

@ -13,15 +13,15 @@ jsonwebtoken = "9"
actix-web = "4.4" actix-web = "4.4"
actix-files = "0.6" actix-files = "0.6"
actix-cors = "0.6" actix-cors = "0.6"
tokio = { version = "1.36", features = ["full"] } tokio = { version = "1.36", features = ["rt-multi-thread", "macros"] }
sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "sqlite", "chrono"] } sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "sqlite", "chrono"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
anyhow = "1.0" anyhow = "1.0"
thiserror = "1.0" thiserror = "1.0"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
uuid = { version = "1.7", features = ["v4", "serde"] } uuid = { version = "1.7", features = ["v4"] } # Remove serde if not using UUID serialization
base62 = "2.0" base62 = "2.0"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
dotenv = "0.15" dotenv = "0.15"

View file

@ -72,4 +72,9 @@ export const getLinkSourceStats = async (id: number) => {
return response.data; return response.data;
}; };
export const checkFirstUser = async () => {
const response = await api.get<{ isFirstUser: boolean }>('/auth/check-first-user');
return response.data.isFirstUser;
};
export { api }; export { api };

View file

@ -1,4 +1,4 @@
import { useState } from 'react' import { useState, useEffect } from 'react'
import { useForm } from 'react-hook-form' import { useForm } from 'react-hook-form'
import { z } from 'zod' import { z } from 'zod'
import { zodResolver } from '@hookform/resolvers/zod' import { zodResolver } from '@hookform/resolvers/zod'
@ -6,7 +6,6 @@ import { useAuth } from '../context/AuthContext'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input' import { Input } from '@/components/ui/input'
import { Card } from '@/components/ui/card' import { Card } from '@/components/ui/card'
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { import {
Form, Form,
FormControl, FormControl,
@ -16,17 +15,18 @@ import {
FormMessage, FormMessage,
} from '@/components/ui/form' } from '@/components/ui/form'
import { useToast } from '@/hooks/use-toast' import { useToast } from '@/hooks/use-toast'
import { checkFirstUser } from '../api/client'
const formSchema = z.object({ const formSchema = z.object({
email: z.string().email('Invalid email address'), email: z.string().email('Invalid email address'),
password: z.string().min(6, 'Password must be at least 6 characters long'), password: z.string().min(6, 'Password must be at least 6 characters long'),
adminToken: z.string(), adminToken: z.string().optional(),
}) })
type FormValues = z.infer<typeof formSchema> type FormValues = z.infer<typeof formSchema>
export function AuthForms() { export function AuthForms() {
const [activeTab, setActiveTab] = useState<'login' | 'register'>('login') const [isFirstUser, setIsFirstUser] = useState<boolean | null>(null)
const { login, register } = useAuth() const { login, register } = useAuth()
const { toast } = useToast() const { toast } = useToast()
@ -39,12 +39,26 @@ export function AuthForms() {
}, },
}) })
useEffect(() => {
const init = async () => {
try {
const isFirst = await checkFirstUser()
setIsFirstUser(isFirst)
} catch (err) {
console.error('Error checking first user:', err)
setIsFirstUser(false)
}
}
init()
}, [])
const onSubmit = async (values: FormValues) => { const onSubmit = async (values: FormValues) => {
try { try {
if (activeTab === 'login') { if (isFirstUser) {
await login(values.email, values.password) await register(values.email, values.password, values.adminToken || '')
} else { } else {
await register(values.email, values.password, values.adminToken) await login(values.email, values.password)
} }
form.reset() form.reset()
} catch (err: any) { } catch (err: any) {
@ -56,68 +70,74 @@ export function AuthForms() {
} }
} }
if (isFirstUser === null) {
return <div>Loading...</div>
}
return ( return (
<Card className="w-full max-w-md mx-auto p-6"> <Card className="w-full max-w-md mx-auto p-6">
<Tabs value={activeTab} onValueChange={(value: string) => setActiveTab(value as 'login' | 'register')}> <div className="mb-6 text-center">
<TabsList className="grid w-full grid-cols-2"> <h2 className="text-2xl font-bold">
<TabsTrigger value="login">Login</TabsTrigger> {isFirstUser ? 'Create Admin Account' : 'Login'}
<TabsTrigger value="register">Register</TabsTrigger> </h2>
</TabsList> <p className="text-sm text-muted-foreground mt-1">
{isFirstUser
? 'Set up your admin account to get started'
: 'Welcome back! Please login to your account'}
</p>
</div>
<TabsContent value={activeTab}> <Form {...form}>
<Form {...form}> <form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4"> <FormField
<FormField control={form.control}
control={form.control} name="email"
name="email" render={({ field }) => (
render={({ field }) => ( <FormItem>
<FormItem> <FormLabel>Email</FormLabel>
<FormLabel>Email</FormLabel> <FormControl>
<FormControl> <Input type="email" {...field} />
<Input type="email" {...field} /> </FormControl>
</FormControl> <FormMessage />
<FormMessage /> </FormItem>
</FormItem> )}
)} />
/>
<FormField <FormField
control={form.control} control={form.control}
name="password" name="password"
render={({ field }) => ( render={({ field }) => (
<FormItem> <FormItem>
<FormLabel>Password</FormLabel> <FormLabel>Password</FormLabel>
<FormControl> <FormControl>
<Input type="password" {...field} /> <Input type="password" {...field} />
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
</FormItem> </FormItem>
)} )}
/> />
{activeTab === 'register' && ( {isFirstUser && (
<FormField <FormField
control={form.control} control={form.control}
name="adminToken" name="adminToken"
render={({ field }) => ( render={({ field }) => (
<FormItem> <FormItem>
<FormLabel>Admin Setup Token</FormLabel> <FormLabel>Admin Setup Token</FormLabel>
<FormControl> <FormControl>
<Input type="text" {...field} /> <Input type="text" {...field} />
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
</FormItem> </FormItem>
)}
/>
)} )}
/>
)}
<Button type="submit" className="w-full"> <Button type="submit" className="w-full">
{activeTab === 'login' ? 'Sign in' : 'Create account'} {isFirstUser ? 'Create Account' : 'Sign in'}
</Button> </Button>
</form> </form>
</Form> </Form>
</TabsContent>
</Tabs>
</Card> </Card>
) )
} }

View file

@ -16,6 +16,7 @@ use argon2::{Argon2, PasswordHash, PasswordHasher};
use jsonwebtoken::{encode, EncodingKey, Header}; use jsonwebtoken::{encode, EncodingKey, Header};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use serde_json::json;
use sqlx::{Postgres, Sqlite}; use sqlx::{Postgres, Sqlite};
lazy_static! { lazy_static! {
@ -690,3 +691,24 @@ pub async fn get_link_sources(
Ok(HttpResponse::Ok().json(sources)) Ok(HttpResponse::Ok().json(sources))
} }
pub async fn check_first_user(state: web::Data<AppState>) -> Result<impl Responder, AppError> {
let user_count = match &state.db {
DatabasePool::Postgres(pool) => {
sqlx::query_as::<Postgres, (i64,)>("SELECT COUNT(*)::bigint FROM users")
.fetch_one(pool)
.await?
.0
}
DatabasePool::Sqlite(pool) => {
sqlx::query_as::<Sqlite, (i64,)>("SELECT COUNT(*) FROM users")
.fetch_one(pool)
.await?
.0
}
};
Ok(HttpResponse::Ok().json(json!({
"isFirstUser": user_count == 0
})))
}

View file

@ -72,6 +72,10 @@ async fn main() -> Result<()> {
) )
.route("/auth/register", web::post().to(handlers::register)) .route("/auth/register", web::post().to(handlers::register))
.route("/auth/login", web::post().to(handlers::login)) .route("/auth/login", web::post().to(handlers::login))
.route(
"/auth/check-first-user",
web::get().to(handlers::check_first_user),
)
.route("/health", web::get().to(handlers::health_check)), .route("/health", web::get().to(handlers::health_check)),
) )
.service(web::resource("/{short_code}").route(web::get().to(handlers::redirect_to_url))) .service(web::resource("/{short_code}").route(web::get().to(handlers::redirect_to_url)))