lux/net/lux_router.go

142 lines
2.9 KiB
Go

package net
import (
"errors"
"lux/crypto"
"lux/proto"
"net"
"sync"
)
// LuxRoute ties a key to a UDP destination and to the router channel that
// was created for that destination.
type LuxRoute struct {
// Key is the key found in the key store for outbound routes, or the
// router's own key for routes created by CreateInboundChannel.
Key crypto.LuxKey
// Destination is the Address of the associated channel.
Destination *net.UDPAddr
// Associated points at the channel registered with the router for this route.
Associated *LuxChannel
}
// LuxRouter owns a set of outbound and inbound channels plus the routes that
// reference them. Datagrams received on inbound channels are funnelled into a
// single datagram channel read through RecvDgram.
//
// Channels are stored as pointers so that every *LuxChannel handed out (to
// routes, to receiver goroutines, to CloseChannel callers) stays valid while
// the channel lists grow or shrink. Storing values and returning &slice[i]
// is unsafe: append may move the backing array, and removal shifts elements
// underneath existing pointers.
//
// NOTE(review): the inbound/outbound field types changed from []LuxChannel
// to []*LuxChannel; confirm no other file in package net uses them directly.
type LuxRouter struct {
	thisKey  crypto.LuxKey
	keyStore crypto.LuxKeyStore
	routes   []LuxRoute

	// channelLock guards outbound and inbound.
	channelLock sync.RWMutex
	outbound    []*LuxChannel
	inbound     []*LuxChannel

	dgramChan chan LuxDatagram
}

// NewLuxRouter creates a router. key is used as the route key for inbound
// channels; ks is consulted to find peer keys for outbound routes. The
// datagram channel is unbuffered.
//
// The returned value contains a sync.RWMutex and must not be copied once in use.
func NewLuxRouter(key crypto.LuxKey, ks crypto.LuxKeyStore) LuxRouter {
	return LuxRouter{
		thisKey:   key,
		keyStore:  ks,
		routes:    make([]LuxRoute, 0),
		outbound:  make([]*LuxChannel, 0),
		inbound:   make([]*LuxChannel, 0),
		dgramChan: make(chan LuxDatagram),
	}
}

// addOutboundChannel registers ch as an outbound channel and returns a pointer
// to the registered channel. The pointer is not invalidated by later additions
// or removals.
func (r *LuxRouter) addOutboundChannel(ch LuxChannel) *LuxChannel {
	channel := &ch // ch is a per-call copy, so its address is stable
	r.channelLock.Lock()
	defer r.channelLock.Unlock()
	r.outbound = append(r.outbound, channel)
	return channel
}

// addInboundChannel registers ch as an inbound channel and returns a pointer
// to the registered channel. The pointer is not invalidated by later additions
// or removals.
func (r *LuxRouter) addInboundChannel(ch LuxChannel) *LuxChannel {
	channel := &ch // ch is a per-call copy, so its address is stable
	r.channelLock.Lock()
	defer r.channelLock.Unlock()
	r.inbound = append(r.inbound, channel)
	return channel
}

// CreateOutboundRoute looks up the key for id in the key store, opens an
// outbound channel of type chType to udpAddr and records a route for it.
// It returns an error if the key is unknown or the channel cannot be created.
func (r *LuxRouter) CreateOutboundRoute(id proto.LuxID, chType LuxChannelType, udpAddr string) error {
	key, ok := r.keyStore.Get(id)
	if !ok {
		return errors.New("key not found")
	}
	channel, err := NewLuxOutboundChannel(udpAddr, chType)
	if err != nil {
		return err
	}
	r.routes = append(r.routes, LuxRoute{
		Key:         key,
		Destination: channel.Address,
		Associated:  r.addOutboundChannel(channel),
	})
	return nil
}

// CreateInboundChannel opens an inbound channel of type chType on udpAddr and
// records a route for it keyed by the router's own key.
// It returns the channel-creation error, if any.
func (r *LuxRouter) CreateInboundChannel(chType LuxChannelType, udpAddr string) error {
	channel, err := NewLuxInboundChannel(udpAddr, chType)
	if err != nil {
		return err
	}
	r.routes = append(r.routes, LuxRoute{
		Key:         r.thisKey,
		Destination: channel.Address,
		Associated:  r.addInboundChannel(channel),
	})
	return nil
}

// CloseChannel unregisters channel from the router's outbound and inbound
// lists; it is used when a channel reports an error. Channels are matched by
// pointer identity, so channel must be a pointer previously handed out by the
// router. Only the router's bookkeeping is updated: nothing is called on the
// channel itself, and routes referencing it are left in place.
func (r *LuxRouter) CloseChannel(channel *LuxChannel) {
	r.channelLock.Lock()
	defer r.channelLock.Unlock()
	r.outbound = removeLuxChannel(r.outbound, channel)
	r.inbound = removeLuxChannel(r.inbound, channel)
}

// removeLuxChannel returns list without its first element identical to
// channel, clearing the vacated tail slot so the removed channel can be freed.
func removeLuxChannel(list []*LuxChannel, channel *LuxChannel) []*LuxChannel {
	for i, ch := range list {
		if ch == channel {
			copy(list[i:], list[i+1:])
			list[len(list)-1] = nil
			return list[:len(list)-1]
		}
	}
	return list
}

// GetDgramChannel returns the send side of the router's datagram channel.
func (r *LuxRouter) GetDgramChannel() chan<- LuxDatagram {
	return r.dgramChan
}

// channelReceiver is run as a goroutine per inbound channel. It forwards every
// datagram received on channel to the router until Recv returns an error, and
// then unregisters the channel. The datagram returned alongside an error is
// discarded instead of being forwarded.
func channelReceiver(r *LuxRouter, channel *LuxChannel) {
	dgramChan := r.GetDgramChannel()
	for {
		dgram, err := channel.Recv()
		if err != nil {
			// TODO(review): the receive error is dropped; consider reporting it.
			break
		}
		dgramChan <- dgram
	}
	r.CloseChannel(channel)
}

// Start launches one receiver goroutine for each inbound channel registered
// at the time of the call.
func (r *LuxRouter) Start() {
	r.channelLock.RLock()
	defer r.channelLock.RUnlock()
	for _, ch := range r.inbound {
		// ch is the stored pointer itself, so the goroutine and CloseChannel
		// agree on the channel's identity.
		go channelReceiver(r, ch)
	}
}
// RecvDgram blocks until a datagram arrives on the router's datagram channel
// (fed by the inbound channel receivers) and returns it.
func (r *LuxRouter) RecvDgram() LuxDatagram {
	received := <-r.dgramChan
	return received
}