gnn.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package main
  2. import (
  3. "fmt"
  4. "math"
  5. "math/rand"
  6. "sync"
  7. )
  8. const (
  9. Input int = 1433
  10. Hidden int = 64
  11. Output int = 7
  12. Sample int = 10000
  13. Batch int = 16
  14. Dropout float64 = 0.5
  15. Rate float64 = 0.01
  16. )
  17. type (
  18. Parameter struct {
  19. W, B Matrix
  20. }
  21. Layer struct {
  22. d int
  23. f func(Matrix) Matrix
  24. p Parameter
  25. D Vector
  26. }
  27. )
  28. func ReLU(A Matrix) Matrix {
  29. for i := 0; i < A.N(); i++ {
  30. for j := 0; j < A.M(); j++ {
  31. A[i][j] = math.Max(0, A[i][j])
  32. }
  33. }
  34. return A
  35. }
  36. func Softmax(A Matrix) Matrix {
  37. for i := 0; i < A.N(); i++ {
  38. _, max := A[i].Max()
  39. sum := 0.
  40. for j := 0; j < A.M(); j++ {
  41. A[i][j] = math.Exp(A[i][j] - max)
  42. sum += A[i][j]
  43. }
  44. for j := 0; j < A.M(); j++ {
  45. A[i][j] /= sum
  46. }
  47. }
  48. return A
  49. }
  50. func GetAggregation(G Graph, u, k int, l []Layer) Matrix {
  51. if len(G.A[u]) == 0 {
  52. return MakeMatrix(1, l[k].d)
  53. }
  54. // GCN
  55. A := MakeMatrix(1, l[k].d)
  56. for v := range G.A[u] {
  57. A.Add(Matrix{G.E[k][v]})
  58. }
  59. return A.Divide(float64(len(G.A[u])))
  60. // GAT
  61. // A := MakeMatrix(0, l[k].d)
  62. // for v := range G.A[u] {
  63. // A = append(A, G.E[k][v])
  64. // }
  65. // C := MakeMatrix(1, A.N())
  66. // Me := G.E[k][u].Modulus()
  67. // for i := 0; i < A.N(); i++ {
  68. // Ma := A[i].Modulus()
  69. // if Me > 0 && Ma > 0 {
  70. // C[0][i] = G.E[k][u].Dot(A[i]) / Me / Ma
  71. // }
  72. // }
  73. // return Softmax(C).Multiply(A)
  74. }
  75. func GetEmbedding(G Graph, u, k int, l []Layer, train bool) Matrix {
  76. E := MakeMatrix(1, l[k].d)
  77. if k == 0 {
  78. E.Add(Matrix{G.X[u]})
  79. } else {
  80. for v := range G.A[u] {
  81. GetEmbedding(G, v, k-1, l, train)
  82. }
  83. E.Add(GetAggregation(G, u, k-1, l).Multiply(l[k].p.W))
  84. E.Add(GetEmbedding(G, u, k-1, l, train).Multiply(l[k].p.B))
  85. l[k].f(E)
  86. }
  87. if train && l[k].D != nil {
  88. E.Dropout(l[k].D)
  89. }
  90. G.E[k][u] = E[0]
  91. return E
  92. }
  93. // A += B * C
  94. func StartCalc(wg *sync.WaitGroup, A, B, C Matrix) {
  95. wg.Add(1)
  96. go func() {
  97. A.Add(B.Transpose().Multiply(C))
  98. wg.Done()
  99. }()
  100. }
  101. // A += B / c
  102. func StartRefine(wg *sync.WaitGroup, A, B Matrix, c float64) {
  103. wg.Add(1)
  104. go func() {
  105. A.Add(B.Divide(c))
  106. wg.Done()
  107. }()
  108. }
  109. func Train(G Graph) []Layer {
  110. p1 := Parameter{MakeRandomMatrix(Input, Hidden), MakeRandomMatrix(Input, Hidden)}
  111. p2 := Parameter{MakeRandomMatrix(Hidden, Output), MakeRandomMatrix(Hidden, Output)}
  112. l := []Layer{{d: Input}, {d: Hidden, f: ReLU, p: p1}, {d: Output, f: Softmax, p: p2}}
  113. for i := 0; i < Sample; i++ {
  114. if i%100 == 0 {
  115. Test(G, l, false)
  116. // fmt.Println("sampling", i)
  117. }
  118. var wg sync.WaitGroup
  119. l[0].D, l[1].D = MakeDropoutVector(Input), MakeDropoutVector(Hidden)
  120. DW2, DB2 := MakeMatrix(Hidden, Output), MakeMatrix(Hidden, Output)
  121. DW1, DB1 := MakeMatrix(Input, Hidden), MakeMatrix(Input, Hidden)
  122. for j := 0; j < Batch; j++ {
  123. u := nodeId[rand.Intn(len(nodeId))]
  124. GetEmbedding(G, u, 2, l, true)
  125. delta := MakeMatrix(1, Output)
  126. delta[0][G.L[u]] = 1
  127. delta.Sub(Matrix{G.E[2][u]}).Divide(float64(Batch))
  128. StartCalc(&wg, DW2, GetAggregation(G, u, 1, l), delta)
  129. StartCalc(&wg, DB2, Matrix{G.E[1][u]}, delta)
  130. deltaB := delta.Multiply(l[2].p.B.Transpose())
  131. for k := 0; k < Hidden; k++ {
  132. if G.E[1][u][k] == 0 {
  133. deltaB[0][k] = 0
  134. }
  135. }
  136. StartCalc(&wg, DW1, GetAggregation(G, u, 0, l), deltaB)
  137. StartCalc(&wg, DB1, Matrix{G.E[0][u]}, deltaB)
  138. deltaW := delta.Multiply(l[2].p.W.Transpose())
  139. for v := range G.A[u] {
  140. delta = MakeMatrix(1, Hidden).Add(deltaW)
  141. for k := 0; k < Hidden; k++ {
  142. if G.E[1][v][k] == 0 {
  143. delta[0][k] = 0
  144. }
  145. }
  146. StartCalc(&wg, DW1, GetAggregation(G, v, 0, l), delta)
  147. StartCalc(&wg, DB1, Matrix{G.E[0][v]}, delta)
  148. }
  149. wg.Wait()
  150. }
  151. Rate := 0.2 * math.Exp(-float64(i)/1000)
  152. StartRefine(&wg, l[2].p.W, DW2, 1/Rate)
  153. StartRefine(&wg, l[2].p.B, DB2, 1/Rate)
  154. StartRefine(&wg, l[1].p.W, DW1, 1/Rate)
  155. StartRefine(&wg, l[1].p.B, DB1, 1/Rate)
  156. wg.Wait()
  157. }
  158. return l
  159. }
  160. func Test(G Graph, l []Layer, detail bool) {
  161. cnt1, cnt2, loss := 0, 0, 0.
  162. for u := range G.X {
  163. GetEmbedding(G, u, 2, l, false)
  164. id, _ := G.E[2][u].Max()
  165. if detail {
  166. fmt.Println(u, id)
  167. }
  168. if G.L[u] == id {
  169. cnt1++
  170. }
  171. if G.L[u] == id+Output {
  172. cnt2++
  173. }
  174. loss -= math.Log(G.E[2][u][G.L[u]%Output])
  175. }
  176. if detail {
  177. fmt.Println(cnt1, "/", len(nodeId), ",", cnt2, "/", Node-len(nodeId))
  178. fmt.Println(
  179. 100*float64(cnt1)/float64(len(nodeId)), ",",
  180. 100*float64(cnt2)/float64((Node-len(nodeId))), ",",
  181. loss/float64(len(G.X)),
  182. )
  183. } else {
  184. fmt.Println(100*float64(cnt2)/float64((Node-len(nodeId))), loss/float64(len(G.X)))
  185. }
  186. }