diff --git a/rpc/lux_rpc_client.go b/rpc/lux_rpc_client.go new file mode 100644 index 0000000..d1e2d59 --- /dev/null +++ b/rpc/lux_rpc_client.go @@ -0,0 +1,80 @@ +package rpc + +import ( + "encoding/xml" + "fmt" + "net" +) + +type LuxRpcClient struct { + conn net.Conn + counter int +} + +func LuxDialRpc(network string, address string) (LuxRpcClient, error) { + conn, err := net.Dial(network, address) + if err != nil { + return LuxRpcClient{}, err + } + + return LuxRpcClient{ + conn: conn, + counter: 0, + }, nil +} + +func (rpc *LuxRpcClient) Close() { + rpc.conn.Close() +} + +func (rpc *LuxRpcClient) Execute(request LuxRpcRequest) (LuxRpcResponse, LuxRpcError, error) { + var rpcRes LuxRpcResponse + var rpcErr LuxRpcError + + request.RequestID = rpc.counter + rpc.counter++ + + xmlBytes, err := xml.Marshal(&request) + if err != nil { + log.Errorf("failed to marshal request: %v", err) + return rpcRes, rpcErr, err + } + + _, err = rpc.conn.Write(xmlBytes) + if err != nil { + log.Debugf("rpc client send failed: %v", err) + return rpcRes, rpcErr, err + } + + def := NewLuxRpcDefrag() + part := make([]byte, 1500) + + for { + n, err := rpc.conn.Read(part) + if err != nil { + log.Debugf("rpc client recv failed: %v", err) + return rpcRes, rpcErr, err + } + + if def.Feed(part[n:]) { + // got full data, either its response or error + if def.HasResponse() { + xmlBytes := def.GetAndForget() + if err := xml.Unmarshal(xmlBytes, &rpcRes); err != nil { + return rpcRes, rpcErr, fmt.Errorf("failed to unmarshal rpc response %v: %s", err, string(xmlBytes)) + } + + return rpcRes, rpcErr, nil + } else if def.HasError() { + xmlBytes := def.GetAndForget() + if err := xml.Unmarshal(xmlBytes, &rpcErr); err != nil { + return rpcRes, rpcErr, fmt.Errorf("failed to unmarshal rpc error %v: %s", err, string(xmlBytes)) + } + + return rpcRes, rpcErr, nil + } else { + return rpcRes, rpcErr, fmt.Errorf("got unknown response from rpc server: %s", string(def.GetAndForget())) + } + } + } +} diff --git a/rpc/lux_rpc_server.go b/rpc/lux_rpc_server.go index 69a81f2..03ccd75 100644 --- a/rpc/lux_rpc_server.go +++ b/rpc/lux_rpc_server.go @@ -3,6 +3,7 @@ package rpc import ( "encoding/xml" "net" + "os" "sync" ) @@ -41,6 +42,11 @@ func (rpc *LuxRpcServer) HandleRequest(request LuxRpcRequest, rpcType LuxRpcType } func (rpc *LuxRpcServer) AddEndpoint(network string, listenOn string, rpcType LuxRpcType) error { + // cleanup old socket files + if network == "unix" { + os.Remove(listenOn) + } + listener, err := net.Listen(network, listenOn) if err != nil { return err diff --git a/tests/lux_rpc_test.go b/tests/lux_rpc_data_test.go similarity index 100% rename from tests/lux_rpc_test.go rename to tests/lux_rpc_data_test.go diff --git a/tests/lux_rpc_server_test.go b/tests/lux_rpc_server_test.go new file mode 100644 index 0000000..7acab82 --- /dev/null +++ b/tests/lux_rpc_server_test.go @@ -0,0 +1,46 @@ +package tests + +import ( + "lux/rpc" + "testing" +) + +type DummyController struct { + t *testing.T +} + +func (ctrl *DummyController) GetRpcName() string { + return "dummy" +} + +func (ctrl *DummyController) Register(sv *rpc.LuxRpcServer) {} + +func (ctrl *DummyController) Handle(request rpc.LuxRpcRequest, rpcType rpc.LuxRpcType) (rpc.LuxRpcResponse, rpc.LuxRpcError, bool) { + ctrl.t.Log(request, rpcType) + + return rpc.LuxRpcResponse{}, rpc.LuxRpcError{}, true +} + +func TestRpcServerClient(t *testing.T) { + sv := rpc.NewLuxRpcServer() + sv.RegisterController(&DummyController{t: t}) + err := sv.AddEndpoint("unix", "/tmp/lux-rpc-test.sock", rpc.LuxRpcTypeRoot) + if err != nil { + t.Fatal(err) + } + + cl, err := rpc.LuxDialRpc("unix", "/tmp/lux-rpc-test.sock") + if err != nil { + t.Fatal(err) + } + + rpcRes, rpcErr, err := cl.Execute(rpc.LuxRpcRequest{ + Controller: "dummy", + }) + if err != nil { + t.Fatal(err) + } + + t.Log(rpcRes) + t.Log(rpcErr) +}