# Source code for deploydb

"""Top-level package for deploydb."""
# Package metadata exposed as module-level dunder attributes.
__author__ = 'Mert Guvencli'
__email__ = 'guvenclimert@gmail.com'
# Version string of this package.
__version__ = '0.2.2'

import os
import sys
import traceback
from datetime import datetime
import time
import json
from typing import Any

import pyodbc
from git import Repo, Git
from .model import Config
from .db import Database
from .utils import _save_csv, _set_commit_log, _last_commit_hash
from .script import DATABASES, OBJECTS, TABLES, CREATE_TABLE, GET_OBJECT


class Base:
    """Check and validate configuration.

    The constructor turns the ``config`` argument into a ``Config`` model and
    then opens a database connection with the configured credentials to make
    sure they work.
    """

    def __init__(self, config) -> None:
        """Build the validated configuration model.

        Args:
            config: path to a JSON file, or a ``dict`` holding the already
                parsed file contents.

        Raises:
            ValueError: if ``config`` is neither an existing path nor a dict.
        """
        self.config = config
        self._config: Config = None
        self._handle_config()

    def _is_file_path(self) -> bool:
        """Return True when ``self.config`` names an existing filesystem path."""
        try:
            return os.path.exists(self.config)
        except (TypeError, ValueError, OSError):
            # Non path-like values (for example a dict) cannot be stat'ed.
            # Only these errors are swallowed so that KeyboardInterrupt and
            # SystemExit still propagate.
            return False

    def _handle_config(self) -> None:
        """Populate ``self._config`` and run a ``SELECT 1`` connection probe.

        Raises:
            ValueError: if ``self.config`` is neither an existing path nor a
                dict.
        """
        if self._is_file_path():
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(self.config, encoding='utf-8') as json_file:
                self._config = Config(**json.load(json_file))
        elif isinstance(self.config, dict):
            self._config = Config(**self.config)
        else:
            raise ValueError(
                'Invalid Config argument: "{0}". Config argument must be a file path, '
                'or a dict containing the parsed file contents.'.format(self.config)
            )

        if self._config:
            _db = Database(self._config.db_creds)
            with _db.connect() as db:
                print("Database connection successfully created!")
                db.execute("SELECT 1").fetchone()
