์ด๋ฒ ํฌ์คํธ์์๋ PyTorch๋ฅผ ์ด์ฉํด ๊ฐ๋จํ CNN ๋ชจ๋ธ์ ๊ตฌ์ถํ๊ณ , ์ด๋ฅผ ํตํด ์ฌ์์ ํธ๋์ด ์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฅํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด๊ฒ ์ต๋๋ค. ๋ฐ์ดํฐ์ ๋ก๋๋ถํฐ ๋ชจ๋ธ ํ์ต, ํ๊ฐ๊น์ง์ ์ ์ฒด ๊ณผ์ ์ ์์ธํ ์ค๋ช ํ๊ฒ ์ต๋๋ค.
< ๋ชฉ์ฐจ >
1. ๋ฐ์ดํฐ์
์ค๋น
2. ๋ฐ์ดํฐ์
ํด๋์ค ์ ์
3. ๋ฐ์ดํฐ ๋ณํ ์ค์
4. ๋ฐ์ดํฐ์
๋ถํ ๋ฐ ๋ฐ์ดํฐ ๋ก๋ ์์ฑ
5. ๊ฐ๋จํ CNN ๋ชจ๋ธ ์ ์
6. ๋ชจ๋ธ ํ์ต ๋ฐ ๊ฒ์ฆ
7. ํ
์คํธ ๋ฐ์ดํฐ ํ๊ฐ
8. ํ์ต ๋ฐ ๊ฒ์ฆ ์์ค, ์ ํ๋ ์๊ฐํ
9. ํ
์คํธ ๊ฒฐ๊ณผ ์๊ฐํ
10. ๊ฒฐ๊ณผ
1. ๋ฐ์ดํฐ์ ์ค๋น
๋จผ์ , ์ด๋ฏธ์ง ๊ฒฝ๋ก๋ฅผ ๋ถ๋ฌ์ค๊ณ , ๊ฐ ์ด๋ฏธ์ง๋ฅผ ๋ ์ด๋ธ๋งํฉ๋๋ค. ์ฌ์๋ 0, ํธ๋์ด๋ 1๋ก ๋ผ๋ฒจ์ ์ง์ ํฉ๋๋ค.
import glob

