Skip to content

Inference#12

Open
glicerico wants to merge 19 commits into
macabdul9:mainfrom
glicerico:inference
Open

Inference#12
glicerico wants to merge 19 commits into
macabdul9:mainfrom
glicerico:inference

Conversation

@glicerico

Copy link
Copy Markdown
Contributor

@macabdul9 sorry for the large PR, but I had to accumulate improvements to achieve proper inference.
Summary of changes:

  • I further cleaned the Switchboard dataset, to remove all utterances tagged with +, as those are continuations of interrupted utterances. Unless these are somehow joined back to their initial utterance, I believe they are useless... See discussion here.
  • Fix issue incompatible dataset labels #10 : Obtain a dictionary for classes from training data, to ensure consistency even if test data has less classes (as is the case with this swithboard split).
  • Implement code to perform inference, and add steps to use it in README

@PolKul

PolKul commented Feb 15, 2021

Copy link
Copy Markdown

@glicerico, I've just created a new fork from yours. I made some changes to the code related to running on gpu. Also added annotations to the classes. And added evaluation of the model on a test dataset. The problem is, when using your trained checkpoint, my eval method gives me quite poor results. May I ask you to try it on your side and see if there is any problem with your trained checkpoint or with my code? My fork is here https://github.com/PolKul/CASA-Dialogue-Act-Classifier.git

here is the result of running Eval on your checkpoint "epoch=29-val_accuracy=0.751411.ckpt"

              precision    recall  f1-score   support

           0       0.59      0.81      0.68       360
           1       0.00      0.00      0.00       328
           2       0.00      0.00      0.00        19
           3       0.00      0.00      0.00         0
           4       0.00      0.00      0.00         7
           5       0.00      0.00      0.00        17
           6       0.00      0.00      0.00       208
           7       0.00      0.00      0.00         7
           8       0.00      0.00      0.00        27
           9       0.00      0.00      0.00         3
          10       0.00      0.00      0.00         3
          11       0.00      0.00      0.00       765
          12       0.00      0.00      0.00        21
          13       0.00      0.00      0.00        76
          14       0.00      0.00      0.00         1
          15       0.00      0.00      0.00        23
          16       0.00      0.00      0.00        21
          17       0.00      0.00      0.00        28
          18       0.00      0.00      0.00         9
          19       0.00      0.00      0.00         2
          20       0.37      0.09      0.14        81
          21       0.00      0.00      0.00        16
          22       0.00      0.00      0.00         5
          23       0.00      0.00      0.00         7
          24       0.00      0.00      0.00        23
          25       0.00      0.00      0.00        10
          26       0.00      0.00      0.00         6
          27       0.00      0.00      0.00        26
          28       0.00      0.00      0.00         6
          29       0.00      0.00      0.00        73
          30       0.00      0.00      0.00         0
          31       0.00      0.00      0.00        12
          32       0.00      0.00      0.00        16
          33       0.00      0.00      0.00         2
          34       0.00      0.00      0.00        55
          35       0.00      0.00      0.00         1
          36       0.00      0.00      0.00        84
          37       0.01      0.33      0.01        36
          38       0.20      0.11      0.14      1317
          39       0.00      0.00      0.00       718
          40       0.00      0.00      0.00         1
          41       0.00      0.00      0.00         0
          42       0.00      0.00      0.00        94

    accuracy                           0.10      4514
   macro avg       0.03      0.03      0.02      4514
weighted avg       0.11      0.10      0.10      4514

It shows accuracy of only 10%...

@glicerico

glicerico commented Feb 16, 2021

Copy link
Copy Markdown
Contributor Author

Hey @PolKul , as commented in one of the issues, that checkpoint was trained before the classes problem was solved, so it probably is using the wrong labels.
Do you have a newer checkpoint, like epoch=5-val_accuracy=0.779101.ckpt that I uploaded before (and removed bc my dropbox was full)?

@glicerico

Copy link
Copy Markdown
Contributor Author

@PolKul I uploaded it here again... please try your evaluation with this checkpoint and let me know.
I'll probably remove the file in a day or 2, unless someone wants to host it somewhere else :)
https://www.dropbox.com/s/e88ymjfej80zabs/epoch%3D28-val_accuracy%3D0.746056.ckpt?dl=0

@glicerico

Copy link
Copy Markdown
Contributor Author

