diff --git a/server/server.go b/server/server.go index 65ea7e8..2a4a47b 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "archive/zip" "bytes" "context" + "embed" "encoding/json" "fmt" "io" @@ -12,8 +13,6 @@ import ( "os" "time" - _ "embed" - "cernobor.cz/oko-server/errs" "cernobor.cz/oko-server/models" "crawshaw.io/sqlite" @@ -23,11 +22,8 @@ import ( "github.com/sirupsen/logrus" ) -//go:embed sql_schema/V1_init.sql -var sql_v1 string - -//go:embed sql_schema/V2_proposals.sql -var sql_v2 string +//go:embed sql_schema/V*.sql +var sqlSchema embed.FS type Server struct { config ServerConfig @@ -240,7 +236,7 @@ func (s *Server) initDB(reinit bool) error { defer s.requestCheckpoint() if reinit { - s.log.Warn("REinitializing DB.") + s.log.Warn("Reinitializing DB.") tables := []string{} err := sqlitex.Exec(conn, "select name from sqlite_master where type = 'table'", func(stmt *sqlite.Stmt) error { tables = append(tables, stmt.ColumnText(0)) @@ -262,6 +258,16 @@ func (s *Server) initDB(reinit bool) error { return fmt.Errorf("failed to reset user version: %w", err) } } + + err := s.migrateDb(conn) + if err != nil { + return fmt.Errorf("failed to migrate db: %w", err) + } + + return nil +} + +func (s *Server) migrateDb(conn *sqlite.Conn) error { var version int err := sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error { version = stmt.ColumnInt(0) @@ -270,21 +276,96 @@ func (s *Server) initDB(reinit bool) error { if err != nil { return fmt.Errorf("failed to get user version: %w", err) } - s.log.Debugf("Current db version: %d", version) - if version <= 0 { - s.log.Debugf("Running db migration V1") - err = sqlitex.ExecScript(conn, sql_v1) + + entries, err := sqlSchema.ReadDir("sql_schema") + if err != nil { + return fmt.Errorf("failed to read sql_schema migrations: %w", err) + } + + type migration struct { + file string + version int + name string + } + + pattern := regexp.MustCompile("^V([0-9]+)_(.*)[.][sS][qQ][lL]$") + migrations := []migration{} + for _, entry := range entries { + name := entry.Name() + if entry.IsDir() { + return fmt.Errorf("embedded sql migration '%s' is a directory", name) + } + matches := pattern.FindStringSubmatch(name) + if matches == nil { + return fmt.Errorf("embedded sql migration '%s' does not match the filename pattern", name) + } + if len(matches) != 3 { + return fmt.Errorf("embedded sql migration '%s' does not have the correct number of submatches", name) + } + version, err := strconv.Atoi(matches[1]) if err != nil { - return fmt.Errorf("failed to run V1 init script") + return fmt.Errorf("failed to parse version number of migration '%s': %w", name, err) + } + migName := matches[2] + file := path.Join("sql_schema", name) + migrations = append(migrations, migration{ + file: file, + version: version, + name: migName, + }) + } + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].version < migrations[j].version + }) + + for _, migration := range migrations { + if version >= migration.version { + s.log.Debugf("Skipping migration version %d because current version %d is not smaller.", migration.version, version) + continue + } + + migContent, err := sqlSchema.ReadFile(migration.file) + if err != nil { + return fmt.Errorf("failed to read embedded migration '%s': %w", migration.file, err) + } + + err = func() (err error) { + rollback := sqlitex.Save(conn) + defer func() { + if err != nil { + s.log.Info("Rolling back last migration attempt.") + } + rollback(&err) + }() + + s.log.Infof("Executing migration V%d - %s", migration.version, migration.name) + err = sqlitex.ExecScript(conn, string(migContent)) + if err != nil { + return fmt.Errorf("failed to execute migration '%s': %w", migration.name, err) + } + + err = sqlitex.Exec(conn, fmt.Sprintf("PRAGMA user_version = %d", migration.version), nil) + if err != nil { + return fmt.Errorf("failed to set user_version in db: %w", err) + } + + err = sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error { + version = stmt.ColumnInt(0) + return nil + }) + if err != nil { + return fmt.Errorf("failed to get user_version: %w", err) + } + + s.log.Infof("Migrated db to version: %d", version) + return nil + }() + if err != nil { + return err } } - if version <= 1 { - s.log.Debugf("Running db migration V2") - err = sqlitex.ExecScript(conn, sql_v2) - if err != nil { - return fmt.Errorf("failed to run V2 init script") - } + return nil } return nil @@ -300,7 +381,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error) var id *int64 if hc.Exists { err = sqlitex.Exec(conn, "select id from users where name = ?", func(stmt *sqlite.Stmt) error { - id = ptrInt64(stmt.ColumnInt64(0)) + id = ptr(stmt.ColumnInt64(0)) return nil }, hc.Name) if sqlite.ErrCode(err) != sqlite.SQLITE_OK { @@ -314,7 +395,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error) } } else { err = sqlitex.Exec(conn, "insert into users(name) values(?)", func(stmt *sqlite.Stmt) error { - id = ptrInt64(stmt.ColumnInt64(0)) + id = ptr(stmt.ColumnInt64(0)) return nil }, hc.Name) if sqlite.ErrCode(err) == sqlite.SQLITE_CONSTRAINT_UNIQUE { @@ -323,7 +404,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error) if sqlite.ErrCode(err) != sqlite.SQLITE_OK { return 0, err } - id = ptrInt64(conn.LastInsertRowID()) + id = ptr(conn.LastInsertRowID()) s.requestCheckpoint() } return *id, nil @@ -925,7 +1006,7 @@ func (s *Server) getPhoto(featureID models.FeatureID, photoID models.FeaturePhot if found { return fmt.Errorf("multiple photos returned for feature id %d, photo id %d", featureID, photoID) } - contentType = ptrString(stmt.ColumnText(0)) + contentType = ptr(stmt.ColumnText(0)) data = make([]byte, stmt.ColumnLen(1)) stmt.ColumnBytes(1, data) found = true diff --git a/server/sql_schema/V1_init.sql b/server/sql_schema/V1_init.sql index bb77e0b..84dca1c 100644 --- a/server/sql_schema/V1_init.sql +++ b/server/sql_schema/V1_init.sql @@ -21,5 +21,3 @@ CREATE TABLE feature_photos ( thumbnail_contents blob NOT NULL, contents blob NOT NULL ); - -PRAGMA user_version = 1; \ No newline at end of file diff --git a/server/sql_schema/V2_proposals.sql b/server/sql_schema/V2_proposals.sql index 45e8495..150d9c1 100644 --- a/server/sql_schema/V2_proposals.sql +++ b/server/sql_schema/V2_proposals.sql @@ -3,5 +3,3 @@ CREATE TABLE proposals ( description text NOT NULL, how text NOT NULL ); - -PRAGMA user_version = 2; \ No newline at end of file diff --git a/server/utils.go b/server/utils.go index b08b295..d3e30a1 100644 --- a/server/utils.go +++ b/server/utils.go @@ -13,15 +13,7 @@ import ( "cernobor.cz/oko-server/models" ) -func ptrInt(x int) *int { - return &x -} - -func ptrInt64(x int64) *int64 { - return &x -} - -func ptrString(x string) *string { +func ptr[T any](x T) *T { return &x }