Skip to content

Commit 3190400

Browse files
committed
update docs for new release
1 parent 1562b14 commit 3190400

16 files changed

+118
-114
lines changed

autokeras/tasks/structured_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Union
2121

2222
import pandas as pd
23+
import tensorflow as tf
2324
from tensorflow.python.util import nest
2425

2526
from autokeras import auto_model
@@ -122,7 +123,7 @@ def fit(
122123
self._target_col_name = y
123124
x, y = self._read_from_csv(x, y)
124125

125-
if validation_data:
126+
if validation_data and not isinstance(validation_data, tf.data.Dataset):
126127
x_val, y_val = validation_data
127128
if isinstance(x_val, str):
128129
validation_data = self._read_from_csv(x_val, y_val)

autokeras/utils/io_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,8 @@ def image_dataset_from_directory(
229229
)
230230

231231
images = tf.data.Dataset.from_tensor_slices(image_paths)
232-
images = images.map(tf.io.read_file)
233232
images = images.map(
234-
lambda img: tf.io.decode_image(
235-
img, channels=num_channels, expand_animations=False
236-
)
237-
)
238-
images = images.map(
239-
lambda img: tf.image.resize(img, image_size, method=interpolation)
233+
lambda img: path_to_image(img, num_channels, image_size, interpolation)
240234
)
241235

242236
labels = np.array(class_names)[np.array(labels)]
@@ -245,3 +239,11 @@ def image_dataset_from_directory(
245239
dataset = tf.data.Dataset.zip((images, labels))
246240
dataset = dataset.batch(batch_size)
247241
return dataset
242+
243+
244+
def path_to_image(image, num_channels, image_size, interpolation):
245+
image = tf.io.read_file(image)
246+
image = tf.io.decode_image(image, channels=num_channels, expand_animations=False)
247+
image = tf.image.resize(image, image_size, method=interpolation)
248+
image.set_shape((image_size[0], image_size[1], num_channels))
249+
return image

docs/autogen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@
8888
'autokeras.ClassificationHead',
8989
'autokeras.RegressionHead',
9090
],
91+
'utils.md': [
92+
'autokeras.image_dataset_from_directory',
93+
'autokeras.text_dataset_from_directory',
94+
]
9195
}
9296

9397

docs/mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ nav:
4242
- Multi-Modal and Multi-Task: tutorial/multi.md
4343
- Customized Model: tutorial/customized.md
4444
- Export Model: tutorial/export.md
45+
- Load Data from Disk: tutorial/load.md
4546
- FAQ: tutorial/faq.md
4647
- Extensions:
4748
- TensorFlow Cloud: extensions/tf_cloud.md
@@ -59,4 +60,5 @@ nav:
5960
- Base Class: base.md
6061
- Node: node.md
6162
- Block: block.md
63+
- Utils: utils.md
6264
- About: about.md

docs/py/customized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""shell
22
pip install autokeras
3-
pip install git+https://github.com/keras-team/[email protected].2rc2
3+
pip install git+https://github.com/keras-team/[email protected].2rc3
44
"""
55

66
"""

docs/py/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
"""shell
1111
pip install autokeras
12-
pip install git+https://github.com/keras-team/[email protected].2rc2
12+
pip install git+https://github.com/keras-team/[email protected].2rc3
1313
"""
1414

1515
import tensorflow as tf

docs/py/image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""shell
22
pip install autokeras
3-
pip install git+https://github.com/keras-team/[email protected].2rc2
3+
pip install git+https://github.com/keras-team/[email protected].2rc3
44
"""
55

66
"""

docs/py/image_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""shell
22
pip install autokeras
3-
pip install git+https://github.com/keras-team/[email protected].2rc2
3+
pip install git+https://github.com/keras-team/[email protected].2rc3
44
"""
55

66
"""

docs/py/load.py