Oh, @PolKul , I am just noticing that you are using your own class label numbering... so it's expected that the predictions won't match.
In the following line:
https://github.com/PolKul/CASA-Dialogue-Act-Classifier/blob/32214d64d556505424b1efe54905371e7f417dcb/predict.py#L130
you give an arbitrary number to each tag, based on enumerate and the order in which you defined the class labels in dataset.py. The checkpoint was trained using the sorted list of tags from the training set, as it was suggested by @Christopher-Thornton:
https://github.com/glicerico/CASA-Dialogue-Act-Classifier/blob/92400edff9e0ab724d545d4495346e5eae4cd77e/dataset/dataset.py#L18

So, you probably should leave the classes as it was proposed in my pull request, or train a model with the label order that you prefer :)

@PolKul

PolKul commented Feb 16, 2021

Copy link
Copy Markdown

@glicerico, thank you for your review. However my question was more about the eval() method of the DialogClassifier class. As you can see it doesn't use my annotated classes (act_label_names list) in any way and still produces really bad results (0.1 F1 score). To avoid confusion, you can add the same eval method to your branch and try running it. Let me know if you can see any better statistics from it?

@glicerico

Copy link
Copy Markdown
Contributor Author

You're right, I see that you only use act_label_names to print (to print incorrectly, as the classes in act_label_names are numbered differently from the predictions).
The other point I made above still remains: that model was trained when there were some problems with class numbering.
Sorry, but I would prefer not to have to checkout your branch, figure it out, and run it, until you explore all possible reasons we see :)

@PolKul

PolKul commented Feb 17, 2021

Copy link
Copy Markdown

Sorry, but I don't see where you see the problem with the act_label_names.

to print incorrectly, as the classes in act_label_names are numbered differently from the predictions

that is a dictionary, with the following structure: ["name","act_tag","example"]. The code below is finding a "name" by "act_tag":

for utterance, prediction in zip(utterances, predicted_acts):
    for index, act_tag in enumerate(act_label_names['act_tag']):
        if act_tag == prediction:
            print(f"{prediction}({utterance})-> {act_label_names['name'][index]}")

Or you mean that "prediction" is incorrectly labeled?

@glicerico

glicerico commented Feb 17, 2021

Copy link
Copy Markdown
Contributor Author

After your past comment, I don't see a problem with act_label_names.
I am talking about the 2 posts prior to that:

Hey @PolKul , as commented in one of the issues, that checkpoint was trained before the classes problem was solved, so it probably is using the wrong labels.
Do you have a newer checkpoint, like epoch=5-val_accuracy=0.779101.ckpt that I uploaded before (and removed bc my dropbox was full)?

@PolKul I uploaded it here again... please try your evaluation with this checkpoint and let me know.
I'll probably remove the file in a day or 2, unless someone wants to host it somewhere else :)
https://www.dropbox.com/s/e88ymjfej80zabs/epoch%3D28-val_accuracy%3D0.746056.ckpt?dl=0

@glicerico

Copy link
Copy Markdown
Contributor Author

Or you mean that "prediction" is incorrectly labeled?

I mean that prediction is labeled differently

@PolKul

PolKul commented Feb 17, 2021

Copy link
Copy Markdown

I confirm that both "epoch=28-val_accuracy=0.746056.ckpt" and "epoch=29-val_accuracy=0.751411.ckpt" give the same (bad) results with F1 score of 0.1

It would be interesting to see the results of your eval()...

@glicerico

Copy link
Copy Markdown
Contributor Author

Hi @PolKul , these are the results I got using the best checkpoint I trained, with unfrozen Roberta weights.
I invite you to use that checkpoint, I uploaded it here: https://www.dropbox.com/s/1zj4vq59z9h6re3/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0
Please let me know when you get it, so I don't have my dropbox account completely full.

