wargh/main.go
2024-11-24 12:54:59 +02:00

319 lines
6.2 KiB
Go

package main
import (
"crypto/rand"
"database/sql"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"strconv"
"text/template"
"time"
"wargh/db"
"github.com/gorilla/websocket"
"golang.org/x/crypto/bcrypt"
)
var (
	// DB is the process-wide database handle; main assigns it from db.Open.
	DB *sql.DB
	// templates holds the parsed page templates; main loads it once at startup.
	templates *template.Template
)

const (
	// COOKIE_NAME is the name of the cookie that carries the session token.
	COOKIE_NAME = "WARGH_SESSION"
	// COOKIE_EXPIRATION is the lifetime given to the session cookie's Expires.
	COOKIE_EXPIRATION = 7 * 24 * time.Hour // 7 days
)
// UserSession is the server-side state stored for one session token.
type UserSession struct {
	Id int // id of the logged-in row in the `user` table
}

// sessions maps a session token (the cookie value) to its UserSession.
var sessions = make(map[string]UserSession)
// Error codes passed to the /error page in its ?error= query parameter.
// Each code is also the index of its message in ERROR_TEXT, so the numeric
// values must stay stable (they appear in redirect URLs).
const (
	ERROR_UNKNOWN = iota
	ERROR_INVALID_INPUT
	ERROR_INVALID_LOGIN_PASSWORD
	ERROR_UNAUTHORIZED
)

// ERROR_TEXT holds the user-facing message for each error code. The keyed
// elements tie every message to its constant so the two lists cannot drift.
var ERROR_TEXT = []string{
	ERROR_UNKNOWN:                "Unknown error",
	ERROR_INVALID_INPUT:          "Invalid input",
	ERROR_INVALID_LOGIN_PASSWORD: "Invalid login or password",
	ERROR_UNAUTHORIZED:           "Unauthorized",
}
// wsUpgrader converts /ws requests (after checkSession has passed) into
// websocket connections, using 1 KiB read and write buffers. CheckOrigin is
// not set, so gorilla/websocket's default origin policy applies.
var wsUpgrader = websocket.Upgrader{WriteBufferSize: 1 << 10, ReadBufferSize: 1 << 10}
// Job is a started shell command together with the read ends of its output
// pipes. Presumably its output is meant to be relayed to the websockets in
// Socket (wsHandler does not attach sockets yet — confirm once it does).
type Job struct {
	Command    *exec.Cmd         // the started process
	StdoutPipe io.ReadCloser     // read end of the process's stdout
	StderrPipe io.ReadCloser     // read end of the process's stderr
	Socket     []*websocket.Conn // websocket connections attached to the job
}

// jobs indexes every job created by createJob by its random job id.
var jobs = make(map[string]Job)
// errorHandler renders error.html with the message for the numeric code in
// the ?error= query parameter. A missing, non-numeric or out-of-range code
// falls back to ERROR_UNKNOWN so the page always has a message to show.
func errorHandler(w http.ResponseWriter, r *http.Request) {
	// Atoi fails on the empty string too, which covers a missing parameter.
	code, err := strconv.Atoi(r.URL.Query().Get("error"))
	if err != nil || code < 0 || code >= len(ERROR_TEXT) {
		code = ERROR_UNKNOWN
	}
	if err := templates.Lookup("error.html").Execute(w, ERROR_TEXT[code]); err != nil {
		// Headers may already be sent, so the failure can only be logged.
		log.Print(err)
	}
}
// redirectError sends the client to the /error page for errorCode (one of
// the ERROR_* constants) with a 303 See Other.
func redirectError(w http.ResponseWriter, r *http.Request, errorCode int) {
	target := fmt.Sprintf("/error?error=%d", errorCode)
	http.Redirect(w, r, target, http.StatusSeeOther)
}

