Commit 11e745e

Merge pull request #87 from microsoft/longnet/longvit
Release LongNet and LongViT
2 parents 3ff2f1f + ef15951

31 files changed, +4210 −76 lines changed

README.md

Lines changed: 50 additions & 0 deletions
@@ -20,6 +20,7 @@ Fundamental research to develop new architectures for foundation models and A(G)
 
 ## News
 
+- December, 2023: [LongNet](torchscale/model/LongNet.py) and [LongViT](examples/longvit/README.md) released
 - October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
 - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
@@ -37,6 +38,18 @@ cd torchscale
 pip install -e .
 ```
 
+For faster training, install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs:
+```
+pip install flash-attn
+```
+or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs:
+```
+# cuda 11.8 version
+pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
+# cuda 12.1 version
+pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
+```
+
 ## Getting Started
 
 It takes only several lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:
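The encoder snippet itself lies outside this hunk; as a reference point for the LongNet example in the next hunk, here is a minimal sketch of the config-driven construction it describes (the exact snippet in the README may differ):

```python
>>> from torchscale.architecture.config import EncoderConfig
>>> from torchscale.architecture.encoder import Encoder

>>> config = EncoderConfig(vocab_size=64000)
>>> encoder = Encoder(config)
>>> print(encoder)
```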
@@ -85,6 +98,21 @@ It takes only several lines of code to create a RetNet model:
 >>> print(retnet)
 ```
 
+For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required):
+```python
+>>> import torch
+>>> from torchscale.architecture.config import EncoderConfig, DecoderConfig
+>>> from torchscale.model.longnet import LongNetEncoder, LongNetDecoder
+
+# Creating a LongNet encoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
+>>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
+>>> longnet = LongNetEncoder(config)
+
+# Creating a LongNet decoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
+>>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
+>>> longnet = LongNetDecoder(config)
+```
+
 ## Key Features
 
 - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
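A minimal usage sketch for the decoder constructed in the hunk above, assuming a CUDA GPU with flash-attn installed, half precision, and the standard torchscale decoder interface (token ids in, logits plus an auxiliary dict out); exact signatures may differ:

```python
>>> import torch
>>> from torchscale.architecture.config import DecoderConfig
>>> from torchscale.model.longnet import LongNetDecoder

>>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
>>> longnet = LongNetDecoder(config).half().cuda()  # flash attention kernels expect fp16/bf16 on GPU

>>> tokens = torch.randint(0, 64000, (1, 4096)).cuda()  # dummy token ids standing in for tokenized text
>>> logits, extra = longnet(tokens)                     # assumed (output, aux-dict) return, as in torchscale's Decoder
>>> print(logits.shape)                                 # expected: torch.Size([1, 4096, 64000])
```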
@@ -142,6 +170,8 @@ We have examples of how to use TorchScale in the following scenarios/tasks:
 
 - Vision
 
+* [LongViT](examples/longvit/README.md)
+
 * ViT/BEiT [In progress]
 
 - Speech
@@ -228,6 +258,26 @@ If you find this repository useful, please consider citing our work:
 }
 ```
 
+```
+@article{longnet,
+  author  = {Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei},
+  title   = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens},
+  journal = {ArXiv},
+  volume  = {abs/2307.02486},
+  year    = {2023}
+}
+```
+
+```
+@article{longvit,
+  title   = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
+  author  = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei},
+  journal = {ArXiv},
+  volume  = {abs/2312.03558},
+  year    = {2023}
+}
+```
+
 ## Contributing
 
 This project welcomes contributions and suggestions. Most contributions require you to agree to a

examples/fairseq/README.md

Lines changed: 39 additions & 0 deletions
@@ -251,6 +251,45 @@ python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
     --use-xmoe
 ```
 
+### LongNet Model
+
+```bash
+cd examples/fairseq/
+python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
+    ${PATH_TO_DATA} \
+    --num-workers 2 \
+    --activation-fn gelu \
+    --share-decoder-input-output-embed \
+    --validate-interval-updates 1000 \
+    --save-interval-updates 1000 \
+    --no-epoch-checkpoints \
+    --memory-efficient-fp16 \
+    --fp16-init-scale 4 \
+    --arch lm_base \
+    --task language_modeling \
+    --sample-break-mode none \
+    --tokens-per-sample 4096 \
+    --optimizer adam --adam-betas "(0.9, 0.98)" \
+    --adam-eps 1e-08 \
+    --clip-norm 0.0 \
+    --lr 5e-4 \
+    --lr-scheduler polynomial_decay \
+    --warmup-updates 750 \
+    --dropout 0.1 \
+    --attention-dropout 0.1 \
+    --weight-decay 0.01 \
+    --batch-size 4 \
+    --update-freq 1 \
+    --required-batch-size-multiple 1 \
+    --total-num-update 50000 \
+    --max-update 50000 \
+    --seed 1 \
+    --ddp-backend=c10d \
+    --flash-attention \
+    --segment-length [2048,4096] \
+    --dilated-ratio [1,2]
+```
+
 ## Example: Machine Translation
 
 ### Data Format
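The --segment-length and --dilated-ratio values above are passed as plain strings and land in the string-typed config fields added to language_modeling.py below; a minimal sketch of decoding such strings into paired integer lists, assuming a JSON-style format (the parsing inside torchscale may differ):

```python
import json

# Hypothetical decoding of the string-valued dilated-attention flags.
segment_length = json.loads("[2048,4096]")  # -> [2048, 4096]
dilated_ratio = json.loads("[1,2]")         # -> [1, 2]

# Each segment length is paired with a dilation ratio, so the two lists must have equal length.
assert len(segment_length) == len(dilated_ratio)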

examples/fairseq/models/language_modeling.py

Lines changed: 40 additions & 1 deletion
@@ -25,6 +25,7 @@
 
 from torchscale.architecture.config import DecoderConfig
 from torchscale.architecture.decoder import Decoder
+from torchscale.model.LongNet import LongNetDecoder
 
 DEFAULT_MAX_TARGET_POSITIONS = 1024
 logger = logging.getLogger(__name__)
@@ -196,6 +197,19 @@ class LanguageConfig(FairseqDataclass):
     xpos_scale_base: Optional[int] = field(
         default=512,
     )
+    flash_attention: Optional[bool] = field(
+        default=False,
+    )
+    seq_parallel: Optional[bool] = field(
+        default=False,
+    )
+    segment_length: Optional[str] = field(
+        default='',
+    )
+    dilated_ratio: Optional[str] = field(
+        default='',
+    )
+
 
 
 @register_model("lm", dataclass=LanguageConfig)
@@ -256,7 +270,13 @@ def build_model(cls, args, task):
         config = DecoderConfig()
         config.override(args)
 
-        decoder = LMDecoder(
+        if args.segment_length != '':
+            assert args.dilated_ratio != ''
+            DECODER_CLASS = LongNetLMDecoder
+        else:
+            DECODER_CLASS = LMDecoder
+
+        decoder = DECODER_CLASS(
             config,
             embed_tokens,
             embed_positions,
@@ -291,6 +311,25 @@ def reorder_incremental_state_scripting(
                 incremental_state[module][key] = result
 
 
+class LongNetLMDecoder(LongNetDecoder, FairseqIncrementalDecoder):
+    def forward(self, src_tokens, **kwargs):
+        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
+        return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
+
+    def max_positions(self):
+        return self.embed_positions.max_positions
+
+    def reorder_incremental_state_scripting(
+        self,
+        incremental_state,
+        new_order,
+    ):
+        for module in incremental_state:
+            for key in incremental_state[module]:
+                result = incremental_state[module][key].index_select(0, new_order)
+                incremental_state[module][key] = result
+
+
 @register_model_architecture("lm", "lm_base")
 def base_lm_architecture(args):
     # backward compatibility for older model checkpoints
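LongNetLMDecoder mirrors the existing LMDecoder wrapper but routes through LongNetDecoder; its forward pass builds a padding mask from the token ids before delegating. A small illustration of that mask construction, assuming fairseq's default pad index of 1:

```python
import torch

pad_idx = 1  # fairseq dictionaries use 1 for <pad> by default (assumption for this illustration)
src_tokens = torch.tensor([[5, 8, 13, pad_idx, pad_idx]])

# Mirrors `src_tokens.eq(self.dictionary.pad())` in the class above: True marks padded positions.
self_attn_padding_mask = src_tokens.eq(pad_idx)
print(self_attn_padding_mask)  # tensor([[False, False, False,  True,  True]])
```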

examples/longvit/README.md

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# [(LongViT) When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology](https://arxiv.org/abs/2312.03558)
+
+**LongViT** is a vision Transformer that can process gigapixel images (e.g., 32,768x32,768 images) in an end-to-end manner. We split the image into millions of patches and employ [LongNet](https://arxiv.org/abs/2307.02486) to directly model the extremely long sequence. We apply LongViT to computational pathology and achieve remarkable performance on cancer subtyping and survival prediction tasks.
+
+
+## Setup
+```
+pip install -r requirements.txt
+pip install git+https://github.com/shumingma/fairseq.git@moe
+pip install -v -U git+https://github.com/facebookresearch/[email protected]#egg=xformers
+```
+
+
+## Pretraining
+
+We perform self-supervised pretraining on TCGA diagnostic slides using the [DINO](https://arxiv.org/abs/2104.14294) objective. The detailed instructions can be found at [`get_started_for_tcga_pretraining.md`](get_started/get_started_for_tcga_pretraining.md).
+
+The link to the pretrained LongViT model on TCGA diagnostic slides:
+- [`LongViT`](https://conversationhub.blob.core.windows.net/beit-share-public/longvit/longvit_small_patch32_1024.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=384; FFN factor=4x; #head=16; patch=32x32
+
+
+## Fine-tuning on Subtyping Classification
+
+We perform finetuning for cancer subtyping on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_subtyping.md`](get_started/get_started_for_tcga_subtyping.md).
+
+
+## Fine-tuning on Survival Prediction
+
+We perform finetuning for survival prediction on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_survival_prediction.md`](get_started/get_started_for_tcga_survival_prediction.md).
+
+
+## Citation
+
+If you find this repository useful, please consider citing our work:
+```
+@article{longvit,
+  title={When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
+  author={Wang, Wenhui and Ma, Shuming and Xu, Hanwen and Usuyama, Naoto and Ding, Jiayu and Poon, Hoifung and Wei, Furu},
+  journal={arXiv preprint arXiv:2312.03558},
+  year={2023}
+}
+
+@article{longnet,
+  title={LongNet: Scaling Transformers to 1,000,000,000 Tokens},
+  author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Zheng, Nanning and Wei, Furu},
+  journal={arXiv preprint arXiv:2307.02486},
+  year={2023}
+}
+
+@article{torchscale,
+  title={TorchScale: Transformers at Scale},
+  author={Ma, Shuming and Wang, Hongyu and Huang, Shaohan and Wang, Wenhui and Chi, Zewen and Dong, Li and Benhaim, Alon and Patra, Barun and Chaudhary, Vishrav and Song, Xia and others},
+  journal={arXiv preprint arXiv:2211.13184},
+  year={2022}
+}
+```
+
+
+## Acknowledgement
+
+This repository is built using the [BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3), [MCAT](https://github.com/mahmoodlab/MCAT), [DINO](https://github.com/facebookresearch/dino), and [HIPT](https://github.com/mahmoodlab/HIPT) repositories and the [timm](https://github.com/rwightman/pytorch-image-models) library.
+
+
+## License
+This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
+
+[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
+
+### Contact Information
+
+For help or issues using LongViT models, please submit a GitHub issue.
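A rough sketch of the patch arithmetic behind the "millions of patches" claim in the LongViT README above, using the 32x32 patch size of the released checkpoint:

```python
# 32,768 x 32,768 gigapixel image, 32x32 patches (as in the released LongViT checkpoint).
image_size = 32_768
patch_size = 32

patches_per_side = image_size // patch_size   # 1024
num_patches = patches_per_side ** 2           # 1,048,576 -> the ~1M-token sequence LongNet models

print(patches_per_side, num_patches)
```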
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import os
+import sys
+import torch
+import random
+import argparse
+from PIL import Image, ImageFilter, ImageOps
+from multiprocessing import Pool, cpu_count
+from timm.data.transforms import RandomResizedCropAndInterpolation
+import torchvision.transforms as transforms
+
+Image.MAX_IMAGE_PIXELS = 6400000000
+
+
+def build_transform(input_size):
+    train_interpolation = "bicubic"
+    t = [
+        RandomResizedCropAndInterpolation(input_size, scale=(0.5, 1.0), interpolation=train_interpolation),
+        transforms.RandomHorizontalFlip(),
+    ]
+    t = transforms.Compose(t)
+
+    return t
+
+
+def pil_loader(path):
+    with open(path, "rb") as f:
+        img = Image.open(f)
+        return img.convert("RGB")
+
+
+def save_image(transformed_img, output_image_path):
+    if isinstance(transformed_img, torch.Tensor):
+        transformed_img = transforms.ToPILImage()(transformed_img)
+    transformed_img.save(output_image_path)
+
+
+def get_image_files(input_dir):
+    for root, _, files in os.walk(input_dir):
+        for file in files:
+            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
+                yield os.path.join(root, file)
+
+
+def transform_and_save_crops(args):
+    input_path, input_dir, output_dir, transform = args
+    print(input_path)
+    file_basename = os.path.basename(input_path)
+
+    img = pil_loader(input_path)
+    transformed_img = transform(img)
+    output_image_path = os.path.join(output_dir, file_basename)
+    save_image(transformed_img, output_image_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Save transformed images in a directory.')
+    parser.add_argument('input_dir', help='Path to the input directory.')
+    parser.add_argument('output_dir', help='Path to the output directory.')
+    parser.add_argument('-p', '--processes', type=int, default=cpu_count(), help='Number of processes to use. Default: number of CPU cores')
+    parser.add_argument('--input_size', type=int, default=16384, help='input image size')
+    args = parser.parse_args()
+
+    input_dir = args.input_dir
+    output_dir = args.output_dir
+    num_processes = args.processes
+    input_size = args.input_size
+    print("num_processes: {}".format(num_processes))
+    print("input_size: {}".format(input_size))
+
+    transform = build_transform(input_size=input_size)
+
+    image_files = list(get_image_files(input_dir))
+    task_args = [(file, input_dir, output_dir, transform) for file in image_files]
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    with Pool(processes=num_processes) as pool:
+        pool.map(transform_and_save_crops, task_args)
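A quick smoke test for the crop transform defined in the script above, assuming build_transform is imported or pasted into the same session (the dummy image is hypothetical; real inputs are pre-resized whole-slide images):

```python
from PIL import Image

# Dummy 512x512 RGB image standing in for a pre-resized whole-slide image.
dummy = Image.new("RGB", (512, 512), color="white")

transform = build_transform(input_size=256)  # build_transform as defined in the script above
crop = transform(dummy)

# The pipeline ends with RandomHorizontalFlip (no ToTensor), so the output is a PIL image.
print(crop.size)  # (256, 256)
```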
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import os
+import glob
+import argparse
+import openslide
+
+from PIL import Image
+from concurrent.futures import ProcessPoolExecutor
+
+
+def convert_wsi_to_images(slide_path, image_path, target_size, level=0):
+    slide = openslide.open_slide(slide_path)
+    level_dims = slide.level_dimensions
+    region = slide.read_region((0,0), level, level_dims[level])
+    region = region.convert("RGB")
+    print("convert: {}({}) -> {}".format(slide_path, region.size, image_path))
+    resized_img = region.resize((target_size, target_size), Image.BICUBIC)
+    resized_img.save(image_path)
+
+
+def process_slides(input_folder, output_folder, target_size, level=0):
+    if not os.path.exists(output_folder):
+        os.makedirs(output_folder)
+
+    slide_paths = glob.glob(os.path.join(input_folder, "*.svs"))
+
+    with ProcessPoolExecutor(max_workers=1) as executor:
+        for slide_path in slide_paths:
+            image_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0] + ".jpg")
+            executor.submit(convert_wsi_to_images, slide_path, image_path, target_size, level=level)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert slides into images")
+    parser.add_argument("input_folder", type=str, help="")
+    parser.add_argument("output_folder", type=str, help="")
+    parser.add_argument("target_size", type=int, help="")
+    parser.add_argument("level", type=int, help="")
+
+    args = parser.parse_args()
+    input_folder = args.input_folder
+    output_folder = args.output_folder
+    target_size = args.target_size
+    level = args.level
+
+    process_slides(input_folder, output_folder, target_size, level=level)
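The level argument selects a pyramid level of the whole-slide image, with level 0 being full resolution; a minimal sketch for inspecting the available levels before choosing one, assuming openslide is installed (the path is a placeholder):

```python
import openslide

# Placeholder path; any TCGA-style .svs whole-slide image works here.
slide = openslide.open_slide("/path/to/slide.svs")

# Level 0 is the full-resolution plane; higher levels are progressively downsampled.
print(slide.level_count)        # number of pyramid levels
print(slide.level_dimensions)   # (width, height) per level
print(slide.level_downsamples)  # downsample factor per level relative to level 0
```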
