Skip to content

Commit 4b288a0

Browse files
authored
fix: session invalid issue (#9301)
* feat(auth): Enhanced device login session management - Upon login, obtain and verify `Client-Id` to ensure unique device sessions. - If there are too many device sessions, clean up old ones according to the configured policy or return an error. - If a device session is invalid, deregister the old token and return a 401 error. - Added `EnsureActiveOnLogin` function to handle the creation and refresh of device sessions during login. * feat(session): Modified session deletion logic to mark sessions as inactive. - Changed session deletion logic to mark sessions as inactive using the `MarkInactive` method. - Adjusted error handling to ensure an error is returned if marking fails. * feat(session): Added device limits and eviction policies - Added a device limit, controlling the maximum number of devices using the `MaxDevices` configuration option. - If the number of devices exceeds the limit, the configured eviction policy is used. - If the policy is `evict_oldest`, the oldest device is evicted. - Otherwise, an error message indicating too many devices is returned. * refactor(session): Filter for the user's oldest active session - Renamed `GetOldestSession` to `GetOldestActiveSession` to more accurately reflect its functionality - Updated the SQL query to add the `status = SessionActive` condition to retrieve only active sessions - Replaced all callpoints and unified the new function name to ensure logical consistency
1 parent 63391a2 commit 4b288a0

File tree

4 files changed

+100
-19
lines changed

4 files changed

+100
-19
lines changed

internal/db/session.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@ func DeleteSessionsBefore(ts int64) error {
3838
return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error)
3939
}
4040

