mirror of
https://github.com/mudita/MuditaOS.git
synced 2026-01-01 18:39:03 -05:00
* Removed duplicated copyright header from 'up/down/devel.sql' when committing databases. * During installation, 'devel.sql' was created even if specific revisions didn't contain one. * Removed unnecessary product-specific .lua scripts.
472 lines
17 KiB
Python
472 lines
17 KiB
Python
#!/usr/bin/python3
|
|
# Copyright (c) 2017-2023, Mudita Sp. z.o.o. All rights reserved.
|
|
# For licensing, see https://github.com/mudita/MuditaOS/LICENSE.md
|
|
import os
|
|
import uuid
|
|
import sqlite3
|
|
from argparse import ArgumentParser
|
|
from pathlib import Path
|
|
import sys
|
|
import datetime
|
|
import json
|
|
import shutil
|
|
import traceback
|
|
import itertools
|
|
|
|
# Constants
# File names found inside every revision / version directory.
up_script = "up.sql"
down_script = "down.sql"
devel_script = "devel.sql"
meta_file = ".meta"

# Database set description file and environment settings file.
databases_set = "databases.json"
env_file = "dbm_env.ini"

# Header prepended to every generated SQL file. Readers strip it again by counting its newlines,
# so its line layout must stay stable.
license_header = (
    f"-- Copyright (c) 2017-{datetime.date.today().year}, Mudita Sp. z.o.o. All rights reserved.\n"
    "-- For licensing, see https://github.com/mudita/MuditaOS/LICENSE.md\n\n"
)

# Top-level command line parser; sub-commands register themselves on `subparsers`.
cli = ArgumentParser()
subparsers = cli.add_subparsers(dest="subcommand")
|
|
|
|
|
|
def subcommand(args=(), parent=None):
    """Decorator factory registering the decorated function as an argparse sub-command.

    The sub-command is named after the function and uses the function's docstring as its description.
    After parsing, the namespace carries ``func`` set to the decorated function so the caller can dispatch.

    :param args: iterable of ``(name_or_flags, kwargs)`` pairs, as produced by ``argument``
    :param parent: sub-parsers action to register with; defaults to the module-level ``subparsers``
    :return: a decorator that registers the function and returns it unchanged
    """
    # Resolved at call time rather than at definition time (avoids binding a module global as a default).
    target = subparsers if parent is None else parent

    def decorator(func):
        parser = target.add_parser(func.__name__, description=func.__doc__)
        for arg in args:
            parser.add_argument(*arg[0], **arg[1])
        parser.set_defaults(func=func)
        # Returning func keeps the decorated module-level name bound to the function instead of None.
        return func

    return decorator
|
|
|
|
|
|
def argument(*name_or_flags, **kwargs):
    """Bundle flags and options for a later ``ArgumentParser.add_argument`` call.

    :return: ``(flags, options)`` where flags is a list of the positional names/flags and options is the
             keyword-argument dict
    """
    flags = list(name_or_flags)
    return flags, kwargs
|
|
|
|
|
|
class RevisionMetadata:
    """Metadata of one migration revision, persisted as a JSON object in the revision's ``.meta`` file.

    Stored keys: ``id`` (always kept as a string), ``date``, ``message`` and ``parent`` (the id of the
    preceding revision; revisions created by this tool use 0 when there is no predecessor).
    """
    _key_id = "id"
    _key_date = "date"
    _key_message = "message"
    _key_parent = "parent"
    file_name = ".meta"

    def __init__(self, id, date, message, parent):
        # 'id' is normalized to str so that e.g. a uuid.UUID and its JSON round-trip compare equal.
        self.set = {RevisionMetadata._key_id: str(id), RevisionMetadata._key_date: date,
                    RevisionMetadata._key_message: message, RevisionMetadata._key_parent: parent}

    def id(self):
        """Return the revision id (str)."""
        return self.set[RevisionMetadata._key_id]

    def parent(self):
        """Return the parent revision id exactly as stored."""
        return self.set[RevisionMetadata._key_parent]

    def message(self):
        """Return the revision message."""
        return self.set[RevisionMetadata._key_message]

    @classmethod
    def from_file(cls, path: Path):
        """Load metadata from the JSON file at *path*.

        :raises OSError: if the file cannot be opened
        :raises json.JSONDecodeError: if the content is not a single JSON document
        :raises KeyError: if one of the required keys is missing
        """
        with open(path, "r") as f:
            raw = json.load(f)
        return cls(raw[cls._key_id], raw[cls._key_date], raw[cls._key_message], raw[cls._key_parent])

    def dump_to_file(self, path: Path):
        """Write the metadata as JSON to *path*, replacing any previous content.

        Append mode would leave two concatenated JSON documents after a repeated dump, which
        ``from_file`` cannot parse.
        """
        with path.open('w') as file:
            file.write(json.dumps(self.set, indent=1))
