generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 514
⚡ Add support for Chronos-Bolt models #204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 22 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
4eccf75
Add support for ChronosBolt models
e3215ae
Fix test
e2b39b3
Fix tests
7554de9
Fix mypy issues
7c28bdc
Update eval script
8b18e08
Add bolt tests
09556d6
Update docstrings
fcb78fa
Add tests for predict_quantiles
711a369
Use dummy bolt model for tests
386caeb
Add versions hints
abb3e97
Readme update
d4c48bc
Update readme
696a5a0
Update AutoGluon tip
daede64
fix readme
e512f89
fix
3af9465
Add authors
0ab0032
Fix docstrings
67358b2
Add metrics csvs
3d6f600
Add autogluon src
d5d5b67
Update zero-shot plot
67a0bd9
Update version ranges
77ce929
fix
a18bc20
Apply suggestions from code review
abdulfatir 0efed52
Merge branch 'main' into add-chronos-bolt
abdulfatir 005e10b
compilation fixes
8b9e67e
Rename autogluon -> amazon
46a1bbc
Update README
b48c62e
Rename models
fab8fb9
Update README.md
abdulfatir ed7a4a1
Use SVG
d3986a8
wider
f57a3d1
Remove PNG
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,19 +1,19 @@ | ||
| [project] | ||
| name = "chronos" | ||
| version = "1.2.1" | ||
| version = "1.3.0" | ||
| requires-python = ">=3.8" | ||
| license = { file = "LICENSE" } | ||
| dependencies = [ | ||
| "torch~=2.0", # package was tested on 2.2 | ||
| "transformers~=4.30", | ||
| "accelerate", | ||
| "torch>=2.0,<2.6", # package was tested on 2.2 | ||
| "transformers>=4.30,<4.48", | ||
| "accelerate>=0.32,<1", | ||
| ] | ||
|
|
||
| [project.optional-dependencies] | ||
| test = ["pytest~=8.0", "numpy~=1.21"] | ||
| typecheck = ["mypy~=1.9"] | ||
| training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config", "joblib", "scikit-learn"] | ||
| evaluation = ["gluonts[pro]", "datasets", "numpy", "typer"] | ||
| training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"] | ||
| evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"] | ||
|
|
||
| [tool.mypy] | ||
| ignore_missing_imports = true | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| import pandas as pd | ||
| import typer | ||
| from scipy.stats import gmean | ||
| from pathlib import Path | ||
|
|
||
| app = typer.Typer(pretty_exceptions_enable=False) | ||
| DEFAULT_RESULTS_DIR = Path(__file__).parent / "results" | ||
|
|
||
|
|
||
| def agg_relative_score(model_csv: Path, baseline_csv: Path): | ||
| model_df = pd.read_csv(model_csv).set_index("dataset") | ||
| baseline_df = pd.read_csv(baseline_csv).set_index("dataset") | ||
| relative_score = model_df.drop("model", axis="columns") / baseline_df.drop( | ||
| "model", axis="columns" | ||
| ) | ||
| return relative_score.agg(gmean) | ||
|
|
||
|
|
||
| @app.command() | ||
| def main( | ||
| model_name: str, | ||
| baseline_name: str = "seasonal-naive", | ||
| results_dir: Path = DEFAULT_RESULTS_DIR, | ||
| ): | ||
| """ | ||
| Compute the aggregated relative score as reported in the Chronos paper. | ||
| Results will be saved to {results_dir}/{model_name}-agg-rel-scores.csv | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model_name : str | ||
| Name of the model used in the CSV files. The in-domain and zero-shot CSVs | ||
| are expected to be named {model_name}-in-domain.csv and {model_name}-zero-shot.csv. | ||
| results_dir : Path, optional, default = results/ | ||
| Directory where results CSVs generated by evaluate.py are stored | ||
| """ | ||
|
|
||
| in_domain_agg_score_df = agg_relative_score( | ||
| results_dir / f"{model_name}-in-domain.csv", | ||
| results_dir / f"{baseline_name}-in-domain.csv", | ||
| ) | ||
| in_domain_agg_score_df.name = "value" | ||
| in_domain_agg_score_df.index.name = "metric" | ||
|
|
||
| zero_shot_agg_score_df = agg_relative_score( | ||
| results_dir / f"{model_name}-zero-shot.csv", | ||
| results_dir / f"{baseline_name}-zero-shot.csv", | ||
| ) | ||
| zero_shot_agg_score_df.name = "value" | ||
| zero_shot_agg_score_df.index.name = "metric" | ||
|
|
||
| agg_score_df = pd.concat( | ||
| {"in-domain": in_domain_agg_score_df, "zero-shot": zero_shot_agg_score_df}, | ||
| names=["benchmark"], | ||
| ) | ||
| agg_score_df.to_csv(f"{results_dir}/{model_name}-agg-rel-scores.csv") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| app() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| benchmark,metric,value | ||
| in-domain,MASE,0.6800133628315155 | ||
| in-domain,WQL,0.5339263811489279 | ||
| zero-shot,MASE,0.7914551113353537 | ||
| zero-shot,WQL,0.6241424984163773 |
16 changes: 16 additions & 0 deletions
16
scripts/evaluation/results/chronos-bolt-base-in-domain.csv
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| dataset,model,MASE,WQL | ||
| electricity_15min,autogluon/chronos-bolt-base,0.41069374835605243,0.0703533790998506 | ||
| m4_daily,autogluon/chronos-bolt-base,3.205192517121196,0.02110308498174413 | ||
| m4_hourly,autogluon/chronos-bolt-base,0.8350129849014075,0.025353803894164 | ||
| m4_monthly,autogluon/chronos-bolt-base,0.9491758928362231,0.09382496106659234 | ||
| m4_weekly,autogluon/chronos-bolt-base,2.0847827409162742,0.03816605075768161 | ||
| monash_electricity_hourly,autogluon/chronos-bolt-base,1.254966217685461,0.09442192616975713 | ||
| monash_electricity_weekly,autogluon/chronos-bolt-base,1.8391546050108039,0.06410971963960499 | ||
| monash_kdd_cup_2018,autogluon/chronos-bolt-base,0.6405985809360102,0.2509172188706336 | ||
| monash_london_smart_meters,autogluon/chronos-bolt-base,0.701398572604996,0.3218915088923906 | ||
| monash_pedestrian_counts,autogluon/chronos-bolt-base,0.2646412642278343,0.18789459806066328 | ||
| monash_rideshare,autogluon/chronos-bolt-base,0.7695376426829713,0.11637119433040358 | ||
| monash_temperature_rain,autogluon/chronos-bolt-base,0.8983612698773724,0.6050555216496304 | ||
| taxi_30min,autogluon/chronos-bolt-base,0.7688908266765317,0.2363178601205094 | ||
| uber_tlc_daily,autogluon/chronos-bolt-base,0.8231767493519677,0.0926036406916842 | ||
| uber_tlc_hourly,autogluon/chronos-bolt-base,0.6632193728217927,0.14987786887626975 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.