|
12 | 12 | select_least_cloudy_images_pipeline,
|
13 | 13 | )
|
14 | 14 | from rslp.forest_loss_driver.predict_pipeline import MODEL_CFG_FNAME
|
15 |
| -from rslp.forest_loss_driver.train import CATEGORIES |
16 | 15 | from rslp.log_utils import get_logger
|
17 | 16 | from rslp.utils.rslearn import run_model_predict
|
18 | 17 |
|
@@ -58,30 +57,9 @@ def test_forest_loss_driver_model_predict(
|
58 | 57 |
|
59 | 58 | with output_path.open("r") as f:
|
60 | 59 | output_json = json.load(f)
|
61 |
| - # TODO: Ideally we would have a pydantic model for this output perhaps that we could subclass from rslearn? |
62 |
| - # Check properties except probs |
63 |
| - assert output_json["type"] == expected_output_json["type"] |
64 |
| - assert output_json["properties"] == expected_output_json["properties"] |
65 |
| - assert len(output_json["features"]) == len(expected_output_json["features"]) |
66 |
| - assert ( |
67 |
| - output_json["features"][0]["type"] |
68 |
| - == expected_output_json["features"][0]["type"] |
69 |
| - ) |
70 |
| - assert ( |
71 |
| - output_json["features"][0]["geometry"] |
72 |
| - == expected_output_json["features"][0]["geometry"] |
73 |
| - ) |
| 60 | + |
| 61 | + # Check that the predicted label is correct (this example should be river). |
74 | 62 | assert (
|
75 | 63 | output_json["features"][0]["properties"]["new_label"]
|
76 | 64 | == expected_output_json["features"][0]["properties"]["new_label"]
|
77 | 65 | )
|
78 |
| - |
79 |
| - # Ensure river class is at least 0.9 probability and others at most 0.05. |
80 |
| - actual_probs = output_json["features"][0]["properties"]["probs"] |
81 |
| - expected_category = expected_output_json["features"][0]["properties"]["new_label"] |
82 |
| - assert len(actual_probs) == len(CATEGORIES) |
83 |
| - for prob, category_name in zip(actual_probs, CATEGORIES): |
84 |
| - if category_name == expected_category: |
85 |
| - assert prob >= 0.9, f"Probability for category {category_name} < 0.9" |
86 |
| - else: |
87 |
| - assert prob <= 0.05, f"Probability for category {category_name} > 0.05" |
0 commit comments