# Azure Web App deployment actions: update container images, swap slots, and restart apps via the Azure CLI.

import json, logging, os, sys, time
from time import sleep, time
from queue import Queue
from threading import Thread
import argparse, secrets

from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from cryptography.fernet import Fernet

__author__      = "Vishal Anand"
__email__       = "vishal.anand@columbia.edu"
__copyright__   = "Copyright 2024"


# Names of the Azure Key Vault secrets to fetch with SecretClient.get_secret().
# NOTE(review): the bare `...` below is a literal Ellipsis placeholder left in
# this copy; replace it with the real secret names before running.
KEY_VAULT_KEYS = [
    "keys", 
    ...
]
# Azure CLI command templates, filled positionally with str.format().
# In these non-raw triple-quoted strings a backslash at the end of a line is a
# line continuation, so each template renders as a single shell command line.
#
# STAGING_COMMAND placeholders: the three parts of the image name
# ("{}/{}:{}"), registry server URL, registry user, registry password,
# app name, resource group. Always targets the "staging" slot.
# NOTE(review): not referenced anywhere else in this file; deployment() uses
# REGULAR_COMMAND instead.
STAGING_COMMAND = """az webapp config container set \
    --docker-custom-image-name {}/{}:{} \
    --docker-registry-server-url {} \
    --docker-registry-server-user {} \
    --docker-registry-server-password {} \
    --name {} \
    --slot staging \
    --resource-group {}"""

# Same as STAGING_COMMAND, except the slot is a free placeholder between the
# app name and the resource group; deployment() fills it with "--slot <name>"
# or an empty string for the default slot.
REGULAR_COMMAND = """az webapp config container set \
    --docker-custom-image-name {}/{}:{} \
    --docker-registry-server-url {} \
    --docker-registry-server-user {} \
    --docker-registry-server-password {} \
    --name {} \
    {} \
    --resource-group {}"""

# Swap the "staging" slot into "production". Placeholders: app name, resource
# group. The last option line has no continuation backslash, so the string
# ends with a newline plus indentation (trailing whitespace for the shell).
SWAPPING_COMMAND = """az webapp deployment slot swap \
    --name {} \
    --resource-group {} \
    --slot staging \
    --target-slot production
    """

# Restart an app. Placeholders: app name, resource group.
REBOOT_COMMAND = """az webapp restart \
    --name {} \
    --resource-group {}
    """
# Restart one slot of an app. Placeholders: app name, slot, resource group.
REBOOT_SLOT_COMMAND = """az webapp restart \
    --name {} \
    --slot {} \
    --resource-group {}
    """

class CLIWorker(Thread):
    """Worker thread that executes shell commands pulled from a queue.

    Each item taken from ``queue`` is a complete shell command string. The
    worker echoes the command (with any registry password masked), runs it
    with ``os.system`` and reports a non-zero exit status on stderr.
    ``task_done()`` is always called, so the producer's ``Queue.join()`` can
    return even when a command fails. ``run`` never returns; callers should
    mark the thread as a daemon.
    """

    def __init__(self, queue):
        super().__init__()
        self.queue = queue

    @staticmethod
    def _redact(command):
        """Return *command* with the value after ``--docker-registry-server-password`` masked."""
        import re

        # The password is interpolated verbatim into the command line, so it
        # must never be echoed. The lookahead keeps an empty password from
        # swallowing the next ``--option``.
        return re.sub(
            r"(--docker-registry-server-password\s+)(?!--)\S+",
            r"\1***",
            command,
        )

    def run(self):
        while True:
            command = self.queue.get()
            try:
                print(self._redact(command))
                status = os.system(command)
                # os.system reports failure through its return value, not by
                # raising, so check it explicitly.
                if status != 0:
                    print(
                        "Command exited with status {}: {}".format(status, self._redact(command)),
                        file=sys.stderr,
                    )
            finally:
                self.queue.task_done()

def call_queue(commands, logger):
    """Run shell commands concurrently, one worker thread per command.

    Args:
        commands: sized sequence of ``(server_name, shell_command)`` pairs; the
            server name is used only in log messages.
        logger: logger that receives the per-server "Queueing" messages and
            the final elapsed-time message.

    Blocks until every queued command has been processed (``Queue.join``).
    """
    started = time()
    queue = Queue()
    # CLIWorker.run loops forever, so workers must be daemon threads or the
    # interpreter could never exit.
    for _ in range(len(commands)):
        worker = CLIWorker(queue)
        worker.daemon = True
        worker.start()
    for server, command in commands:
        logger.info('Queueing {}'.format(server))
        queue.put(command)
    queue.join()
    # Log through the caller-supplied logger so its handlers and level apply.
    logger.info('Took {:.3f}s'.format(time() - started))

