Skip to content

Commit

Permalink
fix source error
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pczajka committed Mar 5, 2024
1 parent c3de383 commit cdf311d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 16 deletions.
16 changes: 8 additions & 8 deletions src/snowflake/cli/plugins/object/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class StageManager(SqlExecutionMixin):
@staticmethod
def get_standard_stage_name(name: str) -> str:
def get_standard_stage_prefix(name: str) -> str:
# Handle embedded stages
if name.startswith("snow://") or name.startswith("@"):
return name
Expand All @@ -30,7 +30,7 @@ def get_standard_stage_name(name: str) -> str:
def get_standard_stage_directory_path(path):
if not path.endswith("/"):
path += "/"
return StageManager.get_standard_stage_name(path)
return StageManager.get_standard_stage_prefix(path)

@staticmethod
def get_stage_name_from_path(path: str):
Expand All @@ -45,7 +45,7 @@ def quote_stage_name(name: str) -> str:
if name.startswith("'") and name.endswith("'"):
return name # already quoted

standard_name = StageManager.get_standard_stage_name(name)
standard_name = StageManager.get_standard_stage_prefix(name)
if standard_name.startswith("@") and not re.fullmatch(
r"@([\w./$])+", standard_name
):
Expand All @@ -60,13 +60,13 @@ def _to_uri(self, local_path: str):
return to_string_literal(uri)

def list_files(self, stage_name: str) -> SnowflakeCursor:
stage_name = self.get_standard_stage_name(stage_name)
stage_name = self.get_standard_stage_prefix(stage_name)
return self._execute_query(f"ls {self.quote_stage_name(stage_name)}")

def get(
self, stage_path: str, dest_path: Path, parallel: int = 4
) -> SnowflakeCursor:
stage_path = self.get_standard_stage_directory_path(stage_path)
stage_path = self.get_standard_stage_prefix(stage_path)
dest_directory = f"{dest_path}/"
return self._execute_query(
f"get {self.quote_stage_name(stage_path)} {self._to_uri(dest_directory)} parallel={parallel}"
Expand All @@ -87,7 +87,7 @@ def put(
and switch back to the original role for the next commands to run.
"""
with self.use_role(role) if role else nullcontext():
stage_path = self.get_standard_stage_name(stage_path)
stage_path = self.get_standard_stage_prefix(stage_path)
local_resolved_path = path_resolver(str(local_path))
log.info("Uploading %s to @%s", local_resolved_path, stage_path)
cursor = self._execute_query(
Expand All @@ -97,7 +97,7 @@ def put(
return cursor

def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor:
source = self.get_standard_stage_directory_path(source_path)
source = self.get_standard_stage_prefix(source_path)
destination = self.get_standard_stage_directory_path(destination_path)
log.info("Copying files from %s to %s", source, destination)
query = f"copy files into {destination} from {source}"
Expand All @@ -113,7 +113,7 @@ def remove(
and switch back to the original role for the next commands to run.
"""
with self.use_role(role) if role else nullcontext():
stage_name = self.get_standard_stage_name(stage_name)
stage_name = self.get_standard_stage_prefix(stage_name)
path = path if path.startswith("/") else "/" + path
quoted_stage_name = self.quote_stage_name(f"{stage_name}{path}")
return self._execute_query(f"remove {quoted_stage_name}")
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/cli/plugins/streamlit/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def deploy(

stage_manager.create(stage_name=stage_name)

root_location = stage_manager.get_standard_stage_name(
root_location = stage_manager.get_standard_stage_prefix(
f"{stage_name}/{streamlit_name_for_root_location}"
)

Expand Down
6 changes: 2 additions & 4 deletions tests/git/test_git_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ def test_copy_to_local_file_system(mock_connector, runner, mock_ctx, temp_dir):

assert result.exit_code == 0, result.output
assert local_path.exists()
# paths in generated SQL should end with '/'
assert (
ctx.get_query()
== f"get @repo_name/branches/main/ file://{local_path.resolve()}/ parallel=4"
== f"get @repo_name/branches/main file://{local_path.resolve()}/ parallel=4"
)


Expand All @@ -87,11 +86,10 @@ def test_copy_to_remote_dir(mock_connector, runner, mock_ctx):
["git", "copy", "@repo_name/branches/main", "@stage_path/dir_in_stage"]
)

# paths in generated SQL should end with '/'
assert result.exit_code == 0, result.output
assert (
ctx.get_query()
== "copy files into @stage_path/dir_in_stage/ from @repo_name/branches/main/"
== "copy files into @stage_path/dir_in_stage/ from @repo_name/branches/main"
)


Expand Down
6 changes: 3 additions & 3 deletions tests/object/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_stage_copy_remote_to_local(mock_execute, runner, mock_cursor):
)
assert result.exit_code == 0, result.output
mock_execute.assert_called_once_with(
f"get @stageName/ file://{Path(tmp_dir).resolve()}/ parallel=4"
f"get @stageName file://{Path(tmp_dir).resolve()}/ parallel=4"
)


Expand All @@ -47,7 +47,7 @@ def test_stage_copy_remote_to_local_quoted_stage(mock_execute, runner, mock_curs
)
assert result.exit_code == 0, result.output
mock_execute.assert_called_once_with(
f"get '@\"stage name\"/' file://{Path(tmp_dir).resolve()}/ parallel=4"
f"get '@\"stage name\"' file://{Path(tmp_dir).resolve()}/ parallel=4"
)


Expand Down Expand Up @@ -80,7 +80,7 @@ def test_stage_copy_remote_to_local_quoted_uri(
["object", "stage", "copy", "-c", "empty", "@stageName", local_path]
)
assert result.exit_code == 0, result.output
mock_execute.assert_called_once_with(f"get @stageName/ {file_uri} parallel=4")
mock_execute.assert_called_once_with(f"get @stageName {file_uri} parallel=4")


@mock.patch(f"{STAGE_MANAGER}._execute_query")
Expand Down

0 comments on commit cdf311d

Please sign in to comment.