|
|
|
|
|
|
class ConstRevisionEntry:
    """Read-only view of an existing revision directory: its metadata and its SQL scripts."""

    def __init__(self, dir: Path):
        self.dir = dir
        meta_path = dir / RevisionMetadata.file_name
        self.metadata = RevisionMetadata.from_file(meta_path)

    def read_sql(self):
        """Return ``(up, down, devel)`` SQL text of this revision.

        The first lines of each file (as many as ``license_header`` spans) are dropped. ``devel`` is None when
        the devel script cannot be read; a missing up/down script raises OSError.
        """
        skip = license_header.count('\n')

        def body_of(script_name):
            with open(self.dir / script_name) as handle:
                return ''.join(handle.readlines()[skip:])

        up = body_of(up_script)
        down = body_of(down_script)
        try:
            devel = body_of(devel_script)
        except OSError:
            devel = None
        return up, down, devel
|
|
|
|
|
|
class RevisionEntry:
    """A new revision to be created under *base_dir*.

    The directory is named ``<first 8 chars of a uuid4>_<message>`` and the revision is chained after the
    latest existing revision in *base_dir* (parent 0 when there is none).
    """

    def __init__(self, base_dir: Path, message: str):
        self.id = uuid.uuid4()
        self.base_dir = base_dir
        self.dir = base_dir / "{id}_{message}".format(id=str(self.id)[:8], message=self._dir_component(message))
        self.date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.message = message
        latest_revision = get_latest_revision(self.base_dir)
        self.metadata = RevisionMetadata(self.id, self.date, self.message,
                                         latest_revision.metadata.id() if latest_revision else 0)

    @staticmethod
    def _dir_component(message: str) -> str:
        """Turn a revision message into a single path component.

        Spaces become '_' (as before). Path separators are replaced too: otherwise a message like 'a/b'
        creates a nested directory without a '.meta' file, which breaks reading the revision chain.
        """
        for ch in (' ', '/', '\\'):
            message = message.replace(ch, '_')
        return message

    def spawn(self):
        """Create the revision directory, write its metadata and the SQL script templates."""
        Path.mkdir(self.dir, exist_ok=True, parents=True)
        self.metadata.dump_to_file(self.dir / RevisionMetadata.file_name)
        self._build_sql_template()

    def _build_sql_template(self):
        """Write up/down/devel scripts, each starting with the license header and the revision description."""
        content = license_header + self._sql_header()
        for script in (up_script, down_script, devel_script):
            with (self.dir / script).open('w') as file:
                file.write(content)

    def _sql_header(self):
        """Return the SQL comment block describing this revision."""
        return f'-- Message: {self.message}\n' \
               f'-- Revision: {self.id}\n' \
               f'-- Create Date: {self.date}\n\n' \
               f'-- Insert SQL here\n'
|
|
|
|
|
|
class DatabaseSet:
    """Wrapper around the database set JSON file.

    Layout read by this class (only the first top-level key, the product, is used)::

        {"<product>": {"databases": [{"name": "<db>", "version": "<n>"}, ...]}}
    """

    def __init__(self, path: Path):
        self.key_db_version = "version"
        self.key_db_name = "name"
        self.key_db_array = "databases"

        with open(path, "r") as f:
            self.set = json.load(f)

        self.path = path
        self.product = list(self.set.keys())[0]

    def _find_entry(self, db_name: str):
        """Return the (mutable) entry dict of *db_name*.

        :raises KeyError: if the database is not listed for the product
        """
        for entry in self.db_array():
            if entry[self.key_db_name] == db_name:
                return entry
        # A bare StopIteration from next() would be confusing (and is turned into RuntimeError inside generators).
        raise KeyError(f"Database '{db_name}' is not listed in {self.path}")

    def get_database_version(self, db_name: str):
        """Return the committed version of *db_name* as int."""
        return int(self._find_entry(db_name)[self.key_db_version])

    def db_array(self):
        """Return the list of database entries of the product."""
        return self.set[self.product][self.key_db_array]

    def list_databases_by_name(self):
        """Return the set of database names listed for the product."""
        return {database[self.key_db_name] for database in self.db_array()}

    def modify_database_version(self, db_name: str, version: int):
        """Set the version of *db_name* (stored as a string) and rewrite the whole file."""
        entry = self._find_entry(db_name)
        entry[self.key_db_version] = str(version)
        with open(self.path, 'w') as file:
            file.write(json.dumps(self.set, indent=1))
