Browse Source

fix: formula error

RegMs If 3 years ago
parent
commit
ce4e2bf49c
3 changed files with 101 additions and 72 deletions
  1. 83 52
      gnn.go
  2. 11 1
      graph.go
  3. 7 19
      main.go

+ 83 - 52
gnn.go

@@ -11,23 +11,22 @@ const (
 	Input   int     = 1433
 	Input   int     = 1433
 	Hidden  int     = 64
 	Hidden  int     = 64
 	Output  int     = 7
 	Output  int     = 7
-	Sample  int     = 10000000
+	Sample  int     = 10000
 	Batch   int     = 5
 	Batch   int     = 5
 	Dropout float64 = 0.5
 	Dropout float64 = 0.5
-	Rate    float64 = 0.1
+	Rate    float64 = 0.01
 )
 )
 
 
 type (
 type (
 	Parameter struct {
 	Parameter struct {
-		W, B, A Matrix
+		W, B Matrix // W multiplies the neighbor aggregation, B multiplies the node's own previous-layer embedding (see GetEmbedding)
 	}
 	}
 
 
 	Layer struct {
 	Layer struct {
-		d    int
-		f    func(Matrix) Matrix
-		p    Parameter
-		D    Vector
-		A, E Matrix
+		d int // width of this layer's embedding row; used as MakeMatrix(1, l[k].d)
+		f func(Matrix) Matrix // activation called on the embedding for k > 0; Train leaves it unset for the input layer
+		p Parameter // weights used to build this layer's embedding from the previous layer
+		D Vector // passed to Dropout during training; dropout is skipped when nil
 	}
 	}
 )
 )
 
 
@@ -55,47 +54,49 @@ func Softmax(A Matrix) Matrix {
 	return A
 	return A
 }
 }
 
 
-func GetEmbedding(G Graph, u, k int, l []Layer, train bool) Matrix {
-	if k == 0 {
-		l[k].E = MakeMatrix(1, l[k].d).Add(Matrix{G.X[u]})
-		if train && l[k].D != nil {
-			l[k].E.Dropout(l[k].D)
-		}
-		return l[k].E
+func GetAggregation(G Graph, u, k int, l []Layer) Matrix { // GetAggregation sums the cached layer-k embeddings G.E[k][v] over the entries v of G.A[u] and divides by their count.
+	if len(G.A[u]) == 0 { // no neighbors: return a fresh 1 x d matrix (presumably all zero) and avoid dividing by zero below
+		return MakeMatrix(1, l[k].d)
 	}
 	}
-	l[k-1].A, l[k].E = MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k].d)
 	// GCN
 	// GCN
-	// deg := 0
+	A := MakeMatrix(1, l[k].d)
+	for v := range G.A[u] { // assumes ranging over G.A[u] yields neighbor ids — TODO confirm against AdjacencyMatrix
+		A.Add(Matrix{G.E[k][v]}) // reads the cache only; the caller must have filled G.E[k][v] beforehand (GetEmbedding does this)
+	}
+	return A.Divide(float64(len(G.A[u])))
+	// GAT
+	// A := MakeMatrix(0, l[k].d)
 	// for v := range G.A[u] {
 	// for v := range G.A[u] {
-	// 	l[k-1].A.Add(GetEmbedding(G, v, k-1, l, train))
-	// 	deg++
+	// 	A = append(A, G.E[k][v])
 	// }
 	// }
-	// l[k].E.Add(GetEmbedding(G, u, k-1, l, train).Multiply(l[k].p.B))
-	// if deg > 0 {
-	// 	l[k].E.Add(l[k-1].A.Divide(float64(deg)).Multiply(l[k].p.W))
+	// C := MakeMatrix(1, A.N())
+	// Me := G.E[k][u].Modulus()
+	// for i := 0; i < A.N(); i++ {
+	// 	Ma := A[i].Modulus()
+	// 	if Me > 0 && Ma > 0 {
+	// 		C[0][i] = G.E[k][u].Dot(A[i]) / Me / Ma
+	// 	}
 	// }
 	// }