# Collect image paths per class. sorted() makes the ordering deterministic --
# glob's result order is filesystem-dependent, which would otherwise make
# the train/val/test split irreproducible across runs/machines.
lion_image_paths = sorted(glob.glob('/mnt/lion/*.jpg'))
tiger_image_paths = sorted(glob.glob('/mnt/tiger/*.jpg'))
image_paths = lion_image_paths + tiger_image_paths
# Integer labels aligned index-for-index with image_paths: 0 = lion, 1 = tiger.
labels = [0] * len(lion_image_paths) + [1] * len(tiger_image_paths)
print(f'Total images: {len(image_paths)}, Total labels: {len(labels)}')
2. ๋ฐ์ดํฐ์ ํด๋์ค ์ ์
PyTorch์ Dataset ํด๋์ค๋ฅผ ์์๋ฐ์ ์ปค์คํ ๋ฐ์ดํฐ์ ํด๋์ค๋ฅผ ์ ์ํฉ๋๋ค. ์ด๋ฏธ์ง ๋ณํ(transform)์ ํ์ต๊ณผ ํ ์คํธ ๋จ๊ณ์์ ๋ค๋ฅด๊ฒ ์ค์ ํฉ๋๋ค.
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
class CustomDataset(Dataset):
    """Image-classification dataset backed by a list of file paths.

    Each item is a (PIL image or transformed tensor, integer label) pair;
    `image_paths` and `labels` must be aligned index-for-index.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths  # one file path per sample
        self.labels = labels            # integer label per sample
        self.transform = transform      # optional callable applied to each image

    def __len__(self):
        # One sample per image path.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Force RGB so grayscale files also come out with three channels.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]
3. ๋ฐ์ดํฐ ๋ณํ ์ค์
ํ์ต ๋ฐ์ดํฐ๋ ๋ค์ํ ์ด๋ฏธ์ง ์ฆ๊ฐ ๊ธฐ๋ฒ์ ์ฌ์ฉํด ๋ฐ์ดํฐ์ ์ ํ๋ถํ๊ฒ ํ๊ณ , ํ ์คํธ ๋ฐ์ดํฐ๋ ๊ธฐ๋ณธ์ ์ธ ํฌ๊ธฐ ์กฐ์ ๋ง ์ํํฉ๋๋ค.
# Training pipeline: stacked random augmentations to enlarge the effective
# dataset, followed by tensor conversion.
_augmentations = [
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),
]
train_transform = transforms.Compose(_augmentations + [transforms.ToTensor()])

# Evaluation pipeline: deterministic resize only, no augmentation.
test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
4. ๋ฐ์ดํฐ์ ๋ถํ ๋ฐ ๋ฐ์ดํฐ ๋ก๋ ์์ฑ
๋ฐ์ดํฐ์ ์ ํ์ต, ๊ฒ์ฆ, ํ ์คํธ ์ธํธ๋ก ๋๋๊ณ , ๊ฐ๊ฐ์ ๋ฐ์ดํฐ ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค.
from torch.utils.data import DataLoader, random_split, Subset

# Two views of the same files: one with augmentation (training) and one with
# the deterministic eval transform. BUG in the original: all three subsets
# from random_split share ONE underlying dataset object, so assigning
# `val_dataset.dataset.transform = test_transform` also replaced the
# training subset's transform -- augmentation was silently never applied.
train_view = CustomDataset(image_paths=image_paths, labels=labels, transform=train_transform)
eval_view = CustomDataset(image_paths=image_paths, labels=labels, transform=test_transform)

# 70 / 15 / 15 split; the remainder goes to test so the sizes always sum up.
train_size = int(0.7 * len(train_view))
val_size = int(0.15 * len(train_view))
test_size = len(train_view) - train_size - val_size
print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {test_size}')

# Split once on the augmented view, then reuse the exact same indices on the
# eval view so val/test read the identical files without augmentation.
train_dataset, _val_split, _test_split = random_split(train_view, [train_size, val_size, test_size])
val_dataset = Subset(eval_view, _val_split.indices)
test_dataset = Subset(eval_view, _test_split.indices)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
5. ๊ฐ๋จํ CNN ๋ชจ๋ธ ์ ์
๊ฐ๋จํ CNN ๋ชจ๋ธ์ ์ ์ํฉ๋๋ค. 2๊ฐ์ ์ปจ๋ณผ๋ฃจ์ ๋ ์ด์ด์ 2๊ฐ์ FC ๋ ์ด์ด๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class SimpleCNN(nn.Module):
    """Small CNN for 2-class classification of 224x224 RGB images.

    Two conv+pool stages (spatial 224 -> 112 -> 56) followed by two fully
    connected layers with dropout; outputs raw logits of shape (batch, 2).
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # 32 channels * 56 * 56 spatial after two 2x2 poolings of 224x224 input.
        self.fc1 = nn.Linear(32 * 56 * 56, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. The original x.view(-1, 32*56*56) would silently
        # change the batch size if the spatial dims ever differed; keeping the
        # batch dimension explicit makes any size mismatch fail loudly instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
# Instantiate model, loss, and optimizer for the 2-class problem.
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()  # expects raw logits, applies softmax internally
optimizer = optim.Adam(model.parameters(), lr=0.001)
6. ๋ชจ๋ธ ํ์ต ๋ฐ ๊ฒ์ฆ
๋ชจ๋ธ์ ํ์ตํ๊ณ , ๊ฒ์ฆ ์ธํธ์์์ ์ฑ๋ฅ์ ํ๊ฐํฉ๋๋ค. ๋ํ, Early Stopping์ ์ ์ฉํด ๊ณผ์ ํฉ์ ๋ฐฉ์งํฉ๋๋ค.
# History buffers for the loss/accuracy plots drawn after training.
train_losses = []
val_losses = []
val_accuracies = []
best_val_accuracy = 0
early_stop_patience = 5  # stop after this many epochs without improvement
early_stop_counter = 0
# training
num_epochs = 50
for epoch in range(num_epochs):
    model.train()  # enable dropout for training
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)  # mean per-batch loss this epoch
    train_losses.append(epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')
    # Validate
    model.eval()  # disable dropout for evaluation
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Predicted class = index of the larger of the two logits.
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = 100 * correct / total
    val_losses.append(val_loss / len(val_loader))
    val_accuracies.append(val_accuracy)
    print(f'Validation Loss: {val_loss/len(val_loader)}, Validation Accuracy: {val_accuracy}%')
    # Early stopping: checkpoint on any improvement, otherwise count stale epochs.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        early_stop_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        early_stop_counter += 1
        if early_stop_counter >= early_stop_patience:
            print("Early stopping triggered.")
            break
# Restore the best checkpoint for final testing.
# NOTE(review): if val_accuracy never exceeds 0, no checkpoint is ever written
# and this load raises FileNotFoundError -- confirm this is acceptable.
model.load_state_dict(torch.load('best_model.pth'))
7. ํ ์คํธ ๋ฐ์ดํฐ ํ๊ฐ
์ต์ข ๋ชจ๋ธ์ ํ ์คํธ ๋ฐ์ดํฐ์์ ํ๊ฐํ๊ณ , ํ ์คํธ ์ ํ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค.
# Final evaluation on the held-out test set, collecting every image,
# true label, and prediction for the visualization section.
model.eval()  # disable dropout for inference
correct = 0
total = 0
all_images = []
all_labels = []
all_preds = []
with torch.no_grad():  # no gradients needed during evaluation
    for images, labels in test_loader:
        outputs = model(images)
        # Predicted class = index of the larger of the two logits.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Keep CPU numpy copies of the whole test set for plotting.
        all_images.extend(images.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())
test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy}%')
8. ํ์ต ๋ฐ ๊ฒ์ฆ ์์ค, ์ ํ๋ ์๊ฐํ
ํ์ต ๋ฐ ๊ฒ์ฆ ์์ค, ๊ฒ์ฆ ์ ํ๋๋ฅผ ๊ทธ๋ํ๋ก ์๊ฐํํฉ๋๋ค.
import matplotlib.pyplot as plt


