From 5a08b0c7aa64b3761051341e37b4f5ff1746acd7 Mon Sep 17 00:00:00 2001 From: mykola2312 <49044616+mykola2312@users.noreply.github.com> Date: Fri, 10 Jan 2025 06:09:47 +0200 Subject: [PATCH] working on router unit test, fixed bug when sockets weren't closed in tests --- net/lux_router.go | 43 +++++++++++++++++++++++++++--- tests/lux_router_test.go | 56 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 4 deletions(-) diff --git a/net/lux_router.go b/net/lux_router.go index 634c7eb..7ba8e12 100644 --- a/net/lux_router.go +++ b/net/lux_router.go @@ -141,6 +141,23 @@ func (r *LuxRouter) Start() { r.channelLock.RUnlock() } +func (r *LuxRouter) Stop() { + // close all channels + r.channelLock.Lock() + for _, inbound := range r.inbound { + inbound.Close() + } + r.inbound = r.inbound[0:] + + for _, outbound := range r.outbound { + outbound.Close() + } + r.outbound = r.outbound[0:] + r.channelLock.Unlock() + + r.routes = r.routes[0:] +} + func (r *LuxRouter) RecvDgram() LuxDatagram { return <-r.dgramChan } @@ -158,14 +175,18 @@ func (r *LuxRouter) getRouteIndex(udpAddr *net.UDPAddr) (LuxRoute, int) { return LuxRoute{}, -1 } -func (r *LuxRouter) getRouteIndexByID(id proto.LuxID) int { +func (r *LuxRouter) getRouteIndexByID(id proto.LuxID) (LuxRoute, int) { for idx, route := range r.routes { if bytes.Equal(route.Key.Id.UUID[:], id.UUID[:]) { - return idx + return route, idx } } - return -1 + return LuxRoute{}, -1 +} + +func (r *LuxRouter) GetRoutes() []LuxRoute { + return r.routes } func (r *LuxRouter) GetRoute(udpAddr *net.UDPAddr) (LuxRoute, bool) { @@ -223,7 +244,7 @@ func (r *LuxRouter) Recv() (LuxPacket, error) { if bytes.Equal(packet.Target.UUID[:], key.Id.UUID[:]) { // key UUID and decrypted UUID matching - create OR update route and return packet - if idx := r.getRouteIndexByID(packet.Target); idx != -1 { + if _, idx := r.getRouteIndexByID(packet.Target); idx != -1 { route := &r.routes[idx] route.Destination = dgram.Target route.Associated = dgram.Channel @@ -243,3 +264,17 @@ func (r *LuxRouter) Recv() (LuxPacket, error) { return packet, fmt.Errorf("non-peer packet from %s", dgram.Target.String()) } + +func (r *LuxRouter) Send(packet LuxPacket) error { + route, idx := r.getRouteIndexByID(packet.Target) + if idx == -1 { + return errors.New("no route to peer") + } + + dgram, err := EncryptLuxPacket(packet, route.Key, route.Destination) + if err != nil { + return err + } + + return route.Associated.Send(dgram) +} diff --git a/tests/lux_router_test.go b/tests/lux_router_test.go index 2d31e61..2e16634 100644 --- a/tests/lux_router_test.go +++ b/tests/lux_router_test.go @@ -10,7 +10,9 @@ import ( func TestDgramChannel(t *testing.T) { ks := crypto.NewLuxKeyStore("/tmp/keystore.dat") + keyA, _ := crypto.NewLuxKey(proto.LuxTypeHost) + ks.Put(keyA) r := net.NewLuxRouter(keyA, ks) err := r.CreateInboundChannel(net.LuxChannelInterior, "127.0.0.2:9979") @@ -29,6 +31,8 @@ func TestDgramChannel(t *testing.T) { } r.Start() + defer r.Stop() + outbound.Send(dgram) recv := r.RecvDgram() @@ -38,3 +42,55 @@ func TestDgramChannel(t *testing.T) { t.Log("payloads are not equal!") } } + +func TestRouterSendRecv(t *testing.T) { + ks := crypto.NewLuxKeyStore("/tmp/keystore.dat") + + keyA, _ := crypto.NewLuxKey(proto.LuxTypeHost) + ks.Put(keyA) + + routerA := net.NewLuxRouter(keyA, ks) + if err := routerA.CreateInboundChannel(net.LuxChannelInterior, "127.0.0.2:9979"); err != nil { + t.Fatal(err) + } + routerA.Start() + defer routerA.Stop() + + t.Log("routerA routing table:") + t.Log(routerA.GetRoutes()) + + routerB := net.NewLuxRouter(keyA, ks) + if err := routerB.CreateOutboundRoute(keyA.Id, net.LuxChannelInterior, "127.0.0.2:9979"); err != nil { + t.Fatal(err) + } + routerB.Start() + defer routerB.Stop() + + t.Log("routerB routing table:") + t.Log(routerB.GetRoutes()) + + payload := []byte{1, 2, 3, 4, 5, 6, 7, 8} + packet := net.LuxPacket{ + Target: keyA.Id, + Type: 0x1234, + Buffer: proto.FromSlice(payload), + } + + routerB.Send(packet) + newPacket, err := routerA.Recv() + if err != nil { + t.Fatal(err) + } + + if newPacket.ChannelType != net.LuxChannelInterior { + t.Fatal("newPacket.ChannelType != net.LuxChannelInterior") + } + if newPacket.Type != packet.Type { + t.Fatal("newPacket.Type != packet.Type") + } + if !bytes.Equal(newPacket.Buffer.AllBytes(), payload) { + t.Log(payload) + t.Log(newPacket.Buffer.AllBytes()) + t.Fatal("payloads aren't equal!") + } +}