|
|
|
|
|
|
class Migration:
    """Front-end for all migration operations (revision, commit, upgrade, install, revert, redo).

    The environment (output directory, migration base directories and database set file) is loaded from
    ``env_file`` inside the directory passed to the constructor, or inside the directory named by the
    ``DB_MIGRATION_ENV`` environment variable when no directory is passed.
    """

    env_var = "DB_MIGRATION_ENV"
    # Per-database sub-directory holding revisions that are not yet committed into a version.
    _rev_base_dir = "current"

    def _get_env(self, path: Path):
        """Load environment settings (JSON) from ``env_file`` in directory *path*, converting paths to Path."""
        with open(path / env_file) as f:
            data = json.load(f)

        data["output_dir"] = Path(data["output_dir"])
        data["dirs"][:] = [Path(e) for e in data["dirs"]]
        data["db_set_dir"] = Path(data["db_set_dir"])
        return data

    def _get_db_set(self):
        """Load the configured database set file (re-read on every call)."""
        return DatabaseSet(self._env["db_set_dir"])

    def _get_database_path(self, db_name):
        """Return ``<dir>/<db_name>`` for the first configured base directory that contains it.

        :raises FileNotFoundError: if no configured directory contains *db_name*
        """
        for base in self._env["dirs"]:
            candidate = Path(base) / db_name
            if candidate.exists():
                return candidate
        raise FileNotFoundError(f"No migration directory found for database '{db_name}'")

    def _invoke_sql(self, db_name, script_name):
        """Run *script_name* of the latest uncommitted revision of *db_name* against the output database."""
        base_dir = self._get_database_path(db_name) / Migration._rev_base_dir
        # An existing but empty 'current' directory has no latest revision either.
        rev = get_latest_revision(base_dir) if base_dir.exists() else None
        if rev is None:
            print("Nothing to invoke")
            return

        execute_db_script(self._env["output_dir"] / f"{db_name}.db", rev.dir / script_name)

    def __init__(self, env_path: Path):
        if not env_path:
            env_dir = os.environ.get(Migration.env_var)
            if env_dir is None:
                raise RuntimeError(f"No environment location given: pass --env or set {Migration.env_var}")
            env_path = Path(env_dir)
        self._env = self._get_env(env_path)

        # Every entry of every base directory; used to filter the database set to what actually exists.
        self.db_names = list(itertools.chain.from_iterable(os.listdir(d) for d in self._env["dirs"]))

    def upgrade(self, db_name, rev, devel):
        """Regenerate ``<output_dir>/<db_name>.db`` from the committed versions, then apply the uncommitted
        revisions up to *rev* (all of them when *rev* is None). devel.sql scripts are applied when *devel*
        is truthy."""
        print(f"Upgrading '{db_name}', devel features: {devel}")

        db_path = self._get_database_path(db_name)
        output_dir = self._env["output_dir"]
        db_file = output_dir / f"{db_name}.db"

        Path.mkdir(output_dir, exist_ok=True, parents=True)
        # Remove old database, if exists
        Path.unlink(db_file, missing_ok=True)

        # First, migrate using already committed db version from database set file
        version = self._get_db_set().get_database_version(db_name)
        print(f"-> Upgrading to committed version: {version}")
        migrate_database_up(db_name, db_path, output_dir, version, devel)

        # Check if 'current' directory exists and apply current revision list
        current_path = db_path / Migration._rev_base_dir
        if not current_path.exists():
            return

        revisions = build_revision_entries(current_path)

        if rev is None:
            if not revisions:
                # 'current' exists but holds no revisions: the committed version is already the newest.
                return
            print(f"-> Upgrading to the newest available revision: {revisions[-1].metadata.id()}")
            revisions_range = revisions[:]
        else:
            # Upgrade up to the specified revision
            revisions_range = build_revision_entries_up_to(revisions, rev)
            if not revisions_range:
                print(f"-> revision: {rev} does not exist")
                return

            print(f"-> Upgrading to the revision: {rev}")

        for revision in revisions_range:
            meta = revision.metadata
            print(f" -> Running upgrade from {meta.parent()} to {meta.id()}")
            execute_db_script(db_file, revision.dir / up_script)

            devel_path = revision.dir / devel_script
            if devel and devel_path.exists():
                execute_db_script(db_file, devel_path)

    def install(self, devel):
        """Recreate the output directory with every database of the set plus the migration scripts."""
        output_dir = self._env["output_dir"]
        shutil.rmtree(output_dir, ignore_errors=True)
        Path.mkdir(output_dir, exist_ok=True, parents=True)

        databases_to_migrate = self._get_db_set().list_databases_by_name().intersection(self.db_names)

        print(f"Database set to be upgraded and installed: {databases_to_migrate}")
        for db_name in databases_to_migrate:
            self.upgrade(db_name, None, devel)

        # Populate output dir with migration scripts, skip 'devel.sql' scripts
        for d in self._env["dirs"]:
            shutil.copytree(d, output_dir / "migration", dirs_exist_ok=True,
                            ignore=shutil.ignore_patterns(devel_script))

    def commit(self, db_name):
        """Merge the uncommitted revisions of *db_name* into a new version directory and bump its version."""
        db_path = self._get_database_path(db_name)
        current_path = db_path / Migration._rev_base_dir
        db_set = self._get_db_set()
        upgrade_version = db_set.get_database_version(db_name) + 1

        print(f"Committing database '{db_name}':")

        if not current_path.exists():
            print("->Nothing to commit")
            return

        # Prepare new version directory structure. version_path already contains db_path; joining it onto
        # db_path again would nest the directory (e.g. 'a/b/a/b/3') whenever db_path is relative.
        version_path = db_path / str(upgrade_version)
        Path.mkdir(version_path, exist_ok=True, parents=True)

        merge_sql_from_dir(current_path, version_path)

        db_set.modify_database_version(db_name, upgrade_version)

        shutil.rmtree(current_path)

        print(f"->New version generated from commit: {upgrade_version}")

    def commit_all(self):
        """Commit every database of the set that has a migration directory (same filter as install)."""
        for db_name in sorted(self._get_db_set().list_databases_by_name().intersection(self.db_names)):
            self.commit(db_name)

    def revision(self, db_name, message):
        """Create a new revision with SQL templates in the database's 'current' directory."""
        base_dir = self._get_database_path(db_name) / Migration._rev_base_dir

        Path.mkdir(base_dir, exist_ok=True, parents=True)
        entry = RevisionEntry(base_dir, message)
        entry.spawn()
        print(f"Added new revision: {entry.metadata.id()}")

    def revert(self, db_name):
        """Run down.sql of the latest uncommitted revision."""
        self._invoke_sql(db_name, down_script)

    def redo(self, db_name):
        """Run down.sql and then up.sql of the latest uncommitted revision."""
        self._invoke_sql(db_name, down_script)
        self._invoke_sql(db_name, up_script)