def _plot_curves(series, ylabel, title):
    """Plot one or more (values, label) curves against epoch and show the figure."""
    plt.figure(figsize=(10, 5))
    for values, label in series:
        plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.show()


# Figure 1: training vs. validation loss per epoch.
_plot_curves([(train_losses, 'Training Loss'), (val_losses, 'Validation Loss')],
             'Loss', 'Training and Validation Loss')
# Figure 2: validation accuracy per epoch.
_plot_curves([(val_accuracies, 'Validation Accuracy')],
             'Accuracy (%)', 'Validation Accuracy')
9. ํ ์คํธ ๊ฒฐ๊ณผ ์๊ฐํ
ํ ์คํธ ๋ฐ์ดํฐ์์ ์์ธก๋ ๊ฒฐ๊ณผ๋ฅผ ์๊ฐํํฉ๋๋ค.
import numpy as np

classes = ['lion', 'tiger']

# Show up to the first 10 test images with predicted and true class labels.
# Guard against small test sets: the original `range(10)` raised IndexError
# whenever fewer than 10 samples were collected.
num_show = min(10, len(all_images))
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
for i in range(num_show):
    # CHW float tensor in [0, 1] -> HWC uint8 array for imshow.
    image = all_images[i].transpose((1, 2, 0))
    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    axes[i].imshow(image)
    # Red frame around each displayed image.
    axes[i].add_patch(plt.Rectangle((0, 0), image.shape[1], image.shape[0],
                                    fill=False, edgecolor='red', linewidth=2))
    axes[i].text(5, 25, f'Pred: {classes[all_preds[i]]}', bbox=dict(facecolor='white', alpha=0.75))
    axes[i].text(5, 50, f'True: {classes[all_labels[i]]}', bbox=dict(facecolor='white', alpha=0.75))
    axes[i].axis('off')
# Blank out leftover axes when fewer than 10 images are available.
for j in range(num_show, len(axes)):
    axes[j].axis('off')
plt.tight_layout()
plt.show()
10. ๊ฒฐ๊ณผ
ํ์ต ๋ฐ ๊ฒ์ฆ ์์ค
๊ฒ์ฆ ์ ํ๋
ํ
์คํธ ๊ฒฐ๊ณผ
์ด์์ผ๋ก ์ฌ์์ ํธ๋์ด ์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฅํ๋ CNN ๋ชจ๋ธ์ ๊ตฌ์ถํ๋ ์ ์ฒด ๊ณผ์ ์ ๋ง์ณค์ต๋๋ค. ์ด๋ฒ ํฌ์คํธ๊ฐ ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ์ ๊ตฌ์ถํ๋ ๋ฐ ๋์์ด ๋๊ธธ ๋ฐ๋๋๋ค. ์ถ๊ฐ์ ์ธ ์ง๋ฌธ์ด ์์ผ์๋ฉด ์ธ์ ๋ ์ง ๋๊ธ๋ก ๋จ๊ฒจ์ฃผ์ธ์!