41-
func GetOldestSession(userID uint) (*model.Session, error) {
41+
// GetOldestActiveSession returns the oldest active session for the specified user.
42+
func GetOldestActiveSession(userID uint) (*model.Session, error) {
4243
var s model.Session
43-
if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil {
44-
return nil, errors.Wrap(err, "failed get oldest session")
44+
if err := db.Where("user_id = ? AND status = ?", userID, model.SessionActive).
45+
Order("last_active ASC").First(&s).Error; err != nil {
46+
return nil, errors.Wrap(err, "failed get oldest active session")
4547
}
4648
return &s, nil
4749
}

internal/device/session.go

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,68 @@ func Handle(userID uint, deviceKey, ua, ip string) error {
2323
ip = utils.MaskIP(ip)
2424

2525
now := time.Now().Unix()
26+
sess, err := db.GetSession(userID, deviceKey)
27+
if err == nil {
28+
if sess.Status == model.SessionInactive {
29+
return errors.WithStack(errs.SessionInactive)
30+
}
31+
sess.Status = model.SessionActive
32+
sess.LastActive = now
33+
sess.UserAgent = ua
34+
sess.IP = ip
35+
return db.UpsertSession(sess)
36+
}
37+
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
38+
return err
39+
}
40+
41+
max := setting.GetInt(conf.MaxDevices, 0)
42+
if max > 0 {
43+
count, err := db.CountActiveSessionsByUser(userID)
44+
if err != nil {
45+
return err
46+
}
47+
if count >= int64(max) {
48+
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
49+
if policy == "evict_oldest" {
50+
if oldest, err := db.GetOldestActiveSession(userID); err == nil {
51+
if err := db.MarkInactive(oldest.DeviceKey); err != nil {
52+
return err
53+
}
54+
}
55+
} else {
56+
return errors.WithStack(errs.TooManyDevices)
57+
}
58+
}
59+
}
60+
61+
s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
62+
return db.CreateSession(s)
63+
}
64+
65+
// EnsureActiveOnLogin is used only in login flow:
66+
// - If session exists (even Inactive): reactivate and refresh fields.
67+
// - If not exists: apply max-devices policy, then create Active session.
68+
func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error {
69+
ip = utils.MaskIP(ip)
70+
now := time.Now().Unix()
71+
2672
sess, err := db.GetSession(userID, deviceKey)
2773
if err == nil {
2874
if sess.Status == model.SessionInactive {
2975
max := setting.GetInt(conf.MaxDevices, 0)
3076
if max > 0 {
31-
count, cerr := db.CountActiveSessionsByUser(userID)
32-
if cerr != nil {
33-
return cerr
77+
count, err := db.CountActiveSessionsByUser(userID)
78+
if err != nil {
79+
return err
3480
}
3581
if count >= int64(max) {
3682
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
3783
if policy == "evict_oldest" {
38-
if oldest, gerr := db.GetOldestSession(userID); gerr == nil {
39-
_ = db.DeleteSession(userID, oldest.DeviceKey)
84+
if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil {
85+
if err := db.MarkInactive(oldest.DeviceKey); err != nil {
86+
return err
87+
}
4088
}
4189
} else {
4290
return errors.WithStack(errs.TooManyDevices)
@@ -63,18 +111,25 @@ func Handle(userID uint, deviceKey, ua, ip string) error {
63111
if count >= int64(max) {
64112
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
65113
if policy == "evict_oldest" {
66-
oldest, err := db.GetOldestSession(userID)
67-
if err == nil {
68-
_ = db.DeleteSession(userID, oldest.DeviceKey)
114+
if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil {
115+
if err := db.MarkInactive(oldest.DeviceKey); err != nil {
116+
return err
117+
}
69118
}
70119
} else {
71120
return errors.WithStack(errs.TooManyDevices)
72121
}
73122
}
74123
}
75124

76-
s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
77-
return db.CreateSession(s)
125+
return db.CreateSession(&model.Session{
126+
UserID: userID,
127+
DeviceKey: deviceKey,
128+
UserAgent: ua,
129+
IP: ip,
130+
LastActive: now,
131+
Status: model.SessionActive,
132+
})
78133
}
79134

80135
// Refresh updates last_active for the session.

server/handles/auth.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@ package handles
33
import (
44
"bytes"
55
"encoding/base64"
6+
"errors"
7+
"fmt"
68
"image/png"
79
"path"
810
"strings"
911
"time"
1012

1113
"github.com/Xhofe/go-cache"
1214
"github.com/alist-org/alist/v3/internal/conf"
15+
"github.com/alist-org/alist/v3/internal/device"
16+
"github.com/alist-org/alist/v3/internal/errs"
1317
"github.com/alist-org/alist/v3/internal/model"
1418
"github.com/alist-org/alist/v3/internal/op"
1519
"github.com/alist-org/alist/v3/internal/session"
1620
"github.com/alist-org/alist/v3/internal/setting"
21+
"github.com/alist-org/alist/v3/pkg/utils"
1722
"github.com/alist-org/alist/v3/server/common"
18-
"github.com/alist-org/alist/v3/server/middlewares"
1923
"github.com/gin-gonic/gin"
2024
"github.com/pquerna/otp/totp"
2125
)
@@ -83,17 +87,29 @@ func loginHash(c *gin.Context, req *LoginReq) {
8387
return
8488
}
8589
}
86-
// generate device session
87-
if !middlewares.HandleSession(c, user) {
90+
91+
clientID := c.GetHeader("Client-Id")
92+
if clientID == "" {
93+
clientID = c.Query("client_id")
94+
}
95+
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s",
96+
user.ID, clientID))
97+
98+
if err := device.EnsureActiveOnLogin(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
99+
if errors.Is(err, errs.TooManyDevices) {
100+
common.ErrorResp(c, err, 403)
101+
} else {
102+
common.ErrorResp(c, err, 400, true)
103+
}
88104
return
89105
}
106+
90107
// generate token
91108
token, err := common.GenerateToken(user)
92109
if err != nil {
93110
common.ErrorResp(c, err, 400, true)
94111
return
95112
}
96-
key := c.GetString("device_key")
97113
common.SuccessResp(c, gin.H{"token": token, "device_key": key})
98114
loginCache.Del(ip)
99115
}

server/middlewares/auth.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package middlewares
22

33
import (
44
"crypto/subtle"
5+
"errors"
56
"fmt"
67

78
"github.com/alist-org/alist/v3/internal/conf"
89
"github.com/alist-org/alist/v3/internal/device"
10+
"github.com/alist-org/alist/v3/internal/errs"
911
"github.com/alist-org/alist/v3/internal/model"
1012
"github.com/alist-org/alist/v3/internal/op"
1113
"github.com/alist-org/alist/v3/internal/setting"
@@ -106,9 +108,15 @@ func HandleSession(c *gin.Context, user *model.User) bool {
106108
if clientID == "" {
107109
clientID = c.Query("client_id")
108110
}
109-
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID))
111+
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID))
110112
if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
111-
common.ErrorResp(c, err, 403)
113+
token := c.GetHeader("Authorization")
114+
if errors.Is(err, errs.SessionInactive) {
115+
_ = common.InvalidateToken(token)
116+
common.ErrorResp(c, err, 401)
117+
} else {
118+
common.ErrorResp(c, err, 403)
119+
}
112120
c.Abort()
113121
return false
114122
}

0 commit comments

Comments
 (0)