|
|
|
|
|
|
def build_revision_entries(base: Path):
    """ Builds the list of ConstRevisionEntry entries where each child is placed after its parent.
    Revision_1(id=1,parent=0) -> Revision_2(id=2,parent=1) -> Revision_n(id=n,parent=2)

    Only sub-directories of *base* are treated as revisions; stray files are ignored. The chain starts at
    the revision whose parent is 0 and ends at the first id that no revision names as its parent. When
    several revisions share a parent, the one whose directory sorts first is used.
    """
    by_parent = {}
    for entry_dir in sorted(base.iterdir()):
        if not entry_dir.is_dir():
            continue
        entry = ConstRevisionEntry(entry_dir)
        by_parent.setdefault(entry.metadata.parent(), entry)

    chain = []
    parent_id = 0
    # The length bound stops a malformed (cyclic) parent graph from looping forever.
    while parent_id in by_parent and len(chain) < len(by_parent):
        entry = by_parent[parent_id]
        chain.append(entry)
        parent_id = entry.metadata.id()
    return chain
|
|
|
|
|
|
def build_revision_entries_up_to(revisions, rev):
    """ Try to build the list of ConstRevisionEntry entries from the already existing list of revisions up to the
    specified revision. For instance, Revision_1(id=1,parent=0) -> Revision_2(id=2,parent=1) -> Revision_n(id=rev,
    parent=2)

    :return: the prefix of *revisions* ending with (and including) the revision whose id is *rev*,
             or None when no such revision exists
    """
    prefix = []
    for entry in revisions:
        prefix.append(entry)
        if entry.metadata.id() == rev:
            return prefix
    return None