// redirectIndex sends the client to the site root with a 303 See Other.
func redirectIndex(w http.ResponseWriter, r *http.Request) {
	const indexPath = "/"
	http.Redirect(w, r, indexPath, http.StatusSeeOther)
}
// isFirstUser reports whether the `user` table is still empty. loginHandler
// uses this to decide whether a login attempt should create the first account.
func isFirstUser() (bool, error) {
	var total int
	// Row.Scan returns any deferred query error itself, so a separate
	// Row.Err check is not needed.
	if scanErr := DB.QueryRow("SELECT COUNT(id) FROM `user`;").Scan(&total); scanErr != nil {
		return false, scanErr
	}
	empty := total == 0
	return empty, nil
}
// createUser stores a new user with a bcrypt hash of password and returns the
// id the database assigned to it. On any error the returned id is 0 and must
// not be used.
func createUser(login string, password string) (int, error) {
	const bcryptCost = 14 // deliberately slow hashing; raises brute-force cost
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
	if err != nil {
		return 0, err
	}
	if _, err := DB.Exec("INSERT INTO `user` (login, password) VALUES (?,?);", login, hash); err != nil {
		return 0, err
	}
	var id int
	// Check the Scan error: ignoring it would hand back id 0 with a nil error.
	if err := DB.QueryRow("SELECT id FROM `user` WHERE login = ?;", login).Scan(&id); err != nil {
		return 0, err
	}
	return id, nil
}
// randString returns n characters drawn uniformly from [0-9A-Za-z] using
// crypto/rand. It is used for session tokens and job ids, so it panics
// rather than return a predictable string if the system randomness source
// fails.
func randString(n int) string {
	const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
	// Rejection sampling: 248 is the largest multiple of 62 that fits in a
	// byte, so b%62 is uniform over the accepted bytes. Reducing all 256 byte
	// values mod 62 would make the first 8 symbols 25% more likely.
	const limit = 256 - 256%len(alphanum)

	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			panic("randString: reading crypto/rand: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			out = append(out, alphanum[int(b)%len(alphanum)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}
// sessionsLock guards the sessions map. net/http runs every request on its
// own goroutine and Go maps are not safe for concurrent use, so every read or
// write of sessions must hold it. It is a one-slot channel used as a mutex:
// sending acquires the lock, receiving releases it.
var sessionsLock = make(chan struct{}, 1)

// createSession registers a new random 64-character session token for
// userId and sends it to the client as the COOKIE_NAME cookie, which expires
// after COOKIE_EXPIRATION.
//
// NOTE(review): nothing removes entries from sessions, so a token stays valid
// server-side until the process restarts, even after the cookie has expired.
func createSession(w http.ResponseWriter, userId int) {
	token := randString(64)

	sessionsLock <- struct{}{}
	sessions[token] = UserSession{Id: userId}
	<-sessionsLock

	http.SetCookie(w, &http.Cookie{
		Name:     COOKIE_NAME,
		Value:    token,
		Path:     "/", // valid for every route, not just the login path
		Expires:  time.Now().Add(COOKIE_EXPIRATION),
		HttpOnly: true,
		// Lax keeps the cookie off cross-site subrequests and form POSTs,
		// which blunts CSRF against the session.
		SameSite: http.SameSiteLaxMode,
	})
}

// checkSession reports whether r carries a COOKIE_NAME cookie naming a known
// session. When it does not, it has already redirected the client to the
// ERROR_UNAUTHORIZED error page, and the caller must stop handling r.
func checkSession(w http.ResponseWriter, r *http.Request) bool {
	cookie, err := r.Cookie(COOKIE_NAME)
	if err == nil {
		sessionsLock <- struct{}{}
		_, known := sessions[cookie.Value]
		<-sessionsLock
		if known {
			return true
		}
	}
	redirectError(w, r, ERROR_UNAUTHORIZED)
	return false
}
// indexHandler renders index.html for GET/HEAD requests to the site root.
// The "/" pattern in http.ServeMux matches every path no other pattern
// claims, so any other path gets 404 instead of a copy of the index page.
// Other methods get 405 Method Not Allowed.
func indexHandler(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}
	if r.Method != http.MethodGet && r.Method != http.MethodHead {
		w.Header().Set("Allow", "GET, HEAD")
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return
	}
	if err := templates.Lookup("index.html").Execute(w, nil); err != nil {
		log.Print(err)
	}
}
// loginHandler serves /login.
//
// GET renders login.html. POST reads the "login" and "password" form fields
// and, on success, starts a session and redirects to "/". While the `user`
// table is empty, a POST creates that account instead of checking
// credentials. Failures redirect to /error with the matching ERROR_* code.
// Other methods get 405 Method Not Allowed.
func loginHandler(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		if err := templates.Lookup("login.html").Execute(w, nil); err != nil {
			log.Print(err)
		}
	case http.MethodPost:
		handleLoginPost(w, r)
	default:
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
	}
}