-	// GAT
-	A := MakeMatrix(0, l[k-1].d)
-	for v := range G.A[u] {
-		A = append(A, GetEmbedding(G, v, k-1, l, train)[0])
-	}
-	l[k].E.Add(GetEmbedding(G, u, k-1, l, train).Multiply(l[k].p.B))
-	if A.N() > 0 {
-		C := MakeMatrix(1, A.N())
-		Me := l[k-1].E[0].Modulus()
-		for i := 0; i < A.N(); i++ {
-			Ma := A[i].Modulus()
-			if Me != 0 && Ma != 0 {
-				C[0][i] = l[k-1].E[0].Dot(A[i]) / Me / Ma
-			}
+	// return Softmax(C).Multiply(A)
+}
+
+func GetEmbedding(G Graph, u, k int, l []Layer, train bool) Matrix { // GetEmbedding computes the layer-k embedding of node u, caches its row in G.E[k][u] and returns it as a 1 x l[k].d matrix.
+	E := MakeMatrix(1, l[k].d)
+	if k == 0 {
+		E.Add(Matrix{G.X[u]}) // layer 0 starts from the node's feature vector
+	} else {
+		for v := range G.A[u] {
+			GetEmbedding(G, v, k-1, l, train) // result discarded: called only to fill G.E[k-1][v] for GetAggregation below
 		}
 		}
-		l[k-1].A.Add(Softmax(C).Multiply(A))
-		l[k].E.Add(l[k-1].A.Multiply(l[k].p.W))
+		E.Add(GetAggregation(G, u, k-1, l).Multiply(l[k].p.W)) // neighbor term: aggregated layer k-1 embeddings times W
+		E.Add(GetEmbedding(G, u, k-1, l, train).Multiply(l[k].p.B)) // self term: u's own layer k-1 embedding times B
+		l[k].f(E) // return value ignored — NOTE(review): relies on f modifying E in place; confirm for ReLU and Softmax
 	}
 	}
 	if train && l[k].D != nil {
 	if train && l[k].D != nil {
-		l[k].E.Dropout(l[k].D)
+		E.Dropout(l[k].D) // applied after the activation, and only during training
 	}
 	}
-	return l[k].f(l[k].E)
+	G.E[k][u] = E[0] // cache the single row so GetAggregation, Train and Test can read it without recomputing
+	return E
 }
 }
 
 
 // A += B * C
 // A += B * C
