
Commit eb89ab6

Update Readme and minor fix

1 parent 64be2f0 commit eb89ab6

3 files changed: 47 additions, 2 deletions

README.md

Lines changed: 45 additions & 0 deletions

@@ -38,6 +38,18 @@ cd torchscale
 pip install -e .
 ```
 
+For faster training, install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs:
+```
+pip install flash-attn
+```
+or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs:
+```
+# CUDA 11.8 version
+pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
+# CUDA 12.1 version
+pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
+```
+
 ## Getting Started
 
 It takes only a few lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:
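Returning to the optional attention backends added above: since both are optional dependencies, a quick import check confirms which one is actually available at runtime. This is a minimal sketch assuming only the standard top-level modules of the two packages (`flash_attn` and `xformers.ops`); it is not part of the README itself:

```python
from typing import List

def attention_backends() -> List[str]:
    """Report which optional accelerated-attention packages are importable."""
    available = []
    try:
        import flash_attn  # noqa: F401  FlashAttention kernels
        available.append("flash-attn")
    except ImportError:
        pass
    try:
        import xformers.ops  # noqa: F401  memory-efficient attention ops
        available.append("xformers")
    except ImportError:
        pass
    return available

print(attention_backends())  # e.g. ['flash-attn', 'xformers']
```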
@@ -86,6 +98,21 @@ It takes only a few lines of code to create a RetNet model:
 >>> print(retnet)
 ```
 
+For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required):
+```python
+>>> import torch
+>>> from torchscale.architecture.config import EncoderConfig, DecoderConfig
+>>> from torchscale.model.longnet import LongNetEncoder, LongNetDecoder
+
+# Create a LongNet encoder with the dilated pattern segment_length=[2048,4096] and dilated_ratio=[1,2]
+>>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
+>>> longnet = LongNetEncoder(config)
+
+# Create a LongNet decoder with the dilated pattern segment_length=[2048,4096] and dilated_ratio=[1,2]
+>>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
+>>> longnet = LongNetDecoder(config)
+```
+
 ## Key Features
 
 - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
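As background on the dilated pattern used in the LongNet example above: for each (segment_length, dilated_ratio) pair, the sequence is split into windows of `segment_length` tokens and only every `dilated_ratio`-th position inside each window participates in attention, with the pairs mixed across heads in the LongNet paper. The following is a toy computation of the kept positions, purely illustrative and separate from torchscale's actual kernels:

```python
from typing import List
import torch

def dilated_positions(seq_len: int, segment_length: int,
                      ratio: int, offset: int = 0) -> List[torch.Tensor]:
    """Toy illustration: positions kept in each segment when every
    `ratio`-th token (starting at `offset`) of a `segment_length`
    window participates in attention."""
    kept = []
    for start in range(0, seq_len, segment_length):
        end = min(start + segment_length, seq_len)
        kept.append(torch.arange(start + offset, end, ratio))
    return kept

# segment_length=4 and ratio=2 on 8 tokens: segments [0..3] and [4..7]
# keep positions [0, 2] and [4, 6].
print(dilated_positions(8, 4, 2))
```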
@@ -231,6 +258,26 @@ If you find this repository useful, please consider citing our work:
 }
 ```
 
+```
+@article{longnet,
+  author = {Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei},
+  title = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens},
+  journal = {ArXiv},
+  volume = {abs/2307.02486},
+  year = {2023}
+}
+```
+
+```
+@article{longvit,
+  title = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
+  author = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei},
+  journal = {ArXiv},
+  volume = {abs/2312.03558},
+  year = {2023}
+}
+```
+
 ## Contributing
 
 This project welcomes contributions and suggestions. Most contributions require you to agree to a

setup.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
     license="MIT",
     url="https://github.com/microsoft/torchscale",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13"],
+    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13", "einops"],
     python_requires=">=3.8.0",
     classifiers=[
         "Programming Language :: Python :: 3",

torchscale/component/flash_attention.py

Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ def backward(cls, ctx, grad, dlse):
             grads = _memory_efficient_attention_backward(
                 ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
             )
-            return grads.dq, grads.dk, grads.dv, grads.db, None, None, None
+            return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
 
     flash_attn_func = FlashAttnFunc.apply
 except ModuleNotFoundError:
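The one-line fix above reorders the returned gradients so they line up with the arguments of the corresponding `forward`: `torch.autograd.Function.backward` must return exactly one value per `forward` argument, in the same positional order, with `None` in the slots of non-differentiable arguments. Here is a minimal self-contained sketch of that contract, using an illustrative op rather than the torchscale attention code:

```python
import torch

class ScaledMatmul(torch.autograd.Function):
    """Illustrative custom op: out = (q @ k.T) * scale."""

    @staticmethod
    def forward(ctx, q, k, scale):
        ctx.save_for_backward(q, k)
        ctx.scale = scale
        return (q @ k.T) * scale

    @staticmethod
    def backward(ctx, grad_out):
        q, k = ctx.saved_tensors
        dq = grad_out @ k * ctx.scale      # dL/dq
        dk = grad_out.T @ q * ctx.scale    # dL/dk
        # One return value per forward argument, in forward order;
        # `scale` is a plain float, so its gradient slot is None.
        return dq, dk, None

q = torch.randn(4, 8, requires_grad=True)
k = torch.randn(4, 8, requires_grad=True)
ScaledMatmul.apply(q, k, 0.5).sum().backward()
```

Returning the tuple in the wrong order, as the pre-fix line did, silently routes a gradient to the wrong input rather than raising an error, which is why this kind of bug is easy to miss.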