|
|
|
|
|
|
def get_latest_revision(base: Path):
    """Return the newest ConstRevisionEntry under *base* (the end of the parent chain), or None if there is none."""
    entries = build_revision_entries(base)
    if entries:
        return entries[-1]
    return None
|
|
|
|
|
|
def merge_sql_from_dir(directory: Path, out: Path):
    """Merge all revisions under *directory* into single up/down(/devel) scripts in *out*.

    - ``up.sql``: revision up scripts in chain order.
    - ``down.sql``: revision down scripts in reverse chain order.
    - ``devel.sql``: written only when at least one revision has non-empty devel SQL.

    Each output file starts with ``license_header``. Every revision's scripts are read before any output
    file is written, so a read failure leaves no half-written output.
    """
    revisions = build_revision_entries(directory)

    up_parts, down_parts, devel_parts = [], [], []
    for rev in revisions:
        print(f"->Merging revision: {rev.metadata.id()}")
        sql_up, sql_down, sql_devel = rev.read_sql()
        up_parts.append(sql_up + '\n')
        down_parts.append(sql_down + '\n')
        if sql_devel:
            devel_parts.append(sql_devel + '\n')

    with open(out / up_script, 'w') as up_file:
        up_file.write(license_header + ''.join(up_parts))

    # Down scripts need to be merged in reversed order
    with open(out / down_script, 'w') as down_file:
        down_file.write(license_header + ''.join(reversed(down_parts)))

    # Truncate rather than append, so a pre-existing devel.sql in *out* cannot leave stale SQL behind.
    if devel_parts:
        with open(out / devel_script, 'w') as devel_file:
            devel_file.write(license_header + ''.join(devel_parts))
|
|
|
|
|
|
def execute_db_script(db_path: Path, script: Path, version: int = None):
    """Run the SQL *script* against the SQLite database at *db_path* (sqlite3 creates the file if missing).

    :param version: when truthy, stored in ``PRAGMA user_version`` after the script has been committed
    :raises OSError: if the script cannot be read
    :raises sqlite3.Error: if the script fails; the connection is closed in any case
    """
    with open(script) as ms:
        sql = ms.read()

    connection = sqlite3.connect(db_path)
    try:
        connection.executescript(sql)
        connection.commit()
        if version:
            # PRAGMA does not accept bound parameters; int() ensures only a number is interpolated.
            connection.execute(f"PRAGMA user_version = {int(version)};")
            connection.commit()
    finally:
        connection.close()
|
|
|
|
|
|
def migrate_database_up(database: str, migration_path: Path, dst_directory: Path, dst_version: int, devel: bool):
    """Build ``<dst_directory>/<database>.db`` from scratch by applying committed versions 0..dst_version.

    Any existing database file at the destination is removed first. For every version ``i`` the script
    ``<migration_path>/<i>/up.sql`` is executed with version ``i``. When *devel* is truthy and
    ``<migration_path>/<i>/devel.sql`` exists, that script is executed with version ``i`` as well.
    *migration_path* and *dst_directory* may be given as str or Path.
    """
    migration_path = Path(migration_path)
    dst_db_path = Path(dst_directory) / f"{database}.db"
    dst_db_path.unlink(missing_ok=True)

    for i in range(dst_version + 1):
        version_dir = migration_path / str(i)
        execute_db_script(dst_db_path, version_dir / up_script, i)

        devel_script_path = version_dir / devel_script
        if devel and devel_script_path.exists():
            execute_db_script(dst_db_path, devel_script_path, i)