def login():
    """Log the Azure CLI in, preferring the managed identity.

    Runs ``az login --identity`` first. If that command exits with a non-zero
    status, it falls back to a tenant login and then selects the subscription.
    """
    # os.system signals failure through its return value rather than by
    # raising, so the exit status (not an exception) decides the fallback.
    if os.system("az login --identity --output none") != 0:
        os.system("az login --tenant <tenant> --output none")
        os.system("az account set --subscription '<>' --output none")

def swap(arguments):
    """Swap the staging slot into production for every app in the requested zone.

    Reads the manifest at ``arguments.deployment_json``. For each entry of its
    ``"Deployment"`` list selected by ``arguments.zone``, it queues one slot-swap
    command per non-empty ``Server``/``Server2``. The commands then run
    concurrently through call_queue().
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    log = logging.getLogger(__name__)

    wanted = arguments.zone
    manifest_path = arguments.deployment_json

    def selected(entry):
        # An entry matches its own zone name; "Zone1" additionally selects
        # every entry whose "Zone2" value is neither "Zone3" nor "Zone4".
        if wanted == entry["Zone"]:
            return True
        return wanted == "Zone1" and entry["Zone2"] not in ["Zone3", "Zone4"]

    swap_jobs = []
    with open(manifest_path) as handle:
        manifest = json.load(handle)
        for entry in manifest["Deployment"]:
            if not selected(entry):
                continue
            servers = [entry["Server"], entry["Server2"]]
            swap_jobs.extend(
                (name, SWAPPING_COMMAND.format(name, entry["ResourceGroup"]))
                for name in servers
                if name
            )

    call_queue(swap_jobs, log)

def deployment(deployment_json, image, tag, requested_zone, registry_url, registry_username, registry_password, slot):
    """Build the container-update commands for every app in a zone.

    Args:
        deployment_json: path to the manifest; entries of its ``"Deployment"``
            list are read for ``Zone``, ``Zone2``, ``Server``, ``Server2`` and
            ``ResourceGroup``.
        image, tag: image name and tag, combined with ``registry_url`` as
            ``<registry_url>/<image>:<tag>``.
        requested_zone: an entry is selected when its ``Zone`` equals this
            value, or when this value is ``"Zone1"`` and the entry's ``Zone2``
            is neither ``"Zone3"`` nor ``"Zone4"``.
        registry_url, registry_username, registry_password: container registry
            coordinates passed to the CLI.
        slot: slot name, or ``""``/``None`` for the app's default slot.

    Returns:
        list of ``(server_name, shell_command)`` tuples, one per non-empty
        ``Server``/``Server2`` of each selected entry.
    """
    # Build the slot argument once, outside the loop: re-assigning `slot` per
    # server would prefix it again ("--slot --slot <name>") for the second app.
    slot_arg = "--slot " + slot if slot else ""

    commands = []
    with open(deployment_json) as f:
        data = json.load(f)

    for zone in data["Deployment"]:
        selected = (requested_zone == zone["Zone"]) or (
            (requested_zone == "Zone1") and (zone["Zone2"] not in ["Zone3", "Zone4"])
        )
        if not selected:
            continue

        for server_name in [zone["Server"], zone["Server2"]]:
            if not server_name:
                continue
            # NOTE(review): values are interpolated into a shell string
            # unquoted; a password with shell metacharacters or spaces will
            # break the command.
            os_command = REGULAR_COMMAND.format(
                registry_url, image, tag,
                registry_url,
                registry_username, registry_password,
                server_name,
                slot_arg,
                zone["ResourceGroup"],
            )
            commands.append((server_name, os_command))

    return commands

def reboot_command(deployment_json, requested_zone, slot=""):
    """Build restart commands for every app in a zone.

    Entries of the manifest's ``"Deployment"`` list are selected the same way
    as for deployments: matching ``Zone``, or ``"Zone1"`` plus a ``Zone2``
    outside ``"Zone3"``/``"Zone4"``. With ``slot == ""`` the whole app is
    restarted; otherwise only that slot is.

    Returns:
        list of ``(server_name, shell_command)`` tuples, one per non-empty
        ``Server``/``Server2`` of each selected entry.
    """
    def in_scope(entry):
        return (requested_zone == entry["Zone"]) or (
            requested_zone == "Zone1" and entry["Zone2"] not in ["Zone3", "Zone4"]
        )

    restarts = []
    with open(deployment_json) as manifest_file:
        manifest = json.load(manifest_file)
        for entry in manifest["Deployment"]:
            if not in_scope(entry):
                continue
            for app_name in (entry["Server"], entry["Server2"]):
                if not app_name:
                    continue
                resource_group = entry["ResourceGroup"]
                cmd = (
                    REBOOT_COMMAND.format(app_name, resource_group)
                    if slot == ""
                    else REBOOT_SLOT_COMMAND.format(app_name, slot, resource_group)
                )
                restarts.append((app_name, cmd))

    return restarts

def reboot(arguments):
    """Restart every app selected by ``arguments.zone``.

    Builds the restart commands with reboot_command() from the manifest at
    ``arguments.deployment_json``, optionally targeting ``arguments.slot``.
    The commands then run concurrently through call_queue().
    """
    # Restarting uses no Key Vault secrets, so none are fetched here; this
    # avoids needless vault round-trips and vault permissions.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    commands = reboot_command(arguments.deployment_json, arguments.zone, arguments.slot)
    call_queue(commands, logger)

def staging(arguments):
    """Point every app selected by ``arguments.zone`` at a new container image.

    Steps:
    1. Fetches the ``KEY_VAULT_KEYS`` secrets from Azure Key Vault. The vault
       URL comes from the ``KEY_VAULT_URL`` environment variable, with a
       placeholder default.
    2. Builds the per-app container-update commands with deployment().
    3. Runs them concurrently through call_queue().
    """
    key_vault_url = os.environ.get("KEY_VAULT_URL") or "https://<keyvault>.vault.azure.net"
    credential = DefaultAzureCredential()
    secret_client = SecretClient(vault_url=key_vault_url, credential=credential)
    # Named vault_secrets so it does not shadow the imported `secrets` module.
    # The values are never printed: they include the registry password.
    vault_secrets = {key: secret_client.get_secret(key).value for key in KEY_VAULT_KEYS}

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    image = arguments.image
    tag = arguments.tag
    requested_zone = arguments.zone
    deployment_json = arguments.deployment_json
    registry_url = vault_secrets["<keys>"]
    registry_username = vault_secrets["<keys>"]
    registry_password = vault_secrets["<keys>"]
    slot = arguments.slot

    # argparse leaves an omitted --image as None (unless a default is set), so
    # treat any empty value as "no image given", not only "".
    if not image and arguments.zone == "zone-name":
        image = vault_secrets["<keys>"]

    commands = deployment(deployment_json, image, tag, requested_zone, registry_url, registry_username, registry_password, slot)
    call_queue(commands, logger)

def main():
    """Parse command-line options and run the requested action.

    ``--login`` runs first when given. Then exactly one action runs:
    - ``--reboot`` restarts the apps;
    - otherwise ``--swap`` swaps staging into production;
    - otherwise the image is deployed via staging().
    """
    parser = argparse.ArgumentParser()
    # Default to "" (like --slot) so an omitted --image is the empty-string
    # sentinel that staging() checks for, rather than None.
    parser.add_argument('-i', '--image', type=str, help='repository-image-name', default="")
    parser.add_argument('-t', '--tag', type=str, help='repository-image-tag-version')
    parser.add_argument('-d', '--deployment-json', type=str, help='path to the deployment details JSON file', default="<>/deployment_details.json")
    parser.add_argument('-slot', '--slot', type=str, help='slot', default="")
    parser.add_argument(
        '-z', '--zone', type=str,
        help='staging-zones', required=False,
        # NOTE(review): `...` is a placeholder in this copy; list the real zones.
        choices=["Zone1", "Zone2", "Zone3", ..., "Zone-n"]
    )
    parser.add_argument('-s', '--swap', action='store_true', help='swap staging into main slot', default=False)
    parser.add_argument('-r', '--reboot', action='store_true', help='restart the web apps (or only the given --slot)', default=False)
    parser.add_argument('-l', '--login', action='store_true', help='Azure CLI login', default=False)
    arguments = parser.parse_args()

    if arguments.login:
        login()

    if arguments.reboot:
        reboot(arguments)
    elif arguments.swap:
        swap(arguments)
    else:
        staging(arguments)


if __name__ == "__main__":
    main()