// handleLoginPost handles a submitted login form. It either creates the first
// user or verifies the credentials. On success it starts a session;
// otherwise it redirects to the error page.
func handleLoginPost(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		log.Print(err)
		redirectError(w, r, ERROR_UNKNOWN)
		return
	}
	login := r.Form.Get("login")
	password := r.Form.Get("password")
	if login == "" || password == "" {
		redirectError(w, r, ERROR_INVALID_INPUT)
		return
	}

	isFirst, err := isFirstUser()
	if err != nil {
		log.Print(err)
		redirectError(w, r, ERROR_UNKNOWN)
		return
	}

	var id int
	if isFirst {
		id, err = createUser(login, password)
		if err != nil {
			// id is meaningless on error; never start a session with it.
			log.Print(err)
			redirectError(w, r, ERROR_UNKNOWN)
			return
		}
		log.Printf("created user %d\n", id)
	} else {
		var ok bool
		id, ok, err = authenticate(login, password)
		if err != nil {
			log.Print(err)
			redirectError(w, r, ERROR_UNKNOWN)
			return
		}
		if !ok {
			redirectError(w, r, ERROR_INVALID_LOGIN_PASSWORD)
			return
		}
	}
	createSession(w, id)
	redirectIndex(w, r)
}

// authenticate looks up login in the `user` table and checks password
// against the stored bcrypt hash. It returns:
//   - the user's id and ok=true on a match;
//   - ok=false with a nil error for an unknown login or a wrong password;
//   - a non-nil error only when the database query itself fails.
func authenticate(login, password string) (id int, ok bool, err error) {
	var hash []byte
	err = DB.QueryRow("SELECT id,password FROM `user` WHERE login = ?;", login).Scan(&id, &hash)
	if err == sql.ErrNoRows {
		// database/sql reports "no matching row" from Scan, not from Row.Err;
		// it is an ordinary failed login, not a server fault.
		return 0, false, nil
	}
	if err != nil {
		return 0, false, err
	}
	if bcrypt.CompareHashAndPassword(hash, []byte(password)) != nil {
		return 0, false, nil
	}
	return id, true, nil
}
// createJob starts a shell command with its stdout and stderr piped, and
// registers it in jobs under a new random 64-character id, which it returns.
// If the command cannot be started, nothing is registered, and the id is ""
// with a non-nil error.
//
// NOTE(review): the loop has no sleep, so it prints (and spawns `date`) as
// fast as the pipe is drained — presumably a placeholder workload; confirm.
func createJob() (string, error) {
	// The while-loop needs its closing "done"; without it sh rejects the
	// script with a syntax error and exits immediately. /bin/sh is the
	// portable shell path (/usr/bin/sh is absent on some systems).
	cmd := exec.Command("/bin/sh", "-c", "while true; do echo $(date) \"test\"; done")
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return "", err
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return "", err
	}
	if err := cmd.Start(); err != nil {
		return "", err
	}
	jobId := randString(64)
	jobs[jobId] = Job{
		Command:    cmd,
		StdoutPipe: stdout,
		StderrPipe: stderr,
		Socket:     make([]*websocket.Conn, 0),
	}
	return jobId, nil
}
// wsHandler upgrades an authenticated /ws request to a websocket, sends a
// single "test" text message and closes the connection.
//
// TODO: get job id and attach this socket to the job (not implemented yet).
func wsHandler(w http.ResponseWriter, r *http.Request) {
	if !checkSession(w, r) {
		return
	}
	ws, err := wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error, or has taken over
		// the connection, so writing a redirect here would be a second,
		// invalid response.
		log.Print(err)
		return
	}
	defer ws.Close()
	if err := ws.WriteMessage(websocket.TextMessage, []byte("test")); err != nil {
		log.Print(err)
	}
}
// main initializes and opens the database (wargh.db, migrations from
// ./migrations) via the project's db package, and parses the page templates.
// It then serves the routes below on :8080. Startup failures are logged and
// exit with status 1.
func main() {
	// NOTE(review): any result db.Init returns is not checked here — confirm
	// it cannot fail.
	db.Init(&db.DBConfig{
		DBPath:         "wargh.db",
		MigrationsPath: "migrations",
	})
	var err error
	DB, err = db.Open()
	if err != nil {
		log.Print(err)
		os.Exit(1)
	}
	defer DB.Close()

	// NOTE(review): text/template does not HTML-escape output; switch to
	// html/template before any template renders user-controlled data.
	templates, err = template.New("templates").ParseFiles(
		"templates/index.html",
		"templates/error.html",
		"templates/login.html",
	)
	if err != nil {
		log.Fatal(err)
	}

	http.HandleFunc("/", indexHandler)
	http.HandleFunc("/error", errorHandler)
	http.HandleFunc("/login", loginHandler)
	http.HandleFunc("/ws", wsHandler)

	// Handler is nil, so requests go to http.DefaultServeMux, where the
	// routes above are registered.
	srv := &http.Server{
		Addr: ":8080",
		// Bound how long a client may take to send request headers, so slow
		// or idle clients cannot pin connections open forever. Whole-request
		// Read/WriteTimeouts are left unset because /ws connections are meant
		// to stay open; revisit if adding them.
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}