|
|
|
|
|
|
@subcommand([argument("-e", "--env", help="where to store environment configuration", required=True, type=Path),
             argument("--dbset", help="location of the file describing database set", required=True, type=Path),
             argument("-o", "--out", help="where to store generated databases", required=True, type=Path),
             argument("--dirs",
                      help="list of migration base directories. It's important to pass product-specific directory as "
                           "a first element on the list",
                      action='append',
                      nargs='*',
                      required=True,
                      type=Path)])
def init(args):
    """Initializes migration environment"""
    # '--dirs' combines action='append' with nargs='*', so args.dirs is a list of lists
    # ('--dirs a b --dirs c' -> [[a, b], [c]]). Flatten it so that no directory is silently dropped;
    # order is preserved, keeping the product-specific directory first.
    dirs = [d.as_posix() for group in args.dirs for d in group]
    env = {"db_set_dir": args.dbset.as_posix(), "output_dir": args.out.as_posix(), "dirs": dirs}
    with open(args.env / env_file, 'w') as f:
        f.write(json.dumps(env, indent=1))
|
|
|
|
|
|
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", required=True, type=str),
             argument("-m", "--message", help="revision message", required=True, type=str)])
def revision(args):
    """Creates a new database migration revision"""
    # CLI entry point: delegate to Migration.revision with the parsed options.
    migration = Migration(args.env)
    migration.revision(args.db, args.message)
|
|
|
|
|
|
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", type=str)])
def commit(args):
    """Commits current set of SQL statements and updates database version number"""
    # Without --db every database of the set is committed.
    migration = Migration(args.env)
    if not args.db:
        migration.commit_all()
        return
    migration.commit(args.db)
|
|
|
|
|
|
@subcommand(
    [argument("-e", "--env", help="environment location", type=Path),
     # Without a type every given string (including "False"/"OFF"/"0") was truthy and enabled devel scripts.
     # Values are now parsed as booleans; a bare '-d' means True and omitting it keeps the default False.
     argument("-d", "--devel", help="with development schema", default=False, nargs='?', const=True,
              type=lambda value: str(value).strip().lower() in ("1", "true", "yes", "y", "on"))])
def install(args):
    """ Generates database set and then installs it in the specific output directory. It also populates output
    directory with corresponding migration scripts"""
    Migration(args.env).install(args.devel)
|
|
|
|
|
|
@subcommand(
    [argument("-e", "--env", help="environment location", type=Path),
     argument("--db", help="database name", type=str, required=True),
     argument("-r", "--revision", help="target revision", type=str),
     # Without a type every given string (including "False"/"OFF"/"0") was truthy and enabled devel scripts.
     # Values are now parsed as booleans; a bare '-d' means True and omitting it keeps the default False.
     argument("-d", "--devel", help="with development schema", default=False, nargs='?', const=True,
              type=lambda value: str(value).strip().lower() in ("1", "true", "yes", "y", "on"))])
def upgrade(args):
    """ Upgrades database to the specific revision(or the newest one if --revision parameter omitted)"""
    Migration(args.env).upgrade(args.db, args.revision, args.devel)
|
|
|
|
|
|
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", type=str, required=True)])
def revert(args):
    """ Runs the (down.sql) for the specified database for the most recent migration"""
    # CLI entry point: delegate to Migration.revert.
    migration = Migration(args.env)
    migration.revert(args.db)
|
|
|
|
|
|
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", type=str, required=True)])
def redo(args):
    """ Runs the (down.sql) and then the (up.sql) for the most recent migration"""
    # CLI entry point: delegate to Migration.redo.
    migration = Migration(args.env)
    migration.redo(args.db)
|
|
|
|
|
|
def main() -> int:
    """Parse the command line and dispatch to the selected sub-command.

    :return: 0 on success; 1 when no sub-command was given or the sub-command raised an exception
    """
    args = cli.parse_args()
    if args.subcommand is None:
        cli.print_help()
        return 1

    try:
        args.func(args)
    except Exception:
        # Report the failure through the exit code. SystemExit and KeyboardInterrupt are no longer
        # swallowed, as they were by the previous bare 'except:'.
        print(traceback.format_exc(), file=sys.stderr)
        return 1
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|