|
| 1 | +import configparser |
| 2 | +import os |
| 3 | +from typing import Optional, Union |
| 4 | + |
| 5 | +from absl import logging |
| 6 | +from cryptography.hazmat import backends |
| 7 | +from cryptography.hazmat.primitives import serialization |
| 8 | + |
| 9 | +_DEFAULT_CONNECTION_FILE = "~/.snowsql/config" |
| 10 | + |
| 11 | + |
| 12 | +def _read_token(token_file: str = "") -> str: |
| 13 | + """ |
| 14 | + Reads token from environment or file provided. |
| 15 | +
|
| 16 | + First tries to read the token from environment variable |
| 17 | + (`SNOWFLAKE_TOKEN`) followed by the token file. |
| 18 | + Both the options are tried out in SnowServices. |
| 19 | +
|
| 20 | + Args: |
| 21 | + token_file: File from which token needs to be read. Optional. |
| 22 | +
|
| 23 | + Returns: |
| 24 | + the token. |
| 25 | + """ |
| 26 | + token = os.getenv("SNOWFLAKE_TOKEN", "") |
| 27 | + if token: |
| 28 | + return token |
| 29 | + if token_file and os.path.exists(token_file): |
| 30 | + with open(token_file) as f: |
| 31 | + token = f.read() |
| 32 | + return token |
| 33 | + |
| 34 | + |
| 35 | +_ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----" |
| 36 | +_UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----" |
| 37 | + |
| 38 | + |
| 39 | +def _load_pem_to_der(private_key_path: str) -> bytes: |
| 40 | + """Given a private key file path (in PEM format), decode key data into DER format.""" |
| 41 | + with open(private_key_path, "rb") as f: |
| 42 | + private_key_pem = f.read() |
| 43 | + private_key_passphrase: Optional[str] = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", None) |
| 44 | + |
| 45 | + # Only PKCS#8 format key will be accepted. However, openssl |
| 46 | + # transparently handle PKCS#8 and PKCS#1 format (by some fallback |
| 47 | + # logic) and their is no function to distinguish between them. By |
| 48 | + # reading openssl source code, apparently they also relies on header |
| 49 | + # to determine if give bytes is PKCS#8 format or not |
| 50 | + if not private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and not private_key_pem.startswith( |
| 51 | + _UNENCRYPTED_PKCS8_PK_HEADER |
| 52 | + ): |
| 53 | + raise Exception("Private key provided is not in PKCS#8 format. Please use correct format.") |
| 54 | + |
| 55 | + if private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and private_key_passphrase is None: |
| 56 | + raise Exception( |
| 57 | + "Private key is encrypted but passphrase could not be found. " |
| 58 | + "Please set SNOWFLAKE_PRIVATE_KEY_PASSPHRASE env variable." |
| 59 | + ) |
| 60 | + |
| 61 | + if private_key_pem.startswith(_UNENCRYPTED_PKCS8_PK_HEADER): |
| 62 | + private_key_passphrase = None |
| 63 | + |
| 64 | + private_key = serialization.load_pem_private_key( |
| 65 | + private_key_pem, |
| 66 | + str.encode(private_key_passphrase) if private_key_passphrase is not None else private_key_passphrase, |
| 67 | + backends.default_backend(), |
| 68 | + ) |
| 69 | + |
| 70 | + return private_key.private_bytes( |
| 71 | + encoding=serialization.Encoding.DER, |
| 72 | + format=serialization.PrivateFormat.PKCS8, |
| 73 | + encryption_algorithm=serialization.NoEncryption(), |
| 74 | + ) |
| 75 | + |
| 76 | + |
| 77 | +def _connection_properties_from_env() -> dict[str, str]: |
| 78 | + """Returns a dict with all possible login related env variables.""" |
| 79 | + sf_conn_prop = { |
| 80 | + # Mandatory fields |
| 81 | + "account": os.environ["SNOWFLAKE_ACCOUNT"], |
| 82 | + "database": os.environ["SNOWFLAKE_DATABASE"], |
| 83 | + # With a default value |
| 84 | + "token_file": os.getenv("SNOWFLAKE_TOKEN_FILE", "/snowflake/session/token"), |
| 85 | + "ssl": os.getenv("SNOWFLAKE_SSL", "on"), |
| 86 | + "protocol": os.getenv("SNOWFLAKE_PROTOCOL", "https"), |
| 87 | + } |
| 88 | + # With empty default value |
| 89 | + for key, env_var in { |
| 90 | + "user": "SNOWFLAKE_USER", |
| 91 | + "authenticator": "SNOWFLAKE_AUTHENTICATOR", |
| 92 | + "password": "SNOWFLAKE_PASSWORD", |
| 93 | + "host": "SNOWFLAKE_HOST", |
| 94 | + "port": "SNOWFLAKE_PORT", |
| 95 | + "schema": "SNOWFLAKE_SCHEMA", |
| 96 | + "warehouse": "SNOWFLAKE_WAREHOUSE", |
| 97 | + "private_key_path": "SNOWFLAKE_PRIVATE_KEY_PATH", |
| 98 | + }.items(): |
| 99 | + value = os.getenv(env_var, "") |
| 100 | + if value: |
| 101 | + sf_conn_prop[key] = value |
| 102 | + return sf_conn_prop |
| 103 | + |
| 104 | + |
| 105 | +def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]: |
| 106 | + """Loads the dictionary from snowsql config file.""" |
| 107 | + snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE) |
| 108 | + if not os.path.exists(snowsql_config_file): |
| 109 | + logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}") |
| 110 | + raise Exception("Snowflake SnowSQL config not found.") |
| 111 | + |
| 112 | + config = configparser.ConfigParser(inline_comment_prefixes="#") |
| 113 | + |
| 114 | + snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME") |
| 115 | + if snowflake_connection_name is not None: |
| 116 | + connection_name = snowflake_connection_name |
| 117 | + |
| 118 | + if connection_name: |
| 119 | + if not connection_name.startswith("connections."): |
| 120 | + connection_name = "connections." + connection_name |
| 121 | + else: |
| 122 | + # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings |
| 123 | + connection_name = "connections" |
| 124 | + |
| 125 | + logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}") |
| 126 | + config.read(snowsql_config_file) |
| 127 | + conn_params = dict(config[connection_name]) |
| 128 | + # Remap names to appropriate args in Python Connector API |
| 129 | + # Note: "dbname" should become "database" |
| 130 | + conn_params = {k.replace("name", ""): v.strip('"') for k, v in conn_params.items()} |
| 131 | + if "db" in conn_params: |
| 132 | + conn_params["database"] = conn_params["db"] |
| 133 | + del conn_params["db"] |
| 134 | + return conn_params |
| 135 | + |
| 136 | + |
def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
    """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.

    NOTE: Token/Auth information is sideloaded in all cases, if provided, in the following order:
      1. If SNOWFLAKE_TOKEN is defined in the environment, it will be used.
      2. Otherwise, if the loaded properties contain a `token_file` entry (when reading from the environment this
         is SNOWFLAKE_TOKEN_FILE or its default) and that file exists, the content of the file will be used.

    If a token is found, it is stored under 'token' and 'authenticator' is always set to 'oauth'. Other loaded
    properties (such as user and password) are left as they were loaded.

    If no token is found and a 'private_key_path' is present without a 'private_key', the key file is loaded and
    stored under 'private_key' as DER bytes.

    Python Connector:
    >> ctx = snowflake.connector.connect(**(SnowflakeLoginOptions()))

    Snowpark Session:
    >> session = Session.builder.configs(SnowflakeLoginOptions()).create()

    Usage Note:
      Ideally one should have a snowsql config file. Read more here:
      https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings

      If snowsql config file does not exist, it tries auth from env variables.

    Args:
        connection_name: Name of the connection to look for inside the config file. If environment variable
            SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
        login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).

    Returns:
        A dict with connection parameters.

    Raises:
        Exception: if none of config file and environment variable are present.
    """
    conn_prop: dict[str, Union[str, bytes]] = {}
    login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
    # If login file exists, use this exclusively.
    if os.path.exists(login_file):
        conn_prop = {**(_load_from_snowsql_config_file(connection_name, login_file))}
    elif os.getenv("SNOWFLAKE_ACCOUNT", ""):
        # If environment exists for SNOWFLAKE_ACCOUNT, assume everything
        # comes from environment. Mixing is not allowed.
        conn_prop = {**_connection_properties_from_env()}
    else:
        raise Exception("Snowflake credential is neither set in env nor a login file was provided.")

    # Token, if specified, is always side-loaded in all cases.
    token = _read_token(str(conn_prop["token_file"]) if "token_file" in conn_prop else "")
    if token:
        conn_prop["token"] = token
        # A token is only usable with OAuth, so force the authenticator even
        # when the loaded value is present but empty.
        conn_prop["authenticator"] = "oauth"
    elif "private_key_path" in conn_prop and "private_key" not in conn_prop:
        conn_prop["private_key"] = _load_pem_to_der(str(conn_prop["private_key_path"]))

    # "ssl = off" downgrades the protocol to plain HTTP.
    if "ssl" in conn_prop and str(conn_prop["ssl"]).lower() == "off":
        conn_prop["protocol"] = "http"

    return conn_prop
0 commit comments