Medical AI & Article Review

PlexusNet: A neural network architectural concept for medical image classification

Kimhj 2023. 12. 1. 10:47
  • 최신 합성곱 신경망 모델은 의료 영상에서 광범위하게 사용되어 다양한 임상 문제에 대응하고 있지만, 이러한 모델의 복잡성과 규모는 의료 영상에서 정당화되지 않을 수 있으며, 사용 가능한 자원 예산에 따라 다름
  • Feature map의 수를 증가시키면 분류 작업에 대한 모델 설명력이 감소하고, 현재의 데이터 정규화 방법은 모델 개발 전에 고정되어 있으며 데이터 도메인의 명시를 무시하는 문제가 있음.
  • 이러한 문제를 고려하여, PlexusNet이라는 새로운 확장 가능한 모델을 제안함. PlexusNet의 구조는 네트워크의 깊이, 너비 및 분기에 따라 조절되는 블록 아키텍처와 모델 스케일링을 포함하고 있음.
  • PlexusNet은 더 나은 데이터 일반화를 위한 새로운 학습 가능한 데이터 정규화 알고리즘을 포함하고 있음. 본 논문에서는 다섯 가지 임상 분류 문제에 대한 PlexusNet을 설계하기 위해 간단하면서 효과적인 신경 아키텍처 탐색을 적용했봤고, 이는 현 SOTA 모델인  ResNet-18 및 EfficientNet B0/1에 미치지 않는 성능을 달성했음.
  • 그리고 SOTA 모델과 유사한 성능을 가진 PlexusNet은 훨씬 낮은 매개변수 용량과 대표적인 feature map을 가지고 있으며, PlexusNet이 생성한 잠재적인 특징을 기반으로 범주에 연관된 구별 가능한 클러스터를 시각화하여 결과를 보여주었음. (실제 논문 결과에는 포함되어 있지 않음)

 

  • 모델 코드
class PatchConvNet(keras.Model):
    def __init__(
        self,
        stem,
        trunk,
        attention_pooling,
        preprocessing_model,
        train_augmentation_model,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling
        self.train_augmentation_model = train_augmentation_model
        self.preprocessing_model = preprocessing_model

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stem": self.stem,
                "trunk": self.trunk,
                "attention_pooling": self.attention_pooling,
                "train_augmentation_model": self.train_augmentation_model,
                "preprocessing_model": self.preprocessing_model,
            }
        )
        return config

    def _calculate_loss(self, inputs, test=False):
        images, labels = inputs
        # Augment the input images.
        if test:
            augmented_images = self.preprocessing_model(images)
        else:
            augmented_images = self.train_augmentation_model(images)
        # Pass through the stem.
        x = self.stem(augmented_images)
        # Pass through the trunk.
        x = self.trunk(x)
        # Pass through the attention pooling block.
        logits, _ = self.attention_pooling(x)
        # Compute the total loss.
        total_loss = self.compiled_loss(labels, logits)
        return total_loss, logits

    def train_step(self, inputs):
        with tf.GradientTape() as tape:
            total_loss, logits = self._calculate_loss(inputs)
        # Apply gradients.
        train_vars = [
            self.stem.trainable_variables,
            self.trunk.trainable_variables,
            self.attention_pooling.trainable_variables,
        ]
        grads = tape.gradient(total_loss, train_vars)
        trainable_variable_list = []
        for (grad, var) in zip(grads, train_vars):
            for g, v in zip(grad, var):
                trainable_variable_list.append((g, v))
        self.optimizer.apply_gradients(trainable_variable_list)
        # Report progress.
        _, labels = inputs
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, inputs):
        total_loss, logits = self._calculate_loss(inputs, test=True)
        # Report progress.
        _, labels = inputs
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def call(self, images):
        # Augment the input images.
        augmented_images = self.preprocessing_model(images)
        # Pass through the stem.
        x = self.stem(augmented_images)
        # Pass through the trunk.
        x = self.trunk(x)
        # Pass through the attention pooling block.
        logits, viz_weights = self.attention_pooling(x)
        return logits, viz_weights

'Medical AI & Article Review' 카테고리의 다른 글

Image artifact generate 설명  (0) 2023.12.18
상관분석(Correlation)  (1) 2023.12.11
BI tool  (0) 2023.10.11
Ad-hoc 데이터 분석  (0) 2023.10.11
Airflow, Kafka, dbt, Presto  (0) 2023.10.11