Skip to content

Commit 950a646

Browse files
sfc-gh-anavalosSnowflake Authors
andauthored
Project import generated by Copybara. (#161)
GitOrigin-RevId: 1aa057e43c96b7dc03a018f482ab9a9dacccca46 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 66197a8 commit 950a646

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+3099
-497
lines changed

CHANGELOG.md

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Release History
22

3+
## 1.8.6
4+
5+
### Bug Fixes
6+
7+
### New Features
8+
9+
- Registry: Add service container info to logs.
10+
311
## 1.8.5
412

513
### Bug Fixes
@@ -19,6 +27,11 @@
1927
- Registry: No longer checks if the snowflake-ml-python version is available in the Snowflake Conda channel when logging
2028
an SPCS-only model.
2129
- ML Job: Add `min_instances` argument to the job decorator to allow waiting for workers to be ready.
30+
- ML Job: Adjust polling behavior to reduce number of SQL calls.
31+
32+
### Deprecations
33+
34+
- `SnowflakeLoginOptions` is deprecated and will be removed in a future release.
2235

2336
## 1.8.4 (2025-05-12)
2437

@@ -47,10 +60,6 @@
4760

4861
## 1.8.3
4962

50-
### Bug Fixes
51-
52-
### Behavior Change
53-
5463
### New Features
5564

5665
- Registry: Default to the runtime cuda version if available when logging a GPU model in Container Runtime.

bazel/get_affected_targets.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ filter_query_rules_file="${working_dir}/filter_query_rules"
9393
# -- Begin of Query Rules Heredoc --
9494
cat >"${filter_query_rules_file}" <<EndOfMessage
9595
let raw_targets = set($(<"${impacted_targets_path}")) in
96-
\$raw_targets - kind('source file', \$raw_targets) - filter('//external[:/].*', \$raw_targets)
96+
\$raw_targets - kind('source file', \$raw_targets) - filter('//external[:/].*', \$raw_targets) - filter('^//snowflake/ml/_internal:telemetry$', \$raw_targets)
9797
EndOfMessage
9898
# -- End of Query Rules Heredoc --
9999

ci/conda_recipe/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ build:
1717
noarch: python
1818
package:
1919
name: snowflake-ml-python
20-
version: 1.8.5
20+
version: 1.8.6
2121
requirements:
2222
build:
2323
- python

ci/targets/quarantine/prod3.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
//tests/integ/snowflake/ml/extra_tests:xgboost_external_memory_training_test
22
//tests/integ/snowflake/ml/extra_tests:pipeline_with_ohe_and_xgbr_test
3+
//tests/integ/snowflake/ml/lineage:lineage_integ_test
4+
//tests/integ/snowflake/ml/modeling/manifold:spectral_embedding_test

codegen/codegen_rules.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def autogen_estimators(module, estimator_info_list):
8888
"//snowflake/ml/_internal/utils:temp_file_utils",
8989
"//snowflake/ml/_internal/utils:query_result_checker",
9090
"//snowflake/ml/_internal/utils:identifier",
91+
"//snowflake/ml/_internal/utils:connection_params",
9192
"//snowflake/ml/model:model_signature",
9293
"//snowflake/ml/model/_signatures:utils",
9394
"//snowflake/ml/modeling/_internal:estimator_utils",

snowflake/ml/_internal/telemetry.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -411,16 +411,13 @@ def send_custom_usage(
411411
**kwargs: Any,
412412
) -> None:
413413
conn = _get_snowflake_connection()
414-
if conn is None:
415-
raise ValueError(
416-
"""Snowflake connection is required to send custom telemetry. This means there
417-
must be at least one active session, or that telemetry is being sent from within an SPCS service."""
418-
)
419414

420-
client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
421-
common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
422-
data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
423-
client._send(msg=data)
415+
# Send telemetry if Snowflake connection is available.
416+
if conn is not None:
417+
client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
418+
common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
419+
data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
420+
client._send(msg=data)
424421

425422

426423
def send_api_usage_telemetry(

snowflake/ml/_internal/utils/BUILD.bazel

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,18 @@ py_test(
174174
deps = [":temp_file_utils"],
175175
)
176176

177+
py_library(
178+
name = "connection_params",
179+
srcs = ["connection_params.py"],
180+
deps = [],
181+
)
182+
183+
py_test(
184+
name = "connection_params_test",
185+
srcs = ["connection_params_test.py"],
186+
deps = [":connection_params"],
187+
)
188+
177189
py_library(
178190
name = "parallelize",
179191
srcs = ["parallelize.py"],
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import configparser
2+
import os
3+
from typing import Optional, Union
4+
5+
from absl import logging
6+
from cryptography.hazmat import backends
7+
from cryptography.hazmat.primitives import serialization
8+
9+
_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
10+
11+
12+
def _read_token(token_file: str = "") -> str:
13+
"""
14+
Reads token from environment or file provided.
15+
16+
First tries to read the token from environment variable
17+
(`SNOWFLAKE_TOKEN`) followed by the token file.
18+
Both the options are tried out in SnowServices.
19+
20+
Args:
21+
token_file: File from which token needs to be read. Optional.
22+
23+
Returns:
24+
the token.
25+
"""
26+
token = os.getenv("SNOWFLAKE_TOKEN", "")
27+
if token:
28+
return token
29+
if token_file and os.path.exists(token_file):
30+
with open(token_file) as f:
31+
token = f.read()
32+
return token
33+
34+
35+
_ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----"
36+
_UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----"
37+
38+
39+
def _load_pem_to_der(private_key_path: str) -> bytes:
40+
"""Given a private key file path (in PEM format), decode key data into DER format."""
41+
with open(private_key_path, "rb") as f:
42+
private_key_pem = f.read()
43+
private_key_passphrase: Optional[str] = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", None)
44+
45+
# Only PKCS#8 format key will be accepted. However, openssl
46+
# transparently handle PKCS#8 and PKCS#1 format (by some fallback
47+
# logic) and their is no function to distinguish between them. By
48+
# reading openssl source code, apparently they also relies on header
49+
# to determine if give bytes is PKCS#8 format or not
50+
if not private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and not private_key_pem.startswith(
51+
_UNENCRYPTED_PKCS8_PK_HEADER
52+
):
53+
raise Exception("Private key provided is not in PKCS#8 format. Please use correct format.")
54+
55+
if private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and private_key_passphrase is None:
56+
raise Exception(
57+
"Private key is encrypted but passphrase could not be found. "
58+
"Please set SNOWFLAKE_PRIVATE_KEY_PASSPHRASE env variable."
59+
)
60+
61+
if private_key_pem.startswith(_UNENCRYPTED_PKCS8_PK_HEADER):
62+
private_key_passphrase = None
63+
64+
private_key = serialization.load_pem_private_key(
65+
private_key_pem,
66+
str.encode(private_key_passphrase) if private_key_passphrase is not None else private_key_passphrase,
67+
backends.default_backend(),
68+
)
69+
70+
return private_key.private_bytes(
71+
encoding=serialization.Encoding.DER,
72+
format=serialization.PrivateFormat.PKCS8,
73+
encryption_algorithm=serialization.NoEncryption(),
74+
)
75+
76+
77+
def _connection_properties_from_env() -> dict[str, str]:
78+
"""Returns a dict with all possible login related env variables."""
79+
sf_conn_prop = {
80+
# Mandatory fields
81+
"account": os.environ["SNOWFLAKE_ACCOUNT"],
82+
"database": os.environ["SNOWFLAKE_DATABASE"],
83+
# With a default value
84+
"token_file": os.getenv("SNOWFLAKE_TOKEN_FILE", "/snowflake/session/token"),
85+
"ssl": os.getenv("SNOWFLAKE_SSL", "on"),
86+
"protocol": os.getenv("SNOWFLAKE_PROTOCOL", "https"),
87+
}
88+
# With empty default value
89+
for key, env_var in {
90+
"user": "SNOWFLAKE_USER",
91+
"authenticator": "SNOWFLAKE_AUTHENTICATOR",
92+
"password": "SNOWFLAKE_PASSWORD",
93+
"host": "SNOWFLAKE_HOST",
94+
"port": "SNOWFLAKE_PORT",
95+
"schema": "SNOWFLAKE_SCHEMA",
96+
"warehouse": "SNOWFLAKE_WAREHOUSE",
97+
"private_key_path": "SNOWFLAKE_PRIVATE_KEY_PATH",
98+
}.items():
99+
value = os.getenv(env_var, "")
100+
if value:
101+
sf_conn_prop[key] = value
102+
return sf_conn_prop
103+
104+
105+
def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
106+
"""Loads the dictionary from snowsql config file."""
107+
snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
108+
if not os.path.exists(snowsql_config_file):
109+
logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
110+
raise Exception("Snowflake SnowSQL config not found.")
111+
112+
config = configparser.ConfigParser(inline_comment_prefixes="#")
113+
114+
snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
115+
if snowflake_connection_name is not None:
116+
connection_name = snowflake_connection_name
117+
118+
if connection_name:
119+
if not connection_name.startswith("connections."):
120+
connection_name = "connections." + connection_name
121+
else:
122+
# See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
123+
connection_name = "connections"
124+
125+
logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
126+
config.read(snowsql_config_file)
127+
conn_params = dict(config[connection_name])
128+
# Remap names to appropriate args in Python Connector API
129+
# Note: "dbname" should become "database"
130+
conn_params = {k.replace("name", ""): v.strip('"') for k, v in conn_params.items()}
131+
if "db" in conn_params:
132+
conn_params["database"] = conn_params["db"]
133+
del conn_params["db"]
134+
return conn_params
135+
136+
137+
def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
138+
"""Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
139+
140+
NOTE: Token/Auth information is sideloaded in all cases above, if provided in following order:
141+
1. If SNOWFLAKE_TOKEN is defined in the environment, it will be used.
142+
2. If SNOWFLAKE_TOKEN_FILE is defined in the environment and file matching the value found, content of the file
143+
will be used.
144+
145+
If token is found, username, password will be reset and 'authenticator' will be set to 'oauth'.
146+
147+
Python Connector:
148+
>> ctx = snowflake.connector.connect(**(SnowflakeLoginOptions()))
149+
150+
Snowpark Session:
151+
>> session = Session.builder.configs(SnowflakeLoginOptions()).create()
152+
153+
Usage Note:
154+
Ideally one should have a snowsql config file. Read more here:
155+
https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
156+
157+
If snowsql config file does not exist, it tries auth from env variables.
158+
159+
Args:
160+
connection_name: Name of the connection to look for inside the config file. If environment variable
161+
SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
162+
login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
163+
164+
Returns:
165+
A dict with connection parameters.
166+
167+
Raises:
168+
Exception: if none of config file and environment variable are present.
169+
"""
170+
conn_prop: dict[str, Union[str, bytes]] = {}
171+
login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
172+
# If login file exists, use this exclusively.
173+
if os.path.exists(login_file):
174+
conn_prop = {**(_load_from_snowsql_config_file(connection_name, login_file))}
175+
else:
176+
# If environment exists for SNOWFLAKE_ACCOUNT, assume everything
177+
# comes from environment. Mixing it not allowed.
178+
account = os.getenv("SNOWFLAKE_ACCOUNT", "")
179+
if account:
180+
conn_prop = {**_connection_properties_from_env()}
181+
else:
182+
raise Exception("Snowflake credential is neither set in env nor a login file was provided.")
183+
184+
# Token, if specified, is always side-loaded in all cases.
185+
token = _read_token(str(conn_prop["token_file"]) if "token_file" in conn_prop else "")
186+
if token:
187+
conn_prop["token"] = token
188+
if "authenticator" not in conn_prop or conn_prop["authenticator"]:
189+
conn_prop["authenticator"] = "oauth"
190+
elif "private_key_path" in conn_prop and "private_key" not in conn_prop:
191+
conn_prop["private_key"] = _load_pem_to_der(str(conn_prop["private_key_path"]))
192+
193+
if "ssl" in conn_prop and conn_prop["ssl"].lower() == "off":
194+
conn_prop["protocol"] = "http"
195+
196+
return conn_prop

snowflake/ml/utils/connection_params_test.py renamed to snowflake/ml/_internal/utils/connection_params_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from cryptography.hazmat.primitives import serialization
88
from cryptography.hazmat.primitives.asymmetric import rsa
99

10-
from snowflake.ml.utils import connection_params
10+
from snowflake.ml._internal.utils import connection_params
1111

1212

1313
class SnowflakeLoginOptionsTest(absltest.TestCase):

snowflake/ml/jobs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
list_jobs,
88
submit_directory,
99
submit_file,
10+
submit_from_stage,
1011
)
1112

1213
__all__ = [
@@ -18,4 +19,5 @@
1819
"delete_job",
1920
"MLJob",
2021
"JOB_STATUS",
22+
"submit_from_stage",
2123
]

snowflake/ml/jobs/_utils/BUILD.bazel

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ package(default_visibility = ["//visibility:public"])
55
py_library(
66
name = "types",
77
srcs = ["types.py"],
8+
deps = [
9+
":stage_utils",
10+
],
811
)
912

1013
py_library(
@@ -49,12 +52,22 @@ py_library(
4952
data = glob(["scripts/**"]),
5053
)
5154

55+
py_library(
56+
name = "stage_utils",
57+
srcs = ["stage_utils.py"],
58+
deps = [
59+
"//snowflake/ml/_internal/utils:identifier",
60+
],
61+
)
62+
5263
py_library(
5364
name = "payload_utils",
5465
srcs = ["payload_utils.py"],
5566
deps = [
5667
":constants",
68+
":function_payload_utils",
5769
":payload_scripts",
70+
":stage_utils",
5871
":types",
5972
],
6073
)
@@ -67,15 +80,29 @@ py_test(
6780
],
6881
deps = [
6982
":payload_utils",
83+
":stage_utils",
7084
":test_file_helper",
7185
],
7286
)
7387

88+
py_test(
89+
name = "stage_utils_test",
90+
srcs = ["stage_utils_test.py"],
91+
deps = [
92+
":stage_utils",
93+
],
94+
)
95+
7496
py_library(
7597
name = "interop_utils",
7698
srcs = ["interop_utils.py"],
7799
)
78100

101+
py_library(
102+
name = "function_payload_utils",
103+
srcs = ["function_payload_utils.py"],
104+
)
105+
79106
py_test(
80107
name = "interop_utils_test",
81108
srcs = ["interop_utils_test.py"],
@@ -103,5 +130,6 @@ py_library(
103130
":interop_utils",
104131
":payload_utils",
105132
":spec_utils",
133+
":stage_utils",
106134
],
107135
)

0 commit comments

Comments
 (0)