class RepoGenerator(Base):
    """It will create your database object's script that you need.

    Args:
        config (Any): config file or a `dict`.
        export_path (str): where the exported files locate
        includes (list, optional): default takes all databases from the given credential.
        excludes (list, optional): exclude databases from the given credential.
        err_file_path (str, optional): where the errors locate. Defaults to "errors.csv".

    Example:
        from deploydb import RepoGenerator

        scripter = RepoGenerator(
            config="config.json",
            export_path="path-to-export",
            includes=[],
            excludes=[]
        )
        scripter.run()
    """

    def __init__(
        self,
        *,
        config,
        export_path,
        includes=None,
        excludes=None,
        err_file_path="errors.csv"
    ) -> None:
        super().__init__(config)
        self.path = export_path
        # ``None`` defaults avoid sharing one list object between instances.
        self.includes = list(includes) if includes else []
        self.excludes = list(excludes) if excludes else []
        self.err_file_path = err_file_path
        # Rows of [db_name, sub_folder, object_name, error] collected by run().
        self._failure = []
        # One single-entry dict per sub folder: folder name -> README content.
        self.sub_folders = (
            {'Tables': '**Tables**'},
            {'Views': '**Views**'},
            {'Functions': '**Functions**'},
            {'Stored-Procedures': '**Stored-Procedures**'},
            {'Triggers': '**Triggers**'},
            {'Types': '**User Defined Data Types**'},
            {'DMLs': '**DMLs - Data Manipulations**'},
            {'DDLs': '**DDLs - Data Definitions**'}
        )

    def _create_folder(self, db_name):
        """Create ``<export_path>/Databases/<db_name>`` with its sub folders.

        The call is idempotent: folders that already exist are reused, so the
        generator can be re-run against the same export path.

        Returns:
            str: path of the database folder.
        """
        project_path = os.path.join(self.path, 'Databases', db_name)
        os.makedirs(project_path, exist_ok=True)

        for folder in self.sub_folders:
            name, readme_text = next(iter(folder.items()))
            # Join onto project_path only; it already starts with self.path,
            # and prefixing it again doubles a relative export path.
            folder_path = os.path.join(project_path, name)
            os.makedirs(folder_path, exist_ok=True)
            readme_path = os.path.join(folder_path, f'{name}_README.md')
            with open(readme_path, 'w', encoding='utf-8') as f:
                f.write(readme_text)
        return project_path

    def _safe_file_name(self, schema_name, object_name) -> str:
        """Build the ``.sql`` file name for an object.

        The object name is wrapped in brackets when it contains anything other
        than letters, digits and underscores, or when the schema is not
        ``dbo``. A non-``dbo`` schema is added as a ``[schema].`` prefix.
        """
        allowed_chars = 'abcdefghijklmnopqrstuvwxyz_0123456789'
        has_special = any(
            char not in allowed_chars for char in str(object_name).lower()
        )
        if has_special or schema_name != "dbo":
            object_name = f"[{object_name}]"
        prefix = "" if schema_name == "dbo" else f"[{schema_name}]."
        return f"{prefix}{object_name}.sql"

    def _write_script(self, parent, sub, schema_name, object_name, script):
        """Write ``script`` to ``<parent>/<sub>/<safe file name>`` as UTF-8."""
        safe_name = self._safe_file_name(schema_name, object_name)
        path = os.path.join(parent, sub, safe_name)
        with open(path, mode='w', encoding='utf-8') as f:
            f.write(script)

    def _init_project(self, db_name):
        """Export the table and object scripts of one database.

        A failure on a single table or object is recorded in ``self._failure``
        and does not stop the export of the remaining items.
        """
        project_path = self._create_folder(db_name)
        _db = Database(creds=self._config.db_creds)
        with _db.connect(db_name) as db:
            tables = db.execute(TABLES).fetchall()
            print(f'{len(tables)} tables found on {db_name}. Generating table script...')  # noqa
            for index, table in enumerate(tables, start=1):
                try:
                    script = db.execute(
                        CREATE_TABLE, table.SCHEMA_NAME, table.TABLE_NAME
                    ).fetchone().SQL
                    if len(str(script)) > 0:
                        self._write_script(
                            project_path, "Tables",
                            table.SCHEMA_NAME, table.TABLE_NAME, script
                        )
                        print(f'--->{index}/{len(tables)} {table.TABLE_NAME} on {db_name}')
                except Exception:  # keep exporting the remaining tables
                    error = traceback.format_exc()
                    print(f'Failed--->{index}/{len(tables)} {table.TABLE_NAME} on {db_name}')
                    self._failure.append([db_name, "Tables", table.TABLE_NAME, error])

            objects = db.execute(OBJECTS).fetchall()
            print(f'{len(objects)} objects found on {db_name}. Generating object script...')  # noqa
            for index, item in enumerate(objects, start=1):
                try:
                    self._write_script(
                        project_path, item.SUB_FOLDER,
                        item.SCHEMA_NAME, item.OBJECT_NAME, item.SQL
                    )
                    print(f'--->{index}/{len(objects)} {item.OBJECT_NAME} on {db_name}')
                except Exception:  # keep exporting the remaining objects
                    error = traceback.format_exc()
                    print(f'Failed--->{index}/{len(objects)} {item.OBJECT_NAME} on {db_name}')  # noqa
                    # list.append takes one argument: store the row as a list,
                    # in the same shape as the table failures above.
                    self._failure.append(
                        [db_name, item.SUB_FOLDER, item.OBJECT_NAME, error]
                    )

    def _generate(self):
        """Export every selected database.

        ``includes`` wins when given; otherwise the database list is read from
        the server through the ``master`` database. ``excludes`` is applied to
        either list.
        """
        if self.includes:
            db_list = self.includes
        else:
            _db = Database(creds=self._config.db_creds)
            with _db.connect("master") as db:
                db_list = [x.DB_NAME for x in db.execute(DATABASES).fetchall()]

        # Filter first so the progress counter only counts exported databases.
        targets = [name for name in db_list if name not in self.excludes]
        for ix, db_name in enumerate(targets, start=1):
            print(f'Database: {db_name} {ix}/{len(targets)}')  # noqa
            self._init_project(db_name)
            print("*" * 100, end="\n")

    def run(self):
        """Generate all scripts and save collected failures as a CSV file.

        The CSV is written to ``<export_path>/<err_file_path>`` only when at
        least one table or object failed.
        """
        self._generate()
        if self._failure:
            _save_csv(
                path=os.path.join(self.path, self.err_file_path),
                columns=['DB_NAME', 'SUB_FOLDER', 'OBJECT_NAME', 'ERROR'],
                rows=self._failure
            )