Eval on Test dataset
-------------------------------------
100%|██████████| 64/64 [24:00<00:00, 22.50s/it]
/home/andres/src/miniconda3/envs/CASA/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
              precision    recall  f1-score   support

           0       0.78      0.84      0.81       360
           1       0.08      0.05      0.06        19
           3       0.33      0.14      0.20         7
           4       1.00      0.06      0.11        17
           5       0.74      0.46      0.57       208
           6       0.00      0.00      0.00         7
           7       0.41      0.26      0.32        27
           8       0.00      0.00      0.00         3
           9       0.00      0.00      0.00         3
          10       0.79      0.92      0.85       765
          11       0.50      0.05      0.09        21
          12       0.68      0.84      0.75        76
          13       0.00      0.00      0.00         1
          14       1.00      0.04      0.08        23
          15       0.67      0.86      0.75        21
          16       0.47      0.32      0.38        28
          17       0.67      0.44      0.53         9
          18       1.00      1.00      1.00         2
          19       0.84      0.59      0.70        81
          20       0.33      0.19      0.24        16
          21       0.75      0.60      0.67         5
          22       0.00      0.00      0.00         7
          23       0.82      0.61      0.70        23
          24       0.33      0.10      0.15        10
          25       0.50      0.33      0.40         6
          26       0.81      0.81      0.81        26
          27       0.33      0.17      0.22         6
          28       0.79      0.62      0.69        73
          30       0.11      0.08      0.10        12
          31       0.65      0.81      0.72        16
          32       1.00      1.00      1.00         2
          33       0.73      0.75      0.74        55
          34       0.00      0.00      0.00         1
          35       0.69      0.80      0.74        84
          36       0.00      0.00      0.00        36
          37       0.81      0.85      0.83      1317
          38       0.69      0.72      0.71       718
          39       0.00      0.00      0.00         1

    accuracy                           0.76      4092
   macro avg       0.51      0.40      0.42      4092
weighted avg       0.75      0.76      0.74      4092

@PolKul

PolKul commented Feb 25, 2021

Copy link
Copy Markdown

Hi @glicerico, thanks for the checkpoint and eval. I've just updated the repo from your latest inference branch and it worked! Not sure what was wrong with my previous code though.. any way, thank you for your assistance.

@minarainbow

Copy link
Copy Markdown

Hi @PolKul , these are the results I got using the best checkpoint I trained, with unfrozen Roberta weights.
I invite you to use that checkpoint, I uploaded it here: https://www.dropbox.com/s/1zj4vq59z9h6re3/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0
Please let me know when you get it, so I don't have my dropbox account completely full.

Eval on Test dataset
-------------------------------------
100%|██████████| 64/64 [24:00<00:00, 22.50s/it]
/home/andres/src/miniconda3/envs/CASA/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
              precision    recall  f1-score   support

           0       0.78      0.84      0.81       360
           1       0.08      0.05      0.06        19
           3       0.33      0.14      0.20         7
           4       1.00      0.06      0.11        17
           5       0.74      0.46      0.57       208
           6       0.00      0.00      0.00         7
           7       0.41      0.26      0.32        27
           8       0.00      0.00      0.00         3
           9       0.00      0.00      0.00         3
          10       0.79      0.92      0.85       765
          11       0.50      0.05      0.09        21
          12       0.68      0.84      0.75        76
          13       0.00      0.00      0.00         1
          14       1.00      0.04      0.08        23
          15       0.67      0.86      0.75        21
          16       0.47      0.32      0.38        28
          17       0.67      0.44      0.53         9
          18       1.00      1.00      1.00         2
          19       0.84      0.59      0.70        81
          20       0.33      0.19      0.24        16
          21       0.75      0.60      0.67         5
          22       0.00      0.00      0.00         7
          23       0.82      0.61      0.70        23
          24       0.33      0.10      0.15        10
          25       0.50      0.33      0.40         6
          26       0.81      0.81      0.81        26
          27       0.33      0.17      0.22         6
          28       0.79      0.62      0.69        73
          30       0.11      0.08      0.10        12
          31       0.65      0.81      0.72        16
          32       1.00      1.00      1.00         2
          33       0.73      0.75      0.74        55
          34       0.00      0.00      0.00         1
          35       0.69      0.80      0.74        84
          36       0.00      0.00      0.00        36
          37       0.81      0.85      0.83      1317
          38       0.69      0.72      0.71       718
          39       0.00      0.00      0.00         1

    accuracy                           0.76      4092
   macro avg       0.51      0.40      0.42      4092
weighted avg       0.75      0.76      0.74      4092

Dear @glicerico, may I ask if you can re-upload the checkpoint? Somehow I don't get the results, (and my inference speed is so slow when using yours, do you know why?)

@glicerico

Copy link
Copy Markdown
Contributor Author

@minarainbow , you can find the checkpoint at https://www.dropbox.com/s/egiv70dwl1ikrbq/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0, I'll remove it from there in a couple days.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants