From d53aacdf59c802783bc6d5b849cd8ec8728d71d6 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Fri, 8 Mar 2024 16:38:59 +0100 Subject: [PATCH] Refactor cleanup_after_install as a decorator (#885) --- .../cli/plugins/snowpark/package/commands.py | 5 ++--- .../cli/plugins/snowpark/package/manager.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/snowflake/cli/plugins/snowpark/package/commands.py b/src/snowflake/cli/plugins/snowpark/package/commands.py index fc0b938d8..37ac02970 100644 --- a/src/snowflake/cli/plugins/snowpark/package/commands.py +++ b/src/snowflake/cli/plugins/snowpark/package/commands.py @@ -40,6 +40,7 @@ @app.command("lookup", requires_connection=True) +@cleanup_after_install def package_lookup( name: str = typer.Argument(..., help="Name of the package."), install_packages: bool = install_option, @@ -55,7 +56,6 @@ def package_lookup( install_packages = deprecated_install_option lookup_result = lookup(name=name, install_packages=install_packages) - cleanup_after_install() return MessageResult(lookup_result.message) @@ -89,6 +89,7 @@ def package_upload( @app.command("create", requires_connection=True) +@cleanup_after_install def package_create( name: str = typer.Argument( ..., @@ -117,6 +118,4 @@ def package_create( message += "\n" + lookup_result.message else: message = lookup_result.message - - cleanup_after_install() return MessageResult(message) diff --git a/src/snowflake/cli/plugins/snowpark/package/manager.py b/src/snowflake/cli/plugins/snowpark/package/manager.py index 041615f57..e9aa74a74 100644 --- a/src/snowflake/cli/plugins/snowpark/package/manager.py +++ b/src/snowflake/cli/plugins/snowpark/package/manager.py @@ -2,6 +2,7 @@ import logging import os.path +from functools import wraps from pathlib import Path from requirements.requirement import Requirement @@ -71,6 +72,13 @@ def create(zip_name: str): return CreatedSuccessfully(zip_name, Path(file_name)) -def cleanup_after_install(): - if PACKAGES_DIR.exists(): - SecurePath(PACKAGES_DIR).rmdir(recursive=True) +def cleanup_after_install(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + finally: + if PACKAGES_DIR.exists(): + SecurePath(PACKAGES_DIR).rmdir(recursive=True) + + return wrapper