Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,42 @@
# limitations under the License.
#

import glob
import json
import importlib.resources
import os
import zipfile
from pyspark.find_spark_home import _find_spark_home

# Note: Though we call them "error classes" here, the proper name is "error conditions",
# hence why the name of the JSON file is different.
# For more information, please see: https://issues.apache.org/jira/browse/SPARK-46810
# This discrepancy will be resolved as part of: https://issues.apache.org/jira/browse/SPARK-47429
ERROR_CLASSES_JSON = (
importlib.resources.files("pyspark.errors").joinpath("error-conditions.json").read_text()
)
ERROR_CLASSES_MAP = json.loads(ERROR_CLASSES_JSON)


def get_error_classes() -> dict[str, dict]:
    """Load the combined map of Python and Java (JVM) error classes.

    The Python-side definitions are read from the packaged
    ``error-conditions.json``; the Java-side definitions are extracted from
    the ``spark-common-utils`` jar when one is present under ``SPARK_HOME``.
    Installations without jars (e.g. Spark Connect client-only installs)
    simply fall back to the Python-only definitions.

    Returns
    -------
    dict
        Mapping from error-class name to its definition. On name collisions
        the Python-side definition takes precedence.
    """
    python_error_classes_json = (
        importlib.resources.files("pyspark.errors").joinpath("error-conditions.json").read_text()
    )
    python_error_classes_map = json.loads(python_error_classes_json)

    # We load the Java error classes from the jars so Python recognizes them too.
    java_error_classes_map: dict[str, dict] = {}
    spark_home = _find_spark_home()

    # Released spark packages have the jars in SPARK_HOME/jars, and development builds have them
    # in assembly/target/scala-*/jars.
    for bin_dir in ("jars", "assembly/target/scala-*/jars"):
        bin_path = os.path.join(spark_home, bin_dir)
        # sorted() makes the pick deterministic when more than one jar matches;
        # glob order is otherwise filesystem-dependent.
        jars = sorted(glob.glob(os.path.join(bin_path, "spark-common-utils_*.jar")))
        if jars:
            with zipfile.ZipFile(jars[0]) as zf:
                with zf.open("error/error-conditions.json") as f:
                    java_error_classes_json = f.read().decode("utf-8")
                    java_error_classes_map = json.loads(java_error_classes_json)
            break

    # Python entries win over same-named Java entries (dict union keeps the
    # right-hand side on collision).
    return java_error_classes_map | python_error_classes_map


ERROR_CLASSES_MAP = get_error_classes()
21 changes: 18 additions & 3 deletions python/pyspark/errors/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@
# limitations under the License.
#

import importlib.resources
import json
import unittest

from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.errors.error_classes import ERROR_CLASSES_JSON
from pyspark.errors.utils import ErrorClassesReader


class ErrorsTest(unittest.TestCase):
def test_error_classes_sorted(self):
# Test error classes is sorted alphabetically
error_reader = ErrorClassesReader()
error_class_names = list(error_reader.error_info_map.keys())
ERROR_CLASSES_JSON = (
importlib.resources.files("pyspark.errors")
.joinpath("error-conditions.json")
.read_text()
)
error_class_names = list(json.loads(ERROR_CLASSES_JSON).keys())
for i in range(len(error_class_names) - 1):
self.assertTrue(
error_class_names[i] < error_class_names[i + 1],
Expand All @@ -48,6 +52,12 @@ def detect_duplication(pairs):
error_classes_json[name] = message
return error_classes_json

ERROR_CLASSES_JSON = (
importlib.resources.files("pyspark.errors")
.joinpath("error-conditions.json")
.read_text()
)

json.loads(ERROR_CLASSES_JSON, object_pairs_hook=detect_duplication)

def test_invalid_error_class(self):
Expand Down Expand Up @@ -108,6 +118,11 @@ def test_breaking_change_info(self):
subclass_map = error_reader.error_info_map["TEST_ERROR_WITH_SUB_CLASS"]["sub_class"]
self.assertEqual(breaking_change_info2, subclass_map["SUBCLASS"]["breaking_change_info"])

def test_java_error_classes(self):
    # A JVM-side error class (defined only in the Spark jars) must be
    # resolvable through the Python error reader.
    reader = ErrorClassesReader()
    message = reader.get_error_message("AGGREGATE_OUT_OF_MEMORY", {})
    self.assertEqual("No enough memory for aggregation", message)

def test_sqlstate(self):
error = PySparkRuntimeError(errorClass="APPLICATION_NAME_NOT_SET", messageParameters={})
self.assertIsNone(error.getSqlState())
Expand Down
12 changes: 9 additions & 3 deletions python/pyspark/errors_doc_gen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib.resources
import json
import re

from pyspark.errors.error_classes import ERROR_CLASSES_MAP


def generate_errors_doc(output_rst_file_path: str) -> None:
"""
Expand Down Expand Up @@ -47,7 +47,13 @@ def generate_errors_doc(output_rst_file_path: str) -> None:
"""
with open(output_rst_file_path, "w") as f:
f.write(header + "\n\n")
for error_key, error_details in ERROR_CLASSES_MAP.items():
python_error_classes_json = (
importlib.resources.files("pyspark.errors")
.joinpath("error-conditions.json")
.read_text()
)
python_error_classes_map = json.loads(python_error_classes_json)
for error_key, error_details in python_error_classes_map.items():
f.write(error_key + "\n")
# The length of the error class name and underline must be the same
# to satisfy the RST format.
Expand Down