Graph Neural Network Similarity Computation
Today I'd like to share a paper on computing graph similarity with graph neural networks:
SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
Preface
Graph neural networks are among the most popular models today. They use neural networks to learn from graph-structured data, extracting features and patterns to support graph-learning tasks such as clustering, classification, prediction, segmentation, and generation. The model discussed here uses a graph neural network to quickly score the similarity between two graphs.
1. Training Data
This post uses the GEDDataset built into PyTorch Geometric, which can be loaded directly. The dataset contains 700 graphs; each graph has at most 10 nodes, and each node is described by 29 features.
The code looks like this (example):
```python
def process_dataset(self):
    """
    Downloading and processing dataset.
    """
    print("\nPreparing dataset.\n")
    self.training_graphs = GEDDataset(
        "datasets/{}".format(self.args.dataset), self.args.dataset, train=True
    )
    self.testing_graphs = GEDDataset(
        "datasets/{}".format(self.args.dataset), self.args.dataset, train=False
    )
```
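The regression target for each graph pair comes from graph edit distance (GED). A minimal sketch of the usual SimGNN-style conversion, under the assumption that GED is normalized by the average graph size and then mapped through exp(-x):

```python
import math

def ged_to_similarity(ged, n1, n2):
    """Map a raw graph edit distance to a (0, 1] similarity score.
    Assumption for illustration: GED is normalized by the average
    graph size and passed through exp(-x), the usual SimGNN recipe."""
    nged = ged / ((n1 + n2) / 2)
    return math.exp(-nged)

# Identical graphs (GED = 0) score 1.0; larger distances decay toward 0.
print(ged_to_similarity(0, 10, 10))  # 1.0
print(ged_to_similarity(5, 10, 10))
```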
2. Model Input
Each forward pass takes a pair of graphs, consisting of their edge indices and node features.
The code looks like this (example):
```python
def forward(self, data):
    # Edge indices and node features for each graph in the pair.
    edge_index_1 = data["g1"].edge_index
    edge_index_2 = data["g2"].edge_index
    features_1 = data["g1"].x
    features_2 = data["g2"].x
    # If no batch vector is present (single-graph input), assign
    # every node to graph 0.
    batch_1 = (
        data["g1"].batch
        if hasattr(data["g1"], "batch")
        else torch.tensor((), dtype=torch.long).new_zeros(data["g1"].num_nodes)
    )
    batch_2 = (
        data["g2"].batch
        if hasattr(data["g2"], "batch")
        else torch.tensor((), dtype=torch.long).new_zeros(data["g2"].num_nodes)
    )
```
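The `batch` vector that the fallback above constructs tells the model which graph each node belongs to: `batch[i]` is the index of node `i`'s graph in the mini-batch. A toy illustration with a hypothetical helper (not part of the model):

```python
def group_nodes_by_graph(batch):
    """Recover, per graph, the list of node indices that the
    batch vector assigns to it."""
    groups = {}
    for node, graph_id in enumerate(batch):
        groups.setdefault(graph_id, []).append(node)
    return groups

# Three nodes in graph 0, two nodes in graph 1.
print(group_nodes_by_graph([0, 0, 0, 1, 1]))  # {0: [0, 1, 2], 1: [3, 4]}
```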
3. Extracting and Updating Node Features with Graph Convolutions
Using histograms to build features, as this model does, is a fairly novel touch.
```python
def convolutional_pass(self, edge_index, features):
    """
    Making a convolutional pass.
    :param edge_index: Edge indices.
    :param features: Feature matrix.
    :return features: Abstract feature matrix.
    """
    features = self.convolution_1(features, edge_index)
    features = F.relu(features)
    features = F.dropout(features, p=self.args.dropout, training=self.training)
    features = self.convolution_2(features, edge_index)
    features = F.relu(features)
    features = F.dropout(features, p=self.args.dropout, training=self.training)
    features = self.convolution_3(features, edge_index)
    return features
```
```python
# Every node passes through three GCN layers.
abstract_features_1 = self.convolutional_pass(edge_index_1, features_1)
abstract_features_2 = self.convolutional_pass(edge_index_2, features_2)
```
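To make the convolutional pass concrete, here is a deliberately simplified pure-Python sketch of one graph-convolution step. It uses plain mean aggregation over each node's neighbors plus a self-loop rather than GCNConv's exact symmetric normalization, so treat it as an illustration of the idea, not the real layer:

```python
def gcn_layer(adj, features, weight):
    """One simplified convolution step: average each node's features
    with its neighbors' (self-loop included), then apply a linear map.
    `adj` is a dense 0/1 adjacency matrix; `weight` is a list of
    output columns, each of input-feature length."""
    n, d_in = len(features), len(features[0])
    out = []
    for i in range(n):
        neigh = [j for j in range(n) if adj[i][j]] + [i]  # neighbors + self
        agg = [sum(features[j][d] for j in neigh) / len(neigh)
               for d in range(d_in)]
        out.append([sum(a * w for a, w in zip(agg, col)) for col in weight])
    return out

# Two connected nodes with one-hot features; identity weight matrix.
print(gcn_layer([[0, 1], [1, 0]],
                [[1.0, 0.0], [0.0, 1.0]],
                [[1.0, 0.0], [0.0, 1.0]]))  # [[0.5, 0.5], [0.5, 0.5]]
```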
4. Computing Node-to-Node Relations to Build the Histogram Feature
```python
def calculate_histogram(
    self, abstract_features_1, abstract_features_2, batch_1, batch_2
):
    # Pad node embeddings into dense [batch, max_nodes, features] tensors.
    abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
    abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)
    B1, N1, _ = abstract_features_1.size()
    B2, N2, _ = abstract_features_2.size()
    mask_1 = mask_1.view(B1, N1)
    mask_2 = mask_2.view(B2, N2)
    num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))
    # Pairwise node-similarity scores; detached, so no gradient flows
    # through the histogram branch.
    scores = torch.matmul(
        abstract_features_1, abstract_features_2.permute([0, 2, 1])
    ).detach()
    hist_list = []
    for i, mat in enumerate(scores):
        mat = torch.sigmoid(mat[: num_nodes[i], : num_nodes[i]]).view(-1)
        hist = torch.histc(mat, bins=self.args.bins)
        hist = hist / torch.sum(hist)  # normalize to a distribution
        hist = hist.view(1, -1)
        hist_list.append(hist)
    return torch.stack(hist_list).view(-1, self.args.bins)
```

```python
if self.args.histogram:
    hist = self.calculate_histogram(
        abstract_features_1, abstract_features_2, batch_1, batch_2
    )
```
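The bucketing that `torch.histc` performs can be mimicked in plain Python. A sketch assuming the sigmoid-squashed scores lie in [0, 1] and the bins split that interval uniformly:

```python
def similarity_histogram(scores, bins):
    """Count scores into `bins` equal-width buckets over [0, 1] and
    normalize the counts into a distribution, as the model does with
    torch.histc followed by division by the total."""
    counts = [0] * bins
    for s in scores:
        idx = min(int(s * bins), bins - 1)  # clamp s == 1.0 into the last bin
        counts[idx] += 1
    total = sum(counts)
    return [c / total for c in counts]

print(similarity_histogram([0.1, 0.2, 0.55, 0.9], bins=4))
# [0.5, 0.0, 0.25, 0.25]
```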
5. An Attention Layer to Obtain Graph-Level Features
```python
def forward(self, x, batch, size=None):
    size = batch[-1].item() + 1 if size is None else size
    # Global context: the transformed mean node embedding per graph.
    mean = scatter_mean(x, batch, dim=0, dim_size=size)
    transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))
    # Attention coefficient of each node w.r.t. its graph's context.
    coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))
    weighted = coefs.unsqueeze(-1) * x
    return scatter_add(weighted, batch, dim=0, dim_size=size)
```

```python
pooled_features_1 = self.attention(abstract_features_1, batch_1)
pooled_features_2 = self.attention(abstract_features_2, batch_2)
```
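In words, the layer builds a global context vector as tanh of the linearly transformed mean node embedding, scores every node against that context with a sigmoid, and sums the attention-weighted nodes. A single-graph pure-Python sketch (the weight matrix here is a plain list of rows, an illustration rather than the trained parameter):

```python
import math

def attention_pool(node_embs, weight):
    """Sketch of the attention readout: context = tanh(mean @ W),
    per-node coefficient = sigmoid(<node, context>), output = sum of
    coefficient-weighted node embeddings."""
    d = len(node_embs[0])
    mean = [sum(v[k] for v in node_embs) / len(node_embs) for k in range(d)]
    context = [math.tanh(sum(mean[k] * weight[k][c] for k in range(d)))
               for c in range(d)]
    pooled = [0.0] * d
    for v in node_embs:
        score = sum(vk * ck for vk, ck in zip(v, context))
        coef = 1.0 / (1.0 + math.exp(-score))  # sigmoid attention weight
        pooled = [p + coef * vk for p, vk in zip(pooled, v)]
    return pooled

# Two identical nodes, identity weight: each node gets the same
# attention weight, so pooling scales the shared embedding.
print(attention_pool([[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]))
```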
6. Relating the Two Graph Embeddings with an NTN
```python
def forward(self, embedding_1, embedding_2):
    batch_size = len(embedding_1)
    scoring = torch.matmul(
        embedding_1, self.weight_matrix.view(self.args.filters_3, -1)
    )
    # filters_3 can be read as the number of relation "slices" to learn.
    scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1])
    scoring = torch.matmul(
        scoring, embedding_2.view(batch_size, self.args.filters_3, 1)
    ).view(batch_size, -1)
    combined_representation = torch.cat((embedding_1, embedding_2), 1)
    block_scoring = torch.t(
        torch.mm(self.weight_matrix_block, torch.t(combined_representation))
    )
    scores = F.relu(scoring + block_scoring + self.bias.view(-1))
    return scores
```
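Per relation slice k, the NTN computes ReLU(h1ᵀ W_k h2 + V·[h1; h2] + b_k): a bilinear interaction between the two graph embeddings plus a linear term on their concatenation. A pure-Python sketch with illustrative parameter shapes (K slices of d×d matrices, a K×2d linear matrix, K biases; all values hypothetical):

```python
def ntn_score(h1, h2, slices, V, b):
    """Score two graph embeddings with a Neural Tensor Network head:
    one bilinear form per slice, a shared linear term on [h1; h2],
    a per-slice bias, then ReLU."""
    d = len(h1)
    out = []
    for W, v_row, bias in zip(slices, V, b):
        bilinear = sum(h1[i] * W[i][j] * h2[j]
                       for i in range(d) for j in range(d))
        linear = sum(vw * x for vw, x in zip(v_row, h1 + h2))
        out.append(max(0.0, bilinear + linear + bias))
    return out

# One slice whose W links h1's first dim to h2's second dim.
print(ntn_score([1.0, 0.0], [0.0, 1.0],
                [[[0.0, 1.0], [0.0, 0.0]]],   # K=1 slice, 2x2
                [[0.0, 0.0, 0.0, 0.0]],       # K x 2d linear weights
                [0.0]))                       # [1.0]
```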
7. Training and Prediction
```python
def process_batch(self, data):
    self.optimizer.zero_grad()
    data = self.transform(data)
    target = data["target"]
    prediction = self.model(data)
    # Regress the predicted similarity onto the ground-truth score.
    loss = F.mse_loss(prediction, target, reduction="sum")
    loss.backward()
    self.optimizer.step()
    return loss.item()
```
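The sum-reduced MSE used above has a one-line plain-Python equivalent, shown here to make the training objective explicit (the predictions and targets are illustrative numbers):

```python
def mse_sum(pred, target):
    """Equivalent of F.mse_loss(pred, target, reduction="sum"):
    the sum of squared errors over the batch."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))

print(mse_sum([0.9, 0.4], [1.0, 0.5]))  # ≈ 0.02
```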
Summary
This model compares graphs at two levels, node-to-node (the similarity histogram) and graph-to-graph (the NTN on pooled embeddings), and combines both signals to score how similar two graphs are. Along the way it uses GCNs, an NTN, attention pooling, and histogram features, which makes for quite a creative design.