diff --git a/app/Fixtures/constants.py b/app/Fixtures/constants.py
index 6ff5650..6f0a61e 100644
--- a/app/Fixtures/constants.py
+++ b/app/Fixtures/constants.py
@@ -1,3 +1,14 @@
+import os
+
+
+"""
+App root directory. This is dependent on constants.py being one directory
+down from the app directory.
+"""
+APP_ROOT_DIR = os.path.split(
+    os.path.dirname(os.path.abspath(__file__))
+)[0]
+
 RANDOM_SEED = 1067641072
 
 WINSOR_THRESHOLDS = {
diff --git a/app/Fixtures/gams.py b/app/Fixtures/gams.py
index d703404..847dff3 100644
--- a/app/Fixtures/gams.py
+++ b/app/Fixtures/gams.py
@@ -1,8 +1,15 @@
-import pickle
+import os
+import pickle
+
+from app.Fixtures.constants import APP_ROOT_DIR
 
-study_export = pickle.load(open("app/Fixtures/production_assets.pkl", "rb"))
+# Resolve the pickle relative to the app root so loading works from any CWD.
+with open(
+    os.path.join(APP_ROOT_DIR, "Fixtures", "production_assets.pkl"), "rb"
+) as assets_file:
+    study_export = pickle.load(assets_file)
 
-MORTALTIY_GAM = study_export["mortality"]["model"]
+MORTALITY_GAM = study_export["mortality"]["model"]
 LACTATE_GAM = study_export["lactate"]["model"]
 ALBUMIN_GAM = study_export["albumin"]["model"]
 
diff --git a/app/prediction/predict.py b/app/prediction/predict.py
index 5e0d79b..b71f369 100644
--- a/app/prediction/predict.py
+++ b/app/prediction/predict.py
@@ -4,7 +4,7 @@ from sklearn.preprocessing import QuantileTransformer
 from pygam import GAM, LinearGAM
 from pygam.distributions import NormalDist
 
-from app.Fixtures.gams import MORTALTIY_GAM
+from app.Fixtures.gams import MORTALITY_GAM
 
 
 def quick_sample(
@@ -142,7 +142,7 @@ def predict_mortality(
         (features.shape[0] * n_samples_per_row,)
     """
     return quick_sample(
-        gam=MORTALTIY_GAM,
+        gam=MORTALITY_GAM,
         sample_at_X=features,
         quantity="mu",
         n_draws=n_samples_per_row,
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 6a04385..1ffbcc5 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -6,7 +6,7 @@ import app.prediction.predict as predict
 
 
 from app.Fixtures.gams import LACTATE_GAM, LACTATE_TRANSFORMER
-from app.Fixtures.gams import MORTALTIY_GAM
+from app.Fixtures.gams import MORTALITY_GAM
 
 
 def lineargam_data(n_rows: int) -> Tuple[np.ndarray, np.ndarray]:
diff --git a/tests/test_predict_api.py b/tests/test_predict_api.py
index 270009d..f929915 100644
--- a/tests/test_predict_api.py
+++ b/tests/test_predict_api.py
@@ -1,5 +1,9 @@
+import numpy as np
 from fastapi.testclient import TestClient
 from app.main import api
+from app.Fixtures.gams import study_export
+from app.prediction.predict import predict_mortality
+from app.Fixtures.constants import RANDOM_SEED
 
 client = TestClient(api)
 
@@ -9,7 +13,7 @@ def test_index():
     assert response.status_code == 200
 
 
-pred = {
+patient1 = {
     "Age": 40,
     "ASA": 3,
     "HR": 87,
@@ -31,9 +35,57 @@ def test_index():
 }
 
 
+# An example patient with observed lactate & albumin, Winsorisation not required
+patient2 = {
+    "Age": 81,
+    "ASA": 2,
+    "HR": 82,
+    "SBP": 104,
+    "WCC": 9.1,
+    "Na": 135,
+    "K": 4.4,
+    "Urea": 8.7,
+    "Creat": 78,
+    "GCS": 15,
+    "Resp": 0,
+    "Cardio": 1,
+    "Sinus": 0,
+    "CT_performed": 1,
+    "Indication": 0,
+    "Malignancy": 0,
+    "Soiling": 1,
+    "Lactate": 3.2,
+    "Albumin": 25
+}
+
+
+# Keys = API variable names, values = corresponding NELA variable names
+api_nela_var_map = {
+    "Age": "S01AgeOnArrival",
+    "ASA": "S03ASAScore",
+    "HR": "S03Pulse",
+    "SBP": "S03SystolicBloodPressure",
+    "WCC": "S03WhiteCellCount",
+    "Na": "S03Sodium",
+    "K": "S03Potassium",
+    "Urea": "S03Urea",
+    "Creat": "S03SerumCreatinine",
+    "GCS": 'S03GlasgowComaScore',
+    "Resp": 'S03RespiratorySigns',
+    "Cardio": 'S03CardiacSigns',
+    "Sinus": "S03ECG",
+    "CT_performed": "S02PreOpCTPerformed",
+    "Indication": "Indication",
+    "Malignancy": 'S03DiagnosedMalignancy',
+    "Soiling": 'S03Pred_Peritsoil',
+    "Lactate": 'S03PreOpArterialBloodLactate',
+    "Albumin": 'S03PreOpLowestAlbumin'
+}
+
+
 def test_predict_api_both_impute():
     response = client.post(
-        "/predict", headers={"Content-Type": "application/json"}, json=pred
+        "/predict", headers={"Content-Type": "application/json"}, json=patient1
     )
     assert response.status_code == 200
 
@@ -42,10 +94,10 @@ def test_predict_api_both_impute():
 
 
 def test_predict_api_alb_impute():
-    pred["Albumin"] = 40
+    patient1["Albumin"] = 40
 
     response = client.post(
-        "/predict", headers={"Content-Type": "application/json"}, json=pred
+        "/predict", headers={"Content-Type": "application/json"}, json=patient1
     )
     assert response.status_code == 200
 
@@ -54,10 +106,10 @@ def test_predict_api_alb_impute():
 
 
 def test_predict_api_basic():
-    pred["Lactate"] = 1
+    patient1["Lactate"] = 1
 
     response = client.post(
-        "/predict", headers={"Content-Type": "application/json"}, json=pred
+        "/predict", headers={"Content-Type": "application/json"}, json=patient1
     )
     assert response.status_code == 200
 
@@ -70,19 +122,62 @@ def test_predict_api_basic():
 
 
 def test_predict_api_invalid_cat():
-    pred["Soiling"] = 7
+    patient1["Soiling"] = 7
 
     response = client.post(
-        "/predict", headers={"Content-Type": "application/json"}, json=pred
+        "/predict", headers={"Content-Type": "application/json"}, json=patient1
     )
     assert response.status_code == 422
 
 
 def test_predict_api_invalid_type():
-    pred["soiling"] = 1
-    pred["SBP"] = 103.4
+    patient1["soiling"] = 1
+    patient1["SBP"] = 103.4
 
     response = client.post(
-        "/predict", headers={"Content-Type": "application/json"}, json=pred
+        "/predict", headers={"Content-Type": "application/json"}, json=patient1
     )
     assert response.status_code == 422
+
+
+def test_predict_api_vs_direct_prediction():
+    """
+    Compares mortality risk predictions from the predict API, to those
+    generated by direct use of predict_mortality() in a patient that doesn't
+    require Winsorisation or lactate / albumin imputation. These should be the
+    same.
+    """
+    # Get API mortality risk prediction
+    response = client.post(
+        "/predict", headers={"Content-Type": "application/json"}, json=patient2
+    )
+    assert response.status_code == 200
+    api_pred = np.array(response.json()["Result"])
+
+    # Get a 1-row DataFrame with same columns as input to predict_mortality()
+    features = study_export['mortality']['input_data']['describe'].iloc[
+        5:6
+    ].reset_index(drop=True)
+
+    # Replace values with those from example patient 2, and add missingness vars
+    # NOTE(review): direct_patient is never written into `features`, so the
+    # direct prediction below uses the row selected above unchanged - confirm.
+    direct_patient = {}
+    for api_name, value in patient2.items():
+        direct_patient[api_nela_var_map[api_name]] = value
+    direct_patient["S03PreOpLowestAlbumin_missing"] = 0.
+    direct_patient['S03PreOpArterialBloodLactate_missing'] = 0.
+
+    # Get direct mortality risk prediction
+    direct_pred = predict_mortality(
+        features=features,
+        n_samples_per_row=api_pred.size,
+        random_seed=RANDOM_SEED
+    )
+
+    # Compare predictions
+    # TODO: We need to round for the test to pass - why so much numerical error?
+    decimal_places = 4
+    assert (
+        direct_pred.round(decimal_places) == api_pred.round(decimal_places)
+    ).all()