[Graph Neural Networks] GraphSAGE
2025.06.08 - [Data & Research] - [Graph Neural Networks] Table of Contents
뭔가 GCN이 큰 고비같은 느낌이었는데 마치고나니 한결 수월하게 다음 내용을 포스팅하도록 하겠습니다. 일단 들어가기 전에 이전 포스팅의 내용 중 다시 리마인드해야할 부분이 있습니다.
2025.06.03 - [Data & Research] - [Graph Neural Networks] Transductive vs. Inductive
[Graph Neural Networks] Transductive vs. Inductive
본격적으로 Graph Neural Network에 들어가기 앞 서 일단 먼저 알아두어야 할 컨셉이 있습니다. Transductive Learning과 Inductive Learning인데요 1. Transductive LearningTransductive Learning 환경에서는 모델이 학습 단
trillionver2.tistory.com
1. Transductive model의 한계
GraphSAGE 이전의 GCN와 같은 GNN 모델들은 기본적으로 transductive learning 세팅을 따르고 있는 모델들이었습니다. 즉, 소셜 네트워크에 새로운 사용자가 가입하거나, 단백질 상호작용 네트워크에서 새로운 단백질이 발견되는 등 그래프가 동적으로 변하거나, 학습 과정에서 보지 못한 노드가 등장하는 현실 세계의 문제에 적용하기 어렵다는 치명적인 단점이 있습니다. 이런 세팅에서는 새로운 노드 하나가 추가될 때마다 전체 그래프 구조가 바뀌므로 모델 전체를 재학습해야 합니다.
2. GraphSAGE의 목적과 원리
GraphSAGE는 inductive learning 맥락에서 이러한 한계를 극복하기 위해 제안되었습니다. 특정 노드의 embedding 자체를 학습하는 것이 아니라, 노드의 특징과 local 구조를 바탕으로 embedding을 생성하는 함수를 학습하는 것입니다(즉, aggregate function이 학습가능한(learnable) 형태).
GraphSAGE의 이름(Graph SAmple and aggreGatE)이 그 핵심 원리를 말해줍니다. 각 노드는 자신의 지역 이웃 정보를 샘플링하고, 샘플링된 이웃들의 정보를 집계(Aggregate)하여 자신의 embedding(hidden representation)을 업데이트합니다. Message Passing 패러다임에 맞추어 구체적으로 한 번 살펴볼까요?
step1) Neighborhood Aggregation
먼저 노드 \(v\)의 이웃집합 \(\mathcal{N}(v)\)에서 고정된 개수\( \mathcal{S}_{\mathcal{N}(v)} \)의 이웃을 uniform sampling합니다. 그러면 샘플링된 이웃들을 대상으로 이전 \(k-1\) layer의 embedding을 가져와서 Aggregator function을 이용하여 집계합니다.
\[ \mathbf{a}_v^{(k)} = \text{AGGREGATE}_k\left(\{\mathbf{h}_u^{(k-1)} : u \in \mathcal{S}_{\mathcal{N}(v)}\}\right) \]
Aggregation이라는 추상적 이름으로 표기한 것은 그 방식이 다양할 수 있기 때문입니다.
- 평균 집계(Mean Aggregator) : 샘플링 된 이웃 embedding의 원소별 평균을 취하는 방식입니다.
\[ \mathbf{a}_v^{(k)} = \frac{1}{|\mathcal{S}_{\mathcal{N}(v)}|} \sum_{u \in \mathcal{S}_{\mathcal{N}(v)}} \mathbf{h}_u^{(k-1)} \] - 풀링 집계(Polling Aggregator) : 각 이웃 embedding을 독립적인 신경망에 통과시킨 후, 순서에 영향을 받지 않는 연산(예: 요소별 최대값(max-pooling) 또는 평균)을 적용하는 방식입니다.
$$ \mathbf{a}_v^{(k)} = \max\left(\{\sigma(W_{\text{pool}}\mathbf{h}_u^{(k-1)} + \mathbf{b}) : u \in \mathcal{S}_{\mathcal{N}(v)}\}\right) $$ - LSTM 집계(LSTM Aggregator) : 이웃들이 순서를 갖지 않으므로, 이웃 임베딩들을 임의의 순서로 나열하여 LSTM의 입력으로 사용합니다. 사실 LSTM은 순차적 정보를 처리하기 위한 구조이므로 이 구조를 GraphSAGE에서 사용하기 위해서는 이점에 유의하여 사용해야 합니다(일반적으로 그래프의 node에는 순서가 존재하지 않으니까요).
step2) Update
앞 단계에서 집계된 이전 layer에서의 embedding과 현재 노드의 embedding을 Concatenate합니다. 이 값을 다시 선형 weight로 변환 후 activation함수를 통과시켜 새로운 layer에서의 embedding을 얻어냅니다. \[ \mathbf{h}_v^{(k)} = \sigma\left(W^{(k)} \cdot \text{CONCAT}(\mathbf{h}_v^{(k-1)}, \mathbf{a}_v^{(k)})\right) \]
3. 목적함수(Objective Function)
사실 GraphSAGE도 사용하는 목적에 따라서 목적함수는 달라질 수 있는데요. 논문의 저자들은 목적이 유사도에 기반한 embedding을 학습(Representation Learning)하는 것 - graph상에서 가까운 노드들은 embedding 공간에서도 가깝게 멀리있는 노드들은 embedding공간에서도 멀게 기억나시죠? - 이었기 때문에
$$ J_G(\mathbf{z}_u) = -\log(\sigma(\mathbf{z}_u^T \mathbf{z}_v)) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)}[\log(\sigma(-\mathbf{z}_{v_n}^T \mathbf{z}_v))] $$
와 같은 Loss를 목적함수로 사용하였습니다.
\( -\log(\sigma(\mathbf{z}_u^T \mathbf{z}_v)) \)는 Positive Sample에 대한 loss로 내적 값을 크게 만들어 두 embedding vector가 유사해지도록 학습시킵니다.
\( -\log(\sigma(-\mathbf{z}_{v_n}^T \mathbf{z}_v)) \)는 Negative Sample에 대한 loss로 embedding vector가 반대방향을 가리키도록 학습시킵니다.
\(Q\)는 Positive Sample 하나 당 사용할 Negative Sample이 개수로 \( \mathbb{E}_{v_n \sim P_n(v)} \)를 사실은 \(Q\)개의 negative sample을 뽑아 평균을 내는 방식으로 계산하게 됩니다.