|
|
@@ -1,19 +1,21 @@
|
|
|
package main
|
|
|
|
|
|
import (
|
|
|
+ "fmt"
|
|
|
"math"
|
|
|
"math/rand"
|
|
|
+ "sync"
|
|
|
)
|
|
|
|
|
|
// Model dimensions and training hyperparameters.
const (
	// Layer widths. Train uses them as the d of layers 0, 1 and 2 and as the
	// row/column counts of the parameter and update matrices.
	Input  int = 1433
	Hidden int = 50
	Output int = 7

	// Sample is the number of outer rounds Train runs; Batch is the number
	// of steps whose updates are accumulated within one round.
	Sample int = 1000
	Batch  int = 10

	// Scale factors for the Wo, Wi and B updates. Train divides each
	// per-step update by float64(Batch)/Rate, i.e. multiplies it by
	// Rate/Batch.
	RateWo float64 = 0.06
	RateWi float64 = 0.08
	RateB  float64 = 0.1
)
|
|
|
|
|
|
type (
|
|
|
@@ -59,8 +61,7 @@ func GetEmbedding(G Graph, u, k int, l []Layer) Vector {
|
|
|
if k == 0 {
|
|
|
return G.X[u]
|
|
|
}
|
|
|
- l[k].E = Multiply(Matrix{GetEmbedding(G, u, k-1, l)}, l[k].p.B)
|
|
|
- l[k-1].O, l[k-1].I = MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k-1].d)
|
|
|
+ l[k-1].O, l[k-1].I, l[k].E = MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k-1].d), MakeMatrix(1, l[k].d)
|
|
|
Do, Di := 0, 0
|
|
|
for v, w := range G.A[u] {
|
|
|
if w == 1 {
|
|
|
@@ -77,46 +78,70 @@ func GetEmbedding(G Graph, u, k int, l []Layer) Vector {
|
|
|
if Di > 0 {
|
|
|
l[k].E.Add(Multiply(l[k-1].I.Divide(float64(Di)), l[k].p.Wi))
|
|
|
}
|
|
|
+ l[k].E.Add(Multiply(Matrix{GetEmbedding(G, u, k-1, l)}, l[k].p.B))
|
|
|
return l[k].f(l[k].E)[0]
|
|
|
}
|
|
|
|
|
|
+// A += B * C / d
|
|
|
+func StartCalc(wg *sync.WaitGroup, A, B, C Matrix, d float64) {
|
|
|
+ wg.Add(1)
|
|
|
+ go func() {
|
|
|
+ A.Add(Multiply(B.Transpose(), C).Divide(d))
|
|
|
+ wg.Done()
|
|
|
+ }()
|
|
|
+}
|
|
|
+
|
|
|
+// A += B
|
|
|
+func StartRefine(wg *sync.WaitGroup, A, B Matrix) {
|
|
|
+ wg.Add(1)
|
|
|
+ go func() {
|
|
|
+ A.Add(B)
|
|
|
+ wg.Done()
|
|
|
+ }()
|
|
|
+}
|
|
|
+
|
|
|
func Train(G Graph) []Layer {
|
|
|
p1 := Parameter{MakeRandomMatrix(Input, Hidden), MakeRandomMatrix(Input, Hidden), MakeRandomMatrix(Input, Hidden)}
|
|
|
p2 := Parameter{MakeRandomMatrix(Hidden, Output), MakeRandomMatrix(Hidden, Output), MakeRandomMatrix(Hidden, Output)}
|
|
|
l := []Layer{{d: Input}, {d: Hidden, f: ReLU, p: p1}, {d: Output, f: Softmax, p: p2}}
|
|
|
for i := 0; i < Sample; i++ {
|
|
|
+ if i%10 == 0 {
|
|
|
+ fmt.Println("sampling", i)
|
|
|
+ }
|
|
|
+ var wg sync.WaitGroup
|
|
|
DWo2, DWi2, DB2 := MakeMatrix(Hidden, Output), MakeMatrix(Hidden, Output), MakeMatrix(Hidden, Output)
|
|
|
DWo1, DWi1, DB1 := MakeMatrix(Input, Hidden), MakeMatrix(Input, Hidden), MakeMatrix(Input, Hidden)
|
|
|
- u := nodeId[rand.Intn(len(nodeId))]
|
|
|
- for j := 0; j < Batch; j++ {
|
|
|
+ for u, j := nodeId[rand.Intn(len(nodeId))], 0; j < Batch; j++ {
|
|
|
GetEmbedding(G, u, 2, l)
|
|
|
delta := MakeMatrix(1, Output)
|
|
|
- delta[0][nodeLabel[u]] = 1
|
|
|
+ delta[0][G.L[u]] = 1
|
|
|
delta.Sub(l[2].E)
|
|
|
- DWo2.Add(Multiply(l[1].O.Transpose(), delta).Divide(float64(Batch) / RateWo))
|
|
|
- DWi2.Add(Multiply(l[1].I.Transpose(), delta).Divide(float64(Batch) / RateWi))
|
|
|
- DB2.Add(Multiply(l[1].E.Transpose(), delta).Divide(float64(Batch) / RateB))
|
|
|
+ StartCalc(&wg, DWo2, l[1].O, delta, float64(Batch)/RateWo)
|
|
|
+ StartCalc(&wg, DWi2, l[1].I, delta, float64(Batch)/RateWi)
|
|
|
+ StartCalc(&wg, DB2, l[1].E, delta, float64(Batch)/RateB)
|
|
|
delta = Multiply(delta, l[2].p.B.Transpose())
|
|
|
for k := 0; k < Hidden; k++ {
|
|
|
if l[1].E[0][k] == 0 {
|
|
|
delta[0][k] = 0
|
|
|
}
|
|
|
}
|
|
|
- DWo1.Add(Multiply(l[0].O.Transpose(), delta).Divide(float64(Batch) / RateWo))
|
|
|
- DWi1.Add(Multiply(l[0].I.Transpose(), delta).Divide(float64(Batch) / RateWi))
|
|
|
- DB1.Add(Multiply(Matrix{G.X[u]}.Transpose(), delta).Divide(float64(Batch) / RateB))
|
|
|
+ StartCalc(&wg, DWo1, l[0].O, delta, float64(Batch)/RateWo)
|
|
|
+ StartCalc(&wg, DWi1, l[0].I, delta, float64(Batch)/RateWi)
|
|
|
+ StartCalc(&wg, DB1, Matrix{G.X[u]}, delta, float64(Batch)/RateB)
|
|
|
neighbor := make([]int, 0)
|
|
|
for v := range G.A[u] {
|
|
|
neighbor = append(neighbor, v)
|
|
|
}
|
|
|
u = neighbor[rand.Intn(len(neighbor))]
|
|
|
+ wg.Wait()
|
|
|
}
|
|
|
- l[2].p.Wo.Add(DWo2)
|
|
|
- l[2].p.Wi.Add(DWi2)
|
|
|
- l[2].p.B.Add(DB2)
|
|
|
- l[1].p.Wo.Add(DWo1)
|
|
|
- l[1].p.Wi.Add(DWi1)
|
|
|
- l[1].p.B.Add(DB1)
|
|
|
+ StartRefine(&wg, l[2].p.Wo, DWo2)
|
|
|
+ StartRefine(&wg, l[2].p.Wi, DWi2)
|
|
|
+ StartRefine(&wg, l[2].p.B, DB2)
|
|
|
+ StartRefine(&wg, l[1].p.Wo, DWo1)
|
|
|
+ StartRefine(&wg, l[1].p.Wi, DWi1)
|
|
|
+ StartRefine(&wg, l[1].p.B, DB1)
|
|
|
+ wg.Wait()
|
|
|
}
|
|
|
return l
|
|
|
}
|