forked from aianaconda/pytorch-GNN-2nd-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcode_27_builder.py
118 lines (108 loc) · 5.12 KB
/
code_27_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 14 10:22:46 2020
@author: 代码医生工作室
@公众号:xiangyuejiqiren (内有更多优秀文章及学习资料)
@来源: <PyTorch深度学习和图神经网络(卷2)——开发应用>配套代码
@配套代码技术支持:bbs.aianaconda.com
"""
"""Graph builder from pandas dataframes"""
from collections import namedtuple
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_categorical
import dgl
__all__ = ['PandasGraphBuilder']
def _series_to_tensor(series):
    """Convert a pandas Series into a torch tensor.

    Parameters
    ----------
    series : pandas.Series
        A categorical series, or a series whose values torch can convert
        to floating point.

    Returns
    -------
    torch.Tensor
        For a categorical series, an int64 tensor holding the integer
        category codes. For any other series, a float32 tensor built from
        the series' underlying values.
    """
    # Imported locally: this function needs torch, and importing it here
    # keeps the function usable regardless of the module-level imports.
    import torch
    from pandas.api.types import CategoricalDtype

    # Check the dtype directly instead of calling
    # ``pandas.api.types.is_categorical``, which was removed in pandas 2.0.
    if isinstance(series.dtype, CategoricalDtype):
        # The codes use a small integer dtype; widen to int64 so the result
        # matches what ``torch.LongTensor`` would hold.
        codes = series.cat.codes.values.astype('int64')
        return torch.tensor(codes, dtype=torch.int64)

    # Everything else is treated as numeric and stored as float32.
    return torch.tensor(series.values, dtype=torch.float32)
class PandasGraphBuilder:
    """Assemble a DGL heterogeneous graph from pandas dataframes.

    Node types are registered with :meth:`add_entities`, edge types with
    :meth:`add_binary_relations`, and :meth:`build` hands the collected
    edges and node counts to ``dgl.heterograph``.

    Examples
    --------
    Suppose we have the following three pandas dataframes.

    User table ``users``:

    ===========  ===========  =======
    ``user_id``  ``country``  ``age``
    ===========  ===========  =======
    XYZZY        U.S.         25
    FOO          China        24
    BAR          China        23
    ===========  ===========  =======

    Game table ``games``:

    ===========  =========  ==============  ==================
    ``game_id``  ``title``  ``is_sandbox``  ``is_multiplayer``
    ===========  =========  ==============  ==================
    1            Minecraft  True            True
    2            Tetris 99  False           True
    ===========  =========  ==============  ==================

    Play relationship table ``plays``:

    ===========  ===========  =========
    ``user_id``  ``game_id``  ``hours``
    ===========  ===========  =========
    XYZZY        1            24
    FOO          1            20
    FOO          2            16
    BAR          2            28
    ===========  ===========  =========

    A bidirectional bipartite graph can then be created as follows:

    >>> builder = PandasGraphBuilder()
    >>> builder.add_entities(users, 'user_id', 'user')
    >>> builder.add_entities(games, 'game_id', 'game')
    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 'plays')
    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')
    >>> g = builder.build()
    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_edges('plays')
    4
    """

    def __init__(self):
        # Source tables. Nothing in this class writes to ``entity_tables``;
        # ``relation_tables`` maps relation name -> relation dataframe.
        self.entity_tables = {}
        self.relation_tables = {}

        # Entity (node type) bookkeeping.
        self.entity_pk_to_name = {}   # primary key column name -> entity name
        self.entity_pk = {}           # entity name -> primary key column name
        self.entity_key_map = {}      # entity name -> categorical series of key values
        self.num_nodes_per_type = {}  # entity name -> number of rows in its table

        # Relation (edge type) bookkeeping.
        self.edges_per_relation = {}      # (srctype, name, dsttype) -> (src codes, dst codes)
        self.relation_name_to_etype = {}  # relation name -> (srctype, name, dsttype)
        self.relation_src_key = {}        # relation name -> source key column name
        self.relation_dst_key = {}        # relation name -> destination key column name

    def add_entities(self, entity_table, primary_key, name):
        """Register the rows of ``entity_table`` as nodes of type ``name``.

        Parameters
        ----------
        entity_table : pandas.DataFrame
            Table with one row per entity.
        primary_key : str
            Column of ``entity_table`` that identifies each entity.
        name : str
            Name of the node type.

        Raises
        ------
        ValueError
            If any primary key value appears more than once.
        """
        key_column = entity_table[primary_key]
        keys = key_column.astype('category')

        occurrences = keys.value_counts()
        if (occurrences != 1).any():
            raise ValueError('Different entity with the same primary key detected.')

        # Make the category order follow the row order of the table, so the
        # category code of a key equals the position of its row.
        keys = keys.cat.reorder_categories(key_column.values)

        self.entity_pk_to_name[primary_key] = name
        self.entity_pk[name] = primary_key
        self.num_nodes_per_type[name] = entity_table.shape[0]
        self.entity_key_map[name] = keys
        # The entity table itself is not stored in ``self.entity_tables``.

    def _align_endpoint(self, relation_table, key):
        """Return ``(column, node_type)`` for one endpoint of a relation.

        ``column`` is ``relation_table[key]`` as a categorical series whose
        categories are those of the entity registered under ``key``; values
        not among those categories become null.
        """
        column = relation_table[key].astype('category')
        node_type = self.entity_pk_to_name[key]
        categories = self.entity_key_map[node_type].cat.categories
        return column.cat.set_categories(categories), node_type

    def add_binary_relations(self, relation_table, source_key, destination_key, name):
        """Register the rows of ``relation_table`` as edges of type ``name``.

        Parameters
        ----------
        relation_table : pandas.DataFrame
            Table with one row per edge.
        source_key : str
            Column holding source entity keys; must be the primary key
            column name of an entity already added.
        destination_key : str
            Column holding destination entity keys; must be the primary key
            column name of an entity already added.
        name : str
            Name of the relation.

        Raises
        ------
        ValueError
            If a source or destination value is not a key of the
            corresponding entity.
        """
        # Align both endpoints before validating either one.
        sources, srctype = self._align_endpoint(relation_table, source_key)
        targets, dsttype = self._align_endpoint(relation_table, destination_key)

        if sources.isnull().any():
            raise ValueError(
                'Some source entities in relation %s do not exist in entity %s.' %
                (name, source_key))
        if targets.isnull().any():
            raise ValueError(
                'Some destination entities in relation %s do not exist in entity %s.' %
                (name, destination_key))

        etype = (srctype, name, dsttype)
        self.relation_name_to_etype[name] = etype
        # Category codes serve as the node indices of each edge endpoint.
        self.edges_per_relation[etype] = (
            sources.cat.codes.values,
            targets.cat.codes.values,
        )
        self.relation_tables[name] = relation_table
        self.relation_src_key[name] = source_key
        self.relation_dst_key[name] = destination_key

    def build(self):
        """Return the graph ``dgl.heterograph`` creates from the collected
        edges and per-type node counts."""
        return dgl.heterograph(self.edges_per_relation, self.num_nodes_per_type)