This repository was archived by the owner on Nov 14, 2024. It is now read-only.

Add Excel (XLSX) support #33

Open · wants to merge 8 commits into master

Changes from all commits
7 changes: 7 additions & 0 deletions .gitignore

@@ -189,3 +189,10 @@ dwh_media/
 *.lock
 *.env
 dex/dex-data/dex.db
+
+# training output files
+*.pkl
+*.xlsx
+*.csv
+*.pdf
+*.json
6 changes: 3 additions & 3 deletions README.md

@@ -15,7 +15,7 @@ pip install -r requirements.txt
 
 # input data
 
-csv input file with at least the following columns:
+input file (CSV or XLSX) with at least the following columns:
 | column | description |
 | ------------- | ------------- |
 | Main | Main category |
@@ -30,7 +30,7 @@ See python train.py for all options.
 
 To train Middle and Sub categories use:
 ```
-python train.py --csv file.csv --columns Middle,Sub
+python train.py --input-file file.csv --columns Middle,Sub
 ```
 This step will generate a categories `json` file. Use this file to load the categories in the backend.
 ```
@@ -39,7 +39,7 @@ python manage.py load_categories <file.json>
 
 To train Middle category use:
 ```
-python train.py --csv file.csv --columns Middle
+python train.py --input-file file.csv --columns Middle
 ```
 
 Rename resulting files to "main_model.pkl, sub_model.pkl, main_slugs.pkl, sub_slugs.pkl" and copy the pkl files into the classification endpoint.
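The required input columns map directly onto the training code (`Text` plus the category columns). A minimal sketch of producing an input file in either supported format, assuming the column names from the table above and invented row values:

```python
# Minimal sketch: build a tiny input file in either supported format.
# Column names follow the README table; the row values are invented.
import pandas as pd

df = pd.DataFrame({
    "Main": ["Housing"],
    "Middle": ["Rent"],
    "Sub": ["Arrears"],
    "Text": ["Example complaint text to classify"],
})

df.to_csv("file.csv", index=False)     # CSV input for train.py
df.to_excel("file.xlsx", index=False)  # XLSX input (needs openpyxl installed)
```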
37 changes: 22 additions & 15 deletions app/engine.py

@@ -5,11 +5,10 @@
 from sklearn.linear_model import LogisticRegression
 from nltk.stem.snowball import DutchStemmer
 import joblib
 import warnings
+import os
 import nltk
 import re
-import csv
-import psutil
 
 class TextClassifier:
     _text = 'Text'
@@ -41,51 +40,59 @@ def export_model(self, file):
         joblib.dump(self.model, file)
 
     def preprocessor(self, text):
+        text = str(text)
         text=text.lower()
         text=re.sub("\\W"," ",text) # remove special chars
 
 
         # stem words
         words=re.split("\\s+",text)
         stemmed_words=[self.stemmer.stem(word=word) for word in words]
         return ' '.join(stemmed_words)
 
 
-    def load_data(self, csv_file, frac=1):
-        df = pd.read_csv(csv_file, sep=None, engine='python')
+    def load_data(self, input_file, frac=1):
+        _, extension = os.path.splitext(input_file)
+
+        if extension == '.csv':
+            df = pd.read_csv(input_file, sep=None, engine='python')
+        elif extension == '.xlsx':
+            df = pd.read_excel(input_file)
+        else:
+            raise Exception('Could not read input file. Extension should be .csv or .xlsx')
+
+        print(df)
+
         df = df.dropna(
             axis=0, how='any',
             thresh=None,
             subset=[self._text, self._main, self._middle, self._sub],
             inplace=False
         )
 
         # cleanup dataset
-        df = df.drop_duplicates(subset=[self._text], keep='first')
+        #df = df.drop_duplicates(subset=[self._text], keep='first')
         # for dev use only a subset (for speed purpose)
-        df = df.sample(frac=frac).reset_index(drop=True)
+        #df = df.sample(frac=frac).reset_index(drop=True)
         # construct unique label
         df[self._lbl] = df[self._main] + "|" + df[self._middle] + "|" + df[self._sub]
 
-        number_of_examples = df[self._lbl].value_counts().to_frame()
-        df['is_bigger_than_50'] = df[self._lbl].isin(number_of_examples[number_of_examples[self._lbl]>50].index)
-        df['is_bigger_than_50'].value_counts()
-        df = df[df['is_bigger_than_50'] == True]
-
+        # The example dataset is not large enough to train a good classification model
         # print(len(self.df),'rows valid')
         return df
 
     def make_data_sets(self, df, split=0.9, columns=['Middle', 'Sub']):
 
         texts = df[self._text]
-        labels = df[columns].apply('|'.join, axis=1)
+        labels = df[columns].applymap(lambda x: x.lower().capitalize()).apply('|'.join, axis=1)
 
+        print(labels.value_counts())
+
         train_texts, test_texts, train_labels, test_labels = train_test_split(
             texts, labels, test_size=1-split, stratify=labels)
 
         return texts, labels, train_texts, train_labels, test_texts, test_labels
 
     def fit(self, train_texts, train_labels):
 
         pipeline = Pipeline([
             ('vect', CountVectorizer(preprocessor=self.preprocessor, stop_words=self.stop_words)),
             ('tfidf', TfidfTransformer()),
@@ -119,7 +126,7 @@ def fit(self, train_texts, train_labels):
             'vect__ngram_range': ((1, 1),) # (1,2)
         }
 
-        grid_search = GridSearchCV(pipeline, parameters_slow,verbose=True,n_jobs=psutil.cpu_count(logical=False),cv=5)
+        grid_search = GridSearchCV(pipeline, parameters_slow,verbose=True,n_jobs=1,cv=5)
         grid_search.fit(train_texts, train_labels)
         #print('Best parameters: ')
         #print(grid_search.best_params_)
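The heart of the PR is the extension dispatch in `load_data`. A standalone sketch of the same pattern, with a hypothetical file name, assuming pandas with openpyxl installed for `.xlsx`:

```python
# Sketch of the extension-based loading that load_data now performs.
import os
import pandas as pd

def read_input(input_file):
    _, extension = os.path.splitext(input_file)
    if extension == '.csv':
        # sep=None with the python engine lets pandas sniff the delimiter
        return pd.read_csv(input_file, sep=None, engine='python')
    if extension == '.xlsx':
        return pd.read_excel(input_file)  # backed by openpyxl for .xlsx
    raise Exception('Could not read input file. Extension should be .csv or .xlsx')

df = read_input('file.xlsx')  # hypothetical file name
```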
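The `make_data_sets` change also normalizes label casing before the `|` join, merging categories that differ only in capitalization. A sketch with invented category values:

```python
# Sketch of the casing normalization added in make_data_sets:
# 'RENT' and 'rent' now collapse into the same 'Rent|Arrears' label.
import pandas as pd

df = pd.DataFrame({'Middle': ['RENT', 'rent'], 'Sub': ['ARREARS', 'Arrears']})
labels = df[['Middle', 'Sub']].applymap(lambda x: x.lower().capitalize()) \
                              .apply('|'.join, axis=1)
print(labels.tolist())  # ['Rent|Arrears', 'Rent|Arrears']
```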
8 changes: 4 additions & 4 deletions app/train.py

@@ -7,7 +7,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     optional = parser._action_groups.pop()
     required = parser.add_argument_group('required arguments')
-    required.add_argument('--csv', required=True)
+    required.add_argument('--input-file', required=True)
     optional.add_argument('--columns', default='')
     optional.add_argument('--fract', default=1.0, type=float)
     optional.add_argument('--output-fixtures', const=True, nargs="?", default=True, type=bool)
@@ -70,7 +70,7 @@ def generate_fixtures(categories):
         print("Warning invalid slug: {slug}, length: {length}".format(slug=slug, length=len(slug)))
 
     return cats.values()
-
+
 def train(df, columns, output_validation=False, output_fixtures=True):
     texts, labels, train_texts, train_labels, test_texts, test_labels = classifier.make_data_sets(df, columns=columns)
     colnames = "_".join(columns)
@@ -108,9 +108,9 @@ def train(df, columns, output_validation=False, output_fixtures=True):
 print("Using args: {}".format(args))
 
 classifier = TextClassifier()
-df = classifier.load_data(csv_file=args.csv, frac=args.fract)
+df = classifier.load_data(input_file=args.input_file, frac=args.fract)
 if len(df) == 0:
-    print("Failed to load {}".format(args.csv))
+    print("Failed to load {}".format(args.input_file))
     exit(-1)
 else:
     print("{} rows loaded".format(len(df)))
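Note that argparse converts the dashes in the new `--input-file` flag to underscores, which is why the script reads it back as `args.input_file`. A minimal sketch:

```python
# Minimal sketch: argparse exposes '--input-file' as args.input_file.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input-file', required=True)
args = parser.parse_args(['--input-file', 'file.xlsx'])
print(args.input_file)  # -> file.xlsx
```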
1 change: 0 additions & 1 deletion notebook/requirements.txt

@@ -71,5 +71,4 @@ wcwidth==0.1.9
 webencodings==0.5.1
 wrapt==1.12.1
 xlrd==1.2.0
-xlsx2csv==0.7.6
 zipp==3.1.0
44 changes: 0 additions & 44 deletions requirements-train.txt

This file was deleted.

28 changes: 21 additions & 7 deletions requirements.txt

@@ -1,39 +1,53 @@
 asgiref==3.2.10
 astroid==2.4.2
 attrs==19.3.0
 click==7.1.2
+contourpy==1.0.7
+cycler==0.10.0
 dill==0.3.2
 Django==3.0.7
+et-xmlfile==1.1.0
 Flask==1.1.2
 Flask-Cors==3.0.8
+fonttools==4.39.4
 gunicorn==20.0.4
+importlib-resources==5.12.0
 isort==4.3.21
 itsdangerous==1.1.0
 Jinja2==2.11.2
 joblib==0.15.1
+kiwisolver==1.4.4
 lazy-object-proxy==1.4.3
 MarkupSafe==1.1.1
+matplotlib==3.7.1
 mccabe==0.6.1
 more-itertools==8.4.0
 nltk==3.5
-numpy==1.18.5
+numpy==1.24.3
+openpyxl==3.0.10
 packaging==20.4
-pandas==1.0.4
+pandas==1.5.3
+Pillow==9.5.0
 pluggy==0.13.1
+psutil==5.9.5
 py==1.8.2
 pylint==2.5.3
 pyparsing==2.4.7
 pytest==5.4.3
 python-dateutil==2.8.1
 pytz==2020.1
-regex==2020.6.8
-scikit-learn==0.23.1
-scipy==1.4.1
+regex==2023.6.3
+scikit-learn==1.0.2
+scipy==1.10.1
 six==1.15.0
 sklearn==0.0
 sqlparse==0.3.1
 threadpoolctl==2.1.0
 toml==0.10.1
 tqdm==4.46.1
-uWSGI==2.0.19
+uWSGI==2.0.21
 wcwidth==0.2.4
 Werkzeug==1.0.1
 wrapt==1.12.1
-psutil==5.7.0
+xlrd==2.0.1
+zipp==3.15.0