Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions autokeras/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

import inspect
import os
import ssl
import urllib.request

import keras
import numpy as np
import pandas as pd

import autokeras as ak

# NOTE(review): nothing in this view reads SEED; presumably it seeds random
# generators used by the test helpers — confirm against its callers.
SEED = 5

# Fraction of the Titanic rows written to the training CSV when the cached
# train/test files are built; the remaining rows go to the test CSV.
TRAIN_SPLIT_RATIO = 0.8

COLUMN_NAMES = [
"sex",
Expand All @@ -45,15 +50,74 @@
"embark_town": "categorical",
"alone": "categorical",
}
# Remote locations of the Titanic train and eval CSV files.
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
TEST_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"

# Resolve a local path for each file through keras.utils.get_file, naming the
# local file after the last path component of its URL.
# NOTE(review): TRAIN_CSV_PATH and TEST_CSV_PATH are assigned again further
# down from the OpenML-based cache paths, so these two calls look like the
# superseded side of the change — confirm they are meant to be removed.
TRAIN_CSV_PATH = keras.utils.get_file(
fname=os.path.basename(TRAIN_DATA_URL), origin=TRAIN_DATA_URL
)
TEST_CSV_PATH = keras.utils.get_file(
fname=os.path.basename(TEST_DATA_URL), origin=TEST_DATA_URL
)

# Source of the raw Titanic dataset: a single CSV served by OpenML. It is
# downloaded once and split locally into train/test files.
TITANIC_DATA_URL = "https://www.openml.org/data/get_csv/16826755/phpMYEkMl"

# All dataset files live under ~/.keras/datasets. The directory is created at
# import time if missing (exist_ok=True makes repeated imports harmless).
_cache_dir = os.path.expanduser(os.path.join("~", ".keras", "datasets"))
os.makedirs(_cache_dir, exist_ok=True)
# Local copy of the raw, unsplit download.
_titanic_data_path = os.path.join(_cache_dir, "titanic.csv")

# Paths of the derived train/test split files exposed to the tests.
TRAIN_CSV_PATH = os.path.join(_cache_dir, "titanic_train.csv")
TEST_CSV_PATH = os.path.join(_cache_dir, "titanic_test.csv")

# Build the cached train/test CSV files on first import. Nothing below runs
# when both split files already exist, so repeated imports need no network.
if not (os.path.exists(TRAIN_CSV_PATH) and os.path.exists(TEST_CSV_PATH)):
    # Download raw dataset if it doesn't exist
    if not os.path.exists(_titanic_data_path):
        # WARNING: Using unverified SSL context is a security risk.
        # This is necessary only because some test environments have outdated
        # or missing SSL certificates. In production code, SSL verification
        # should always be enabled to prevent man-in-the-middle attacks.
        # TODO: Remove this workaround once test environments have proper SSL
        # certificate configuration.
        ssl_context = ssl._create_unverified_context()

        # Download into a process-specific temporary file and move it into
        # place only once the transfer has finished. An interrupted download
        # must not leave a truncated file at the cache path, because every
        # later import would treat that file as a valid cached copy.
        _tmp_download_path = f"{_titanic_data_path}.{os.getpid()}.tmp"
        try:
            with urllib.request.urlopen(
                TITANIC_DATA_URL, context=ssl_context
            ) as response:
                with open(_tmp_download_path, "wb") as out_file:
                    out_file.write(response.read())
            os.replace(_tmp_download_path, _titanic_data_path)
        finally:
            # Only present when the download or the rename failed.
            if os.path.exists(_tmp_download_path):
                os.remove(_tmp_download_path)

    # Load and preprocess the dataset to match expected format
    _df = pd.read_csv(_titanic_data_path)

    # Rename columns to match expected format
    _df = _df.rename(
        columns={
            "pclass": "class",
            "sibsp": "n_siblings_spouses",
            "cabin": "deck",
            "embarked": "embark_town",
        }
    )

    # Create 'alone' column: "True" when the passenger has no siblings,
    # spouses, parents or children aboard, otherwise "False". The values are
    # stored as strings, not booleans.
    _df["alone"] = (
        (_df["n_siblings_spouses"] + _df["parch"]) == 0
    ).astype(str)

    # Select only the columns we need in the expected order
    _columns_to_keep = [
        "sex",
        "age",
        "n_siblings_spouses",
        "parch",
        "fare",
        "class",
        "deck",
        "embark_town",
        "alone",
        "survived",
    ]
    _df = _df[_columns_to_keep]

    # Split into train and test by row position.
    # NOTE(review): the rows are not shuffled; if the source file is ordered
    # (for example by passenger class) the two splits will follow different
    # distributions — confirm this is acceptable for the tests.
    _train_size = int(len(_df) * TRAIN_SPLIT_RATIO)
    _train_df = _df.iloc[:_train_size]
    _test_df = _df.iloc[_train_size:]

    # Save train and test splits. Each file is written to a temporary path
    # and renamed afterwards so that a failed write cannot leave a partial
    # CSV that satisfies the existence check at the top of this block.
    for _split_df, _split_path in (
        (_train_df, TRAIN_CSV_PATH),
        (_test_df, TEST_CSV_PATH),
    ):
        _tmp_split_path = f"{_split_path}.{os.getpid()}.tmp"
        try:
            _split_df.to_csv(_tmp_split_path, index=False)
            os.replace(_tmp_split_path, _split_path)
        finally:
            if os.path.exists(_tmp_split_path):
                os.remove(_tmp_split_path)


def generate_data(num_instances=100, shape=(32, 32, 3)):
Expand Down
Loading