@@ -117,12 +118,12 @@ func StartRefine(wg *sync.WaitGroup, A, B Matrix, c float64) {
 }
 }
 
 
 func Train(G Graph) []Layer {
 func Train(G Graph) []Layer {
-	p0 := Parameter{A: MakeRandomMatrix(Input*2, 1)}
-	p1 := Parameter{W: MakeRandomMatrix(Input, Hidden), B: MakeRandomMatrix(Input, Hidden), A: MakeRandomMatrix(Hidden*2, 1)}
-	p2 := Parameter{W: MakeRandomMatrix(Hidden, Output), B: MakeRandomMatrix(Hidden, Output)}
-	l := []Layer{{d: Input, p: p0}, {d: Hidden, f: ReLU, p: p1}, {d: Output, f: Softmax, p: p2}}
+	p1 := Parameter{MakeRandomMatrix(Input, Hidden), MakeRandomMatrix(Input, Hidden)}
+	p2 := Parameter{MakeRandomMatrix(Hidden, Output), MakeRandomMatrix(Hidden, Output)}
+	l := []Layer{{d: Input}, {d: Hidden, f: ReLU, p: p1}, {d: Output, f: Softmax, p: p2}}
 	for i := 0; i < Sample; i++ {
 	for i := 0; i < Sample; i++ {
 		if i%1000 == 0 {
 		if i%1000 == 0 {
+			Test(G, l, false)
 			fmt.Println("sampling", i)
 			fmt.Println("sampling", i)
 		}
 		}
 		var wg sync.WaitGroup
 		var wg sync.WaitGroup
@@ -134,17 +135,28 @@ func Train(G Graph) []Layer {
 			GetEmbedding(G, u, 2, l, true)
 			GetEmbedding(G, u, 2, l, true)
 			delta := MakeMatrix(1, Output)
 			delta := MakeMatrix(1, Output)
 			delta[0][G.L[u]] = 1
 			delta[0][G.L[u]] = 1
-			delta.Sub(l[2].E).Divide(float64(Batch))
-			StartCalc(&wg, DW2, l[1].A, delta)
-			StartCalc(&wg, DB2, l[1].E, delta)
-			delta = delta.Multiply(l[2].p.B.Transpose())
+			delta.Sub(Matrix{G.E[2][u]}).Divide(float64(Batch))
+			StartCalc(&wg, DW2, GetAggregation(G, u, 1, l), delta)
+			StartCalc(&wg, DB2, Matrix{G.E[1][u]}, delta)
+			deltaB := delta.Multiply(l[2].p.B.Transpose())
 			for k := 0; k < Hidden; k++ {
 			for k := 0; k < Hidden; k++ {
-				if l[1].E[0][k] == 0 {
-					delta[0][k] = 0
+				if G.E[1][u][k] == 0 {
+					deltaB[0][k] = 0
 				}
 				}
 			}
 			}
-			StartCalc(&wg, DW1, l[0].A, delta)
-			StartCalc(&wg, DB1, l[0].E, delta)
+			StartCalc(&wg, DW1, GetAggregation(G, u, 0, l), deltaB)
+			StartCalc(&wg, DB1, Matrix{G.E[0][u]}, deltaB)
+			deltaW := delta.Multiply(l[2].p.W.Transpose())
+			for v := range G.A[u] {
+				delta = MakeMatrix(1, Hidden).Add(deltaW).Divide(float64(len(G.A[u])))
+				for k := 0; k < Hidden; k++ {
+					if G.E[1][v][k] == 0 {
+						delta[0][k] = 0
+					}
+				}
+				StartCalc(&wg, DW1, GetAggregation(G, v, 0, l), delta)
+				StartCalc(&wg, DB1, Matrix{G.E[0][v]}, delta)
+			}
 			wg.Wait()
 			wg.Wait()
 		}
 		}
 		StartRefine(&wg, l[2].p.W, DW2, 1/Rate)
 		StartRefine(&wg, l[2].p.W, DW2, 1/Rate)
@@ -155,3 +167,22 @@ func Train(G Graph) []Layer {
 	}
 	}
 	return l
 	return l
 }
 }
+
+func Test(G Graph, l []Layer, detail bool) { // Test prints how many nodes in G.X have a prediction matching their label; with detail set it also prints each node id and prediction.
+	cnt1, cnt2 := 0, 0 // cnt1 counts label == prediction, cnt2 counts label == prediction+Output
+	for u := range G.X {
+		GetEmbedding(G, u, 2, l, false) // train is false, so the dropout branch is not taken; the result is read back from G.E[2][u]
+		id, _ := G.E[2][u].Max() // id is compared with the label below — presumably the index of the largest output; verify against Vector.Max
+		if detail {
+			fmt.Println(u, id)
+		}
+		if G.L[u] == id {
+			cnt1++
+		}
+		if G.L[u] == id+Output { // main adds every node past the first 20 of its class with label j+Output
+			cnt2++
+		}
+	}
+	fmt.Println(cnt1, "/", len(nodeId), ",", cnt2, "/", Node-len(nodeId))
+	fmt.Println(100.*cnt1/len(nodeId), ",", 100.*cnt2/(Node-len(nodeId))) // NOTE(review): the operands are ints, so the percentages truncate, and an empty nodeId (or len(nodeId) == Node) divides by zero
+}