class Listener(Base):
    """Watch a git repository of SQL scripts and deploy changed scripts.

    Args:
        config: config file or a `dict`.
        ssh_path (str, optional): private key used for SSH remotes.
            Defaults to "~/.ssh/id_rsa".
        changelog_path (str, optional): file that stores the last deployed
            commit. Defaults to "changelog.csv".
        err_path (str, optional): CSV file that receives deployment errors.
            Defaults to "errors.csv".
    """

    def __init__(
        self,
        config,
        *,
        ssh_path="~/.ssh/id_rsa",
        changelog_path="changelog.csv",
        err_path="errors.csv"
    ) -> None:
        super().__init__(config)
        self.ssh_path = ssh_path
        self.changelog_path = changelog_path
        self.err_path = err_path
        self._init_table()

    def _init_table(self):
        """Run the log-table creation script on the default database."""
        # Explicit import: the module only imports selected names from
        # ``.script``, so do not rely on ``script`` being bound implicitly.
        from . import script
        with self._db().connect(self._config.db_creds.default_db) as db:
            db.execute(script.CREATE_LOG_TABLE)

    def _db(self):
        """Return a new ``Database`` built from the configured credentials."""
        return Database(creds=self._config.db_creds)

    def _prep_cmd(self, file) -> str:
        """Read the script ``file`` (relative to the local clone) as UTF-8."""
        path = os.path.join(self._config.local_path, file)
        with open(path, mode='r', encoding='utf-8') as f:
            return f.read()

    def _run_cmd(self, db_name, target_hash, file) -> None:
        """Execute one script on ``db_name``.

        A ``pyodbc.ProgrammingError`` is not re-raised; its message is written
        to the log table on the default database instead.
        """
        from . import script
        cmd = self._prep_cmd(file)
        with self._db().connect(db_name) as db:
            try:
                db.execute(cmd)
            except pyodbc.ProgrammingError as ex:
                # Take the last argument as the message instead of unpacking,
                # so an unexpected number of args cannot raise ValueError.
                msg = ex.args[-1] if ex.args else str(ex)
                default_db = self._config.db_creds.default_db
                with self._db().connect(default_db) as log_db:
                    log_db.execute(script.LOG_INSERT, target_hash, file, msg)

    def _ssh_command(self) -> str:
        """Return the ``GIT_SSH_COMMAND`` value selecting the configured key."""
        return f'ssh -i {os.path.expanduser(self.ssh_path)}'

    def _pull(self):
        """Clone the target branch into the local path (SSH first, then HTTPS).

        Raises:
            Exception: if neither an SSH nor an HTTPS url is configured.
        """
        if self._config.ssh_url:
            print("SSH connection starting...")
            # The environment must be given to clone_from itself; setting it
            # on a separate Git() object does not reach the clone command.
            Repo.clone_from(
                self._config.ssh_url,
                self._config.local_path,
                branch=self._config.target_branch,
                env={'GIT_SSH_COMMAND': self._ssh_command()}
            )
        elif self._config.https_url:
            print("HTTPS connection starting...")
            Repo.clone_from(
                self._config.https_url,
                self._config.local_path,
                branch=self._config.target_branch
            )
        else:
            raise Exception('No found repository!')

    def _extract_creds(self, changed_file):
        """Split a repository path into (db_name, object_type, object_name).

        The path is split on ``/`` and elements 1, 2 and 3 are used, so it
        must have at least four segments.

        Raises:
            IndexError: if the path has fewer than four segments.
        """
        x = changed_file.split('/')
        db_name = str(x[1])
        object_type = str(x[2])
        object_name = str(x[3]).split('.sql')[0]
        return db_name, object_type, object_name

    def _is_object_exists(self, db_name, object_type, object_name):
        """Return True when the ``GET_OBJECT`` query finds the object."""
        print("Is object exists", db_name, object_type, object_name)
        with self._db().connect(db_name) as db:
            res = any(db.execute(GET_OBJECT, object_type, object_name).fetchall())
        print(f"{object_name} is exists : {res}")
        return res

    def policy(self, file):
        """Decide whether the script ``file`` may be executed.

        A script under ``Tables`` is rejected when the table already exists;
        every other script is accepted.
        """
        db_name, object_type, object_name = self._extract_creds(file)
        # If table already created, script wont execute.
        if object_type == "Tables":
            if self._is_object_exists(db_name, object_type, object_name):
                print("Item rejected!")
                return False
        return True

    def _apply_changes(self):
        """Pull the remote and run every changed ``.sql`` file once.

        Failures of individual files are collected and written to
        ``self.err_path``; they do not stop the remaining files.
        """
        print("Checking changes...", datetime.now())
        repo = Repo(self._config.local_path)
        if self._config.ssh_url:
            # Use the same key for pulls as for the initial clone.
            repo.git.update_environment(GIT_SSH_COMMAND=self._ssh_command())
        repo.remotes.origin.pull()

        source_hash = _last_commit_hash(path=self.changelog_path)
        target_hash = repo.head.commit.hexsha
        if source_hash == target_hash:
            return

        print("Changes detected...")
        git_diff = repo.commit(source_hash).diff(repo.commit(target_hash))
        failure = []
        for diff in git_diff:
            # A deleted file has nothing to run. b_path is the path on the
            # target side, which is the one that exists after a rename.
            if diff.deleted_file:
                continue
            item = diff.b_path
            if not str(item).endswith('.sql'):
                continue
            print("Changed file:", item)
            try:
                # Refers customized applied policies.
                # Pre-defined rules are listed. You may customize that.
                # Say for instance:
                # Prevent DDL commands side affects over existing table.
                if self.policy(file=item):
                    db_name, _, _ = self._extract_creds(item)
                    self._run_cmd(db_name, target_hash, item)
            except Exception:  # record and continue with the next file
                failure.append([item, traceback.format_exc()])

        _set_commit_log(hexsha=target_hash, path=self.changelog_path)
        if failure:
            columns = ['commit_hexsha', 'time', 'object', 'error']
            rows = [[target_hash, datetime.now(), x[0], x[1]] for x in failure]
            _save_csv(self.err_path, columns, rows)

    def sync(self, loop=False, sleep=15, max_retry=3):
        """Handles changes and deploys to your server automatically.

        Args:
            loop (bool, optional): creates infinite loop to handle changes.
                Defaults to False.
            sleep (int, optional): seconds to wait between two checks.
                Defaults to 15.
            max_retry (int, optional): number of consecutive failed checks
                after which the loop stops. Defaults to 3.
        """
        if not os.path.exists(self._config.local_path):
            print(f"Initial pulling branch: {self._config.target_branch}")
            os.mkdir(self._config.local_path)
            self._pull()

        self._apply_changes()

        fails = 0
        while loop:
            try:
                self._apply_changes()
            except Exception:
                # ``Exception`` keeps Ctrl+C and SystemExit able to stop the loop.
                fails += 1
                print(f"An error occured, retrying ... Error: {traceback.format_exc()}")
                if fails >= max_retry:
                    print('Terminating process...Maximum retry value is over.')
                    break
            else:
                # A successful check ends the run of consecutive failures.
                fails = 0
            # Wait before the next check, and also before a retry.
            time.sleep(sleep)