417 lines
9.9 KiB
TypeScript
417 lines
9.9 KiB
TypeScript
'use client'
|
||
|
||
import React, { createContext, useContext, useEffect, useState, ReactNode } from 'react'
|
||
import { AuthState, User, LoginCredentials, RegisterData, JWTPayload } from '@/types/user'
|
||
import { useMutation, gql, useQuery, ApolloClient } from '@apollo/client'
|
||
|
||
/**
 * Value carried by `UserContext`: the current `AuthState` plus the
 * operations that read or change it.
 */
interface UserContextType extends AuthState {
  /** Sign in with the given credentials. */
  login: (credentials: LoginCredentials) => Promise<void>
  /** Run the registration mutation and sign in with the token it returns. */
  register: (data: RegisterData) => Promise<void>
  /** Sign out and forget the stored token. */
  logout: () => void
  /** Exchange the current token for a newly issued one. */
  refreshToken: () => Promise<void>
  /** Re-run the current-user query on demand. */
  refetchUser: () => Promise<any>
  /** Reset the Apollo Client store. */
  resetApolloCache: () => void
}
|
||
|
||
/** Props accepted by `UserProvider`. */
interface UserProviderProps {
  /** Subtree that may read auth state through `useUser()` / `useAuth()`. */
  children: ReactNode
}

/**
 * Context holding the auth state and actions. Its default value is
 * `undefined`, which lets `useUser` detect use outside a `UserProvider`.
 */
const UserContext = createContext<UserContextType | undefined>(undefined)
|
||
|
||
// localStorage key under which the auth JWT is stored.
const TOKEN_KEY = 'auth_token'

// Current-user profile query; UserProvider polls it every 30 s while authenticated.
// NOTE(review): `avatar` is not selected here, although UserProvider's updateUserInfo reads it.
const GET_USER_QUERY = gql`
  query GetUser {
    currentUser {
      id
      username
      email
      role
    }
  }
`

// Registration mutation. Only `token` is selected from the result;
// UserProvider's register() uses it to sign the new user in.
const REGISTER_MUTATION = gql`
  mutation Register($username: String!, $password: String!) {
    register(username: $username, password: $password) {
      token
    }
  }
`
|
||
|
||
// Module-wide handle to the app's Apollo Client instance. It stays null until
// setGlobalApolloClient() is called; the auth provider uses it to reset the store.
let globalApolloClient: ApolloClient<any> | null = null;

/**
 * Registers the Apollo Client instance whose store the auth provider resets
 * (for example on logout).
 *
 * @param client - the application's Apollo Client
 */
export function setGlobalApolloClient(client: ApolloClient<any>): void {
  globalApolloClient = client;
}
|
||
|
||
/** Auth state used whenever nobody is signed in (no stored token, logout, or a failure). */
const SIGNED_OUT_STATE: AuthState = {
  user: null,
  token: null,
  isAuthenticated: false,
  isLoading: false
}

/** Largest delay setTimeout honours; larger values overflow and fire immediately. */
const MAX_TIMEOUT_MS = 2 ** 31 - 1

/**
 * Decodes the payload segment of a JWT. The signature is NOT verified, so the
 * result is only suitable for client-side display and expiry checks.
 *
 * @returns the decoded payload, or null if the token is malformed
 */
