Implementing a Lion vs. Tiger Image Classifier with a Simple CNN

leejaejae 2024. 6. 19. 14:33

In this post, we build a simple CNN model with PyTorch and use it to classify images of lions and tigers. The whole process is covered step by step, from loading the dataset to training and evaluating the model.

< Table of Contents >

1. Preparing the dataset
2. Defining the dataset class
3. Setting up data transforms
4. Splitting the dataset and creating data loaders
5. Defining a simple CNN model
6. Training and validating the model
7. Evaluating on the test data
8. Visualizing training and validation loss and accuracy
9. Visualizing test results
10. Results


1. Preparing the dataset

First, collect the image paths and assign a label to each image: 0 for lions and 1 for tigers.

import glob

lion_image_paths = glob.glob('/mnt/lion/*.jpg')
tiger_image_paths = glob.glob('/mnt/tiger/*.jpg')

image_paths = lion_image_paths + tiger_image_paths
labels = [0] * len(lion_image_paths) + [1] * len(tiger_image_paths)  # 0: lion, 1: tiger

print(f'Total images: {len(image_paths)}, Total labels: {len(labels)}')
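
As a quick sanity check on the labeling step, the sketch below builds the same kind of path and label lists from a temporary folder. The helper name `collect_labeled_paths` and the temporary directory layout are assumptions made for illustration; the post itself reads from `/mnt/lion` and `/mnt/tiger`. Sorting is added here so the order is reproducible, which `glob.glob` alone does not guarantee.

```python
import glob
import os
import tempfile


def collect_labeled_paths(class_dirs):
    """Return (paths, labels); the label is the index of the class folder."""
    paths, labels = [], []
    for label, folder in enumerate(class_dirs):
        # sorted() makes the order deterministic across runs and file systems
        found = sorted(glob.glob(os.path.join(folder, '*.jpg')))
        paths.extend(found)
        labels.extend([label] * len(found))
    return paths, labels


# Hypothetical layout: empty placeholder files stand in for real images
root = tempfile.mkdtemp()
lion_dir = os.path.join(root, 'lion')
tiger_dir = os.path.join(root, 'tiger')
for folder, count in [(lion_dir, 3), (tiger_dir, 2)]:
    os.makedirs(folder)
    for i in range(count):
        open(os.path.join(folder, f'{i}.jpg'), 'w').close()

demo_paths, demo_labels = collect_labeled_paths([lion_dir, tiger_dir])
print(demo_labels)  # [0, 0, 0, 1, 1]
```

Comparing the number of 0s and 1s at this point also shows whether the two classes are roughly balanced before any training starts.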

 

2. Defining the dataset class

Define a custom dataset class by subclassing PyTorch's Dataset. The image transform is configured differently for the training and test stages.

from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

 

3. Setting up data transforms

The training data goes through several image augmentation techniques to enrich the dataset, while the test data is only resized.

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

 

4. Splitting the dataset and creating data loaders

Split the dataset into training, validation, and test sets, and create a data loader for each.

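from torch.utils.data import DataLoader, Subset, random_split

# Two views of the same files: one with augmentation, one with resizing only
dataset = CustomDataset(image_paths=image_paths, labels=labels, transform=train_transform)
eval_dataset = CustomDataset(image_paths=image_paths, labels=labels, transform=test_transform)

train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {test_size}')

# The subsets returned by random_split share one underlying dataset, so assigning
# subset.dataset.transform would also replace the training transform.
# Split the indices once and apply them to the matching view instead.
train_split, val_split, test_split = random_split(range(len(dataset)), [train_size, val_size, test_size])

train_dataset = Subset(dataset, train_split.indices)
val_dataset = Subset(eval_dataset, val_split.indices)
test_dataset = Subset(eval_dataset, test_split.indices)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)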
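
To make the 70/15/15 arithmetic concrete, here is a minimal sketch of the size calculation on its own. The function name `split_sizes` and the sample counts are hypothetical; the point is only that `int()` truncates, so the test split absorbs whatever remainder is left.

```python
def split_sizes(n, train_frac=0.7, val_frac=0.15):
    """Mirror the size arithmetic used for the split: int() truncates the
    train and validation sizes, and the test split takes the remainder."""
    train_size = int(train_frac * n)
    val_size = int(val_frac * n)
    test_size = n - train_size - val_size
    return train_size, val_size, test_size


print(split_sizes(100))  # (70, 15, 15)
print(split_sizes(101))  # (70, 15, 16)
```

Because the three sizes always sum to the dataset length, `random_split` receives a valid list of lengths regardless of how the fractions round.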

 

5. Defining a simple CNN model

Define a simple CNN model consisting of two convolutional layers and two fully connected (FC) layers.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 56 * 56, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
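
The input size of `fc1`, `32 * 56 * 56`, follows from the layer settings and the 224 x 224 input. The sketch below traces the spatial size through the network using the standard output-size formula for convolution and pooling layers; the helper name `conv_out` is an assumption introduced only for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 224                                             # images are resized to 224 x 224
side = conv_out(side, kernel=3, stride=1, padding=1)   # conv1 keeps 224
side = conv_out(side, kernel=2, stride=2)              # pool halves it to 112
side = conv_out(side, kernel=3, stride=1, padding=1)   # conv2 keeps 112
side = conv_out(side, kernel=2, stride=2)              # pool halves it to 56

flat_features = 32 * side * side                       # 32 channels after conv2
print(side, flat_features)  # 56 100352
```

If the input resolution or the number of pooling steps changes, this number changes too, and the first argument of `fc1` has to be updated to match.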

 

6. Training and validating the model

Train the model and evaluate its performance on the validation set. Early stopping is also applied to help prevent overfitting.

train_losses = []
val_losses = []
val_accuracies = []

best_val_accuracy = 0
early_stop_patience = 5
early_stop_counter = 0

# training
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')

    # Validate
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = 100 * correct / total
    val_losses.append(val_loss / len(val_loader))
    val_accuracies.append(val_accuracy)
    print(f'Validation Loss: {val_loss/len(val_loader)}, Validation Accuracy: {val_accuracy}%')

    # Early stopping
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        early_stop_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        early_stop_counter += 1

    if early_stop_counter >= early_stop_patience:
        print("Early stopping triggered.")
        break

model.load_state_dict(torch.load('best_model.pth'))
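
The early-stopping bookkeeping in the loop can be replayed on a plain list of accuracies to see when it triggers. This is a sketch under stated assumptions: the function name `simulate_early_stopping` and the accuracy history are made up for illustration and are not results from this model.

```python
def simulate_early_stopping(val_accuracies, patience):
    """Replay the counter logic on per-epoch validation accuracies.

    Returns (best_epoch, last_epoch_run), both 0-based.
    """
    best_acc = 0
    best_epoch = None
    counter = 0
    last_epoch_run = None
    for epoch, acc in enumerate(val_accuracies):
        last_epoch_run = epoch
        if acc > best_acc:
            # A strict improvement resets the counter (a checkpoint would be saved here)
            best_acc, best_epoch, counter = acc, epoch, 0
        else:
            # Ties count as "no improvement"
            counter += 1
        if counter >= patience:
            break
    return best_epoch, last_epoch_run


# Hypothetical accuracy history, not measured values
history = [60.0, 72.0, 80.0, 80.0, 79.0, 78.0, 80.0, 77.0, 90.0]
print(simulate_early_stopping(history, patience=5))  # (2, 7)
```

In this made-up history, training stops after epoch 7 and never reaches the later jump to 90.0, which is the usual trade-off when choosing the patience value.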

 

7. Evaluating on the test data

Evaluate the final model on the test data and compute the test accuracy.

model.eval()
correct = 0
total = 0
all_images = []
all_labels = []
all_preds = []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        all_images.extend(images.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy}%')
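
A single accuracy number hides which class the model gets wrong. As a possible follow-up, the sketch below builds a 2 x 2 confusion matrix from label and prediction lists such as `all_labels` and `all_preds`. The function names and the sample labels are hypothetical and do not come from the post's results.

```python
def confusion_matrix(true_labels, predicted_labels, num_classes=2):
    """matrix[t][p] counts samples with true class t and predicted class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[int(t)][int(p)] += 1
    return matrix


def per_class_accuracy(matrix):
    """Fraction of each true class predicted correctly (None if the class is absent)."""
    return [row[i] / sum(row) if sum(row) else None for i, row in enumerate(matrix)]


# Hypothetical labels for illustration only (0: lion, 1: tiger)
demo_true = [0, 0, 0, 1, 1, 1, 1, 0]
demo_pred = [0, 1, 0, 1, 1, 0, 1, 0]

cm = confusion_matrix(demo_true, demo_pred)
print(cm)                      # [[3, 1], [1, 3]]
print(per_class_accuracy(cm))  # [0.75, 0.75]
```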

 

8. Visualizing training and validation loss and accuracy

Plot the training and validation loss and the validation accuracy.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Validation Accuracy')
plt.show()

 

9. Visualizing test results

Visualize the predictions on the test data.

import numpy as np

classes = ['lion', 'tiger']
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()

for i in range(min(10, len(all_images))):  # avoid an IndexError when the test set has fewer than 10 images
    image = all_images[i].transpose((1, 2, 0))
    image = (image * 255).astype(np.uint8)
    label = all_labels[i]
    pred = all_preds[i]

    axes[i].imshow(image)
    axes[i].add_patch(plt.Rectangle((0, 0), image.shape[1], image.shape[0], fill=False, edgecolor='red', linewidth=2))
    axes[i].text(5, 25, f'Pred: {classes[pred]}', bbox=dict(facecolor='white', alpha=0.75))
    axes[i].text(5, 50, f'True: {classes[label]}', bbox=dict(facecolor='white', alpha=0.75))
    axes[i].axis('off')

plt.tight_layout()
plt.show()
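
The `transpose((1, 2, 0))` and `* 255` steps exist because `ToTensor()` produces channel-first (C, H, W) float arrays in [0, 1], while `imshow` expects channel-last (H, W, C). A minimal NumPy-only sketch of that conversion follows; the helper name `to_displayable` and the added clipping are assumptions for illustration, not part of the original code.

```python
import numpy as np


def to_displayable(chw_image):
    """Convert a float (C, H, W) array in [0, 1] to a uint8 (H, W, C) array."""
    hwc = np.transpose(chw_image, (1, 2, 0))
    # Clipping guards against values slightly outside [0, 1]
    return (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)


# Synthetic 4 x 5 image whose red channel is fully on
demo = np.zeros((3, 4, 5), dtype=np.float32)
demo[0] = 1.0

out = to_displayable(demo)
print(out.shape, out.dtype, out[0, 0].tolist())  # (4, 5, 3) uint8 [255, 0, 0]
```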

 

10. Results

Training and validation loss


Validation accuracy


Test results

 

That completes the full process of building a CNN model that classifies lion and tiger images. I hope this post helps you build your own image classification model. If you have any further questions, feel free to leave a comment!