diff --git a/net/lux_router.go b/net/lux_router.go index af5d975..a38a00d 100644 --- a/net/lux_router.go +++ b/net/lux_router.go @@ -5,6 +5,7 @@ import ( "lux/crypto" "lux/proto" "net" + "sync" ) type LuxRoute struct { @@ -19,28 +20,42 @@ type LuxRouter struct { routes []LuxRoute - outbound []LuxChannel - inbound []LuxChannel + channelLock sync.RWMutex + outbound []LuxChannel + inbound []LuxChannel + + dgramChan chan LuxDatagram } func NewLuxRouter(key crypto.LuxKey, ks crypto.LuxKeyStore) LuxRouter { return LuxRouter{ - thisKey: key, - keyStore: ks, - routes: make([]LuxRoute, 0), - outbound: make([]LuxChannel, 0), - inbound: make([]LuxChannel, 0), + thisKey: key, + keyStore: ks, + routes: make([]LuxRoute, 0), + outbound: make([]LuxChannel, 0), + inbound: make([]LuxChannel, 0), + dgramChan: make(chan LuxDatagram), } } func (r *LuxRouter) addOutboundChannel(ch LuxChannel) *LuxChannel { + r.channelLock.Lock() + r.outbound = append(r.outbound, ch) - return &r.outbound[len(r.outbound)-1] + channel := &r.outbound[len(r.outbound)-1] + + r.channelLock.Unlock() + return channel } func (r *LuxRouter) addInboundChannel(ch LuxChannel) *LuxChannel { + r.channelLock.Lock() + r.inbound = append(r.inbound, ch) - return &r.inbound[len(r.inbound)-1] + channel := &r.inbound[len(r.inbound)-1] + + r.channelLock.Unlock() + return channel } func (r *LuxRouter) CreateOutboundRoute(id proto.LuxID, chType LuxChannelType, udpAddr string) error { @@ -77,3 +92,51 @@ func (r *LuxRouter) CreateInboundChannel(chType LuxChannelType, udpAddr string) }) return nil } + +// close channel when error happened +func (r *LuxRouter) CloseChannel(channel *LuxChannel) { + r.channelLock.Lock() + for i, ch := range r.outbound { + if &ch == channel { + r.outbound = append(r.outbound[:i], r.outbound[i+1:]...) + } + } + + for i, ch := range r.inbound { + if &ch == channel { + r.inbound = append(r.inbound[:i], r.inbound[:i+1]...) + } + } + r.channelLock.Unlock() +} + +func (r *LuxRouter) GetDgramChannel() chan<- LuxDatagram { + return r.dgramChan +} + +// goroutine to receive datagrams and send them to router over channel +func channelReceiver(r *LuxRouter, channel *LuxChannel) { + dgramChan := r.GetDgramChannel() + + var dgram LuxDatagram + var err error + + for err == nil { + dgram, err = channel.Recv() + dgramChan <- dgram + } + + r.CloseChannel(channel) +} + +func (r *LuxRouter) Start() { + r.channelLock.RLock() + for _, inbound := range r.inbound { + go channelReceiver(r, &inbound) + } + r.channelLock.RUnlock() +} + +func (r *LuxRouter) RecvDgram() LuxDatagram { + return <-r.dgramChan +} diff --git a/tests/lux_router_test.go b/tests/lux_router_test.go new file mode 100644 index 0000000..2d31e61 --- /dev/null +++ b/tests/lux_router_test.go @@ -0,0 +1,40 @@ +package tests + +import ( + "bytes" + "lux/crypto" + "lux/net" + "lux/proto" + "testing" +) + +func TestDgramChannel(t *testing.T) { + ks := crypto.NewLuxKeyStore("/tmp/keystore.dat") + keyA, _ := crypto.NewLuxKey(proto.LuxTypeHost) + + r := net.NewLuxRouter(keyA, ks) + err := r.CreateInboundChannel(net.LuxChannelInterior, "127.0.0.2:9979") + if err != nil { + t.Fatal(err) + } + + outbound, err := net.NewLuxOutboundChannel("127.0.0.2:9979", net.LuxChannelInterior) + if err != nil { + t.Fatal(err) + } + + dgram := net.LuxDatagram{ + Target: outbound.Address, + Payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + } + + r.Start() + outbound.Send(dgram) + + recv := r.RecvDgram() + if !bytes.Equal(dgram.Payload, recv.Payload) { + t.Log(dgram) + t.Log(recv) + t.Log("payloads are not equal!") + } +}