Lines changed: 78 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""shell
22
pip install autokeras
3-
pip install git+https://github.com/keras-team/[email protected].2rc2
3+
pip install git+https://github.com/keras-team/[email protected].2rc3
44
"""
55

66
"""
@@ -10,18 +10,19 @@
1010
First, we download the data and extract the files.
1111
"""
1212

13+
import autokeras as ak
1314
import tensorflow as tf
1415
import os
1516

16-
# dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
17-
# local_file_path = tf.keras.utils.get_file(origin=dataset_url,
18-
# fname='image_data',
19-
# extract=True)
20-
# # The file is extracted in the same directory as the downloaded file.
21-
# local_dir_path = os.path.dirname(local_file_path)
22-
# # After check mannually, we know the extracted data is in 'flower_photos'.
23-
# data_dir = os.path.join(local_dir_path, 'flower_photos')
24-
# print(data_dir)
17+
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
18+
local_file_path = tf.keras.utils.get_file(origin=dataset_url,
19+
fname='image_data',
20+
extract=True)
21+
# The file is extracted in the same directory as the downloaded file.
22+
local_dir_path = os.path.dirname(local_file_path)
23+
# After check mannually, we know the extracted data is in 'flower_photos'.
24+
data_dir = os.path.join(local_dir_path, 'flower_photos')
25+
print(data_dir)
2526

2627
"""
2728
The directory should look like this. Each folder contains the images in the same class.
@@ -42,33 +43,31 @@
4243
img_height = 180
4344
img_width = 180
4445

45-
# train_data = tf.keras.preprocessing.image_dataset_from_directory(
46-
# data_dir,
47-
# # Use 20% data as testing data.
48-
# validation_split=0.2,
49-
# subset="training",
50-
# # Set seed to ensure the same split when loading testing data.
51-
# seed=123,
52-
# image_size=(img_height, img_width),
53-
# batch_size=batch_size)
54-
55-
# test_data = tf.keras.preprocessing.image_dataset_from_directory(
56-
# data_dir,
57-
# validation_split=0.2,
58-
# subset="validation",
59-
# seed=123,
60-
# image_size=(img_height, img_width),
61-
# batch_size=batch_size)
46+
train_data = ak.image_dataset_from_directory(
47+
data_dir,
48+
# Use 20% data as testing data.
49+
validation_split=0.2,
50+
subset="training",
51+
# Set seed to ensure the same split when loading testing data.
52+
seed=123,
53+
image_size=(img_height, img_width),
54+
batch_size=batch_size)
55+
56+
test_data = ak.image_dataset_from_directory(
57+
data_dir,
58+
validation_split=0.2,
59+
subset="validation",
60+
seed=123,
61+
image_size=(img_height, img_width),
62+
batch_size=batch_size)
6263

6364
"""
6465
Then we just do one quick demo of AutoKeras to make sure the dataset works.
6566
"""
6667

67-
import autokeras as ak
68-
69-
# clf = ak.ImageClassifier(overwrite=True, max_trials=1)
70-
# clf.fit(train_data, epochs=1)
71-
# print(clf.evaluate(test_data))
68+
clf = ak.ImageClassifier(overwrite=True, max_trials=1)
69+
clf.fit(train_data, epochs=1)
70+
print(clf.evaluate(test_data))
7271

7372
"""
7473
You can also load text datasets in the same way.
@@ -94,76 +93,59 @@
9493
For this dataset, the data is already split into train and test.
9594
We just load them separately.
9695
"""
97-
print(data_dir)
98-
train_data = tf.keras.preprocessing.text_dataset_from_directory(
99-
os.path.join(data_dir, 'train'),
100-
class_names=['pos', 'neg'],
101-
validation_split=0.2,
102-
subset="training",
103-
# shuffle=False,
104-
seed=123,
105-
batch_size=batch_size)
10696