+ 11 - 1
graph.go

@@ -5,6 +5,7 @@ type (
 
 
 	Graph struct {
 	Graph struct {
 		X map[int]Vector
 		X map[int]Vector
+		E []map[int]Vector // embedding cache indexed as E[layer][node]; written by GetEmbedding
 		L map[int]int
 		L map[int]int
 		A AdjacencyMatrix
 		A AdjacencyMatrix
 	}
 	}
@@ -18,7 +19,16 @@ func (A AdjacencyMatrix) Modify(u, v, w int) {
 }
 }
 
 
 func MakeGraph() Graph {
 func MakeGraph() Graph {
-	return Graph{make(map[int]Vector), make(map[int]int), make(AdjacencyMatrix)}
+	G := Graph{ // positional literal: values follow the field order X, E, L, A
+		make(map[int]Vector),
+		make([]map[int]Vector, 3), // NOTE(review): 3 is hard-coded and has to match the number of layers built in Train
+		make(map[int]int),
+		make(AdjacencyMatrix),
+	}
+	for i := 0; i < 3; i++ {
+		G.E[i] = make(map[int]Vector) // each per-layer map must exist before GetEmbedding writes G.E[k][u]
+	}
+	return G
 }
 }
 
 
 func (G Graph) AddNode(u int, V Vector, l int) {
 func (G Graph) AddNode(u int, V Vector, l int) {

+ 7 - 19
main.go

@@ -2,14 +2,12 @@ package main
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"math/rand"
 	"os"
 	"os"
 )
 )
 
 
 const (
 const (
-	Node    int     = 2708
-	Edge    int     = 5429
-	Labeled float64 = 0.05
+	Node int = 2708
+	Edge int = 5429
 )
 )
 
 
 var (
 var (
@@ -22,6 +20,7 @@ var (
 		"Rule_Learning",
 		"Rule_Learning",
 		"Theory",
 		"Theory",
 	}
 	}
+	labelCnt = make([]int, len(labelName))
 
 
 	nodeId = make([]int, 0)
 	nodeId = make([]int, 0)
 )
 )
@@ -46,17 +45,19 @@ func main() {
 		fmt.Fscan(file, &label)
 		fmt.Fscan(file, &label)
 		for j := 0; j < Output; j++ {
 		for j := 0; j < Output; j++ {
 			if labelName[j] == label {
 			if labelName[j] == label {
-				if rand.Float64() < Labeled {
+				if labelCnt[j] < 20 {
 					G.AddNode(u, V, j)
 					G.AddNode(u, V, j)
 					nodeId = append(nodeId, u)
 					nodeId = append(nodeId, u)
 				} else {
 				} else {
 					G.AddNode(u, V, j+Output)
 					G.AddNode(u, V, j+Output)
 				}
 				}
+				labelCnt[j]++
 				break
 				break
 			}
 			}
 		}
 		}
 	}
 	}
 	file.Close()
 	file.Close()
+	fmt.Println(labelCnt)
 	file, err = os.Open("cora/cora.cites")
 	file, err = os.Open("cora/cora.cites")
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
@@ -68,18 +69,5 @@ func main() {
 	}
 	}
 	file.Close()
 	file.Close()
 	l := Train(G)
 	l := Train(G)
-	cnt1, cnt2 := 0, 0
-	for u := range G.X {
-		GetEmbedding(G, u, 2, l, false)
-		id, _ := l[2].E[0].Max()
-		fmt.Println(u, id)
-		if G.L[u] == id {
-			cnt1++
-		}
-		if G.L[u] == id+Output {
-			cnt2++
-		}
-	}
-	fmt.Println(cnt1, "/", len(nodeId), ",", cnt2, "/", Node-len(nodeId))
-	fmt.Println(100.*cnt1/len(nodeId), ",", 100.*cnt2/(Node-len(nodeId)))
+	Test(G, l, true)
 }
 }