RegMs If 3 лет назад
Родитель
Сommit
0c20b14336
3 измененных файлов с 60 добавлено и 35 удалено
  1. 47 22
      gnn.go
  2. 4 2
      graph.go
  3. 9 11
      main.go

+ 47 - 22
gnn.go

@@ -1,19 +1,21 @@
 package main
 
 import (
+	"fmt"
 	"math"
 	"math/rand"
+	"sync"
 )
 
 const (
 	Input  int     = 1433
 	Hidden int     = 50
 	Output int     = 7
-	Sample int     = 5000
-	Batch  int     = 20
-	RateWo float64 = 0.02
-	RateWi float64 = 0.04
-	RateB  float64 = 0.05
+	Sample int     = 1000
+	Batch  int     = 10
+	RateWo float64 = 0.06
+	RateWi float64 = 0.08
+	RateB  float64 = 0.1
 )
 
 type (
@@ -59,8 +61,7 @@ func GetEmbedding(G Graph, u, k int, l []Layer) Vector {
 	if k == 0 {
 		return G.X[u]
 	}
-	l[k].E = Multiply(Matrix{GetEmbedding(G, u, k-1, l)}, l[k].p.B)
-	l[k-1].O, l[k-1].I = MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k-1].d)
+	l[k-1].O, l[k-1].I, l[k].E = MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k].d)
 	Do, Di := 0, 0
 	for v, w := range G.A[u] {
 		if w == 1 {
@@ -77,46 +78,70 @@ func GetEmbedding(G Graph, u, k int, l []Layer) Vector {
 	if Di > 0 {
 		l[k].E.Add(Multiply(l[k-1].I.Divide(float64(Di)), l[k].p.Wi))
 	}
+	l[k].E.Add(Multiply(Matrix{GetEmbedding(G, u, k-1, l)}, l[k].p.B))
 	return l[k].f(l[k].E)[0]
 }
 
+// A += B * C / d
+func StartCalc(wg *sync.WaitGroup, A, B, C Matrix, d float64) {
+	wg.Add(1)
+	go func() {
+		A.Add(Multiply(B.Transpose(), C).Divide(d))
+		wg.Done()
+	}()
+}
+
+// A += B
+func StartRefine(wg *sync.WaitGroup, A, B Matrix) {
+	wg.Add(1)
+	go func() {
+		A.Add(B)
+		wg.Done()
+	}()
+}
+
 func Train(G Graph) []Layer {
 	p1 := Parameter{MakeRandomMatrix(Input, Hidden), MakeRandomMatrix(Input, Hidden), MakeRandomMatrix(Input, Hidden)}
 	p2 := Parameter{MakeRandomMatrix(Hidden, Output), MakeRandomMatrix(Hidden, Output), MakeRandomMatrix(Hidden, Output)}
 	l := []Layer{{d: Input}, {d: Hidden, f: ReLU, p: p1}, {d: Output, f: Softmax, p: p2}}
 	for i := 0; i < Sample; i++ {
+		if i%10 == 0 {
+			fmt.Println("sampling", i)
+		}
+		var wg sync.WaitGroup
 		DWo2, DWi2, DB2 := MakeMatrix(Hidden, Output), MakeMatrix(Hidden, Output), MakeMatrix(Hidden, Output)
 		DWo1, DWi1, DB1 := MakeMatrix(Input, Hidden), MakeMatrix(Input, Hidden), MakeMatrix(Input, Hidden)
-		u := nodeId[rand.Intn(len(nodeId))]
-		for j := 0; j < Batch; j++ {
+		for u, j := nodeId[rand.Intn(len(nodeId))], 0; j < Batch; j++ {
 			GetEmbedding(G, u, 2, l)
 			delta := MakeMatrix(1, Output)
-			delta[0][nodeLabel[u]] = 1
+			delta[0][G.L[u]] = 1
 			delta.Sub(l[2].E)
-			DWo2.Add(Multiply(l[1].O.Transpose(), delta).Divide(float64(Batch) / RateWo))
-			DWi2.Add(Multiply(l[1].I.Transpose(), delta).Divide(float64(Batch) / RateWi))
-			DB2.Add(Multiply(l[1].E.Transpose(), delta).Divide(float64(Batch) / RateB))
+			StartCalc(&wg, DWo2, l[1].O, delta, float64(Batch)/RateWo)
+			StartCalc(&wg, DWi2, l[1].I, delta, float64(Batch)/RateWi)
+			StartCalc(&wg, DB2, l[1].E, delta, float64(Batch)/RateB)
 			delta = Multiply(delta, l[2].p.B.Transpose())
 			for k := 0; k < Hidden; k++ {
 				if l[1].E[0][k] == 0 {
 					delta[0][k] = 0
 				}
 			}
-			DWo1.Add(Multiply(l[0].O.Transpose(), delta).Divide(float64(Batch) / RateWo))
-			DWi1.Add(Multiply(l[0].I.Transpose(), delta).Divide(float64(Batch) / RateWi))
-			DB1.Add(Multiply(Matrix{G.X[u]}.Transpose(), delta).Divide(float64(Batch) / RateB))
+			StartCalc(&wg, DWo1, l[0].O, delta, float64(Batch)/RateWo)
+			StartCalc(&wg, DWi1, l[0].I, delta, float64(Batch)/RateWi)
+			StartCalc(&wg, DB1, Matrix{G.X[u]}, delta, float64(Batch)/RateB)
 			neighbor := make([]int, 0)
 			for v := range G.A[u] {
 				neighbor = append(neighbor, v)
 			}
 			u = neighbor[rand.Intn(len(neighbor))]
+			wg.Wait()
 		}
-		l[2].p.Wo.Add(DWo2)
-		l[2].p.Wi.Add(DWi2)
-		l[2].p.B.Add(DB2)
-		l[1].p.Wo.Add(DWo1)
-		l[1].p.Wi.Add(DWi1)
-		l[1].p.B.Add(DB1)
+		StartRefine(&wg, l[2].p.Wo, DWo2)
+		StartRefine(&wg, l[2].p.Wi, DWi2)
+		StartRefine(&wg, l[2].p.B, DB2)
+		StartRefine(&wg, l[1].p.Wo, DWo1)
+		StartRefine(&wg, l[1].p.Wi, DWi1)
+		StartRefine(&wg, l[1].p.B, DB1)
+		wg.Wait()
 	}
 	return l
 }

+ 4 - 2
graph.go

@@ -5,6 +5,7 @@ type (
 
 	Graph struct {
 		X map[int]Vector
+		L map[int]int
 		A AdjacencyMatrix
 	}
 )
@@ -17,11 +18,12 @@ func (A AdjacencyMatrix) Modify(u, v, w int) {
 }
 
 func MakeGraph() Graph {
-	return Graph{make(map[int]Vector), make(AdjacencyMatrix)}
+	return Graph{make(map[int]Vector), make(map[int]int), make(AdjacencyMatrix)}
 }
 
-func (G Graph) AddNode(u int, V Vector) {
+func (G Graph) AddNode(u int, V Vector, l int) {
 	G.X[u] = V
+	G.L[u] = l
 }
 
 func (G Graph) AddEdge(u, v int) {

+ 9 - 11
main.go

@@ -21,8 +21,7 @@ var (
 		"Theory",
 	}
 
-	nodeId    = make([]int, 0)
-	nodeLabel = make(map[int]int)
+	nodeId = make([]int, 0)
 )
 
 func main() {
@@ -32,6 +31,9 @@ func main() {
 	}
 	G := MakeGraph()
 	for i := 0; i < Node; i++ {
+		if i%100 == 0 {
+			fmt.Println("reading node", i)
+		}
 		var u int
 		fmt.Fscan(file, &u)
 		nodeId = append(nodeId, u)
@@ -39,18 +41,14 @@ func main() {
 		for j := 0; j < Input; j++ {
 			fmt.Fscan(file, &V[j])
 		}
-		G.AddNode(u, V)
 		var label string
 		fmt.Fscan(file, &label)
 		for j := 0; j < Output; j++ {
 			if labelName[j] == label {
-				nodeLabel[u] = j
+				G.AddNode(u, V, j)
 				break
 			}
 		}
-		if i%100 == 0 {
-			fmt.Println("reading node", i)
-		}
 	}
 	file.Close()
 	file, err = os.Open("cora/cora.cites")
@@ -64,7 +62,7 @@ func main() {
 	}
 	file.Close()
 	l := Train(G)
-	cnt, tot := 0, 0
+	cnt := 0
 	for u := range G.X {
 		GetEmbedding(G, u, 2, l)
 		id, max := 0, 0.
@@ -73,10 +71,10 @@ func main() {
 				id, max = i, l[2].E[0][i]
 			}
 		}
-		if nodeLabel[u] == id {
+		fmt.Println(u, id)
+		if G.L[u] == id {
 			cnt++
 		}
-		tot++
 	}
-	fmt.Println(cnt, "/", tot)
+	fmt.Println(cnt, "/", Node)
 }