From c6000b93e7b6a72090bbf3166967024325aa8747 Mon Sep 17 00:00:00 2001 From: Bart Jeukendrup Date: Tue, 19 Jul 2022 15:56:44 +0200 Subject: [PATCH 1/8] Add XLSX support --- README.md | 6 +++--- app/engine.py | 27 ++++++++++++++++++--------- app/train.py | 8 ++++---- notebook/requirements.txt | 1 - requirements-train.txt | 1 + 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 572128e..95a645b 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ pip install -r requirements.txt # input data -csv input file with at least the following columns: +input file (CSV or XLSX) with at least the following columns: | column | description | | ------------- | ------------- | | Main | Main category | @@ -30,7 +30,7 @@ See python train.py for all options. To train Middle and Sub categoeries use: ``` -python train.py --csv file.csv --columns Middle,Sub +python train.py --input-file file.csv --columns Middle,Sub ``` This step will generate a categories `json` file. Use this file to load the categories in the backend. ``` @@ -39,7 +39,7 @@ python manage.py load_categories To train Middle category use: ``` -python train.py --csv file.csv --columns Middle +python train.py --input-file file.csv --columns Middle ``` Rename resulting files to "main_model.pkl, sub_model.pkl, main_slugs.pkl, sub_slugs.pkl" and copy the pkl files into the classification endpoint. 
diff --git a/app/engine.py b/app/engine.py index 56dc2ba..1727f7b 100644 --- a/app/engine.py +++ b/app/engine.py @@ -5,6 +5,7 @@ from sklearn.linear_model import LogisticRegression from nltk.stem.snowball import DutchStemmer import joblib +import os import warnings import nltk import re @@ -41,17 +42,27 @@ def export_model(self, file): joblib.dump(self.model, file) def preprocessor(self, text): + text = str(text) text=text.lower() - text=re.sub("\\W"," ",text) # remove special chars - + # stem words words=re.split("\\s+",text) stemmed_words=[self.stemmer.stem(word=word) for word in words] return ' '.join(stemmed_words) - def load_data(self, csv_file, frac=1): - df = pd.read_csv(csv_file, sep=None, engine='python') + def load_data(self, input_file, frac=1): + _, extension = os.path.splitext(input_file) + + if extension == '.csv': + df = pd.read_csv(input_file, sep=None, engine='python') + elif extension == '.xlsx': + df = pd.read_excel(input_file) + else: + raise Exception('Could not read input file. 
Extension should be .csv or .xlsx') + + print(df) + df = df.dropna( axis=0, how='any', thresh=None, @@ -60,16 +71,14 @@ def load_data(self, csv_file, frac=1): ) # cleanup dataset - df = df.drop_duplicates(subset=[self._text], keep='first') + #df = df.drop_duplicates(subset=[self._text], keep='first') # for dev use only a subset (for speed purpose) - df = df.sample(frac=frac).reset_index(drop=True) + #df = df.sample(frac=frac).reset_index(drop=True) # construct unique label df[self._lbl] = df[self._main] + "|" + df[self._middle] + "|" + df[self._sub] number_of_examples = df[self._lbl].value_counts().to_frame() - df['is_bigger_than_50'] = df[self._lbl].isin(number_of_examples[number_of_examples[self._lbl]>50].index) - df['is_bigger_than_50'].value_counts() - df = df[df['is_bigger_than_50'] == True] + # The example dataset is not large enough to train a good classification model # print(len(self.df),'rows valid') return df diff --git a/app/train.py b/app/train.py index 580b901..72046fb 100644 --- a/app/train.py +++ b/app/train.py @@ -7,7 +7,7 @@ def parse_args(): parser = argparse.ArgumentParser() optional = parser._action_groups.pop() required = parser.add_argument_group('required arguments') - required.add_argument('--csv', required=True) + required.add_argument('--input-file', required=True) optional.add_argument('--columns', default='') optional.add_argument('--fract', default=1.0, type=float) optional.add_argument('--output-fixtures', const=True, nargs="?", default=True, type=bool) @@ -70,7 +70,7 @@ def generate_fixtures(categories): print("Warning invalid slug: {slug}, length: {length}".format(slug=slug, length=len(slug))) return cats.values() - + def train(df, columns, output_validation=False, output_fixtures=True): texts, labels, train_texts, train_labels, test_texts, test_labels = classifier.make_data_sets(df, columns=columns) colnames = "_".join(columns) @@ -108,9 +108,9 @@ def train(df, columns, output_validation=False, output_fixtures=True): print("Using 
args: {}".format(args)) classifier = TextClassifier() - df = classifier.load_data(csv_file=args.csv, frac=args.fract) + df = classifier.load_data(input_file=args.input_file, frac=args.fract) if len(df) == 0: - print("Failed to load {}".format(args.csv)) + print("Failed to load {}".format(args.input_file)) exit(-1) else: print("{} rows loaded".format(len(df))) diff --git a/notebook/requirements.txt b/notebook/requirements.txt index 3375307..1f00bc0 100644 --- a/notebook/requirements.txt +++ b/notebook/requirements.txt @@ -71,5 +71,4 @@ wcwidth==0.1.9 webencodings==0.5.1 wrapt==1.12.1 xlrd==1.2.0 -xlsx2csv==0.7.6 zipp==3.1.0 diff --git a/requirements-train.txt b/requirements-train.txt index b3e1048..0371243 100644 --- a/requirements-train.txt +++ b/requirements-train.txt @@ -20,6 +20,7 @@ mccabe==0.6.1 more-itertools==8.4.0 nltk==3.5 numpy +openpyxl==3.0.10 packaging==20.4 pandas pluggy==0.13.1 From cd251adeb621543719ff66e6b576cc71ef3cdd7e Mon Sep 17 00:00:00 2001 From: Bart Date: Fri, 29 Jul 2022 10:27:30 +0200 Subject: [PATCH 2/8] Update requirements-train.txt --- requirements-train.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-train.txt b/requirements-train.txt index 0371243..215bc81 100644 --- a/requirements-train.txt +++ b/requirements-train.txt @@ -30,7 +30,7 @@ pyparsing==2.4.7 pytest==5.4.3 python-dateutil==2.8.1 pytz==2020.1 -regex==2020.6.8 +regex scikit-learn scipy six==1.15.0 From 83f8afc2504510488eefe8f74e8ce2a67f182ac3 Mon Sep 17 00:00:00 2001 From: Bart Date: Fri, 29 Jul 2022 10:30:46 +0200 Subject: [PATCH 3/8] Update requirements-train.txt --- requirements-train.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/requirements-train.txt b/requirements-train.txt index 215bc81..0089dec 100644 --- a/requirements-train.txt +++ b/requirements-train.txt @@ -12,7 +12,6 @@ isort==4.3.21 itsdangerous==1.1.0 Jinja2==2.11.2 joblib==0.15.1 -kiwisolver==1.2.0 lazy-object-proxy==1.4.3 MarkupSafe==1.1.1 matplotlib @@ -42,4 
+41,3 @@ tqdm==4.46.1 wcwidth==0.2.4 Werkzeug==1.0.1 wrapt==1.12.1 -psutil==5.7.0 From e0111e5916ef4f7eb8738b637aafa6de162a809c Mon Sep 17 00:00:00 2001 From: Bart Date: Fri, 29 Jul 2022 10:32:01 +0200 Subject: [PATCH 4/8] Update requirements-train.txt --- requirements-train.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-train.txt b/requirements-train.txt index 0089dec..d2b607b 100644 --- a/requirements-train.txt +++ b/requirements-train.txt @@ -41,3 +41,4 @@ tqdm==4.46.1 wcwidth==0.2.4 Werkzeug==1.0.1 wrapt==1.12.1 +psutil From f06b53dd484cb85e99472347f44d993995edca02 Mon Sep 17 00:00:00 2001 From: Bart Jeukendrup Date: Fri, 5 Aug 2022 11:19:52 +0200 Subject: [PATCH 5/8] Set n_jobs to 1 --- app/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/engine.py b/app/engine.py index 1727f7b..2cbd34e 100644 --- a/app/engine.py +++ b/app/engine.py @@ -128,7 +128,7 @@ def fit(self, train_texts, train_labels): 'vect__ngram_range': ((1, 1),) # (1,2) } - grid_search = GridSearchCV(pipeline, parameters_slow,verbose=True,n_jobs=psutil.cpu_count(logical=False),cv=5) + grid_search = GridSearchCV(pipeline, parameters_slow,verbose=True,n_jobs=1,cv=5) grid_search.fit(train_texts, train_labels) #print('Best parameters: ') #print(grid_search.best_params_) From 4d98de85c0010f982908c904be67c859388afdb1 Mon Sep 17 00:00:00 2001 From: Bart Jeukendrup Date: Wed, 8 Mar 2023 11:11:41 +0100 Subject: [PATCH 6/8] Make pandas version explicit, remove thresh parameter --- .gitignore | 7 +++++++ app/engine.py | 6 ++---- requirements-train.txt | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 3df34cb..7e283bd 100755 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,10 @@ dwh_media/ *.lock *.env dex/dex-data/dex.db + +# training output files +*.pkl +*.xlsx +*.csv +*.pdf +*.json \ No newline at end of file diff --git a/app/engine.py b/app/engine.py index 2cbd34e..5ac086f 100644 --- a/app/engine.py +++ 
b/app/engine.py @@ -6,11 +6,9 @@ from nltk.stem.snowball import DutchStemmer import joblib import os -import warnings import nltk import re import csv -import psutil class TextClassifier: _text = 'Text' @@ -65,7 +63,6 @@ def load_data(self, input_file, frac=1): df = df.dropna( axis=0, how='any', - thresh=None, subset=[self._text, self._main, self._middle, self._sub], inplace=False ) @@ -88,13 +85,14 @@ def make_data_sets(self, df, split=0.9, columns=['Middle', 'Sub']): texts = df[self._text] labels = df[columns].apply('|'.join, axis=1) + print(labels.value_counts()) + train_texts, test_texts, train_labels, test_labels = train_test_split( texts, labels, test_size=1-split, stratify=labels) return texts, labels, train_texts, train_labels, test_texts, test_labels def fit(self, train_texts, train_labels): - pipeline = Pipeline([ ('vect', CountVectorizer(preprocessor=self.preprocessor, stop_words=self.stop_words)), ('tfidf', TfidfTransformer()), diff --git a/requirements-train.txt b/requirements-train.txt index d2b607b..d3f1991 100644 --- a/requirements-train.txt +++ b/requirements-train.txt @@ -21,7 +21,7 @@ nltk==3.5 numpy openpyxl==3.0.10 packaging==20.4 -pandas +pandas==1.5.3 pluggy==0.13.1 py==1.8.2 pylint==2.5.3 From fe2c34deaba7f4829bf167e1141399df8a655554 Mon Sep 17 00:00:00 2001 From: Bart Jeukendrup Date: Mon, 2 Sep 2024 21:24:40 +0200 Subject: [PATCH 7/8] slugify labels --- app/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/engine.py b/app/engine.py index 5ac086f..d6b46ca 100644 --- a/app/engine.py +++ b/app/engine.py @@ -83,7 +83,7 @@ def load_data(self, input_file, frac=1): def make_data_sets(self, df, split=0.9, columns=['Middle', 'Sub']): texts = df[self._text] - labels = df[columns].apply('|'.join, axis=1) + labels = df[columns].applymap(lambda x: x.lower().capitalize()).apply('|'.join, axis=1) print(labels.value_counts()) From 12287bfccfacaa1d86e84e479a23f647ed1fdabb Mon Sep 17 00:00:00 2001 From: Bart Jeukendrup Date: 
Mon, 2 Sep 2024 21:24:58 +0200 Subject: [PATCH 8/8] chore: update deps --- requirements-train.txt | 44 ------------------------------------------ requirements.txt | 28 ++++++++++++++++++++------- 2 files changed, 21 insertions(+), 51 deletions(-) delete mode 100644 requirements-train.txt diff --git a/requirements-train.txt b/requirements-train.txt deleted file mode 100644 index d3f1991..0000000 --- a/requirements-train.txt +++ /dev/null @@ -1,44 +0,0 @@ -asgiref==3.2.10 -astroid==2.4.2 -attrs==19.3.0 -click==7.1.2 -cycler==0.10.0 -dill==0.3.2 -Django==3.0.7 -Flask==1.1.2 -Flask-Cors==3.0.8 -gunicorn==20.0.4 -isort==4.3.21 -itsdangerous==1.1.0 -Jinja2==2.11.2 -joblib==0.15.1 -lazy-object-proxy==1.4.3 -MarkupSafe==1.1.1 -matplotlib -mccabe==0.6.1 -more-itertools==8.4.0 -nltk==3.5 -numpy -openpyxl==3.0.10 -packaging==20.4 -pandas==1.5.3 -pluggy==0.13.1 -py==1.8.2 -pylint==2.5.3 -pyparsing==2.4.7 -pytest==5.4.3 -python-dateutil==2.8.1 -pytz==2020.1 -regex -scikit-learn -scipy -six==1.15.0 -sklearn==0.0 -sqlparse==0.3.1 -threadpoolctl==2.1.0 -toml==0.10.1 -tqdm==4.46.1 -wcwidth==0.2.4 -Werkzeug==1.0.1 -wrapt==1.12.1 -psutil diff --git a/requirements.txt b/requirements.txt index b8ed1b3..76df6b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,39 +1,53 @@ +asgiref==3.2.10 astroid==2.4.2 attrs==19.3.0 click==7.1.2 +contourpy==1.0.7 +cycler==0.10.0 dill==0.3.2 +Django==3.0.7 +et-xmlfile==1.1.0 Flask==1.1.2 Flask-Cors==3.0.8 +fonttools==4.39.4 gunicorn==20.0.4 +importlib-resources==5.12.0 isort==4.3.21 itsdangerous==1.1.0 Jinja2==2.11.2 joblib==0.15.1 +kiwisolver==1.4.4 lazy-object-proxy==1.4.3 MarkupSafe==1.1.1 +matplotlib==3.7.1 mccabe==0.6.1 more-itertools==8.4.0 nltk==3.5 -numpy==1.18.5 +numpy==1.24.3 +openpyxl==3.0.10 packaging==20.4 -pandas==1.0.4 +pandas==1.5.3 +Pillow==9.5.0 pluggy==0.13.1 +psutil==5.9.5 py==1.8.2 pylint==2.5.3 pyparsing==2.4.7 pytest==5.4.3 python-dateutil==2.8.1 pytz==2020.1 -regex==2020.6.8 -scikit-learn==0.23.1 -scipy==1.4.1 
+regex==2023.6.3 +scikit-learn==1.0.2 +scipy==1.10.1 six==1.15.0 sklearn==0.0 +sqlparse==0.3.1 threadpoolctl==2.1.0 toml==0.10.1 tqdm==4.46.1 -uWSGI==2.0.19 +uWSGI==2.0.21 wcwidth==0.2.4 Werkzeug==1.0.1 wrapt==1.12.1 -psutil==5.7.0 +xlrd==2.0.1 +zipp==3.15.0 \ No newline at end of file