107-
val_data = tf.keras.preprocessing.text_dataset_from_directory(
97+
print(data_dir)
98+
train_data = ak.text_dataset_from_directory(
10899
os.path.join(data_dir, 'train'),
109-
class_names=['pos', 'neg'],
110-
validation_split=0.2,
111-
subset="validation",
112-
# shuffle=False,
113-
seed=123,
114100
batch_size=batch_size)
115101

116-
test_data = tf.keras.preprocessing.text_dataset_from_directory(
102+
test_data = ak.text_dataset_from_directory(
117103
os.path.join(data_dir, 'test'),
118-
class_names=['pos', 'neg'],
119104
shuffle=False,
120105
batch_size=batch_size)
121106

122-
for x, y in train_data:
123-
print(x.numpy()[0])
124-
print(y.numpy()[0])
125-
# record_x = x.numpy()
126-
# record_y = y.numpy()
127-
break
128-
129-
for x, y in train_data:
130-
print(x.numpy()[0])
131-
print(y.numpy()[0])
132-
break
133-
134-
# train_data = tf.keras.preprocessing.text_dataset_from_directory(
135-
# os.path.join(data_dir, 'train'),
136-
# class_names=['pos', 'neg'],
137-
# shuffle=True,
138-
# seed=123,
139-
# batch_size=batch_size)
140-
141-
# for x, y in train_data:
142-
# for i, a in enumerate(x.numpy()):
143-
# for j, b in enumerate(record_x):
144-
# if a == b:
145-
# print('*')
146-
# assert record_y[j] == y.numpy()[i]
147-
148-
# import numpy as np
149-
# x_train = []
150-
# y_train = []
151-
# for x, y in train_data:
152-
# for a in x.numpy():
153-
# x_train.append(a)
154-
# for a in y.numpy():
155-
# y_train.append(a)
156-
157-
# x_train = np.array(x_train)
158-
# y_train = np.array(y_train)
159-
160-
# train_data = train_data.shuffle(1000, seed=123, reshuffle_each_iteration=False)
161-
162-
163-
clf = ak.TextClassifier(overwrite=True, max_trials=2)
164-
# clf.fit(train_data, validation_data=test_data)
165-
# clf.fit(train_data, validation_data=train_data)
166-
clf.fit(train_data, validation_data=val_data)
167-
# clf.fit(x_train, y_train)
168-
# clf.fit(train_data)
107+
clf = ak.TextClassifier(overwrite=True, max_trials=1)
108+
clf.fit(train_data, epochs=2)
169109
print(clf.evaluate(test_data))
110+
111+
112+
"""
113+
If you want to use generators, you can refer to the following code.
114+
"""
115+
116+
import math
117+
118+
import numpy as np
119+
120+
N_BATCHES = 30
121+
BATCH_SIZE = 100
122+
N_FEATURES = 10
123+
124+
125+
def get_data_generator(n_batches, batch_size, n_features):
126+
"""Get a generator returning n_batches random data of batch_size with n_features."""
127+
128+
def data_generator():
129+
for _ in range(n_batches * batch_size):
130+
x = np.random.randn(n_features)
131+
y = x.sum(axis=0) / n_features > 0.5
132+
yield x, y
133+
134+
return data_generator
135+
136+
137+
dataset = tf.data.Dataset.from_generator(
138+
get_data_generator(N_BATCHES, BATCH_SIZE, N_FEATURES),
139+
output_types=(tf.float32, tf.float32),
140+
output_shapes=((N_FEATURES,), tuple()),
141+
).batch(BATCH_SIZE)
142+
143+
clf = ak.StructuredDataClassifier(overwrite=True, max_trials=1, seed=5)
144+
clf.fit(x=dataset, validation_data=dataset, batch_size=BATCH_SIZE)
145+
print(clf.evaluate(dataset))
146+
147+
"""
148+
## Reference
149+
[image_dataset_from_directory](utils/#image_dataset_from_directory-function)
150+
[text_dataset_from_directory](utils/#text_dataset_from_directory-function)
151+
"""

docs/py/multi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""shell
22
pip install autokeras
3-
pip install git+https://github.com/keras-team/[email protected].2rc2
3+
pip install git+https://github.com/keras-team/[email protected].2rc3
44
"""
55

66
"""

0 commit comments

Comments
 (0)