Browse Source

fix: refine hyperparameters

RegMs If 3 years ago
parent
commit
3874efd72f
1 changed files with 14 additions and 11 deletions
  1. 14 11
      gnn.go

+ 14 - 11
gnn.go

@@ -5,7 +5,10 @@ import (
 )
 
 const (
-	HIDDEN int = 200
+	Hidden int     = 50
+	RateWo float64 = 0.02
+	RateWi float64 = 0.04
+	RateB  float64 = 0.05
 )
 
 type (
@@ -75,28 +78,28 @@ func GetEmbedding(G Graph, u, k int, l []Layer) Vector {
 }
 
 func Train(G Graph) []Layer {
-	p1 := Parameter{MakeRandomMatrix(1433, HIDDEN), MakeRandomMatrix(1433, HIDDEN), MakeRandomMatrix(1433, HIDDEN)}
-	p2 := Parameter{MakeRandomMatrix(HIDDEN, 7), MakeRandomMatrix(HIDDEN, 7), MakeRandomMatrix(HIDDEN, 7)}
-	l := []Layer{{d: 1433}, {d: HIDDEN, f: ReLU, p: p1}, {d: 7, f: Softmax, p: p2}}
+	p1 := Parameter{MakeRandomMatrix(1433, Hidden), MakeRandomMatrix(1433, Hidden), MakeRandomMatrix(1433, Hidden)}
+	p2 := Parameter{MakeRandomMatrix(Hidden, 7), MakeRandomMatrix(Hidden, 7), MakeRandomMatrix(Hidden, 7)}
+	l := []Layer{{d: 1433}, {d: Hidden, f: ReLU, p: p1}, {d: 7, f: Softmax, p: p2}}
 	for u, X := range G.X {
 		GetEmbedding(G, u, 2, l)
 		delta := MakeMatrix(1, 7)
 		delta[0][nodeLabel[u]] = 1
 		delta.Sub(l[2].E)
 		DWo2, DWi2, DB2 := Multiply(l[1].O.Transpose(), delta), Multiply(l[1].I.Transpose(), delta), Multiply(l[1].E.Transpose(), delta)
-		DWo2.Divide(10)
-		DWi2.Divide(10)
-		DB2.Divide(10)
+		DWo2.Divide(1 / RateWo)
+		DWi2.Divide(1 / RateWi)
+		DB2.Divide(1 / RateB)
 		delta = Multiply(delta, l[2].p.B.Transpose())
-		for i := 0; i < HIDDEN; i++ {
+		for i := 0; i < Hidden; i++ {
 			if l[1].E[0][i] == 0 {
 				delta[0][i] = 0
 			}
 		}
 		DWo1, DWi1, DB1 := Multiply(l[0].O.Transpose(), delta), Multiply(l[0].I.Transpose(), delta), Multiply(Matrix{X}.Transpose(), delta)
-		DWo1.Divide(10)
-		DWi1.Divide(10)
-		DB1.Divide(10)
+		DWo1.Divide(1 / RateWo)
+		DWi1.Divide(1 / RateWi)
+		DB1.Divide(1 / RateB)
 		l[2].p.Wo.Add(DWo2)
 		l[2].p.Wi.Add(DWi2)
 		l[2].p.B.Add(DB2)