function parseJWT(token: string): JWTPayload | null {
  try {
    const segment = token.split('.')[1]
    if (!segment) {
      throw new Error('JWT has no payload segment')
    }
    // base64url -> base64, then decode the bytes as UTF-8.
    const base64 = segment.replace(/-/g, '+').replace(/_/g, '/')
    const jsonPayload = decodeURIComponent(
      atob(base64)
        .split('')
        .map(c => '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2))
        .join('')
    )
    return JSON.parse(jsonPayload)
  } catch (error) {
    console.error('Failed to parse JWT:', error)
    return null
  }
}

/** True when `token` decodes and its `exp` claim (seconds since epoch) lies in the future. */
function isTokenValid(token: string): boolean {
  const payload = parseJWT(token)
  if (!payload) return false
  return payload.exp > Date.now() / 1000
}

/** Builds the client-side `User` from the claims carried in a token payload. */
function userFromPayload(payload: JWTPayload): User {
  return {
    id: payload.sub,
    email: payload.email,
    name: payload.name,
    role: payload.role
  }
}

/**
 * Resets the registered Apollo Client store; does nothing if none was registered.
 * resetStore() refetches active queries and returns a promise that can reject.
 * The rejection is logged here instead of being left unhandled.
 */
function resetApolloCache(): void {
  if (!globalApolloClient) return
  globalApolloClient.resetStore().catch(error => {
    console.error('Failed to reset Apollo cache:', error)
  })
}

/**
 * Provides authentication state and actions to its subtree.
 *
 * On mount it restores a still-valid JWT from localStorage and syncs it to
 * `/api/session/sync`. While authenticated it polls GET_USER_QUERY every 30 s
 * and schedules a token refresh shortly before the token expires.
 */
export function UserProvider({ children }: UserProviderProps) {
  const [authState, setAuthState] = useState<AuthState>({
    user: null,
    token: null,
    isAuthenticated: false,
    isLoading: true
  })

  const [registerMutation, { loading: registerLoading }] = useMutation(REGISTER_MUTATION)

  // Periodically re-fetch the current user's profile while signed in.
  const {
    data: userData,
    loading: userLoading,
    error: userError,
    refetch: refetchUser
  } = useQuery(GET_USER_QUERY, {
    skip: !authState.isAuthenticated,
    pollInterval: 30000, // poll every 30 seconds
    errorPolicy: 'all',
    notifyOnNetworkStatusChange: true
  })

  /** Restores the session from a stored, unexpired token; otherwise signs out. */
  const initializeAuth = async () => {
    try {
      const token = localStorage.getItem(TOKEN_KEY)
      const payload = token && isTokenValid(token) ? parseJWT(token) : null

      if (token && payload) {
        setAuthState({
          user: userFromPayload(payload),
          token,
          isAuthenticated: true,
          isLoading: false
        })

        const res = await fetch('/api/session/sync', {
          method: 'POST',
          // Declare the JSON body, as the other auth requests in this provider do.
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify({ jwt: token })
        })
        if (!res.ok) {
          throw new Error('Failed to sync session')
        }
        return
      }

      // The token is missing, malformed or expired, so discard it.
      localStorage.removeItem(TOKEN_KEY)
      setAuthState(SIGNED_OUT_STATE)
    } catch (error) {
      console.error('Auth initialization error:', error)
      setAuthState(SIGNED_OUT_STATE)
    }
  }

  /**
   * Signs in via `/api/login` and stores the returned JWT.
   * @throws Error if the request fails or the returned token is invalid.
   */
  const login = async (credentials: LoginCredentials) => {
    try {
      setAuthState(prev => ({ ...prev, isLoading: true }))

      const response = await fetch('/api/login', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(credentials)
      })

      // Check the HTTP status first, because error responses may not carry a JSON body.
      if (!response.ok) {
        throw new Error('Login failed')
      }
      const { token, ok } = await response.json()
      if (!ok) {
        throw new Error('Login failed')
      }

      if (!isTokenValid(token)) {
        throw new Error('Invalid token received')
      }
      const payload = parseJWT(token)
      if (!payload) {
        throw new Error('Failed to parse token')
      }

      localStorage.setItem(TOKEN_KEY, token)
      setAuthState({
        user: userFromPayload(payload),
        token,
        isAuthenticated: true,
        isLoading: false
      })

      // Drop data cached before sign-in, as register() and refreshToken() do.
      resetApolloCache()
    } catch (error) {
      console.error('Login error:', error)
      setAuthState(SIGNED_OUT_STATE)
      throw error
    }
  }

  /**
   * Runs REGISTER_MUTATION and signs in with the token it returns.
   * @throws Error if the mutation fails or returns no valid token.
   */
  const register = async (data: RegisterData) => {
    try {
      setAuthState(prev => ({ ...prev, isLoading: true }))

      const response = await registerMutation({ variables: data })

      // Inspect this call's own result. The hook's `error` value is captured from the
      // render that created this closure, so it describes a previous attempt. Checking it
      // would make a retry after a failure throw even when the new attempt succeeds.
      const token = response.data?.register?.token
      if (response.errors?.length || !token) {
        throw new Error('Registration failed')
      }

      if (!isTokenValid(token)) {
        throw new Error('Invalid token received')
      }
      const payload = parseJWT(token)
      if (!payload) {
        throw new Error('Failed to parse token')
      }

      localStorage.setItem(TOKEN_KEY, token)
      setAuthState({
        user: userFromPayload(payload),
        token,
        isAuthenticated: true,
        isLoading: false
      })

      // Reset the Apollo cache after a successful registration.
      resetApolloCache()
    } catch (error) {
      console.error('Registration error:', error)
      setAuthState(SIGNED_OUT_STATE)
      throw error
    }
  }

  /** Forgets the stored token, signs out, and resets the Apollo cache. */
  const logout = () => {
    localStorage.removeItem(TOKEN_KEY)
    setAuthState(SIGNED_OUT_STATE)
    resetApolloCache()
  }

  /**
   * Exchanges the current token for a fresh one via `/api/auth/refresh`.
   * On any failure it logs out and then rethrows.
   */
  const refreshToken = async () => {
    try {
      const currentToken = authState.token
      if (!currentToken) {
        throw new Error('No token to refresh')
      }

      const response = await fetch('/api/auth/refresh', {
        method: 'POST',
        headers: {
          'Authorization': `Bearer ${currentToken}`,
          'Content-Type': 'application/json'
        }
      })
      if (!response.ok) {
        throw new Error('Token refresh failed')
      }

      const { token } = await response.json()
      if (!isTokenValid(token)) {
        throw new Error('Invalid refreshed token')
      }
      const payload = parseJWT(token)
      if (!payload) {
        throw new Error('Failed to parse refreshed token')
      }

      localStorage.setItem(TOKEN_KEY, token)
      setAuthState({
        user: userFromPayload(payload),
        token,
        isAuthenticated: true,
        isLoading: false
      })

      // Reset the Apollo cache after the token changes.
      resetApolloCache()
    } catch (error) {
      console.error('Token refresh error:', error)
      logout()
      throw error
    }
  }

  /** Replaces `authState.user` with the profile from GET_USER_QUERY, when present. */
  const updateUserInfo = (result: any) => {
    const current = result?.currentUser
    if (!current) return

    const updatedUser: User = {
      id: current.id,
      email: current.email,
      name: current.username,
      // NOTE(review): GET_USER_QUERY does not select `avatar`, so this is presumably always undefined.
      avatar: current.avatar,
      role: current.role
    }
    setAuthState(prev => ({ ...prev, user: updatedUser }))
  }

  useEffect(() => {
    // initializeAuth catches its own errors, so this promise never rejects.
    void initializeAuth()
  }, [])

  // Mirror registration progress and the *initial* user fetch into isLoading.
  // Background polls also report `loading` (notifyOnNetworkStatusChange is on).
  // Those polls are ignored once user data exists, so consumers don't flash a
  // loading state every 30 seconds.
  useEffect(() => {
    setAuthState(prev => ({
      ...prev,
      isLoading: registerLoading || (userLoading && !userData)
    }))
  }, [registerLoading, userLoading, userData])

  // Keep the user profile in sync with the polled query result.
  useEffect(() => {
    if (userData && authState.isAuthenticated) {
      updateUserInfo(userData)
    }
  }, [userData, authState.isAuthenticated])

  // A failing user query may mean the token expired, so try to refresh it.
  useEffect(() => {
    if (!userError || !authState.isAuthenticated) return

    console.error('User data fetch error:', userError)
    if (userError.message.includes('Unauthorized') || userError.message.includes('401')) {
      // refreshToken() already logs out on failure; only keep the rejection handled.
      refreshToken().catch(() => {})
    }
  }, [userError, authState.isAuthenticated])

  // Schedule a refresh 5 minutes before expiry, but no sooner than 30 s from now.
  useEffect(() => {
    if (!authState.token || !authState.isAuthenticated) return
    const payload = parseJWT(authState.token)
    if (!payload) return

    const timeUntilExpiry = payload.exp * 1000 - Date.now()
    const refreshTime = Math.min(
      Math.max(timeUntilExpiry - 5 * 60 * 1000, 30 * 1000),
      // A delay above 2^31-1 ms overflows setTimeout and fires immediately,
      // which would turn a long-lived token into a tight refresh loop.
      MAX_TIMEOUT_MS
    )

    const refreshTimer = setTimeout(() => {
      // refreshToken() already logs out on failure.
      refreshToken().catch(() => {})
    }, refreshTime)

    return () => clearTimeout(refreshTimer)
  }, [authState.token, authState.isAuthenticated])

  const value: UserContextType = {
    ...authState,
    login,
    register,
    logout,
    refreshToken,
    resetApolloCache,
    refetchUser
  }

  return (
    <UserContext.Provider value={value}>
      {children}
    </UserContext.Provider>
  )
}
|
||
|
||
/**
 * Reads the auth context supplied by the nearest enclosing `UserProvider`.
 *
 * @returns the current auth state together with its actions
 * @throws Error if no `UserProvider` is above the calling component
 */
export function useUser() {
  const ctx = useContext(UserContext)
  if (ctx !== undefined) {
    return ctx
  }
  throw new Error('useUser must be used within a UserProvider')
}
|
||
|
||
/**
 * Alias for `useUser()`: returns the same context value and throws under the
 * same conditions (when used outside a `UserProvider`).
 */
export function useAuth() {
  const auth = useUser()
  return auth
}