diff --git a/yeonsu/.gitignore b/yeonsu/.gitignore
new file mode 100644
index 0000000..9d9cf0b
--- /dev/null
+++ b/yeonsu/.gitignore
@@ -0,0 +1,229 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+
+# experiments
+experiments/
+
+# checkpoints
+checkpoints/
+
+# data
+data/
+
+# logs
+logs/
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+# Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+# poetry.lock
+# poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+# pdm.lock
+# pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+# pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# Redis
+*.rdb
+*.aof
+*.pid
+
+# RabbitMQ
+mnesia/
+rabbitmq/
+rabbitmq-data/
+
+# ActiveMQ
+activemq-data/
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+# .idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
\ No newline at end of file
diff --git a/yeonsu/.gitkeep b/yeonsu/.gitkeep
deleted file mode 100644
index e69de29..0000000
diff --git a/yeonsu/README.md b/yeonsu/README.md
new file mode 100644
index 0000000..ba2dac2
--- /dev/null
+++ b/yeonsu/README.md
@@ -0,0 +1,232 @@
+# CIFAR-10 ResNet-18 Training Project
+
+This project trains a ResNet-18 model on the CIFAR-10 dataset, applying a range of techniques aimed in particular at reducing the confusion between the Cat and Dog classes.
+
+## 📋 Table of Contents
+
+- [Project Overview](#project-overview)
+- [Key Features](#key-features)
+- [Installation](#installation)
+- [Preparing the Data](#preparing-the-data)
+- [Usage](#usage)
+- [Experiment Settings](#experiment-settings)
+- [Project Structure](#project-structure)
+- [Experiment Results](#experiment-results)
+
+## 🎯 Project Overview
+
+This project trains ResNet-18 on CIFAR-10 and applies the following techniques, mainly to address the confusion between the visually similar Cat and Dog classes:
+
+- **Multiple loss functions**: Focal Loss, Weighted Cross-Entropy Loss, Label Smoothing
+- **Advanced data augmentation**: Cutout, Mixup, ColorJitter
+- **Systematic experiment management**: the configuration and results of every run are saved automatically
+
+## ✨ Key Features
+
+### 1. Multiple Loss Functions
+
+- **Cross-Entropy Loss**: the baseline loss
+- **Focal Loss**: `FL(p_t) = -α(1 - p_t)^γ · log(p_t)` — down-weights easy examples so training focuses on hard ones (effective for the Cat/Dog confusion problem)
+- **Weighted Cross-Entropy Loss**: assigns a higher weight to selected classes (Cat/Dog)
+- **Label Smoothing**: prevents overconfident predictions and improves generalization
+
+### 2. Data Augmentation Techniques
+
+- **Advanced augmentation**: RandomHorizontalFlip, RandomCrop, ColorJitter
+- **Cutout**: masks out a random region of the image so the model cannot rely on a few local details
+- **Mixup**: blends pairs of images to smooth the decision boundary (see the sketch below)
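+
+Conceptually, Mixup draws `λ ~ Beta(α, α)` and trains on convex combinations of sample pairs. A minimal sketch of the batch-level operation, mirroring the `Mixup` class in `augmentation.py`:
+
+```python
+import numpy as np
+import torch
+
+def mixup_batch(images, labels, alpha=1.0):
+    # lam ~ Beta(alpha, alpha); alpha = 1.0 makes lam uniform on [0, 1]
+    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
+    index = torch.randperm(images.size(0))            # random partner for each sample
+    mixed = lam * images + (1 - lam) * images[index]  # pixel-wise convex combination
+    # the loss is blended the same way:
+    #   lam * criterion(pred, labels) + (1 - lam) * criterion(pred, labels[index])
+    return mixed, labels, labels[index], lam
+```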
+
+### 3. Experiment Management System
+
+- A uniquely named folder is created automatically for every experiment
+- Hyperparameter settings are saved automatically (`config.json`)
+- Training-curve plots are generated automatically (`training_curves.png`)
+- Per-epoch training results are saved as CSV (`results.csv`)
+- A final summary of the run is saved (`summary.json`)
+
+## 📦 Installation
+
+### Installing the packages
+
+```bash
+pip install -r requirements.txt
+```
+
+Required packages:
+- torch >= 2.0.0
+- torchvision >= 0.15.0
+- tqdm >= 4.65.0
+- numpy >= 1.24.0
+- scikit-learn >= 1.3.0
+- matplotlib >= 3.7.0
+- seaborn >= 0.12.0
+
+## 📁 Preparing the Data
+
+Download and prepare the CIFAR-10 dataset automatically:
+
+```bash
+python download_cifar10.py
+```
+
+The data is stored under `./data/cifar-10-batches-py/`.
+
+## 🚀 Usage
+
+### Basic training
+
+The simplest way to start training:
+
+```bash
+python train.py
+```
+
+### Training with advanced options
+
+Train with Focal Loss and advanced data augmentation:
+
+```bash
+python train.py \
+    --epochs 150 \
+    --lr 0.1 \
+    --batch_size 128 \
+    --loss_type focal \
+    --focal_gamma 2.0 \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16
+```
+
+### Using the provided scripts
+
+Several ready-made experiment scripts are included:
+
+```bash
+# Baseline experiment (Cross-Entropy Loss)
+bash scripts/baseline_train.sh
+
+# Experiment 1: Focal Loss + advanced augmentation + Cutout
+bash scripts/exp1.sh
+
+# Experiment 2: Focal Loss + advanced augmentation + Cutout + Mixup
+bash scripts/exp2.sh
+
+# Experiment 3: Weighted Loss + advanced augmentation + Cutout
+bash scripts/exp3.sh
+```
+
+### Evaluating a checkpoint
+
+A separate `test.py` script reports overall and per-class accuracy together with a confusion matrix, e.g. `python test.py --checkpoint experiments/<run>/checkpoints/best_model.pth --save_cm`.
+
+## ⚙️ Experiment Settings
+
+### Main hyperparameters
+
+#### Training
+- `--epochs`: number of training epochs (default: 150)
+- `--lr`: initial learning rate (default: 0.1)
+- `--batch_size`: batch size (default: 128)
+- `--momentum`: SGD momentum (default: 0.9)
+- `--weight_decay`: weight decay (default: 5e-4)
+
+#### Learning-rate scheduler
+- `--scheduler`: scheduler type (`step`, `cosine`, `none`)
+- `--step_size`: StepLR step_size (default: 30)
+- `--gamma`: StepLR gamma (default: 0.1)
+
+#### Loss function
+- `--loss_type`: loss type (`ce`, `focal`, `weighted`, `label_smooth`)
+- `--focal_alpha`: Focal Loss alpha (default: 1.0)
+- `--focal_gamma`: Focal Loss gamma (default: 2.0)
+- `--class_weight`: class weight for the weighted loss (default: 2.0)
+- `--label_smoothing`: label-smoothing factor (default: 0.1)
+
+#### Data augmentation
+- `--use_advanced_aug`: enable advanced data augmentation
+- `--use_cutout`: enable Cutout
+- `--cutout_length`: Cutout patch size (default: 16)
+- `--use_mixup`: enable Mixup
+- `--mixup_alpha`: Mixup alpha (default: 1.0)
+
+### Recommended combinations
+
+#### Combination 1: basic improvements (recommended)
+```bash
+python train.py \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16 \
+    --loss_type focal \
+    --focal_gamma 2.0
+```
+
+#### Combination 2: stronger improvements
+```bash
+python train.py \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16 \
+    --loss_type focal \
+    --focal_gamma 2.5 \
+    --use_mixup \
+    --mixup_alpha 0.8
+```
+
+#### Combination 3: all techniques combined
+```bash
+python train.py \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16 \
+    --loss_type weighted \
+    --class_weight 2.0 \
+    --use_mixup \
+    --mixup_alpha 1.0
+```
+
+## 📂 Project Structure
+
+```
+yeonsu/
+├── README.md                  # this file
+├── requirements.txt           # package dependencies
+├── train.py                   # main training script
+├── test.py                    # evaluation script (per-class accuracy, confusion matrix)
+├── dataset.py                 # data loaders
+├── losses.py                  # loss-function implementations
+├── augmentation.py            # data-augmentation implementations
+├── download_cifar10.py        # CIFAR-10 download script
+├── models/
+│   └── resnet18.py            # ResNet-18 model definition
+├── scripts/
+│   ├── baseline_train.sh      # baseline training script
+│   ├── exp1.sh                # experiment 1 script
+│   ├── exp2.sh                # experiment 2 script
+│   └── exp3.sh                # experiment 3 script
+├── data/                      # data directory
+│   └── cifar-10-batches-py/
+├── experiments/               # experiment output directory
+│   ├── ce_20260106_133608/
+│   │   ├── config.json        # experiment configuration
+│   │   ├── summary.json       # result summary
+│   │   ├── results.csv        # training results
+│   │   ├── training_curves.png  # training curves
+│   │   └── checkpoints/       # model checkpoints
+│   └── ...
+└── augmentation_examples/     # example images of the augmentations
+    ├── cutout_example.png
+    └── mixup_example.png
+```
+
+## 📊 Experiment Results
+
+The results of every experiment are stored under the `experiments/` directory. Each experiment folder contains the following files:
+
+- `config.json`: the hyperparameter settings used
+- `summary.json`: a summary with the best accuracy, final accuracy, training time, etc.
+- `results.csv`: per-epoch train/test loss and accuracy
+- `training_curves.png`: training-curve plots
+- `checkpoints/`: model checkpoint files
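+
+Because every run writes the same set of files, results can be compared programmatically. A minimal sketch (assuming the layout above):
+
+```python
+import json
+from pathlib import Path
+
+for exp in sorted(Path('experiments').iterdir()):
+    summary_file = exp / 'summary.json'
+    if summary_file.exists():
+        summary = json.loads(summary_file.read_text())
+        print(f"{exp.name}: best {summary['best_test_accuracy']:.2f}%, "
+              f"final {summary['final_test_accuracy']:.2f}%")
+```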
+
+### Example Results
+
+Representative runs:
+- **Baseline (CE)**: best accuracy ~94%
+- **Focal Loss + augmentation**: best accuracy ~95%+
+- **Focal Loss + augmentation + Mixup**: best accuracy ~94%+
diff --git a/yeonsu/augmentation.py b/yeonsu/augmentation.py
new file mode 100644
index 0000000..48d1964
--- /dev/null
+++ b/yeonsu/augmentation.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torchvision import transforms
+
+
+class Cutout(object):
+    """Zero out a random square patch of a (C, H, W) tensor image."""
+
+    def __init__(self, length=16):
+        self.length = length
+
+    def __call__(self, img):
+        h, w = img.size(1), img.size(2)
+        mask = np.ones((h, w), np.float32)
+
+        # Pick a random center; the patch is clipped at the image borders.
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1:y2, x1:x2] = 0.
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img = img * mask
+
+        return img
+
+
+def get_advanced_augmentation(use_cutout=True, cutout_length=16):
+    # CIFAR-10 channel statistics used for normalization.
+    mean = (0.4914, 0.4822, 0.4465)
+    std = (0.2023, 0.1994, 0.2010)
+
+    train_transform = transforms.Compose([
+        transforms.RandomHorizontalFlip(p=0.5),
+        transforms.RandomCrop(32, padding=4),
+        transforms.ColorJitter(
+            brightness=0.2,
+            contrast=0.2,
+            saturation=0.2,
+            hue=0.1
+        ),
+        transforms.ToTensor(),
+        transforms.Normalize(mean, std),
+    ])
+
+    # Cutout operates on tensors, so it is appended after ToTensor/Normalize.
+    if use_cutout:
+        train_transform.transforms.append(Cutout(cutout_length))
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean, std)
+    ])
+
+    return train_transform, test_transform
+
+
+class Mixup(object):
+    """Blend a batch with a shuffled copy of itself: lam * x + (1 - lam) * x[perm]."""
+
+    def __init__(self, alpha=1.0):
+        self.alpha = alpha
+
+    def __call__(self, batch):
+        images, labels = batch
+        if self.alpha > 0:
+            lam = np.random.beta(self.alpha, self.alpha)
+        else:
+            lam = 1
+
+        batch_size = images.size(0)
+        index = torch.randperm(batch_size).to(images.device)
+
+        mixed_images = lam * images + (1 - lam) * images[index]
+        labels_a, labels_b = labels, labels[index]
+
+        return mixed_images, labels_a, labels_b, lam
+
+
+def mixup_criterion(criterion, pred, y_a, y_b, lam):
+    """Loss for a mixed batch: the same convex combination applied to the loss."""
+    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
+
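+
+if __name__ == "__main__":
+    # Quick sanity check (illustrative only; not used by training): Cutout
+    # zeroes a square patch of a (C, H, W) tensor, and Mixup returns a convex
+    # combination of a batch with a shuffled copy of itself.
+    img = torch.ones(3, 32, 32)
+    cut = Cutout(length=16)(img)
+    print('pixels zeroed by Cutout:', int((cut == 0).sum().item()))
+
+    images = torch.randn(8, 3, 32, 32)
+    labels = torch.randint(0, 10, (8,))
+    mixed, y_a, y_b, lam = Mixup(alpha=1.0)((images, labels))
+    print('mixup lambda:', round(float(lam), 3), 'mixed batch shape:', tuple(mixed.shape))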
diff --git a/yeonsu/augmentation_examples/cutout_example.png b/yeonsu/augmentation_examples/cutout_example.png
new file mode 100644
index 0000000..ad82f13
Binary files /dev/null and b/yeonsu/augmentation_examples/cutout_example.png differ
diff --git a/yeonsu/augmentation_examples/mixup_example.png b/yeonsu/augmentation_examples/mixup_example.png
new file mode 100644
index 0000000..3ddf653
Binary files /dev/null and b/yeonsu/augmentation_examples/mixup_example.png differ
diff --git a/yeonsu/dataset.py b/yeonsu/dataset.py
new file mode 100644
index 0000000..5e8413b
--- /dev/null
+++ b/yeonsu/dataset.py
@@ -0,0 +1,63 @@
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from augmentation import get_advanced_augmentation
+
+
+def get_cifar10_dataloaders(data_dir='./data', batch_size=128, num_workers=4,
+                            use_advanced_aug=False, use_cutout=True, cutout_length=16):
+    if use_advanced_aug:
+        train_transform, test_transform = get_advanced_augmentation(
+            use_cutout=use_cutout,
+            cutout_length=cutout_length
+        )
+    else:
+        mean = (0.4914, 0.4822, 0.4465)
+        std = (0.2023, 0.1994, 0.2010)
+
+        train_transform = transforms.Compose([
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomCrop(32, padding=4),
+            transforms.ToTensor(),
+            transforms.Normalize(mean, std)
+        ])
+
+        test_transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean, std)
+        ])
+
+    # download=False: the data is expected to be fetched once via download_cifar10.py.
+    train_dataset = torchvision.datasets.CIFAR10(
+        root=data_dir,
+        train=True,
+        download=False,
+        transform=train_transform
+    )
+
+    test_dataset = torchvision.datasets.CIFAR10(
+        root=data_dir,
+        train=False,
+        download=False,
+        transform=test_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        pin_memory=torch.cuda.is_available()
+    )
+
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=torch.cuda.is_available()
+    )
+
+    classes = ('plane', 'car', 'bird', 'cat', 'deer',
+               'dog', 'frog', 'horse', 'ship', 'truck')
+
+    return train_loader, test_loader, classes
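+
+
+if __name__ == "__main__":
+    # Illustrative smoke test (assumes the data was already fetched with
+    # download_cifar10.py, since the datasets above use download=False).
+    train_loader, test_loader, classes = get_cifar10_dataloaders(batch_size=4, num_workers=0)
+    images, labels = next(iter(train_loader))
+    print('batch shape:', tuple(images.shape))  # (4, 3, 32, 32)
+    print('labels:', [classes[label] for label in labels])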
diff --git a/yeonsu/download_cifar10.py b/yeonsu/download_cifar10.py
new file mode 100644
index 0000000..838df30
--- /dev/null
+++ b/yeonsu/download_cifar10.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+"""
+CIFAR-10 download script.
+Fetches the CIFAR-10 dataset via torchvision.
+"""
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import os
+
+def download_cifar10(data_dir='./data'):
+    """
+    Download the CIFAR-10 dataset.
+
+    Args:
+        data_dir: directory in which to store the data
+    """
+    # Create the data directory.
+    os.makedirs(data_dir, exist_ok=True)
+
+    print(f"Downloading CIFAR-10 into {data_dir}...")
+
+    # Basic transform, used only for the download.
+    transform = transforms.ToTensor()
+
+    # Download the training split.
+    print("Downloading the training set...")
+    trainset = torchvision.datasets.CIFAR10(
+        root=data_dir,
+        train=True,
+        download=True,
+        transform=transform
+    )
+    print(f"Training set ready: {len(trainset)} images")
+
+    # Download the test split.
+    print("Downloading the test set...")
+    testset = torchvision.datasets.CIFAR10(
+        root=data_dir,
+        train=False,
+        download=True,
+        transform=transform
+    )
+    print(f"Test set ready: {len(testset)} images")
+
+    # Print the class names.
+    classes = ('plane', 'car', 'bird', 'cat', 'deer',
+               'dog', 'frog', 'horse', 'ship', 'truck')
+    print("\nCIFAR-10 classes:")
+    for i, class_name in enumerate(classes):
+        print(f"  {i}: {class_name}")
+
+    print(f"\nDone! The data is stored under {data_dir}/cifar-10-batches-py/.")
+    return trainset, testset
+
+if __name__ == "__main__":
+    # Create the data folder next to this script.
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    data_dir = os.path.join(script_dir, 'data')
+
+    try:
+        trainset, testset = download_cifar10(data_dir)
+        print("\nSuccessfully downloaded CIFAR-10!")
+    except Exception as e:
+        print(f"\nError: {e}")
+        import traceback
+        traceback.print_exc()
+
diff --git a/yeonsu/losses.py b/yeonsu/losses.py
new file mode 100644
index 0000000..75dd1d1
--- /dev/null
+++ b/yeonsu/losses.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FocalLoss(nn.Module):
+    """FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t); gamma=0 recovers cross-entropy."""
+
+    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
+        super(FocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, inputs, targets):
+        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
+        pt = torch.exp(-ce_loss)  # probability assigned to the true class
+        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
+
+        if self.reduction == 'mean':
+            return focal_loss.mean()
+        elif self.reduction == 'sum':
+            return focal_loss.sum()
+        else:
+            return focal_loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+    """Cross-entropy against a smoothed target distribution."""
+
+    def __init__(self, smoothing=0.1, num_classes=10):
+        super(LabelSmoothingCrossEntropy, self).__init__()
+        self.smoothing = smoothing
+        self.num_classes = num_classes
+
+    def forward(self, pred, target):
+        log_probs = F.log_softmax(pred, dim=1)
+        with torch.no_grad():
+            # smoothing/(K-1) mass on the wrong classes, 1-smoothing on the true class.
+            true_dist = torch.zeros_like(log_probs)
+            true_dist.fill_(self.smoothing / (self.num_classes - 1))
+            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
+
+        return torch.mean(torch.sum(-true_dist * log_probs, dim=1))
+
+
+class WeightedCrossEntropyLoss(nn.Module):
+    """Cross-entropy with a per-class weight applied to each sample's loss."""
+
+    def __init__(self, class_weights=None, num_classes=10):
+        super(WeightedCrossEntropyLoss, self).__init__()
+        if class_weights is None:
+            class_weights = torch.ones(num_classes)
+        else:
+            class_weights = torch.tensor(class_weights, dtype=torch.float32)
+        # Registered as a buffer so .to(device)/.cuda() moves it with the module.
+        self.register_buffer('class_weights', class_weights)
+
+    def forward(self, inputs, targets):
+        weights = self.class_weights.to(inputs.device)[targets]
+
+        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
+        weighted_loss = weights * ce_loss
+
+        return weighted_loss.mean()
+
+
+def get_cat_dog_focused_weights(num_classes=10, cat_idx=3, dog_idx=5, weight=2.0):
+    """Uniform weights, except cat (index 3) and dog (index 5) get a higher weight."""
+    weights = torch.ones(num_classes)
+    weights[cat_idx] = weight
+    weights[dog_idx] = weight
+    return weights.tolist()
+
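+
+if __name__ == "__main__":
+    # Illustrative checks (not used by training). With the default cat/dog
+    # indices (3 and 5), the weight vector doubles those two classes:
+    print(get_cat_dog_focused_weights(weight=2.0))
+    # -> [1.0, 1.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0]
+
+    # FocalLoss with gamma=0 and alpha=1 reduces to plain cross-entropy.
+    logits = torch.randn(4, 10)
+    targets = torch.randint(0, 10, (4,))
+    fl = FocalLoss(alpha=1.0, gamma=0.0)(logits, targets)
+    ce = F.cross_entropy(logits, targets)
+    print(float(fl), float(ce))  # equal up to floating-point error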
diff --git a/yeonsu/models/resnet18.py b/yeonsu/models/resnet18.py
new file mode 100644
index 0000000..5e1ba04
--- /dev/null
+++ b/yeonsu/models/resnet18.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        # Projection shortcut when the spatial size or channel count changes.
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_channels != out_channels:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet18(nn.Module):
+    def __init__(self, num_classes=10):
+        super(ResNet18, self).__init__()
+        # For CIFAR-10 inputs (32 x 32 x 3) the stem uses kernel_size=3, stride=1,
+        # padding=1; the original paper uses kernel_size=7, stride=2, padding=3.
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(64, 64, 2, stride=1)
+        self.layer2 = self._make_layer(64, 128, 2, stride=2)
+        self.layer3 = self._make_layer(128, 256, 2, stride=2)
+        self.layer4 = self._make_layer(256, 512, 2, stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512, num_classes)
+        self._initialize_weights()
+
+    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
+        layers = []
+        layers.append(BasicBlock(in_channels, out_channels, stride))
+        for _ in range(1, num_blocks):
+            layers.append(BasicBlock(out_channels, out_channels, stride=1))
+        return nn.Sequential(*layers)
+
+    def _initialize_weights(self):
+        # He initialization for conv layers, constant init for batch norm and linear bias.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+        return x
+
+def resnet18(num_classes=10):
+    return ResNet18(num_classes=num_classes)
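+
+
+if __name__ == "__main__":
+    # Illustrative shape check (not part of training): with the 3x3 stem and no
+    # initial max-pool, a 32x32 input reaches the classifier head as a 4x4
+    # feature map (32 -> 32 -> 16 -> 8 -> 4 across the four stages).
+    model = resnet18(num_classes=10)
+    out = model(torch.randn(2, 3, 32, 32))
+    print(tuple(out.shape))  # (2, 10)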
diff --git a/yeonsu/requirements.txt b/yeonsu/requirements.txt
new file mode 100644
index 0000000..3cbcf74
--- /dev/null
+++ b/yeonsu/requirements.txt
@@ -0,0 +1,8 @@
+torch>=2.0.0
+torchvision>=0.15.0
+tqdm>=4.65.0
+numpy>=1.24.0
+scikit-learn>=1.3.0
+matplotlib>=3.7.0
+seaborn>=0.12.0
+
diff --git a/yeonsu/scripts/baseline_train.sh b/yeonsu/scripts/baseline_train.sh
new file mode 100755
index 0000000..f31d545
--- /dev/null
+++ b/yeonsu/scripts/baseline_train.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Baseline experiment: default settings (CrossEntropy Loss, basic data augmentation)
+# Uses GPU 1
+
+export CUDA_VISIBLE_DEVICES=1
+
+python train.py \
+    --epochs 150 \
+    --lr 0.1 \
+    --momentum 0.9 \
+    --weight_decay 5e-4 \
+    --scheduler step \
+    --step_size 30 \
+    --gamma 0.1 \
+    --batch_size 128 \
+    --num_workers 4 \
+    --save_freq 10 \
+    --loss_type ce \
+    --data_dir ./data
+
diff --git a/yeonsu/scripts/exp1.sh b/yeonsu/scripts/exp1.sh
new file mode 100755
index 0000000..16245cd
--- /dev/null
+++ b/yeonsu/scripts/exp1.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Exp1: Focal Loss + advanced data augmentation + Cutout
+# Uses GPU 2
+
+export CUDA_VISIBLE_DEVICES=2
+
+python train.py \
+    --epochs 150 \
+    --lr 0.1 \
+    --momentum 0.9 \
+    --weight_decay 5e-4 \
+    --scheduler step \
+    --step_size 30 \
+    --gamma 0.1 \
+    --batch_size 128 \
+    --num_workers 4 \
+    --save_freq 10 \
+    --loss_type focal \
+    --focal_alpha 1.0 \
+    --focal_gamma 2.0 \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16 \
+    --data_dir ./data
+
diff --git a/yeonsu/scripts/exp2.sh b/yeonsu/scripts/exp2.sh
new file mode 100755
index 0000000..18718cd
--- /dev/null
+++ b/yeonsu/scripts/exp2.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Exp2: Focal Loss + advanced data augmentation + Cutout + Mixup
+# Uses GPU 3
+
+export CUDA_VISIBLE_DEVICES=3
+
+python train.py \
+    --epochs 150 \
+    --lr 0.1 \
+    --momentum 0.9 \
+    --weight_decay 5e-4 \
+    --scheduler step \
+    --step_size 30 \
+    --gamma 0.1 \
+    --batch_size 128 \
+    --num_workers 4 \
+    --save_freq 10 \
+    --loss_type focal \
+    --focal_alpha 1.0 \
+    --focal_gamma 2.0 \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16 \
+    --use_mixup \
+    --mixup_alpha 1.0 \
+    --data_dir ./data
+
diff --git a/yeonsu/scripts/exp3.sh b/yeonsu/scripts/exp3.sh
new file mode 100755
index 0000000..a5d2c40
--- /dev/null
+++ b/yeonsu/scripts/exp3.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Exp3: Weighted Loss (focused on cat/dog) + advanced data augmentation + Cutout
+# Uses GPU 4
+
+export CUDA_VISIBLE_DEVICES=4
+
+python train.py \
+    --epochs 150 \
+    --lr 0.1 \
+    --momentum 0.9 \
+    --weight_decay 5e-4 \
+    --scheduler step \
+    --step_size 30 \
+    --gamma 0.1 \
+    --batch_size 128 \
+    --num_workers 4 \
+    --save_freq 10 \
+    --loss_type weighted \
+    --class_weight 2.0 \
+    --use_advanced_aug \
+    --use_cutout \
+    --cutout_length 16 \
+    --data_dir ./data
+
diff --git a/yeonsu/test.py b/yeonsu/test.py
new file mode 100644
index 0000000..9470a74
--- /dev/null
+++ b/yeonsu/test.py
@@ -0,0 +1,233 @@
+"""
+Test/evaluation script for the trained ResNet-18 model.
+"""
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+import argparse
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import confusion_matrix, classification_report
+
+from models.resnet18 import resnet18
+from dataset import get_cifar10_dataloaders
+
+
+def test_model(model, test_loader, device, classes=None):
+    """Evaluate the model and print detailed statistics."""
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+
+    test_loss = 0
+    correct = 0
+    total = 0
+    class_correct = list(0. for _ in range(10))
+    class_total = list(0. for _ in range(10))
+
+    # Collect predictions and targets for the confusion matrix.
+    all_preds = []
+    all_targets = []
+
+    with torch.no_grad():
+        for inputs, targets in tqdm(test_loader, desc='Testing'):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+
+            all_preds.extend(predicted.cpu().numpy())
+            all_targets.extend(targets.cpu().numpy())
+
+            # Per-class accuracy bookkeeping.
+            c = (predicted == targets)
+            for i in range(targets.size(0)):
+                label = targets[i]
+                class_correct[label] += c[i].item()
+                class_total[label] += 1
+
+    # Overall statistics.
+    test_loss /= len(test_loader)
+    test_acc = 100. * correct / total
+
+    print('\n' + '=' * 60)
+    print('Test results')
+    print('=' * 60)
+    print(f'Overall test loss: {test_loss:.4f}')
+    print(f'Overall test accuracy: {test_acc:.2f}% ({correct}/{total})')
+    print('=' * 60)
+
+    # Per-class accuracy.
+    if classes:
+        print('\nPer-class accuracy:')
+        print('-' * 60)
+        for i in range(10):
+            if class_total[i] > 0:
+                acc = 100 * class_correct[i] / class_total[i]
+                print(f'{classes[i]:10s}: {acc:6.2f}% ({int(class_correct[i]):5d}/{int(class_total[i]):5d})')
+            else:
+                print(f'{classes[i]:10s}: N/A')
+        print('-' * 60)
+
+    # Compute and print the confusion matrix.
+    all_preds = np.array(all_preds)
+    all_targets = np.array(all_targets)
+    cm = confusion_matrix(all_targets, all_preds)
+
+    print('\n' + '=' * 60)
+    print('Confusion Matrix')
+    print('=' * 60)
+    print_confusion_matrix(cm, classes)
+
+    # Classification report.
+    if classes:
+        print('\n' + '=' * 60)
+        print('Classification Report')
+        print('=' * 60)
+        report = classification_report(
+            all_targets, all_preds,
+            target_names=classes,
+            digits=4
+        )
+        print(report)
+
+    return test_loss, test_acc, cm
+
+
+def print_confusion_matrix(cm, classes=None):
+    """Print the confusion matrix as text."""
+    if classes is None:
+        classes = [f'Class {i}' for i in range(len(cm))]
+
+    # Header row.
+    print(f'\n{"true/pred":>12}', end='')
+    for cls in classes:
+        print(f'{cls[:6]:>8}', end='')
+    print()
+    print('-' * (12 + 8 * len(classes)))
+
+    # One row per true class.
+    for i, cls in enumerate(classes):
+        print(f'{cls:>12}', end='')
+        for j in range(len(classes)):
+            print(f'{cm[i, j]:>8}', end='')
+        print(f'  (accuracy: {100*cm[i,i]/cm[i].sum():.1f}%)')
+    print('-' * (12 + 8 * len(classes)))
+
+
+def plot_confusion_matrix(cm, classes=None, save_path='confusion_matrix.png', figsize=(10, 8)):
+    """Render the confusion matrix as an image and save it."""
+    if classes is None:
+        classes = [f'Class {i}' for i in range(len(cm))]
+
+    # Row-normalized confusion matrix (proportions).
+    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+
+    # Raw counts.
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                xticklabels=classes, yticklabels=classes,
+                ax=ax1, cbar_kws={'label': 'Count'})
+    ax1.set_title('Confusion Matrix (Count)', fontsize=14, fontweight='bold')
+    ax1.set_xlabel('Predicted Label', fontsize=12)
+    ax1.set_ylabel('True Label', fontsize=12)
+    ax1.tick_params(axis='both', labelsize=9)
+
+    # Normalized proportions.
+    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
+                xticklabels=classes, yticklabels=classes,
+                ax=ax2, cbar_kws={'label': 'Proportion'})
+    ax2.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
+    ax2.set_xlabel('Predicted Label', fontsize=12)
+    ax2.set_ylabel('True Label', fontsize=12)
+    ax2.tick_params(axis='both', labelsize=9)
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.close()
+    print(f'Saved confusion matrix figure: {save_path}')
+
+
+def load_model(checkpoint_path, device, num_classes=10):
+    """Load a model from a checkpoint file."""
+    print(f'Loading checkpoint: {checkpoint_path}')
+
+    # Build the model.
+    model = resnet18(num_classes=num_classes).to(device)
+
+    # Load the checkpoint.
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    if 'model_state_dict' in checkpoint:
+        model.load_state_dict(checkpoint['model_state_dict'])
+        if 'epoch' in checkpoint and 'accuracy' in checkpoint:
+            print(f'Checkpoint info: epoch {checkpoint["epoch"]}, accuracy {checkpoint["accuracy"]:.2f}%')
+    else:
+        # The file contains a bare state_dict.
+        model.load_state_dict(checkpoint)
+
+    model.eval()
+    return model
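+
+
+def cat_dog_confusion(cm, cat_idx=3, dog_idx=5):
+    # Illustrative helper (a hypothetical addition, not wired into main): reads
+    # off the two off-diagonal cells this project targets. With the CIFAR-10
+    # class order used here, cat is index 3 and dog is index 5.
+    print(f'cats predicted as dogs: {cm[cat_idx, dog_idx]}')
+    print(f'dogs predicted as cats: {cm[dog_idx, cat_idx]}')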
+
+
+def main():
+    parser = argparse.ArgumentParser(description='ResNet-18 CIFAR-10 evaluation')
+    parser.add_argument('--checkpoint', type=str, required=True,
+                        help='path to the checkpoint file')
+    parser.add_argument('--data_dir', type=str, default='./data',
+                        help='data directory')
+    parser.add_argument('--batch_size', type=int, default=128,
+                        help='batch size')
+    parser.add_argument('--num_workers', type=int, default=4,
+                        help='number of data-loading workers')
+    parser.add_argument('--save_cm', action='store_true',
+                        help='save the confusion matrix as an image')
+    parser.add_argument('--cm_path', type=str, default='confusion_matrix.png',
+                        help='where to save the confusion matrix image')
+
+    args = parser.parse_args()
+
+    # Make sure the checkpoint exists.
+    if not os.path.exists(args.checkpoint):
+        print(f'Error: checkpoint file not found: {args.checkpoint}')
+        return
+
+    # Select the device.
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f'Using device: {device}')
+
+    # Load the model.
+    model = load_model(args.checkpoint, device)
+
+    # Data loaders.
+    print('\nPreparing data loaders...')
+    _, test_loader, classes = get_cifar10_dataloaders(
+        data_dir=args.data_dir,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers
+    )
+
+    # Run the evaluation.
+    test_loss, test_acc, cm = test_model(model, test_loader, device, classes)
+
+    if args.save_cm:
+        plot_confusion_matrix(cm, classes, save_path=args.cm_path)
+
+    print(f'\nFinal test accuracy: {test_acc:.2f}%')
+
+
+if __name__ == '__main__':
+    main()
+
diff --git a/yeonsu/train.py b/yeonsu/train.py
new file mode 100644
index 0000000..52da3f0
--- /dev/null
+++ b/yeonsu/train.py
@@ -0,0 +1,461 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
+import os
+import time
+import json
+import csv
+import argparse
+from datetime import datetime
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+from models.resnet18 import resnet18
+from dataset import get_cifar10_dataloaders
+from losses import FocalLoss, LabelSmoothingCrossEntropy, WeightedCrossEntropyLoss, get_cat_dog_focused_weights
+from augmentation import Mixup, mixup_criterion
+
+
+def train_epoch(model, train_loader, criterion, optimizer, device, epoch, use_mixup=False, mixup_alpha=1.0):
+    model.train()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    mixup_fn = Mixup(alpha=mixup_alpha) if use_mixup else None
+
+    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
+    for batch_idx, (inputs, targets) in enumerate(pbar):
+        inputs, targets = inputs.to(device), targets.to(device)
+
+        if use_mixup and mixup_fn:
+            inputs, targets_a, targets_b, lam = mixup_fn((inputs, targets))
+
+        optimizer.zero_grad()
+        outputs = model(inputs)
+
+        if use_mixup and mixup_fn:
+            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
+        else:
+            loss = criterion(outputs, targets)
+
+        loss.backward()
+        optimizer.step()
+
+        running_loss += loss.item()
+        if use_mixup:
+            # With Mixup the labels are blended, so accuracy is estimated as the
+            # same convex combination of matches against both label sets.
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += (lam * predicted.eq(targets_a).sum().item()
+                        + (1 - lam) * predicted.eq(targets_b).sum().item())
+        else:
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+
+        pbar.set_postfix({
+            'loss': f'{running_loss/(batch_idx+1):.4f}',
+            'acc': f'{100.*correct/total:.2f}%'
+        })
+
+    epoch_loss = running_loss / len(train_loader)
+    epoch_acc = 100. * correct / total
+    return epoch_loss, epoch_acc
+
+
+def evaluate(model, test_loader, criterion, device):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for inputs, targets in tqdm(test_loader, desc='Evaluating'):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+
+    test_loss /= len(test_loader)
+    test_acc = 100. * correct / total
+    return test_loss, test_acc
+
+
+def save_checkpoint(model, optimizer, scheduler, epoch, acc, filepath):
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        # The scheduler may be None when --scheduler none is used.
+        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
+        'accuracy': acc,
+    }
+    torch.save(checkpoint, filepath)
+    print(f'Saved checkpoint: {filepath}')
+
+
+def load_checkpoint(model, optimizer, scheduler, filepath):
+    checkpoint = torch.load(filepath)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
+        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+    epoch = checkpoint['epoch']
+    acc = checkpoint['accuracy']
+    print(f'Loaded checkpoint: {filepath} (epoch {epoch}, accuracy {acc:.2f}%)')
+    return epoch, acc
+
+
+def create_experiment_folder(config, base_dir='./experiments'):
+    """Create the experiment folder and save the hyperparameters."""
+    # Build the experiment name, e.g. "focal_g2.0_adv_aug_cutout16_<timestamp>".
+    exp_name_parts = []
+
+    # Loss function.
+    if config['loss_type'] == 'focal':
+        exp_name_parts.append(f"focal_g{config['focal_gamma']}")
+    elif config['loss_type'] == 'weighted':
+        exp_name_parts.append(f"weighted_w{config['class_weight']}")
+    elif config['loss_type'] == 'label_smooth':
+        exp_name_parts.append(f"labelsmooth_s{config['label_smoothing']}")
+    else:
+        exp_name_parts.append("ce")
+
+    # Data augmentation.
+    aug_parts = []
+    if config['use_advanced_aug']:
+        aug_parts.append("adv_aug")
+    if config['use_cutout']:
+        aug_parts.append(f"cutout{config['cutout_length']}")
+    if config['use_mixup']:
+        aug_parts.append(f"mixup{config['mixup_alpha']}")
+
+    if aug_parts:
+        exp_name_parts.append("_".join(aug_parts))
+
+    # Append a timestamp.
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    exp_name = "_".join(exp_name_parts) + f"_{timestamp}"
+
+    # Create the experiment folder.
+    exp_dir = os.path.join(base_dir, exp_name)
+    os.makedirs(exp_dir, exist_ok=True)
+
+    # Save the hyperparameters.
+    config_path = os.path.join(exp_dir, 'config.json')
+    with open(config_path, 'w', encoding='utf-8') as f:
+        json.dump(config, f, indent=4, ensure_ascii=False)
+
+    print(f'Created experiment folder: {exp_dir}')
+    print(f'Saved hyperparameters: {config_path}')
+
+    return exp_dir
+
+
+def save_results(exp_dir, train_losses, train_accs, test_losses, test_accs):
+    """Save the per-epoch results as CSV."""
+    results_path = os.path.join(exp_dir, 'results.csv')
+
+    with open(results_path, 'w', newline='', encoding='utf-8') as f:
+        writer = csv.writer(f)
+        writer.writerow(['epoch', 'train_loss', 'train_acc', 'test_loss', 'test_acc'])
+
+        for epoch in range(len(train_losses)):
+            writer.writerow([
+                epoch + 1,
+                f'{train_losses[epoch]:.6f}',
+                f'{train_accs[epoch]:.2f}',
+                f'{test_losses[epoch]:.6f}',
+                f'{test_accs[epoch]:.2f}'
+            ])
+
+    print(f'Saved training results: {results_path}')
+
+
+def save_summary(exp_dir, best_acc, final_acc, total_time):
+    """Save a summary of the final results."""
+    summary = {
+        'best_test_accuracy': best_acc,
+        'final_test_accuracy': final_acc,
+        'total_training_time_seconds': total_time,
+        'completed_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    }
+
+    summary_path = os.path.join(exp_dir, 'summary.json')
+    with open(summary_path, 'w', encoding='utf-8') as f:
+        json.dump(summary, f, indent=4, ensure_ascii=False)
+
+    print(f'Saved summary: {summary_path}')
+
+
+def plot_training_curves(exp_dir, train_losses, train_accs, test_losses, test_accs):
+    """Save the training-curve plots."""
+    epochs = range(1, len(train_losses) + 1)
+
+    # Create a 2x2 subplot figure.
+    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
+
+    # 1. Train/test loss.
+    axes[0, 0].plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
+    axes[0, 0].plot(epochs, test_losses, 'r-', label='Test Loss', linewidth=2)
+    axes[0, 0].set_xlabel('Epoch', fontsize=12)
+    axes[0, 0].set_ylabel('Loss', fontsize=12)
+    axes[0, 0].set_title('Training and Test Loss', fontsize=14, fontweight='bold')
+    axes[0, 0].legend(fontsize=10)
+    axes[0, 0].grid(True, alpha=0.3)
+
+    # 2. Train/test accuracy.
+    axes[0, 1].plot(epochs, train_accs, 'b-', label='Train Accuracy', linewidth=2)
+    axes[0, 1].plot(epochs, test_accs, 'r-', label='Test Accuracy', linewidth=2)
+    axes[0, 1].set_xlabel('Epoch', fontsize=12)
+    axes[0, 1].set_ylabel('Accuracy (%)', fontsize=12)
+    axes[0, 1].set_title('Training and Test Accuracy', fontsize=14, fontweight='bold')
+    axes[0, 1].legend(fontsize=10)
+    axes[0, 1].grid(True, alpha=0.3)
+
+    # 3. Loss, zoomed to the range of the last half of training.
+    axes[1, 0].plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.7)
+    axes[1, 0].plot(epochs, test_losses, 'r-', label='Test Loss', linewidth=2, alpha=0.7)
+    axes[1, 0].set_xlabel('Epoch', fontsize=12)
+    axes[1, 0].set_ylabel('Loss', fontsize=12)
+    axes[1, 0].set_title('Loss (Zoomed)', fontsize=14, fontweight='bold')
+    axes[1, 0].legend(fontsize=10)
+    axes[1, 0].grid(True, alpha=0.3)
+    # Set the y-limits from the last half of the epochs to show fine detail.
+    if len(train_losses) > 0:
+        half_point = len(train_losses) // 2
+        zoom_losses = train_losses[half_point:] + test_losses[half_point:]
+        if len(zoom_losses) > 0:
+            min_loss = min(zoom_losses)
+            max_loss = max(zoom_losses)
+            margin = (max_loss - min_loss) * 0.15
+            axes[1, 0].set_ylim([max(0, min_loss - margin), max_loss + margin])
+
+    # 4. Accuracy, zoomed to the range of the last half of training.
+    axes[1, 1].plot(epochs, train_accs, 'b-', label='Train Accuracy', linewidth=2, alpha=0.7)
+    axes[1, 1].plot(epochs, test_accs, 'r-', label='Test Accuracy', linewidth=2, alpha=0.7)
+    axes[1, 1].set_xlabel('Epoch', fontsize=12)
+    axes[1, 1].set_ylabel('Accuracy (%)', fontsize=12)
+    axes[1, 1].set_title('Accuracy (Zoomed)', fontsize=14, fontweight='bold')
+    axes[1, 1].legend(fontsize=10)
+    axes[1, 1].grid(True, alpha=0.3)
+    # Set the y-limits from the last half of the epochs to show fine detail.
+    if len(train_accs) > 0:
+        half_point = len(train_accs) // 2
+        zoom_accs = train_accs[half_point:] + test_accs[half_point:]
+        if len(zoom_accs) > 0:
+            min_acc = min(zoom_accs)
+            max_acc = max(zoom_accs)
+            margin = (max_acc - min_acc) * 0.15
+            axes[1, 1].set_ylim([max(0, min_acc - margin), min(100, max_acc + margin)])
+
+    plt.tight_layout()
+
+    # Save the figure.
+    plot_path = os.path.join(exp_dir, 'training_curves.png')
+    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
+    plt.close()
+
+    print(f'Saved training curves: {plot_path}')
+
+
+def parse_args():
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser(description='CIFAR-10 ResNet-18 training')
+
+    # Data.
+    parser.add_argument('--data_dir', type=str, default='./data', help='data directory')
+    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
+    parser.add_argument('--num_workers', type=int, default=4, help='number of data-loader workers')
+
+    # Training.
+    parser.add_argument('--epochs', type=int, default=150, help='number of training epochs')
+    parser.add_argument('--lr', type=float, default=0.1, help='initial learning rate')
+    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
+    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
+
+    # Scheduler.
+    parser.add_argument('--scheduler', type=str, default='step', choices=['step', 'cosine', 'none'], help='learning-rate scheduler')
+    parser.add_argument('--step_size', type=int, default=30, help='StepLR step_size')
+    parser.add_argument('--gamma', type=float, default=0.1, help='StepLR gamma')
+
+    # Checkpointing.
+    parser.add_argument('--save_freq', type=int, default=10, help='checkpoint saving interval (epochs)')
+
+    # Data augmentation.
+    parser.add_argument('--use_advanced_aug', action='store_true', help='use advanced data augmentation')
+    parser.add_argument('--use_cutout', action='store_true', help='use Cutout')
+    parser.add_argument('--cutout_length', type=int, default=16, help='Cutout patch size')
+    parser.add_argument('--use_mixup', action='store_true', help='use Mixup')
+    parser.add_argument('--mixup_alpha', type=float, default=1.0, help='Mixup alpha')
+
+    # Loss function.
+    parser.add_argument('--loss_type', type=str, default='ce', choices=['ce', 'focal', 'weighted', 'label_smooth'], help='loss-function type')
+    parser.add_argument('--focal_alpha', type=float, default=1.0, help='Focal Loss alpha')
+    parser.add_argument('--focal_gamma', type=float, default=2.0, help='Focal Loss gamma')
+    parser.add_argument('--class_weight', type=float, default=2.0, help='class weight for the weighted loss')
+    parser.add_argument('--label_smoothing', type=float, default=0.1, help='label-smoothing factor')
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    config = {
+        'data_dir': args.data_dir,
+        'batch_size': args.batch_size,
+        'num_workers': args.num_workers,
+        'epochs': args.epochs,
+        'lr': args.lr,
+        'momentum': args.momentum,
+        'weight_decay': args.weight_decay,
+        'scheduler': args.scheduler,
+        'step_size': args.step_size,
+        'gamma': args.gamma,
+        'save_dir': './checkpoints',
+        'save_freq': args.save_freq,
+
+        'use_advanced_aug': args.use_advanced_aug,
+        'use_cutout': args.use_cutout,
+        'cutout_length': args.cutout_length,
+        'use_mixup': args.use_mixup,
+        'mixup_alpha': args.mixup_alpha,
+
+        'loss_type': args.loss_type,
+        'focal_alpha': args.focal_alpha,
+        'focal_gamma': args.focal_gamma,
+        'class_weight': args.class_weight,
+        'label_smoothing': args.label_smoothing,
+    }
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # Create the experiment folder.
+    exp_dir = create_experiment_folder(config, base_dir='./experiments')
+    config['save_dir'] = os.path.join(exp_dir, 'checkpoints')
+
+    train_loader, test_loader, classes = get_cifar10_dataloaders(
+        data_dir=config['data_dir'],
+        batch_size=config['batch_size'],
+        num_workers=config['num_workers'],
+        use_advanced_aug=config['use_advanced_aug'],
+        use_cutout=config['use_cutout'],
+        cutout_length=config['cutout_length']
+    )
+
+    model = resnet18(num_classes=10).to(device)
+
+    if config['loss_type'] == 'focal':
+        criterion = FocalLoss(
+            alpha=config['focal_alpha'],
+            gamma=config['focal_gamma']
+        ).to(device)
+    elif config['loss_type'] == 'weighted':
+        class_weights = get_cat_dog_focused_weights(
+            num_classes=10,
+            weight=config['class_weight']
+        )
+        criterion = WeightedCrossEntropyLoss(
+            class_weights=class_weights
+        ).to(device)
+    elif config['loss_type'] == 'label_smooth':
+        criterion = LabelSmoothingCrossEntropy(
+            smoothing=config['label_smoothing'],
+            num_classes=10
+        ).to(device)
+    else:
+        criterion = nn.CrossEntropyLoss()
+
+    if config['use_mixup']:
+        print(f'Using Mixup: alpha={config["mixup_alpha"]}')
+    if config['use_advanced_aug']:
+        print(f'Using advanced data augmentation: Cutout={config["use_cutout"]}')
+
+    optimizer = optim.SGD(
+        model.parameters(),
+        lr=config['lr'],
+        momentum=config['momentum'],
+        weight_decay=config['weight_decay']
+    )
+
+    if config['scheduler'] == 'step':
+        scheduler = StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])
+    elif config['scheduler'] == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'])
+    else:
+        scheduler = None
+
+    os.makedirs(config['save_dir'], exist_ok=True)
+
+    best_acc = 0.0
+    total_start_time = time.time()
+
+    print(f'\nStarting training ({config["epochs"]} epochs)...')
+    print('=' * 60)
+
+    train_losses = []
+    train_accs = []
+    test_losses = []
+    test_accs = []
+
+    for epoch in range(config['epochs']):
+        start_time = time.time()
+
+        train_loss, train_acc = train_epoch(
+            model, train_loader, criterion, optimizer, device, epoch,
+            use_mixup=config['use_mixup'],
+            mixup_alpha=config['mixup_alpha']
+        )
+
+        test_loss, test_acc = evaluate(model, test_loader, criterion, device)
+
+        if scheduler:
+            scheduler.step()
+            current_lr = scheduler.get_last_lr()[0]
+        else:
+            current_lr = config['lr']
+
+        train_losses.append(train_loss)
+        train_accs.append(train_acc)
+        test_losses.append(test_loss)
+        test_accs.append(test_acc)
+
+        epoch_time = time.time() - start_time
+        print(f'\nEpoch {epoch+1}/{config["epochs"]} ({epoch_time:.1f}s)')
+        print(f'Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%')
+        print(f'Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.2f}%')
+        print(f'Learning Rate: {current_lr:.6f}')
+        print('-' * 60)
+
+        if test_acc > best_acc:
+            best_acc = test_acc
+            best_path = os.path.join(config['save_dir'], 'best_model.pth')
+            save_checkpoint(model, optimizer, scheduler, epoch, test_acc, best_path)
+
+        if (epoch + 1) % config['save_freq'] == 0:
+            checkpoint_path = os.path.join(config['save_dir'], f'checkpoint_epoch_{epoch+1}.pth')
+            save_checkpoint(model, optimizer, scheduler, epoch, test_acc, checkpoint_path)
+
+    total_time = time.time() - total_start_time
+    final_acc = test_accs[-1] if len(test_accs) > 0 else 0.0
+
+    # Save the results.
+    save_results(exp_dir, train_losses, train_accs, test_losses, test_accs)
+    save_summary(exp_dir, best_acc, final_acc, total_time)
+    plot_training_curves(exp_dir, train_losses, train_accs, test_losses, test_accs)
+
+    print('\n' + '=' * 60)
+    print('Training complete!')
+    print(f'Experiment folder: {exp_dir}')
+    print(f'Best test accuracy: {best_acc:.2f}%')
+    print(f'Final test accuracy: {final_acc:.2f}%')
+    print(f'Total training time: {total_time/60:.1f} min')
+    print('=' * 60)
+
+
+if __name__ == '__main__':
+    main()
+
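One loose end worth noting: `train.py` defines `load_checkpoint` but never calls it, so interrupted runs cannot currently be resumed. A minimal sketch of how resuming could be wired into `main()`, assuming a hypothetical `--resume` argument pointing at a file written by `save_checkpoint`:

```python
# Hypothetical resume hook for train.py's main(); --resume is not an existing flag.
parser.add_argument('--resume', type=str, default='', help='checkpoint to resume from')
...
start_epoch = 0
if args.resume and os.path.exists(args.resume):
    # load_checkpoint restores model/optimizer/scheduler state and returns
    # the epoch and accuracy stored in the checkpoint file.
    last_epoch, best_acc = load_checkpoint(model, optimizer, scheduler, args.resume)
    start_epoch = last_epoch + 1

for epoch in range(start_epoch, config['epochs']):
    ...
```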