From 99b96c4ac577d23ac75e66a840dcce5ff74a7488 Mon Sep 17 00:00:00 2001 From: Jan Sikorski <132985823+sfc-gh-jsikorski@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:58:07 +0100 Subject: [PATCH] Switch from StrictYAML to Pydantic (#870) * Revert "Bump tomlkit from 0.12.3 to 0.12.4 (#848)" (#863) This reverts commit 80467b3d2b50c7996cfd068ed510476e64a714e0. mraba/app-factory: create app with factory (#860) * mraba/app-factory: create app with factory SNOW-1043081: Adding support for qualified names for image repositories. (#823) * SNOW-1043081: Adding support for qualified image repository names * SNOW-1043081: Fixing test imports * SNOW-1043081: Adding tests for getting image repository url without db or schema SNOW-1011771: Added generic REPLACE, IF EXISTS, IF NOT EXISTS flags (#826) * SNOW-1011771: Adding generic OR REPLACE, IF EXISTS, IF NOT EXISTS flags to flags.py * SNOW-1011771: Using generic ReplaceOption in snowpark deploy and streamlit deploy * SNOW-1011771: Using generic IfNotExistsOption in compute pool create and updating unit tests. * SNOW-1011771: Using generic IfNotExistsOption in service create and updating unit tests * SNOW-1011771: Using generic ReplaceOption and IfNotExistsOption in image-repository create. * SNOW-1011771: Fixup * SNOW-1011771: Update release notes * SNOW-1011771: Update test_help_messages * SNOW-1011771: precommit * SNOW-1011771: Adding validation that only one create mode option can be set at once * fixup * SNOW-1011771: Updating tests for REPLACE AND IF NOT EXISTS case on image-repository create to throw error * SNOW-1011771: Adding snapshots * SNOW-1011771: Adding a new mutually_exclusive field to OverrideableOption * formatting * SNOW-1011771: Adding tests for OverrideableOption * SNOW-1011771: Fixing test failures due to improperly quoted string Add snow --help to test_help_messages (#821) * Add snow --help to test_help_messages * update snapshot Avoid plain print, make sure silent is eager flag (#871) [NADE] Update CODEOWNERS to use NADE team id. (#873) update to using nade team in codeowners New workflow to stop running workflows after new commit (#862) * new workflow * new workflow * new workflow * new workflow * typo fix * typo fix * import fix * import fix * import fix * import fix * import fix * import fix * import fix * new approach * new approach * new approach * new approach * new approach * New approach * added to test * Added to more workflows * Dummy commit Schemas adjusting native apps to streamlit fixing streamlit fixies after unit tests fixies after unit tests fixing for snowflake fixing for snowflake Fixes after review Fixes after review Fixes after review * Fixes after review * Implemented error class * Fixes * Fixes * Fixes * Fixes * typo fix * Added unit test * Added unit test * Fixes after review * Fixes after review * Fixes * Fixes * Fixes --------- Co-authored-by: Adam Stus --- RELEASE-NOTES.md | 1 + pyproject.toml | 1 + src/snowflake/cli/api/commands/flags.py | 2 +- src/snowflake/cli/api/project/definition.py | 44 +- .../cli/api/project/definition_manager.py | 3 +- src/snowflake/cli/api/project/errors.py | 24 + .../cli/api/project/schemas/native_app.py | 43 - .../project/schemas/native_app/__init__.py | 0 .../project/schemas/native_app/application.py | 31 + .../project/schemas/native_app/native_app.py | 39 + .../api/project/schemas/native_app/package.py | 40 + .../schemas/native_app/path_mapping.py | 10 + .../api/project/schemas/project_definition.py | 42 +- .../cli/api/project/schemas/relaxed_map.py | 44 - .../cli/api/project/schemas/snowpark.py | 47 -- .../api/project/schemas/snowpark/__init__.py | 0 .../api/project/schemas/snowpark/argument.py | 12 + .../api/project/schemas/snowpark/callable.py | 66 ++ .../api/project/schemas/snowpark/snowpark.py | 22 + .../cli/api/project/schemas/streamlit.py | 20 - .../api/project/schemas/streamlit/__init__.py | 0 .../project/schemas/streamlit/streamlit.py | 28 + .../api/project/schemas/updatable_model.py | 27 + src/snowflake/cli/api/project/util.py | 1 + .../cli/plugins/nativeapp/artifacts.py | 5 +- .../cli/plugins/nativeapp/manager.py | 74 +- .../cli/plugins/nativeapp/run_processor.py | 5 +- .../nativeapp/version/version_processor.py | 3 +- .../cli/plugins/snowpark/commands.py | 63 +- src/snowflake/cli/plugins/snowpark/common.py | 23 +- .../cli/plugins/streamlit/commands.py | 17 +- tests/nativeapp/test_artifacts.py | 6 +- tests/nativeapp/test_manager.py | 2 +- tests/nativeapp/test_package_scripts.py | 2 +- tests/nativeapp/test_run_processor.py | 9 +- tests/nativeapp/test_teardown_processor.py | 2 +- .../test_version_create_processor.py | 2 +- .../nativeapp/test_version_drop_processor.py | 2 +- tests/project/__snapshots__/test_config.ambr | 764 ++++++++++++++++++ tests/project/test_config.py | 86 +- tests/streamlit/test_config.py | 2 +- tests/testing_utils/fixtures.py | 9 +- 42 files changed, 1299 insertions(+), 324 deletions(-) create mode 100644 src/snowflake/cli/api/project/errors.py delete mode 100644 src/snowflake/cli/api/project/schemas/native_app.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/application.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/native_app.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/package.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/path_mapping.py delete mode 100644 src/snowflake/cli/api/project/schemas/relaxed_map.py delete mode 100644 src/snowflake/cli/api/project/schemas/snowpark.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/argument.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/callable.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/snowpark.py delete mode 100644 src/snowflake/cli/api/project/schemas/streamlit.py create mode 100644 src/snowflake/cli/api/project/schemas/streamlit/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/streamlit/streamlit.py create mode 100644 src/snowflake/cli/api/project/schemas/updatable_model.py create mode 100644 tests/project/__snapshots__/test_config.ambr diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 33589aa8e..dc9d351f3 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -10,6 +10,7 @@ ## Fixes and improvements * Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands. * Fixed errors during `spcs image-registry login` not being formatted correctly. +* Project definition no longer accept extra fields. Any extra field will cause an error. # v2.1.0 diff --git a/pyproject.toml b/pyproject.toml index 3fbc25522..00af57691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "typer==0.9.0", "urllib3>=1.21.1,<2.3", "GitPython==3.1.42", + "pydantic==2.6.3" ] classifiers = [ "Development Status :: 3 - Alpha", diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index 6aa90f0d3..693c613bb 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -374,7 +374,7 @@ def project_definition_option(project_name: str): def _callback(project_path: Optional[str]): dm = DefinitionManager(project_path) - project_definition = dm.project_definition.get(project_name) + project_definition = getattr(dm.project_definition, project_name, None) project_root = dm.project_root if not project_definition: diff --git a/src/snowflake/cli/api/project/definition.py b/src/snowflake/cli/api/project/definition.py index ae5477b6f..f6b5373d3 100644 --- a/src/snowflake/cli/api/project/definition.py +++ b/src/snowflake/cli/api/project/definition.py @@ -1,12 +1,10 @@ from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List +import yaml.loader from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB -from snowflake.cli.api.project.schemas.project_definition import ( - project_override_schema, - project_schema, -) +from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition from snowflake.cli.api.project.util import ( append_to_identifier, clean_identifier, @@ -14,32 +12,25 @@ to_identifier, ) from snowflake.cli.api.secure_path import SecurePath -from strictyaml import ( - YAML, - as_document, - load, -) +from yaml import load DEFAULT_USERNAME = "unknown_user" -def merge_left(target: Union[Dict, YAML], source: Union[Dict, YAML]) -> None: +def merge_left(target: Dict, source: Dict) -> None: """ Recursively merges key/value pairs from source into target. Modifies the original dict-like "target". """ for k, v in source.items(): - if k in target and ( - isinstance(v, dict) or (isinstance(v, YAML) and v.is_mapping()) - ): + if k in target and isinstance(target[k], dict): # assumption: all inputs have been validated. - assert isinstance(target[k], dict) or isinstance(target[k], YAML) merge_left(target[k], v) else: target[k] = v -def load_project_definition(paths: List[Path]) -> dict: +def load_project_definition(paths: List[Path]) -> ProjectDefinition: """ Loads project definition, optionally overriding values. Definition values are merged in left-to-right order (increasing precedence). @@ -49,22 +40,23 @@ def load_project_definition(paths: List[Path]) -> dict: raise ValueError("Need at least one definition file.") with spaths[0].open("r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as base_yml: - definition = load(base_yml.read(), project_schema) + definition = load(base_yml.read(), Loader=yaml.loader.BaseLoader) for override_path in spaths[1:]: with override_path.open( "r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB ) as override_yml: - overrides = load(override_yml.read(), project_override_schema) + overrides = load(override_yml.read(), Loader=yaml.loader.BaseLoader) merge_left(definition, overrides) # TODO: how to show good error messages here? - definition.revalidate(project_schema) - return definition.data + return ProjectDefinition(**definition) -def generate_local_override_yml(project: Union[Dict, YAML]) -> YAML: +def generate_local_override_yml( + project: ProjectDefinition, +) -> ProjectDefinition: """ Generates defaults for optional keys in the same YAML structure as the project schema. The returned YAML object can be saved directly to a file, if desired. @@ -76,8 +68,8 @@ def generate_local_override_yml(project: Union[Dict, YAML]) -> YAML: warehouse = conn.warehouse local: dict = {} - if "native_app" in project: - name = clean_identifier(project["native_app"]["name"]) + if project.native_app: + name = clean_identifier(project.native_app.name) app_identifier = to_identifier(name) user_app_identifier = append_to_identifier(app_identifier, f"_{user}") package_identifier = append_to_identifier(app_identifier, f"_pkg_{user}") @@ -90,8 +82,12 @@ def generate_local_override_yml(project: Union[Dict, YAML]) -> YAML: }, "package": {"name": package_identifier, "role": role}, } + # TODO: this is an ugly workaround, because pydantics BaseModel.model_copy(update=) doesn't work properly + # After fixing UpdatableModel.update_from_dict it should be used here + target_definition = project.model_dump() + merge_left(target_definition, local) - return as_document(local, project_override_schema) + return ProjectDefinition(**target_definition) def default_app_package(project_name: str): diff --git a/src/snowflake/cli/api/project/definition_manager.py b/src/snowflake/cli/api/project/definition_manager.py index 663e58156..dd9060616 100644 --- a/src/snowflake/cli/api/project/definition_manager.py +++ b/src/snowflake/cli/api/project/definition_manager.py @@ -7,6 +7,7 @@ from snowflake.cli.api.exceptions import MissingConfiguration from snowflake.cli.api.project.definition import load_project_definition +from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition def _compat_is_mount(path: Path): @@ -100,5 +101,5 @@ def _user_definition_file_if_available(project_path: Path) -> Optional[Path]: ) @functools.cached_property - def project_definition(self) -> dict: + def project_definition(self) -> ProjectDefinition: return load_project_definition(self._project_config_paths) diff --git a/src/snowflake/cli/api/project/errors.py b/src/snowflake/cli/api/project/errors.py new file mode 100644 index 000000000..ba6844b51 --- /dev/null +++ b/src/snowflake/cli/api/project/errors.py @@ -0,0 +1,24 @@ +from textwrap import dedent + +from pydantic import ValidationError + + +class SchemaValidationError(Exception): + generic_message = "For field {loc} you provided '{loc}'. This caused: {msg}" + message_templates = { + "string_type": "{msg} for field '{loc}', you provided '{input}'", + "extra_forbidden": "{msg}. You provided field '{loc}' with value '{input}' that is not present in the schema", + "missing": "Your project definition is missing following fields: {loc}", + } + + def __init__(self, error: ValidationError): + errors = error.errors() + message = f"During evaluation of {error.title} schema following errors were encoutered:\n" + message += "\n".join( + [ + self.message_templates.get(e["type"], self.generic_message).format(**e) + for e in errors + ] + ) + + super().__init__(dedent(message)) diff --git a/src/snowflake/cli/api/project/schemas/native_app.py b/src/snowflake/cli/api/project/schemas/native_app.py deleted file mode 100644 index 0b6f1048c..000000000 --- a/src/snowflake/cli/api/project/schemas/native_app.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from snowflake.cli.api.project.schemas.relaxed_map import FilePath, Glob, RelaxedMap -from snowflake.cli.api.project.util import ( - IDENTIFIER, - SCHEMA_AND_NAME, -) -from strictyaml import Bool, Enum, Optional, Regex, Seq, Str, UniqueSeq - -PathMapping = RelaxedMap( - { - "src": Glob() | Seq(Glob()), - Optional("dest"): FilePath(), - } -) - -native_app_schema = RelaxedMap( - { - "name": Str(), - "artifacts": Seq(FilePath() | PathMapping), - Optional("deploy_root", default="output/deploy/"): FilePath(), - Optional("source_stage", default="app_src.stage"): Regex(SCHEMA_AND_NAME), - Optional("package"): RelaxedMap( - { - Optional("scripts", default=None): UniqueSeq(FilePath()), - Optional("role"): Regex(IDENTIFIER), - Optional("name"): Regex(IDENTIFIER), - Optional("warehouse"): Regex(IDENTIFIER), - Optional("distribution", default="internal"): Enum( - ["internal", "external", "INTERNAL", "EXTERNAL"] - ), - } - ), - Optional("application"): RelaxedMap( - { - Optional("role"): Regex(IDENTIFIER), - Optional("name"): Regex(IDENTIFIER), - Optional("warehouse"): Regex(IDENTIFIER), - Optional("debug", default=True): Bool(), - } - ), - } -) diff --git a/src/snowflake/cli/api/project/schemas/native_app/__init__.py b/src/snowflake/cli/api/project/schemas/native_app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/snowflake/cli/api/project/schemas/native_app/application.py b/src/snowflake/cli/api/project/schemas/native_app/application.py new file mode 100644 index 000000000..e7383bcc3 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/application.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + +class Application(UpdatableModel): + role: Optional[str] = Field( + title="Role to use when creating the application object and consumer-side objects", + default=None, + ) + name: Optional[str] = Field( + title="Name of the application object created when you run the snow app run command", + default=None, + ) + warehouse: Optional[str] = IdentifierField( + title="Name of the application object created when you run the snow app run command", + default=None, + ) + debug: Optional[bool] = Field( + title="Whether to enable debug mode when using a named stage to create an application object", + default=True, + ) + + +DistributionOptions = Literal["internal", "external", "INTERNAL", "EXTERNAL"] diff --git a/src/snowflake/cli/api/project/schemas/native_app/native_app.py b/src/snowflake/cli/api/project/schemas/native_app/native_app.py new file mode 100644 index 000000000..97714cc18 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/native_app.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import re +from typing import List, Optional, Union + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.native_app.application import Application +from snowflake.cli.api.project.schemas.native_app.package import Package +from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel +from snowflake.cli.api.project.util import ( + SCHEMA_AND_NAME, +) + + +class NativeApp(UpdatableModel): + name: str = Field( + title="Project identifier", + ) + artifacts: List[Union[PathMapping, str]] = Field( + title="List of file source and destination pairs to add to the deploy root", + ) + deploy_root: Optional[str] = Field( + title="Folder at the root of your project where the build step copies the artifacts.", + default="output/deploy/", + ) + source_stage: Optional[str] = Field( + title="Identifier of the stage that stores the application artifacts.", + default="app_src.stage", + ) + package: Optional[Package] = Field(title="PackageSchema", default=None) + application: Optional[Application] = Field(title="Application info", default=None) + + @field_validator("source_stage") + @classmethod + def validate_source_stage(cls, input_value: str): + if not re.match(SCHEMA_AND_NAME, input_value): + raise ValueError("Incorrect value for source_stage value of native_app") + return input_value diff --git a/src/snowflake/cli/api/project/schemas/native_app/package.py b/src/snowflake/cli/api/project/schemas/native_app/package.py new file mode 100644 index 000000000..3562ffe0e --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/package.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.native_app.application import DistributionOptions +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + +class Package(UpdatableModel): + scripts: Optional[List[str]] = Field( + title="List of SQL file paths relative to the project root", default=None + ) + role: Optional[str] = IdentifierField( + title="Role to use when creating the application package and provider-side objects", + default=None, + ) + name: Optional[str] = IdentifierField( + title="Name of the application package created when you run the snow app run command", + default=None, + ) + warehouse: Optional[str] = IdentifierField( + title="Warehouse used to run the scripts", default=None + ) + distribution: Optional[DistributionOptions] = Field( + title="Distribution of the application package created by the Snowflake CLI", + default="internal", + ) + + @field_validator("scripts") + @classmethod + def validate_scripts(cls, input_list): + if len(input_list) != len(set(input_list)): + raise ValueError( + "package.scripts field should contain unique values. Check the list for duplicates and try again" + ) + return input_list diff --git a/src/snowflake/cli/api/project/schemas/native_app/path_mapping.py b/src/snowflake/cli/api/project/schemas/native_app/path_mapping.py new file mode 100644 index 000000000..61d2520af --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/path_mapping.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from typing import Optional + +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class PathMapping(UpdatableModel): + src: str + dest: Optional[str] = None diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index beb93af26..1e3bcecd6 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -1,23 +1,27 @@ from __future__ import annotations -from snowflake.cli.api.project.schemas import ( - native_app, - snowpark, - streamlit, -) -from snowflake.cli.api.project.schemas.relaxed_map import RelaxedMap -from strictyaml import ( - Int, - Optional, -) +from typing import Optional -project_schema = RelaxedMap( - { - "definition_version": Int(), - Optional("native_app"): native_app.native_app_schema, - Optional("snowpark"): snowpark.snowpark_schema, - Optional("streamlit"): streamlit.streamlit_schema, - } -) +from pydantic import Field +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp +from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark +from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel -project_override_schema = project_schema.as_fully_optional() + +class ProjectDefinition(UpdatableModel): + definition_version: int = Field( + title="Version of the project definition schema, which is currently 1", + ge=1, + le=1, + ) + native_app: Optional[NativeApp] = Field( + title="Native app definitions for the project", default=None + ) + snowpark: Optional[Snowpark] = Field( + title="Snowpark functions and procedures definitions for the project", + default=None, + ) + streamlit: Optional[Streamlit] = Field( + title="Streamlit definitions for the project", default=None + ) diff --git a/src/snowflake/cli/api/project/schemas/relaxed_map.py b/src/snowflake/cli/api/project/schemas/relaxed_map.py deleted file mode 100644 index bbfe99dfd..000000000 --- a/src/snowflake/cli/api/project/schemas/relaxed_map.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from strictyaml import ( - Any, - Bool, - Decimal, - Int, - MapCombined, - Optional, - Str, -) - -# TODO: use the util regexes to validate paths + globs -FilePath = Str -Glob = Str - - -class RelaxedMap(MapCombined): - """ - A version of a Map that allows any number of unknown key/value pairs. - """ - - def __init__(self, map_validator): - super().__init__( - map_validator, - Str(), - # moves through value validators left-to-right until one matches - Bool() | Decimal() | Int() | Any(), - ) - - def as_fully_optional(self) -> RelaxedMap: - """ - Returns a copy of this schema with all its keys optional, recursing into other - RelaxedMaps we find inside the schema. For existing optional keys, we strip out - the default value and ensure we don't create any new keys. - """ - validator = {} - for key, value in self._validator_dict.items(): - validator[Optional(key)] = ( - value - if not isinstance(value, RelaxedMap) - else value.as_fully_optional() - ) - return RelaxedMap(validator) diff --git a/src/snowflake/cli/api/project/schemas/snowpark.py b/src/snowflake/cli/api/project/schemas/snowpark.py deleted file mode 100644 index ed8d75639..000000000 --- a/src/snowflake/cli/api/project/schemas/snowpark.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -from snowflake.cli.api.project.schemas.relaxed_map import RelaxedMap -from snowflake.cli.api.project.util import IDENTIFIER -from strictyaml import ( - Bool, - EmptyList, - MapPattern, - Optional, - Regex, - Seq, - Str, -) - -Argument = RelaxedMap({"name": Str(), "type": Str(), Optional("default"): Str()}) - -_callable_mapping = { - "name": Str(), - Optional("database", default=None): Regex(IDENTIFIER), - Optional("schema", default=None): Regex(IDENTIFIER), - "handler": Str(), - "returns": Str(), - "signature": Seq(Argument) | EmptyList(), - Optional("runtime"): Str(), - Optional("external_access_integration"): Seq(Str()), - Optional("secrets"): MapPattern(Str(), Str()), - Optional("imports"): Seq(Str()), -} - -function_schema = RelaxedMap(_callable_mapping) - -procedure_schema = RelaxedMap( - { - **_callable_mapping, - Optional("execute_as_caller"): Bool(), - } -) - -snowpark_schema = RelaxedMap( - { - "project_name": Str(), - "stage_name": Str(), - "src": Str(), - Optional("functions"): Seq(function_schema), - Optional("procedures"): Seq(procedure_schema), - } -) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/__init__.py b/src/snowflake/cli/api/project/schemas/snowpark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/snowflake/cli/api/project/schemas/snowpark/argument.py b/src/snowflake/cli/api/project/schemas/snowpark/argument.py new file mode 100644 index 000000000..521925950 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/snowpark/argument.py @@ -0,0 +1,12 @@ +from typing import Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class Argument(UpdatableModel): + name: str = Field(title="Name of the argument") + arg_type: str = Field( + title="Type of the argument", alias="type" + ) # TODO: consider introducing literal/enum here + default: Optional[str] = Field(title="Default value for an argument", default=None) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/callable.py b/src/snowflake/cli/api/project/schemas/snowpark/callable.py new file mode 100644 index 000000000..29b3ee4dc --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/snowpark/callable.py @@ -0,0 +1,66 @@ +from typing import Dict, List, Optional, Union + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.snowpark.argument import Argument +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + +class Callable(UpdatableModel): + name: str = Field( + title="Object identifier" + ) # TODO: implement validator. If a name is filly qualified, database and schema cannot be specified + database: Optional[str] = IdentifierField( + title="Name of the database for the function or procedure", default=None + ) + + schema_name: Optional[str] = IdentifierField( + title="Name of the schema for the function or procedure", + default=None, + alias="schema", + ) + handler: str = Field( + title="Function’s or procedure’s implementation of the object inside source module", + examples=["functions.hello_function"], + ) + returns: str = Field( + title="Type of the result" + ) # TODO: again, consider Literal/Enum + signature: Union[str, List[Argument]] = Field( + title="The signature parameter describes consecutive arguments passed to the object" + ) + runtime: Optional[Union[str, float]] = Field( + title="Python version to use when executing ", default=None + ) + external_access_integrations: Optional[List[str]] = Field( + title="Names of external access integrations needed for this procedure’s handler code to access external networks", + default=[], + ) + secrets: Optional[Dict[str, str]] = Field( + title="Assigns the names of secrets to variables so that you can use the variables to reference the secrets", + default=[], + ) + imports: Optional[List[str]] = Field( + title="Stage and path to previously uploaded files you want to import", + default=[], + ) + + @field_validator("runtime") + @classmethod + def convert_runtime(cls, runtime_input: Union[str, float]) -> str: + if isinstance(runtime_input, float): + return str(runtime_input) + return runtime_input + + +class FunctionSchema(Callable): + pass + + +class ProcedureSchema(Callable): + execute_as_caller: Optional[bool] = Field( + title="Determine whether the procedure is executed with the privileges of the owner (you) or with the privileges of the caller", + default=False, + ) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/snowpark.py b/src/snowflake/cli/api/project/schemas/snowpark/snowpark.py new file mode 100644 index 000000000..0a3f66845 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/snowpark/snowpark.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.snowpark.callable import ( + FunctionSchema, + ProcedureSchema, +) +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class Snowpark(UpdatableModel): + project_name: str = Field(title="Project identifier") + stage_name: str = Field(title="Stage in which project’s artifacts will be stored") + src: str = Field(title="Folder where your code should be located") + functions: Optional[List[FunctionSchema]] = Field( + title="List of functions defined in the project", default=[] + ) + procedures: Optional[List[ProcedureSchema]] = Field( + title="List of procedures defined in the project", default=[] + ) diff --git a/src/snowflake/cli/api/project/schemas/streamlit.py b/src/snowflake/cli/api/project/schemas/streamlit.py deleted file mode 100644 index 8283ead83..000000000 --- a/src/snowflake/cli/api/project/schemas/streamlit.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from snowflake.cli.api.project.schemas.relaxed_map import FilePath, RelaxedMap -from strictyaml import ( - Optional, - Seq, - Str, -) - -streamlit_schema = RelaxedMap( - { - "name": Str(), - Optional("stage", default="streamlit"): Str(), - "query_warehouse": Str(), - Optional("main_file", default="streamlit_app.py"): FilePath(), - Optional("env_file"): FilePath(), - Optional("pages_dir"): FilePath(), - Optional("additional_source_files"): Seq(FilePath()), - } -) diff --git a/src/snowflake/cli/api/project/schemas/streamlit/__init__.py b/src/snowflake/cli/api/project/schemas/streamlit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/snowflake/cli/api/project/schemas/streamlit/streamlit.py b/src/snowflake/cli/api/project/schemas/streamlit/streamlit.py new file mode 100644 index 000000000..ce6b5e08c --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/streamlit/streamlit.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class Streamlit(UpdatableModel): + name: str = Field(title="App identifier") + stage: Optional[str] = Field( + title="Stage in which the app’s artifacts will be stored", default="streamlit" + ) + query_warehouse: str = Field( + title="Snowflake warehouse to host the app", default="streamlit" + ) + main_file: Optional[str] = Field( + title="Entrypoint file of the streamlit app", default="streamlit_app.py" + ) + env_file: Optional[str] = Field( + title="File defining additional configurations for the app, such as external dependencies", + default=None, + ) + pages_dir: Optional[str] = Field(title="Streamlit pages", default=None) + additional_source_files: Optional[List[str]] = Field( + title="List of additional files which should be included into deployment artifacts", + default=None, + ) diff --git a/src/snowflake/cli/api/project/schemas/updatable_model.py b/src/snowflake/cli/api/project/schemas/updatable_model.py new file mode 100644 index 000000000..837e625d5 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/updatable_model.py @@ -0,0 +1,27 @@ +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from snowflake.cli.api.project.errors import SchemaValidationError +from snowflake.cli.api.project.util import IDENTIFIER_NO_LENGTH + + +class UpdatableModel(BaseModel): + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + def __init__(self, *args, **kwargs): + try: + super().__init__(**kwargs) + except ValidationError as e: + raise SchemaValidationError(e) + + def update_from_dict( + self, update_values: Dict[str, Any] + ): # this method works wrong for optional fields set to None + for field, value in update_values.items(): # do we even need this? + if getattr(self, field, None): + setattr(self, field, value) + return self + + +def IdentifierField(*args, **kwargs): # noqa + return Field(max_length=254, pattern=IDENTIFIER_NO_LENGTH, *args, **kwargs) diff --git a/src/snowflake/cli/api/project/util.py b/src/snowflake/cli/api/project/util.py index 61d5109ea..d10646eda 100644 --- a/src/snowflake/cli/api/project/util.py +++ b/src/snowflake/cli/api/project/util.py @@ -4,6 +4,7 @@ from typing import Optional IDENTIFIER = r'((?:"[^"]*(?:""[^"]*)*")|(?:[A-Za-z_][\w$]{0,254}))' +IDENTIFIER_NO_LENGTH = r'((?:"[^"]*(?:""[^"]*)*")|(?:[A-Za-z_][\w$]*))' DB_SCHEMA_AND_NAME = f"{IDENTIFIER}[.]{IDENTIFIER}[.]{IDENTIFIER}" SCHEMA_AND_NAME = f"{IDENTIFIER}[.]{IDENTIFIER}" GLOB_REGEX = r"^[a-zA-Z0-9_\-./*?**\p{L}\p{N}]+$" diff --git a/src/snowflake/cli/plugins/nativeapp/artifacts.py b/src/snowflake/cli/plugins/nativeapp/artifacts.py index 282d11b10..e623502e7 100644 --- a/src/snowflake/cli/plugins/nativeapp/artifacts.py +++ b/src/snowflake/cli/plugins/nativeapp/artifacts.py @@ -6,6 +6,7 @@ import strictyaml from click import ClickException from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB +from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.secure_path import SecurePath @@ -153,8 +154,8 @@ def translate_artifact(item: Union[dict, str]) -> ArtifactMapping: Validation is done later when we actually resolve files / folders. """ - if isinstance(item, dict): - return ArtifactMapping(item["src"], item.get("dest", item["src"])) + if isinstance(item, PathMapping): + return ArtifactMapping(item.src, item.dest if item.dest else item.src) elif isinstance(item, str): return ArtifactMapping(item, item) diff --git a/src/snowflake/cli/plugins/nativeapp/manager.py b/src/snowflake/cli/plugins/nativeapp/manager.py index b6675ad14..ff4591868 100644 --- a/src/snowflake/cli/plugins/nativeapp/manager.py +++ b/src/snowflake/cli/plugins/nativeapp/manager.py @@ -4,7 +4,7 @@ from functools import cached_property from pathlib import Path from textwrap import dedent -from typing import Dict, List, Optional +from typing import List, Optional from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError @@ -13,6 +13,7 @@ default_application, default_role, ) +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.util import ( extract_schema, to_identifier, @@ -99,7 +100,7 @@ class NativeAppManager(SqlExecutionMixin): Base class with frequently used functionality already implemented and ready to be used by related subclasses. """ - def __init__(self, project_definition: Dict, project_root: Path): + def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__() self._project_root = project_root self._project_definition = project_definition @@ -109,27 +110,30 @@ def project_root(self) -> Path: return self._project_root @property - def definition(self) -> Dict: + def definition(self) -> NativeApp: return self._project_definition @cached_property def artifacts(self) -> List[ArtifactMapping]: - return [translate_artifact(item) for item in self.definition["artifacts"]] + return [translate_artifact(item) for item in self.definition.artifacts] @cached_property def deploy_root(self) -> Path: - return Path(self.project_root, self.definition["deploy_root"]) + return Path(self.project_root, self.definition.deploy_root) @cached_property def package_scripts(self) -> List[str]: """ Relative paths to package scripts from the project root. """ - return self.definition.get("package", {}).get("scripts", []) + if self.definition.package and self.definition.package.scripts: + return self.definition.package.scripts + else: + return [] @cached_property def stage_fqn(self) -> str: - return f'{self.package_name}.{self.definition["source_stage"]}' + return f"{self.package_name}.{self.definition.source_stage}" @cached_property def stage_schema(self) -> Optional[str]: @@ -137,55 +141,65 @@ def stage_schema(self) -> Optional[str]: @cached_property def package_warehouse(self) -> Optional[str]: - return self.definition.get("package", {}).get("warehouse", self._conn.warehouse) + if self.definition.package and self.definition.package.warehouse: + return self.definition.package.warehouse + else: + return self._conn.warehouse @cached_property def application_warehouse(self) -> Optional[str]: - return self.definition.get("application", {}).get( - "warehouse", self._conn.warehouse - ) + if self.definition.application and self.definition.application.warehouse: + return self.definition.application.warehouse + else: + return self._conn.warehouse @cached_property def project_identifier(self) -> str: # name is expected to be a valid Snowflake identifier, but PyYAML # will sometimes strip out double quotes so we try to get them back here. - return to_identifier(self.definition["name"]) + return to_identifier(self.definition.name) @cached_property def package_name(self) -> str: - return to_identifier( - self.definition.get("package", {}).get( - "name", default_app_package(self.project_identifier) - ) - ) + if self.definition.package and self.definition.package.name: + return to_identifier(self.definition.package.name) + else: + return to_identifier(default_app_package(self.project_identifier)) @cached_property def package_role(self) -> str: - return self.definition.get("package", {}).get("role", None) or default_role() + if self.definition.package and self.definition.package.role: + return self.definition.package.role + else: + return default_role() @cached_property def package_distribution(self) -> str: - return ( - self.definition.get("package", {}).get("distribution", "internal").lower() - ) + if self.definition.package and self.definition.package.distribution: + return self.definition.package.distribution.lower() + else: + return "internal" @cached_property def app_name(self) -> str: - return to_identifier( - self.definition.get("application", {}).get( - "name", default_application(self.project_identifier) - ) - ) + if self.definition.application and self.definition.application.name: + return to_identifier(self.definition.application.name) + else: + return to_identifier(default_application(self.project_identifier)) @cached_property def app_role(self) -> str: - return ( - self.definition.get("application", {}).get("role", None) or default_role() - ) + if self.definition.application and self.definition.application.role: + return self.definition.application.role + else: + return default_role() @cached_property def debug_mode(self) -> bool: - return self.definition.get("application", {}).get("debug", True) + if self.definition.application: + return self.definition.application.debug + else: + return True @cached_property def get_app_pkg_distribution_in_snowflake(self) -> str: diff --git a/src/snowflake/cli/plugins/nativeapp/run_processor.py b/src/snowflake/cli/plugins/nativeapp/run_processor.py index b514dceb3..8c13310ee 100644 --- a/src/snowflake/cli/plugins/nativeapp/run_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/run_processor.py @@ -1,12 +1,13 @@ from pathlib import Path from textwrap import dedent -from typing import Dict, Optional +from typing import Optional import jinja2 import typer from click import UsageError from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.plugins.nativeapp.constants import ( ALLOWED_SPECIAL_COMMENTS, COMMENT_COL, @@ -38,7 +39,7 @@ class NativeAppRunProcessor(NativeAppManager, NativeAppCommandProcessor): - def __init__(self, project_definition: Dict, project_root: Path): + def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__(project_definition, project_root) def create_app_package(self) -> None: diff --git a/src/snowflake/cli/plugins/nativeapp/version/version_processor.py b/src/snowflake/cli/plugins/nativeapp/version/version_processor.py index 399c337b7..574fa8e7f 100644 --- a/src/snowflake/cli/plugins/nativeapp/version/version_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/version/version_processor.py @@ -6,6 +6,7 @@ from click import BadOptionUsage, ClickException from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.util import unquote_identifier from snowflake.cli.api.utils.cursor import ( find_all_rows, @@ -240,7 +241,7 @@ def process( class NativeAppVersionDropProcessor(NativeAppManager, NativeAppCommandProcessor): - def __init__(self, project_definition: Dict, project_root: Path): + def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__(project_definition, project_root) def process( diff --git a/src/snowflake/cli/plugins/snowpark/commands.py b/src/snowflake/cli/plugins/snowpark/commands.py index f938d9f5a..09634c6b8 100644 --- a/src/snowflake/cli/plugins/snowpark/commands.py +++ b/src/snowflake/cli/plugins/snowpark/commands.py @@ -27,6 +27,12 @@ MessageResult, SingleQueryResult, ) +from snowflake.cli.api.project.schemas.snowpark.callable import ( + Callable, + FunctionSchema, + ProcedureSchema, +) +from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark from snowflake.cli.plugins.object.manager import ObjectManager from snowflake.cli.plugins.object.stage.manager import StageManager from snowflake.cli.plugins.snowpark.common import ( @@ -74,8 +80,8 @@ def deploy( """ snowpark = cli_context.project_definition - procedures = snowpark.get("procedures", []) - functions = snowpark.get("functions", []) + procedures = snowpark.procedures + functions = snowpark.functions if not procedures and not functions: raise ClickException( @@ -109,7 +115,7 @@ def deploy( raise ClickException(msg) # Create stage - stage_name = snowpark.get("stage_name", DEPLOYMENT_STAGE) + stage_name = snowpark.stage_name stage_manager = StageManager() stage_name = stage_manager.to_fully_qualified_name(stage_name) stage_manager.create( @@ -118,7 +124,7 @@ def deploy( packages = get_snowflake_packages() - artifact_stage_directory = get_app_stage_path(stage_name, snowpark["project_name"]) + artifact_stage_directory = get_app_stage_path(stage_name, snowpark.project_name) artifact_stage_target = f"{artifact_stage_directory}/{build_artifact_path.name}" stage_manager.put( @@ -155,11 +161,13 @@ def deploy( return CollectionResult(deploy_status) -def _assert_object_definitions_are_correct(object_type, object_definitions): +def _assert_object_definitions_are_correct( + object_type, object_definitions: List[Callable] +): for definition in object_definitions: - database = definition.get("database") - schema = definition.get("schema") - name = definition["name"] + database = definition.database + schema = definition.schema_name + name = definition.name fqn_parts = len(name.split(".")) if fqn_parts == 3 and database: raise ClickException( @@ -193,7 +201,9 @@ def _find_existing_objects( def _check_if_all_defined_integrations_exists( - om: ObjectManager, functions: List[Dict], procedures: List[Dict] + om: ObjectManager, + functions: List[FunctionSchema], + procedures: List[ProcedureSchema], ): existing_integrations = { i["name"].lower() @@ -203,14 +213,12 @@ def _check_if_all_defined_integrations_exists( declared_integration: Set[str] = set() for object_definition in [*functions, *procedures]: external_access_integrations = { - s.lower() for s in object_definition.get("external_access_integrations", []) + s.lower() for s in object_definition.external_access_integrations } - secrets = [s.lower() for s in object_definition.get("secrets", [])] + secrets = [s.lower() for s in object_definition.secrets] if not external_access_integrations and secrets: - raise SecretsWithoutExternalAccessIntegrationError( - object_definition["name"] - ) + raise SecretsWithoutExternalAccessIntegrationError(object_definition.name) declared_integration = declared_integration | external_access_integrations @@ -229,7 +237,7 @@ def get_app_stage_path(stage_name: Optional[str], project_name: str) -> str: def _deploy_single_object( manager: FunctionManager | ProcedureManager, object_type: ObjectType, - object_definition: Dict, + object_definition: Callable, existing_objects: Dict[str, Dict], packages: List[str], stage_artifact_path: str, @@ -245,8 +253,8 @@ def _deploy_single_object( ) log.info("Deploying %s: %s", object_type, identifier_with_default_values) - handler = object_definition["handler"] - returns = object_definition["returns"] + handler = object_definition.handler + returns = object_definition.returns replace_object = False object_exists = identifier in existing_objects @@ -271,18 +279,15 @@ def _deploy_single_object( "return_type": returns, "artifact_file": stage_artifact_path, "packages": packages, - "runtime": object_definition.get("runtime"), - "external_access_integrations": object_definition.get( - "external_access_integrations" - ), - "secrets": object_definition.get("secrets"), - "imports": object_definition.get("imports", []), + "runtime": object_definition.runtime, + "external_access_integrations": object_definition.external_access_integrations, + "secrets": object_definition.secrets, + "imports": object_definition.imports, } if object_type == ObjectType.PROCEDURE: - create_or_replace_kwargs["execute_as_caller"] = object_definition.get( + create_or_replace_kwargs[ "execute_as_caller" - ) - + ] = object_definition.execute_as_caller manager.create_or_replace(**create_or_replace_kwargs) status = "created" if not object_exists else "definition updated" @@ -293,8 +298,8 @@ def _deploy_single_object( } -def _get_snowpark_artifact_path(snowpark_definition: Dict): - source = Path(snowpark_definition["src"]) +def _get_snowpark_artifact_path(snowpark_definition: Snowpark): + source = Path(snowpark_definition.src) artifact_file = Path.cwd() / (source.name + ".zip") return artifact_file @@ -312,7 +317,7 @@ def build( The archive is built using only the `src` directory specified in the project file. """ snowpark = cli_context.project_definition - source = Path(snowpark.get("src")) + source = Path(snowpark.src) artifact_file = _get_snowpark_artifact_path(snowpark) log.info("Building package using sources from: %s", source.resolve()) diff --git a/src/snowflake/cli/plugins/snowpark/common.py b/src/snowflake/cli/plugins/snowpark/common.py index a6d70af12..11bbbe503 100644 --- a/src/snowflake/cli/plugins/snowpark/common.py +++ b/src/snowflake/cli/plugins/snowpark/common.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB, ObjectType +from snowflake.cli.api.project.schemas.snowpark.argument import Argument from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.plugins.snowpark.package_utils import generate_deploy_stage_name @@ -173,26 +174,26 @@ def _is_signature_type_a_string(sig_type: str) -> bool: def build_udf_sproc_identifier( - udf_sproc_dict, + udf_sproc, slq_exec_mixin, include_parameter_names, include_default_values=False, ): - def format_arg(arg): - result = f"{arg['type']}" + def format_arg(arg: Argument): + result = f"{arg.arg_type}" if include_parameter_names: - result = f"{arg['name']} {result}" - if include_default_values and "default" in arg: - val = f"{arg['default']}" - if _is_signature_type_a_string(arg["type"]): + result = f"{arg.name} {result}" + if include_default_values and arg.default: + val = f"{arg.default}" + if _is_signature_type_a_string(arg.arg_type): val = f"'{val}'" result += f" default {val}" return result - arguments = ", ".join(format_arg(arg) for arg in udf_sproc_dict["signature"]) + arguments = ", ".join(format_arg(arg) for arg in udf_sproc.signature) name = slq_exec_mixin.to_fully_qualified_name( - udf_sproc_dict["name"], - database=udf_sproc_dict.get("database"), - schema=udf_sproc_dict.get("schema"), + udf_sproc.name, + database=udf_sproc.database, + schema=udf_sproc.schema_name, ) return f"{name}({arguments})" diff --git a/src/snowflake/cli/plugins/streamlit/commands.py b/src/snowflake/cli/plugins/streamlit/commands.py index c2b68df6d..3f018ed8d 100644 --- a/src/snowflake/cli/plugins/streamlit/commands.py +++ b/src/snowflake/cli/plugins/streamlit/commands.py @@ -17,6 +17,7 @@ MessageResult, SingleQueryResult, ) +from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit from snowflake.cli.plugins.streamlit.manager import StreamlitManager app = SnowTyper( @@ -81,31 +82,31 @@ def streamlit_deploy( upload environment.yml and pages/ folder if present. If stage name is not specified then 'streamlit' stage will be used. If stage does not exist it will be created by this command. """ - streamlit = cli_context.project_definition + streamlit: Streamlit = cli_context.project_definition if not streamlit: return MessageResult("No streamlit were specified in project definition.") - environment_file = streamlit.get("env_file", None) + environment_file = streamlit.env_file if environment_file and not Path(environment_file).exists(): raise ClickException(f"Provided file {environment_file} does not exist") elif environment_file is None: environment_file = "environment.yml" - pages_dir = streamlit.get("pages_dir", None) + pages_dir = streamlit.pages_dir if pages_dir and not Path(pages_dir).exists(): raise ClickException(f"Provided file {pages_dir} does not exist") elif pages_dir is None: pages_dir = "pages" url = StreamlitManager().deploy( - streamlit_name=streamlit["name"], + streamlit_name=streamlit.name, environment_file=Path(environment_file), pages_dir=Path(pages_dir), - stage_name=streamlit["stage"], - main_file=Path(streamlit["main_file"]), + stage_name=streamlit.stage, + main_file=Path(streamlit.main_file), replace=replace, - query_warehouse=streamlit["query_warehouse"], - additional_source_files=streamlit.get("additional_source_files"), + query_warehouse=streamlit.query_warehouse, + additional_source_files=streamlit.additional_source_files, **options, ) diff --git a/tests/nativeapp/test_artifacts.py b/tests/nativeapp/test_artifacts.py index 2f4bae98c..51b311d8d 100644 --- a/tests/nativeapp/test_artifacts.py +++ b/tests/nativeapp/test_artifacts.py @@ -39,10 +39,10 @@ def dir_structure(path: Path, prefix="") -> List[str]: @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_napp_project_1_artifacts(project_definition_files): project_root = project_definition_files[0].parent - native_app = load_project_definition(project_definition_files)["native_app"] + native_app = load_project_definition(project_definition_files).native_app - deploy_root = Path(project_root, native_app["deploy_root"]) - artifacts = [translate_artifact(item) for item in native_app["artifacts"]] + deploy_root = Path(project_root, native_app.deploy_root) + artifacts = [translate_artifact(item) for item in native_app.artifacts] build_bundle(project_root, deploy_root, artifacts) assert dir_structure(deploy_root) == [ diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index 2477f7a71..58ee2f8e3 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -48,7 +48,7 @@ def _get_na_manager(): dm = DefinitionManager() return NativeAppManager( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_package_scripts.py b/tests/nativeapp/test_package_scripts.py index 24579b361..77af1038a 100644 --- a/tests/nativeapp/test_package_scripts.py +++ b/tests/nativeapp/test_package_scripts.py @@ -22,7 +22,7 @@ def _get_na_manager(working_dir): dm = DefinitionManager(working_dir) return NativeAppRunProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index 12f5b9c6b..3591f4cfd 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -67,7 +67,7 @@ def _get_na_run_processor(): dm = DefinitionManager() return NativeAppRunProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) @@ -784,15 +784,16 @@ def test_create_dev_app_create_new_quoted( definition_version: 1 native_app: name: '"My Native Application"' - + source_stage: app_src.stage - + artifacts: - setup.sql - app/README.md - src: app/streamlit/*.py - dest: ui/ + dest: ui/ + application: name: >- diff --git a/tests/nativeapp/test_teardown_processor.py b/tests/nativeapp/test_teardown_processor.py index c1b431c73..8d97e4244 100644 --- a/tests/nativeapp/test_teardown_processor.py +++ b/tests/nativeapp/test_teardown_processor.py @@ -38,7 +38,7 @@ def _get_na_teardown_processor(): dm = DefinitionManager() return NativeAppTeardownProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_version_create_processor.py b/tests/nativeapp/test_version_create_processor.py index d5c7f7930..e98080362 100644 --- a/tests/nativeapp/test_version_create_processor.py +++ b/tests/nativeapp/test_version_create_processor.py @@ -40,7 +40,7 @@ def _get_version_create_processor(): dm = DefinitionManager() return NativeAppVersionCreateProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_version_drop_processor.py b/tests/nativeapp/test_version_drop_processor.py index 846523a7e..9d5519190 100644 --- a/tests/nativeapp/test_version_drop_processor.py +++ b/tests/nativeapp/test_version_drop_processor.py @@ -40,7 +40,7 @@ def _get_version_drop_processor(): dm = DefinitionManager() return NativeAppVersionDropProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/project/__snapshots__/test_config.ambr b/tests/project/__snapshots__/test_config.ambr new file mode 100644 index 000000000..714bfe003 --- /dev/null +++ b/tests/project/__snapshots__/test_config.ambr @@ -0,0 +1,764 @@ +# serializer version: 1 +# name: test_fields_are_parsed_correctly[integration] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': None, + 'artifacts': list([ + dict({ + 'dest': './', + 'src': 'app/*', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'integration', + 'package': dict({ + 'distribution': 'internal', + 'name': None, + 'role': None, + 'scripts': list([ + 'package/001-shared.sql', + 'package/002-shared.sql', + ]), + 'warehouse': None, + }), + 'source_stage': 'app_src.stage', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[integration_external] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': None, + 'artifacts': list([ + dict({ + 'dest': './', + 'src': 'app/*', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'integration_external', + 'package': dict({ + 'distribution': 'external', + 'name': None, + 'role': None, + 'scripts': list([ + 'package/001-shared.sql', + 'package/002-shared.sql', + ]), + 'warehouse': None, + }), + 'source_stage': 'app_src.stage', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[minimal] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': None, + 'artifacts': list([ + 'setup.sql', + 'README.md', + ]), + 'deploy_root': 'output/deploy/', + 'name': 'minimal', + 'package': None, + 'source_stage': 'app_src.stage', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[napp_project_1] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': dict({ + 'debug': True, + 'name': 'myapp_polly', + 'role': 'myapp_consumer', + 'warehouse': None, + }), + 'artifacts': list([ + 'setup.sql', + 'app/README.md', + dict({ + 'dest': 'ui/', + 'src': 'app/streamlit/*.py', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'myapp', + 'package': dict({ + 'distribution': 'internal', + 'name': 'myapp_pkg_polly', + 'role': 'accountadmin', + 'scripts': list([ + '001-shared.sql', + '002-shared.sql', + ]), + 'warehouse': None, + }), + 'source_stage': '"MySourceSchema"."SRC_Stage"', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[napp_project_with_pkg_warehouse] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': dict({ + 'debug': True, + 'name': 'myapp_polly', + 'role': 'myapp_consumer', + 'warehouse': None, + }), + 'artifacts': list([ + 'setup.sql', + 'app/README.md', + dict({ + 'dest': 'ui/', + 'src': 'app/streamlit/*.py', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'myapp', + 'package': dict({ + 'distribution': 'internal', + 'name': 'myapp_pkg_polly', + 'role': 'accountadmin', + 'scripts': list([ + '001-shared.sql', + '002-shared.sql', + ]), + 'warehouse': 'myapp_pkg_warehouse', + }), + 'source_stage': '"MySourceSchema"."SRC_Stage"', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_function_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + 'external_1', + 'external_2', + ]), + 'handler': 'app.func1_handler', + 'imports': list([ + ]), + 'name': 'func1', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'a', + }), + dict({ + 'arg_type': 'variant', + 'default': None, + 'name': 'b', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_function_fully_qualified_name] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_db.custom_schema.fqn_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_schema.fqn_function_only_schema', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'schema_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'database_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_schema.database_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_database', + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_database.custom_schema.fqn_function_error', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_function_secrets_without_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.func1_handler', + 'imports': list([ + ]), + 'name': 'func1', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'a', + }), + dict({ + 'arg_type': 'variant', + 'default': None, + 'name': 'b', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_functions] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.func1_handler', + 'imports': list([ + ]), + 'name': 'func1', + 'returns': 'string', + 'runtime': '3.10', + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': 'default value', + 'name': 'a', + }), + dict({ + 'arg_type': 'variant', + 'default': None, + 'name': 'b', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedure_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + 'external_1', + 'external_2', + ]), + 'handler': 'app.hello', + 'imports': list([ + ]), + 'name': 'procedureName', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedure_fully_qualified_name] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_db.custom_schema.fqn_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_schema.fqn_procedure_only_schema', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'schema_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'database_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_schema.database_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_database', + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_database.custom_schema.fqn_procedure_error', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedure_secrets_without_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello', + 'imports': list([ + ]), + 'name': 'procedureName', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedures] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'hello', + 'imports': list([ + ]), + 'name': 'procedureName', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'test', + 'imports': list([ + ]), + 'name': 'test', + 'returns': 'string', + 'runtime': '3.10', + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': '', + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedures_coverage] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'foo.func', + 'imports': list([ + ]), + 'name': 'foo', + 'returns': 'variant', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[streamlit_full_definition] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': None, + 'streamlit': dict({ + 'additional_source_files': list([ + 'utils/utils.py', + 'extra_file.py', + ]), + 'env_file': 'environment.yml', + 'main_file': 'streamlit_app.py', + 'name': 'test_streamlit', + 'pages_dir': 'pages', + 'query_warehouse': 'test_warehouse', + 'stage': 'streamlit', + }), + }) +# --- diff --git a/tests/project/test_config.py b/tests/project/test_config.py index 43684a784..4a76cd79c 100644 --- a/tests/project/test_config.py +++ b/tests/project/test_config.py @@ -8,25 +8,25 @@ generate_local_override_yml, load_project_definition, ) -from strictyaml import YAMLValidationError +from snowflake.cli.api.project.errors import SchemaValidationError @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_napp_project_1(project_definition_files): project = load_project_definition(project_definition_files) - assert project["native_app"]["name"] == "myapp" - assert project["native_app"]["deploy_root"] == "output/deploy/" - assert project["native_app"]["package"]["role"] == "accountadmin" - assert project["native_app"]["application"]["name"] == "myapp_polly" - assert project["native_app"]["application"]["role"] == "myapp_consumer" - assert project["native_app"]["application"]["debug"] == True + assert project.native_app.name == "myapp" + assert project.native_app.deploy_root == "output/deploy/" + assert project.native_app.package.role == "accountadmin" + assert project.native_app.application.name == "myapp_polly" + assert project.native_app.application.role == "myapp_consumer" + assert project.native_app.application.debug == True @pytest.mark.parametrize("project_definition_files", ["minimal"], indirect=True) def test_na_minimal_project(project_definition_files: List[Path]): project = load_project_definition(project_definition_files) - assert project["native_app"]["name"] == "minimal" - assert project["native_app"]["artifacts"] == ["setup.sql", "README.md"] + assert project.native_app.name == "minimal" + assert project.native_app.artifacts == ["setup.sql", "README.md"] from os import getenv as original_getenv @@ -46,36 +46,72 @@ def mock_getenv(key: str, default: Optional[str] = None) -> Optional[str]: # a definition structure for these values but directly return defaults # in "getter" functions (higher-level data structures). local = generate_local_override_yml(project) - assert local["native_app"]["application"]["name"] == "minimal_jsmith" - assert local["native_app"]["application"]["role"] == "resolved_role" - assert ( - local["native_app"]["application"]["warehouse"] == "resolved_warehouse" - ) - assert local["native_app"]["application"]["debug"] == True - assert local["native_app"]["package"]["name"] == "minimal_pkg_jsmith" - assert local["native_app"]["package"]["role"] == "resolved_role" + assert local.native_app.application.name == "minimal_jsmith" + assert local.native_app.application.role == "resolved_role" + assert local.native_app.application.warehouse == "resolved_warehouse" + assert local.native_app.application.debug == True + assert local.native_app.package.name == "minimal_pkg_jsmith" + assert local.native_app.package.role == "resolved_role" @pytest.mark.parametrize("project_definition_files", ["underspecified"], indirect=True) def test_underspecified_project(project_definition_files): - with pytest.raises(YAMLValidationError) as exc_info: + with pytest.raises(SchemaValidationError) as exc_info: load_project_definition(project_definition_files) - assert "required key(s) 'artifacts' not found" in str(exc_info.value) + assert "NativeApp schema" in str(exc_info) + assert "Your project definition is missing following fields: ('artifacts',)" in str( + exc_info.value + ) @pytest.mark.parametrize( "project_definition_files", ["no_definition_version"], indirect=True ) def test_fails_without_definition_version(project_definition_files): - with pytest.raises(YAMLValidationError) as exc_info: + with pytest.raises(SchemaValidationError) as exc_info: load_project_definition(project_definition_files) - assert "required key(s) 'definition_version' not found" in str(exc_info.value) + assert "ProjectDefinition" in str(exc_info) + assert ( + "Your project definition is missing following fields: ('definition_version',)" + in str(exc_info.value) + ) @pytest.mark.parametrize("project_definition_files", ["unknown_fields"], indirect=True) -def test_accepts_unknown_fields(project_definition_files): - project = load_project_definition(project_definition_files) - assert project["native_app"]["name"] == "unknown_fields" - assert project["native_app"]["unknown_fields_accepted"] == True +def test_does_not_accept_unknown_fields(project_definition_files): + with pytest.raises(SchemaValidationError) as exc_info: + project = load_project_definition(project_definition_files) + + assert "NativeApp schema" in str(exc_info) + assert ( + "You provided field '('unknown_fields_accepted',)' with value 'true' that is not present in the schema" + in str(exc_info) + ) + + +@pytest.mark.parametrize( + "project_definition_files", + [ + "integration", + "integration_external", + "minimal", + "napp_project_1", + "napp_project_with_pkg_warehouse", + "snowpark_function_external_access", + "snowpark_function_fully_qualified_name", + "snowpark_function_secrets_without_external_access", + "snowpark_functions", + "snowpark_procedure_external_access", + "snowpark_procedure_fully_qualified_name", + "snowpark_procedure_secrets_without_external_access", + "snowpark_procedures", + "snowpark_procedures_coverage", + "streamlit_full_definition", + ], + indirect=True, +) +def test_fields_are_parsed_correctly(project_definition_files, snapshot): + result = load_project_definition(project_definition_files).model_dump() + assert result == snapshot diff --git a/tests/streamlit/test_config.py b/tests/streamlit/test_config.py index a1253f173..86d51c1b7 100644 --- a/tests/streamlit/test_config.py +++ b/tests/streamlit/test_config.py @@ -30,4 +30,4 @@ def test_load_project_definition(test_files, expected): result = load_project_definition(test_files) - assert expected in result["streamlit"]["additional_source_files"] + assert expected in result.streamlit.additional_source_files diff --git a/tests/testing_utils/fixtures.py b/tests/testing_utils/fixtures.py index 5c30a0bb2..cd52a0e4f 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -11,12 +11,11 @@ from unittest import mock import pytest -import strictyaml +import yaml from snowflake.cli.api.project.definition import merge_left from snowflake.cli.app.cli_app import app_factory from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.errors import ProgrammingError -from strictyaml import as_document from typer import Typer from typer.testing import CliRunner @@ -250,10 +249,12 @@ def _temporary_project_directory( test_data_file = test_root_path / "test_data" / "projects" / project_name shutil.copytree(test_data_file, temp_dir, dirs_exist_ok=True) if merge_project_definition: - project_definition = strictyaml.load(Path("snowflake.yml").read_text()).data + project_definition = yaml.load( + Path("snowflake.yml").read_text(), Loader=yaml.BaseLoader + ) merge_left(project_definition, merge_project_definition) with open(Path(temp_dir) / "snowflake.yml", "w") as file: - file.write(as_document(project_definition).as_yaml()) + file.write(yaml.dump(project_definition)) yield Path(temp_dir)