diff --git a/net/lux_packet.go b/net/lux_packet.go new file mode 100644 index 0000000..41b8a1e --- /dev/null +++ b/net/lux_packet.go @@ -0,0 +1,87 @@ +package net + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "lux/crypto" + "lux/proto" + "net" +) + +const LUX_PROTO_PACKET_HDRLEN = proto.LUX_PROTO_ID_SIZE + 2 + +type LuxPacket struct { + Target proto.LuxID + Type uint + Buffer proto.LuxBuffer +} + +func DecryptLuxPacket(dgram LuxDatagram, key crypto.LuxKey) (LuxPacket, error) { + packet := LuxPacket{} + + cipherBlock, err := aes.NewCipher(key.Key) + if err != nil { + return packet, err + } + decrypter := cipher.NewCBCDecrypter(cipherBlock, key.IV) + + payloadLen := len(dgram.Payload) + if payloadLen == 0 || payloadLen%cipherBlock.BlockSize() != 0 { + return packet, errors.New("payload is not block aligned") + } + + payload := make([]byte, payloadLen) + decrypter.CryptBlocks(payload, dgram.Payload) + + packet.Buffer = proto.FromSlice(payload) + if err := packet.Target.Read(&packet.Buffer); err != nil { + return packet, err + + } + + if pType, err := packet.Buffer.ReadUint16(); err == nil { + packet.Type = uint(pType) + } else { + return packet, err + } + + return packet, nil +} + +func EncryptLuxPacket(packet LuxPacket, key crypto.LuxKey, target *net.UDPAddr) (LuxDatagram, error) { + dgram := LuxDatagram{} + + cipherBlock, err := aes.NewCipher(key.Key) + if err != nil { + return dgram, err + } + encrypter := cipher.NewCBCEncrypter(cipherBlock, key.IV) + + var paddingLen int + packetLen := packet.Buffer.Length() + if packetLen%encrypter.BlockSize() != 0 { + paddingLen = ((packetLen / encrypter.BlockSize()) + 1) * encrypter.BlockSize() + } else { + paddingLen = 0 + } + + padding := make([]byte, paddingLen) + _, err = rand.Read(padding) + if err != nil { + return dgram, err + } + + wd := proto.AllocLuxBuffer(LUX_PROTO_PACKET_HDRLEN + packetLen + paddingLen) + key.Id.Write(&wd) + wd.WriteUint16(uint16(packet.Type)) + wd.WriteBytes(packet.Buffer.AllBytes()) + wd.WriteBytes(padding) + + dgram.Target = target + dgram.Payload = make([]byte, wd.Length()) + + encrypter.CryptBlocks(dgram.Payload, wd.AllBytes()) + return dgram, nil +} diff --git a/tests/lux_packet_test.go b/tests/lux_packet_test.go new file mode 100644 index 0000000..3a0a073 --- /dev/null +++ b/tests/lux_packet_test.go @@ -0,0 +1,46 @@ +package tests + +import ( + "bytes" + "lux/crypto" + "lux/net" + "lux/proto" + "testing" +) + +func TestEncryptDecrypt(t *testing.T) { + key, _ := crypto.NewLuxKey(proto.LuxTypeHost) + packet := net.LuxPacket{ + Target: key.Id, + Type: 0x1234, + Buffer: proto.NewLuxBuffer(), + } + packet.Buffer.WriteString("hello world! very unaligned") + + dgram, err := net.EncryptLuxPacket(packet, key, nil) + if err != nil { + t.Fatal(err) + } + + new, err := net.DecryptLuxPacket(dgram, key) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(key.Id.UUID[:], new.Target.UUID[:]) { + t.Fatal(key.Id) + t.Fatal(new.Target.UUID) + t.Fatal("LuxIDs are not equal") + } + if new.Type != 0x1234 { + t.Fatalf("%x new lux packet type instead of 0x1234", new.Type) + } + + str, err := new.Buffer.ReadString() + if err != nil { + t.Fatal(err) + } + if str != "hello world! very unaligned" { + t.Fatalf("wrong string: %s", str) + } +}