Commit 58fa97a

Updated README.md
1 parent 042c1b9 commit 58fa97a

4 files changed

+58
-2
lines changed

.github/images/classification.jpg

25.4 KB

.github/images/tokenizer.jpg

30.4 KB

.github/images/train.jpg

87.1 KB

README.md

Lines changed: 58 additions & 2 deletions
@@ -1,6 +1,62 @@
11
# daTrin 👋
22

3-
What is daTrin?
3+
daTrin is learning material for AI beginners that focuses on training and inference. It is built on top of JAX and Flax and is designed to be easily extensible and customizable, so that readers can follow each stage of the pipeline below:
4+
5+
collect data -> clean -> dataset -> vectorize -> train -> some magic -> save tensors -> inference
6+
7+
In this project, we focus on the following:
8+
- Loading dataset and creating training and testing samples
9+
- Converting samples to vectors using a tokenizer
10+
- Training on vectors, measuring loss & accuracy
11+
- Saving/loading the model and tokenizer data
12+
- Inference on the trained model
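
As a rough illustration of the first two steps in the list above, here is a minimal sketch in plain Python. It is not the project's implementation: daTrin uses SentencePiece for tokenization, while this sketch swaps in a toy whitespace vocabulary. The function names (`split_samples`, `build_vocab`, `vectorize`) and the sample sentences are hypothetical, chosen only for this example.

```python
import random

PAD_ID, UNK_ID = 0, 1  # assumed ids for padding and unknown tokens


def split_samples(samples, test_ratio=0.2, seed=0):
    """Shuffle a copy of the samples and split it into (train, test)."""
    rng = random.Random(seed)
    shuffled = list(samples)
    rng.shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_ratio))
    return shuffled[n_test:], shuffled[:n_test]


def build_vocab(texts):
    """Give every lowercase whitespace-separated token an integer id."""
    vocab = {"<pad>": PAD_ID, "<unk>": UNK_ID}
    for text in texts:
        for token in text.lower().split():
            vocab.setdefault(token, len(vocab))
    return vocab


def vectorize(text, vocab, length=8):
    """Map a text to a fixed-length list of token ids (truncate or pad)."""
    ids = [vocab.get(token, UNK_ID) for token in text.lower().split()]
    ids = ids[:length]
    return ids + [PAD_ID] * (length - len(ids))


# Hypothetical (text, label) pairs standing in for rows of a dataset.
samples = [
    ("stocks rally on earnings", "Business"),
    ("team wins the final", "Sports"),
    ("new chip doubles speed", "Sci/Tech"),
    ("leaders meet for talks", "World"),
    ("striker scores twice", "Sports"),
]

train, test = split_samples(samples)
vocab = build_vocab(text for text, _ in train)
vectors = [vectorize(text, vocab) for text, _ in train]
print(len(train), len(test), vectors[0])
```

The vocabulary is built from the training samples only, so words that appear only in the test split map to the unknown id, which mirrors how a trained tokenizer treats unseen input.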
13+
14+
**Warning: This project is not production-ready and is not intended for production use. It is learning material for AI beginners.**
15+
16+
## Directory structure
17+
18+
- `datrin` - The root directory of the project
19+
- `data` - Contains the datasets **ag_news.csv** and **botchain.txt**
20+
- `dataset` - Contains a class for loading the data and splitting it into training and testing samples
21+
  - `inference` - Contains the inference code for the classification model and tokenizer
22+
  - `model` - Contains the classification model class and the code to save and load its config and tensors
23+
- `out` - Contains the saved config, tensors and tokenizer
24+
- `tokenizer` - Contains the tokenizer class
25+
  - `train` - Contains the training code for the classification model and tokenizer
26+
27+
## Let's go...
28+
29+
```bash
30+
git clone git@github.com:leliuga/datrin.git && cd datrin
31+
python3.10 -m venv venv && . venv/bin/activate
32+
pip install .
33+
34+
# For the GPU version of JAX, run the following command:
35+
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
36+
```
37+
38+
## Let's train a classification model on the AG News dataset
39+
40+
```bash
41+
python -m train.classification --classes=World,Sports,Business,Sci/Tech data/ag_news.csv
42+
```
43+
![train output](.github/images/train.jpg)
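
Training reports loss and accuracy. As a hedged sketch of what those two numbers mean for a classifier (not the project's JAX/Optax code), the plain-Python example below computes the mean cross-entropy loss and the accuracy for a small batch. The function names and the logit values are made up for illustration.

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability of the target class (stable log-sum-exp)."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum - logits[target]


def evaluate(batch_logits, targets):
    """Return (mean loss, accuracy) for a batch of logits and target ids."""
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    correct = sum(
        max(range(len(l)), key=l.__getitem__) == t
        for l, t in zip(batch_logits, targets)
    )
    return sum(losses) / len(losses), correct / len(targets)


# Hypothetical logits for three samples over four classes.
batch_logits = [
    [2.0, 0.1, -1.0, 0.3],
    [0.2, 1.5, 0.0, -0.4],
    [0.0, 0.9, 0.1, 0.8],
]
targets = [0, 1, 3]  # the third sample will be misclassified as class 1

loss, accuracy = evaluate(batch_logits, targets)
print(f"loss={loss:.4f} accuracy={accuracy:.2f}")
```

Loss falls as the model puts more probability on the correct class, while accuracy only counts whether the top-scoring class is right, so the two can move at different rates during training.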
44+
45+
## Let's run inference with the trained tokenizer
46+
47+
```bash
48+
python -m inference.tokenizer --prefix=out/ag_news "After earning a PH.D. in Sociology, Evaldas Leliuga started to work as the general manager"
49+
```
50+
51+
![tokenizer output](.github/images/tokenizer.jpg)
52+
53+
## Let's run inference with the trained model
54+
55+
```bash
56+
python -m inference.classification --prefix=out/ag_news "After earning a PH.D. in Sociology, Evaldas Leliuga started to work as the general manager"
57+
```
58+
59+
![classification output](.github/images/classification.jpg)
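
The final step of classification inference is turning the model's raw scores into a label. The sketch below shows that step in plain Python, assuming the four class names passed to the training command above. The `softmax` and `predict` helpers and the logit values are hypothetical and are not taken from the project's `inference` module.

```python
import math

# Class names as passed via --classes in the training command.
CLASSES = ["World", "Sports", "Business", "Sci/Tech"]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits, classes=CLASSES):
    """Return the most probable class name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]


# Hypothetical model output for a single input text.
logits = [0.1, -1.2, 2.3, 0.4]
label, confidence = predict(logits)
print(label, round(confidence, 3))
```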
460

561
## License
662

@@ -14,4 +70,4 @@ This project could not have been built without the following libraries or projec
1470
- [Flax](https://flax.readthedocs.io/en/latest)
1571
- [Optax](https://optax.readthedocs.io/en/latest)
1672
- [SentencePiece](https://github.com/google/sentencepiece)
17-
- [Safetensors](https://huggingface.co/docs/safetensors/index)
73+
- [Safetensors](https://huggingface.co/docs/safetensors/index)
