diff --git a/cmd/lucifer-server/main.go b/cmd/lucifer-server/main.go index 8f8e244..b1eb65e 100644 --- a/cmd/lucifer-server/main.go +++ b/cmd/lucifer-server/main.go @@ -9,6 +9,7 @@ import ( "git.aiterp.net/lucifer/lucifer/controllers" "git.aiterp.net/lucifer/lucifer/database/sqlite" "git.aiterp.net/lucifer/lucifer/internal/config" + "git.aiterp.net/lucifer/lucifer/middlewares" ) func main() { @@ -22,9 +23,10 @@ func main() { log.Fatalln("Failed to set up database:", err) } - userController := controllers.NewUserController(sqlite.UserRepository) + userController := controllers.NewUserController(sqlite.UserRepository, sqlite.SessionRepository) router := mux.NewRouter() + router.Use(middlewares.Session(sqlite.SessionRepository)) userController.Mount(router, "/api/user/") diff --git a/controllers/user-controller.go b/controllers/user-controller.go index 18edd15..38e1209 100644 --- a/controllers/user-controller.go +++ b/controllers/user-controller.go @@ -2,7 +2,9 @@ package controllers import ( "encoding/json" + "log" "net/http" + "strconv" "time" "git.aiterp.net/lucifer/lucifer/internal/respond" @@ -12,16 +14,19 @@ import ( // The UserController is a controller for all user inports. type UserController struct { - users models.UserRepository + users models.UserRepository + sessions models.SessionRepository } // getUsers (`GET /`): List users func (c *UserController) getUsers(w http.ResponseWriter, r *http.Request) { if session := models.SessionFromContext(r.Context()); session == nil { - respond.Error(w, 403, "permission_denied", "You must log in") + respond.Error(w, http.StatusForbidden, "permission_denied", "You must log in") return } + // TODO: Only admin can do this? + users, err := c.users.List(r.Context()) if err != nil { respond.Error(w, 500, "db_error", err.Error()) @@ -31,6 +36,36 @@ func (c *UserController) getUsers(w http.ResponseWriter, r *http.Request) { respond.JSON(w, 200, users) } +// getUser (`GET /:id`): Get user by id +func (c *UserController) getUser(w http.ResponseWriter, r *http.Request) { + session := models.SessionFromContext(r.Context()) + if session == nil { + respond.Error(w, http.StatusForbidden, "permission_denied", "You must log in") + return + } + + // TODO: Only admin can do this? + + var user models.User + + id, err := strconv.Atoi(mux.Vars(r)["id"]) + if err == nil { + user, err = c.users.FindByID(r.Context(), id) + if err != nil { + respond.Error(w, 404, "not_found", err.Error()) + return + } + } else { + user, err = c.users.FindByName(r.Context(), mux.Vars(r)["id"]) + if err != nil { + respond.Error(w, 404, "not_found", err.Error()) + return + } + } + + respond.JSON(w, 200, user) +} + // login (`POST /login`): Log in as user func (c *UserController) login(w http.ResponseWriter, r *http.Request) { loginData := struct { @@ -61,20 +96,114 @@ func (c *UserController) login(w http.ResponseWriter, r *http.Request) { } session.GenerateID() + if err := c.sessions.Insert(r.Context(), session); err != nil { + log.Printf("Session create for user %s (%d) failed: %s", user.Name, user.ID, err) + respond.Error(w, http.StatusInternalServerError, "session_failure", "Failed to open session.") + return + } + http.SetCookie(w, session.Cookie()) + log.Printf("User %s logged in", user.Name) + respond.JSON(w, 200, user) } +func (c *UserController) register(w http.ResponseWriter, r *http.Request) { + registerData := struct { + Username string `json:"username"` + Password string `json:"password"` + }{} + + if err := json.NewDecoder(r.Body).Decode(®isterData); err != nil { + respond.Error(w, http.StatusBadRequest, "invalid_json", "Input is not valid JSON.") + return + } + + if len(registerData.Username) < 1 { + respond.Error(w, http.StatusBadRequest, "invalid_username", "The username cannot be empty.") + return + } + if _, err := strconv.Atoi(registerData.Username); err == nil { + respond.Error(w, http.StatusBadRequest, "invalid_username", "The username cannot start with a number.") + return + } + + user := models.User{Name: registerData.Username} + + if err := user.SetPassword(registerData.Password); err != nil { + respond.Error(w, http.StatusBadRequest, "invalid_password", "The password is not valid: "+err.Error()) + return + } + + user, err := c.users.Insert(r.Context(), user) + if err != nil { + respond.Error(w, http.StatusBadRequest, "invalid_password", "The password is not valid: "+err.Error()) + return + } + + log.Printf("User %s registered", user.Name) + + respond.JSON(w, 200, user) +} + +// login (`POST /logout`): Log in as user +func (c *UserController) logout(w http.ResponseWriter, r *http.Request) { + logoutData := struct { + ClearAll bool `json:"clearAll"` + }{} + + session := models.SessionFromContext(r.Context()) + if session == nil { + respond.Error(w, http.StatusUnauthorized, "permission_denied", "You are not logged in (that's what you wanted anyway, wasn't it?)") + return + } + + json.NewDecoder(r.Body).Decode(&logoutData) + + user, err := c.users.FindByID(r.Context(), session.UserID) + if err != nil { + respond.Error(w, http.StatusNotFound, "not_found", "You user was not found") + return + } + + if logoutData.ClearAll { + err := c.sessions.Clear(r.Context(), user) + if err != nil { + log.Printf("Session clear for user %s (%d) failed: %s", user.Name, user.ID, err) + respond.Error(w, http.StatusInternalServerError, "clear_failed", "Sesison clear failed") + return + } + } else { + err := c.sessions.Remove(r.Context(), *session) + if err != nil { + log.Printf("Session remove for user %s (%d) with id %s failed: %s", user.Name, user.ID, session.ID, err) + respond.Error(w, http.StatusInternalServerError, "remove_failed", "Session remove failed") + return + } + } + + cookie := session.Cookie() + cookie.Expires = time.Unix(0, 0) + http.SetCookie(w, cookie) + + log.Printf("User %s logged out (clearAll: %t)", user.Name, logoutData.ClearAll) + + respond.JSON(w, 200, logoutData) +} + // Mount mounts the controller func (c *UserController) Mount(router *mux.Router, prefix string) { sub := router.PathPrefix(prefix).Subrouter() sub.HandleFunc("/", c.getUsers).Methods("GET") + sub.HandleFunc("/{id}", c.getUser).Methods("GET") sub.HandleFunc("/login", c.login).Methods("POST") + sub.HandleFunc("/logout", c.logout).Methods("POST") + sub.HandleFunc("/register", c.register).Methods("POST") } // NewUserController creates a new UserController. -func NewUserController(users models.UserRepository) *UserController { - return &UserController{users: users} +func NewUserController(users models.UserRepository, sessions models.SessionRepository) *UserController { + return &UserController{users: users, sessions: sessions} } diff --git a/database/sqlite/session-repository.go b/database/sqlite/session-repository.go new file mode 100644 index 0000000..ceee41e --- /dev/null +++ b/database/sqlite/session-repository.go @@ -0,0 +1,62 @@ +package sqlite + +import ( + "context" + + "git.aiterp.net/lucifer/lucifer/models" +) + +// SessionRepository is a sqlite database. +var SessionRepository = &sessionRepository{} + +type sessionRepository struct{} + +func (r *sessionRepository) FindByID(ctx context.Context, id string) (models.Session, error) { + row := db.QueryRowxContext(ctx, "SELECT * FROM session WHERE id=?", id) + if err := row.Err(); err != nil { + return models.Session{}, err + } + + session := models.Session{} + if err := row.StructScan(&session); err != nil { + return models.Session{}, err + } + + return session, nil +} + +func (r *sessionRepository) Insert(ctx context.Context, session models.Session) error { + _, err := db.NamedExecContext(ctx, "INSERT INTO session (id, user_id, expire_date) VALUES(:id, :user_id, :expire_date)", session) + if err != nil { + return err + } + + return nil +} + +func (r *sessionRepository) Update(ctx context.Context, session models.Session) error { + _, err := db.NamedExecContext(ctx, "UPDATE session SET user_id=:user_id AND expire_date=:expire_date", session) + if err != nil { + return err + } + + return nil +} + +func (r *sessionRepository) Remove(ctx context.Context, session models.Session) error { + _, err := db.NamedExecContext(ctx, "DELETE FROM session WHERE id=:id", session) + if err != nil { + return err + } + + return nil +} + +func (r *sessionRepository) Clear(ctx context.Context, user models.User) error { + _, err := db.NamedExecContext(ctx, "DELETE FROM session WHERE user_id=:id", user) + if err != nil { + return err + } + + return nil +} diff --git a/database/sqlite/user-repository.go b/database/sqlite/user-repository.go index e1b9364..211b044 100644 --- a/database/sqlite/user-repository.go +++ b/database/sqlite/user-repository.go @@ -85,10 +85,10 @@ func (repo *userRepository) Update(ctx context.Context, user models.User) error } func (repo *userRepository) Remove(ctx context.Context, user models.User) error { - _, err := db.ExecContext(ctx, "REMOVE FROM user WHERE id=?", user.ID) + _, err := db.ExecContext(ctx, "DELETE FROM user WHERE id=?", user.ID) if err != nil { return err } - return err + return nil } diff --git a/middlewares/session.go b/middlewares/session.go index ff727a6..6586805 100644 --- a/middlewares/session.go +++ b/middlewares/session.go @@ -24,20 +24,13 @@ func Session(repo models.SessionRepository) mux.MiddlewareFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Find cookie cookie, err := r.Cookie("lucifer_session") - if err == nil && cookie != nil { - next.ServeHTTP(w, r) - return - } - - // Check cookie expiration - if cookie.Expires.IsZero() || time.Now().After(cookie.Expires) { - http.SetCookie(w, clearCookie) + if err != nil || cookie == nil { next.ServeHTTP(w, r) return } // Check session existence - session, err := repo.FindSessionByID(cookie.Value) + session, err := repo.FindByID(r.Context(), cookie.Value) if err != nil { http.SetCookie(w, clearCookie) next.ServeHTTP(w, r) diff --git a/models/session.go b/models/session.go index 2cc7afa..ba69fe1 100644 --- a/models/session.go +++ b/models/session.go @@ -15,9 +15,9 @@ const sessionCtxKey = sessionCtxKeyType("luciter_session") // The Session model represents a user being logged in. type Session struct { - ID string - UserID int - Expires time.Time + ID string `db:"id"` + UserID int `db:"user_id"` + Expires time.Time `db:"expire_date"` } // GenerateID generates a unique ID for the session that's 64 base36 digits long. @@ -73,16 +73,20 @@ func (session *Session) InContext(parent context.Context) context.Context { // SessionFromContext gets the session from context, or `nil` if no session is available. func SessionFromContext(ctx context.Context) *Session { - return ctx.Value(sessionCtxKey).(*Session) + v := ctx.Value(sessionCtxKey) + if v == nil { + return nil + } + + return v.(*Session) } // SessionRepository is an interface for all database operations // the user model makes. type SessionRepository interface { - // FindSessionByID finds a non-expired session by ID. - FindSessionByID(id string) (Session, error) - InsertSession(session Session) error - RemoveSession(session Session) error - ClearUserSessions(user User) error - RemoveExpiredSessions() error + FindByID(ctx context.Context, id string) (Session, error) + Insert(ctx context.Context, session Session) error + Update(ctx context.Context, session Session) error + Remove(ctx context.Context, session Session) error + Clear(ctx context.Context, user User) error } diff --git a/models/user.go b/models/user.go index b58a918..6bbd8f2 100644 --- a/models/user.go +++ b/models/user.go @@ -2,6 +2,8 @@ package models import ( "context" + "errors" + "unicode/utf8" "golang.org/x/crypto/bcrypt" ) @@ -15,7 +17,13 @@ type User struct { // SetPassword sets the user's password func (user *User) SetPassword(password string) error { - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + passwordBytes := []byte(password) + + if utf8.RuneCount(passwordBytes) < 6 { + return errors.New("Password is too short (<6 unicode runes)") + } + + hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) if err != nil { return err }