DB migration, small technicalities.

* DB migration rewritten:
  * Migration scripts are embedded as FS.
  * Migration versions are handled automatically.
* Use generics in utils.
This commit is contained in:
zegkljan 2022-09-22 22:05:41 +02:00
parent c1cdd4f904
commit c9377b04fc
4 changed files with 105 additions and 36 deletions

View File

@ -4,6 +4,7 @@ import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"context" "context"
"embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -12,8 +13,6 @@ import (
"os" "os"
"time" "time"
_ "embed"
"cernobor.cz/oko-server/errs" "cernobor.cz/oko-server/errs"
"cernobor.cz/oko-server/models" "cernobor.cz/oko-server/models"
"crawshaw.io/sqlite" "crawshaw.io/sqlite"
@ -23,11 +22,8 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
//go:embed sql_schema/V1_init.sql //go:embed sql_schema/V*.sql
var sql_v1 string var sqlSchema embed.FS
//go:embed sql_schema/V2_proposals.sql
var sql_v2 string
type Server struct { type Server struct {
config ServerConfig config ServerConfig
@ -240,7 +236,7 @@ func (s *Server) initDB(reinit bool) error {
defer s.requestCheckpoint() defer s.requestCheckpoint()
if reinit { if reinit {
s.log.Warn("REinitializing DB.") s.log.Warn("Reinitializing DB.")
tables := []string{} tables := []string{}
err := sqlitex.Exec(conn, "select name from sqlite_master where type = 'table'", func(stmt *sqlite.Stmt) error { err := sqlitex.Exec(conn, "select name from sqlite_master where type = 'table'", func(stmt *sqlite.Stmt) error {
tables = append(tables, stmt.ColumnText(0)) tables = append(tables, stmt.ColumnText(0))
@ -262,6 +258,16 @@ func (s *Server) initDB(reinit bool) error {
return fmt.Errorf("failed to reset user version: %w", err) return fmt.Errorf("failed to reset user version: %w", err)
} }
} }
err := s.migrateDb(conn)
if err != nil {
return fmt.Errorf("failed to migrate db: %w", err)
}
return nil
}
func (s *Server) migrateDb(conn *sqlite.Conn) error {
var version int var version int
err := sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error { err := sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error {
version = stmt.ColumnInt(0) version = stmt.ColumnInt(0)
@ -270,21 +276,96 @@ func (s *Server) initDB(reinit bool) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get user version: %w", err) return fmt.Errorf("failed to get user version: %w", err)
} }
s.log.Debugf("Current db version: %d", version) s.log.Debugf("Current db version: %d", version)
if version <= 0 {
s.log.Debugf("Running db migration V1") entries, err := sqlSchema.ReadDir("sql_schema")
err = sqlitex.ExecScript(conn, sql_v1) if err != nil {
return fmt.Errorf("failed to read sql_schema migrations: %w", err)
}
type migration struct {
file string
version int
name string
}
pattern := regexp.MustCompile("^V([0-9]+)_(.*)[.][sS][qQ][lL]$")
migrations := []migration{}
for _, entry := range entries {
name := entry.Name()
if entry.IsDir() {
return fmt.Errorf("embedded sql migration '%s' is a directory", name)
}
matches := pattern.FindStringSubmatch(name)
if matches == nil {
return fmt.Errorf("embedded sql migration '%s' does not match the filename pattern", name)
}
if len(matches) != 3 {
return fmt.Errorf("embedded sql migration '%s' does not have the correct number of submatches", name)
}
version, err := strconv.Atoi(matches[1])
if err != nil { if err != nil {
return fmt.Errorf("failed to run V1 init script") return fmt.Errorf("failed to parse version number of migration '%s': %w", name, err)
}
migName := matches[2]
file := path.Join("sql_schema", name)
migrations = append(migrations, migration{
file: file,
version: version,
name: migName,
})
}
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].version < migrations[j].version
})
for _, migration := range migrations {
if version >= migration.version {
s.log.Debugf("Skipping migration version %d because current version %d is not smaller.", migration.version, version)
continue
}
migContent, err := sqlSchema.ReadFile(migration.file)
if err != nil {
return fmt.Errorf("failed to read embedded migration '%s': %w", migration.file, err)
}
err = func() (err error) {
rollback := sqlitex.Save(conn)
defer func() {
if err != nil {
s.log.Info("Rolling back last migration attempt.")
}
rollback(&err)
}()
s.log.Infof("Executing migration V%d - %s", migration.version, migration.name)
err = sqlitex.ExecScript(conn, string(migContent))
if err != nil {
return fmt.Errorf("failed to execute migration '%s': %w", migration.name, err)
}
err = sqlitex.Exec(conn, fmt.Sprintf("PRAGMA user_version = %d", migration.version), nil)
if err != nil {
return fmt.Errorf("failed to set user_version in db: %w", err)
}
err = sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error {
version = stmt.ColumnInt(0)
return nil
})
if err != nil {
return fmt.Errorf("failed to get user_version: %w", err)
}
s.log.Infof("Migrated db to version: %d", version)
return nil
}()
if err != nil {
return err
} }
} }
if version <= 1 { return nil
s.log.Debugf("Running db migration V2")
err = sqlitex.ExecScript(conn, sql_v2)
if err != nil {
return fmt.Errorf("failed to run V2 init script")
}
} }
return nil return nil
@ -300,7 +381,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error)
var id *int64 var id *int64
if hc.Exists { if hc.Exists {
err = sqlitex.Exec(conn, "select id from users where name = ?", func(stmt *sqlite.Stmt) error { err = sqlitex.Exec(conn, "select id from users where name = ?", func(stmt *sqlite.Stmt) error {
id = ptrInt64(stmt.ColumnInt64(0)) id = ptr(stmt.ColumnInt64(0))
return nil return nil
}, hc.Name) }, hc.Name)
if sqlite.ErrCode(err) != sqlite.SQLITE_OK { if sqlite.ErrCode(err) != sqlite.SQLITE_OK {
@ -314,7 +395,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error)
} }
} else { } else {
err = sqlitex.Exec(conn, "insert into users(name) values(?)", func(stmt *sqlite.Stmt) error { err = sqlitex.Exec(conn, "insert into users(name) values(?)", func(stmt *sqlite.Stmt) error {
id = ptrInt64(stmt.ColumnInt64(0)) id = ptr(stmt.ColumnInt64(0))
return nil return nil
}, hc.Name) }, hc.Name)
if sqlite.ErrCode(err) == sqlite.SQLITE_CONSTRAINT_UNIQUE { if sqlite.ErrCode(err) == sqlite.SQLITE_CONSTRAINT_UNIQUE {
@ -323,7 +404,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error)
if sqlite.ErrCode(err) != sqlite.SQLITE_OK { if sqlite.ErrCode(err) != sqlite.SQLITE_OK {
return 0, err return 0, err
} }
id = ptrInt64(conn.LastInsertRowID()) id = ptr(conn.LastInsertRowID())
s.requestCheckpoint() s.requestCheckpoint()
} }
return *id, nil return *id, nil
@ -925,7 +1006,7 @@ func (s *Server) getPhoto(featureID models.FeatureID, photoID models.FeaturePhot
if found { if found {
return fmt.Errorf("multiple photos returned for feature id %d, photo id %d", featureID, photoID) return fmt.Errorf("multiple photos returned for feature id %d, photo id %d", featureID, photoID)
} }
contentType = ptrString(stmt.ColumnText(0)) contentType = ptr(stmt.ColumnText(0))
data = make([]byte, stmt.ColumnLen(1)) data = make([]byte, stmt.ColumnLen(1))
stmt.ColumnBytes(1, data) stmt.ColumnBytes(1, data)
found = true found = true

View File

@ -21,5 +21,3 @@ CREATE TABLE feature_photos (
thumbnail_contents blob NOT NULL, thumbnail_contents blob NOT NULL,
contents blob NOT NULL contents blob NOT NULL
); );
PRAGMA user_version = 1;

View File

@ -3,5 +3,3 @@ CREATE TABLE proposals (
description text NOT NULL, description text NOT NULL,
how text NOT NULL how text NOT NULL
); );
PRAGMA user_version = 2;

View File

@ -13,15 +13,7 @@ import (
"cernobor.cz/oko-server/models" "cernobor.cz/oko-server/models"
) )
func ptrInt(x int) *int { func ptr[T any](x T) *T {
return &x
}
func ptrInt64(x int64) *int64 {
return &x
}
func ptrString(x string) *string {
return &x return &x
} }