Fix options on write request

This commit is contained in:
altugbakan 2023-03-11 19:23:23 +03:00
parent 0df85e156d
commit 181b1107a3
8 changed files with 138 additions and 103 deletions

2
Cargo.lock generated
View file

@ -4,4 +4,4 @@ version = 3
[[package]] [[package]]
name = "tftpd" name = "tftpd"
version = "0.1.0" version = "0.1.1"

View file

@ -1,9 +1,9 @@
[package] [package]
name = "tftpd" name = "tftpd"
version = "0.1.0" version = "0.1.1"
authors = ["Altuğ Bakan <mail@alt.ug>"] authors = ["Altuğ Bakan <mail@alt.ug>"]
edition = "2021" edition = "2021"
description = "TFTP Server Daemon implemented in Rust" description = "Multithreaded TFTP server daemon"
repository = "https://github.com/altugbakan/rs-tftpd" repository = "https://github.com/altugbakan/rs-tftpd"
license = "MIT" license = "MIT"
keywords = ["tftp", "server"] keywords = ["tftp", "server"]

View file

@ -28,14 +28,14 @@ impl Config {
if let Some(ip_str) = args.next() { if let Some(ip_str) = args.next() {
config.ip_address = ip_str.parse::<Ipv4Addr>()?; config.ip_address = ip_str.parse::<Ipv4Addr>()?;
} else { } else {
return Err("missing ip address after flag".into()); return Err("Missing ip address after flag".into());
} }
} }
"-p" | "--port" => { "-p" | "--port" => {
if let Some(port_str) = args.next() { if let Some(port_str) = args.next() {
config.port = port_str.parse::<u16>()?; config.port = port_str.parse::<u16>()?;
} else { } else {
return Err("missing port number after flag".into()); return Err("Missing port number after flag".into());
} }
} }
"-d" | "--directory" => { "-d" | "--directory" => {
@ -45,7 +45,7 @@ impl Config {
} }
config.directory = PathBuf::from(dir_str); config.directory = PathBuf::from(dir_str);
} else { } else {
return Err("missing directory after flag".into()); return Err("Missing directory after flag".into());
} }
} }
"-h" | "--help" => { "-h" | "--help" => {
@ -60,7 +60,7 @@ impl Config {
println!(" -h, --help\t\t\tPrint help information"); println!(" -h, --help\t\t\tPrint help information");
process::exit(0); process::exit(0);
} }
invalid => return Err(format!("invalid flag: {invalid}").into()), invalid => return Err(format!("Invalid flag: {invalid}").into()),
} }
} }

View file

@ -5,7 +5,7 @@ pub struct Convert;
impl Convert { impl Convert {
pub fn to_u16(buf: &[u8]) -> Result<u16, &'static str> { pub fn to_u16(buf: &[u8]) -> Result<u16, &'static str> {
if buf.len() < 2 { if buf.len() < 2 {
Err("error when converting to u16") Err("Error when converting to u16")
} else { } else {
Ok(((buf[0] as u16) << 8) + buf[1] as u16) Ok(((buf[0] as u16) << 8) + buf[1] as u16)
} }
@ -17,7 +17,7 @@ impl Convert {
String::from_utf8(buf[start..start + index].to_vec())?, String::from_utf8(buf[start..start + index].to_vec())?,
index + start, index + start,
)), )),
None => return Err("invalid string".into()), None => return Err("Invalid string".into()),
} }
} }
} }

View file

@ -3,7 +3,7 @@ use std::{
net::{SocketAddr, UdpSocket}, net::{SocketAddr, UdpSocket},
}; };
use crate::packet::{ErrorCode, Opcode, Packet, TransferOption}; use crate::packet::{ErrorCode, Packet, TransferOption};
pub struct Message; pub struct Message;
@ -12,25 +12,16 @@ const MAX_REQUEST_PACKET_SIZE: usize = 512;
impl Message { impl Message {
pub fn send_data( pub fn send_data(
socket: &UdpSocket, socket: &UdpSocket,
block_number: u16, block_num: u16,
data: &[u8], data: Vec<u8>,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let buf = [ socket.send(&Packet::Data { block_num, data }.serialize()?)?;
&Opcode::Data.as_bytes()[..],
&block_number.to_be_bytes(),
data,
]
.concat();
socket.send(&buf)?;
Ok(()) Ok(())
} }
pub fn send_ack(socket: &UdpSocket, block_number: u16) -> Result<(), Box<dyn Error>> { pub fn send_ack(socket: &UdpSocket, block_number: u16) -> Result<(), Box<dyn Error>> {
let buf = [Opcode::Ack.as_bytes(), block_number.to_be_bytes()].concat(); socket.send(&Packet::Ack(block_number).serialize()?)?;
socket.send(&buf)?;
Ok(()) Ok(())
} }
@ -38,9 +29,9 @@ impl Message {
pub fn send_error( pub fn send_error(
socket: &UdpSocket, socket: &UdpSocket,
code: ErrorCode, code: ErrorCode,
msg: &str, msg: String,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
socket.send(&build_error_buf(code, msg))?; socket.send(&Packet::Error { code, msg }.serialize()?)?;
Ok(()) Ok(())
} }
@ -49,9 +40,19 @@ impl Message {
socket: &UdpSocket, socket: &UdpSocket,
to: &SocketAddr, to: &SocketAddr,
code: ErrorCode, code: ErrorCode,
msg: &'a str, msg: String,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
if socket.send_to(&build_error_buf(code, msg), to).is_err() { if socket
.send_to(
&Packet::Error {
code,
msg: msg.clone(),
}
.serialize()?,
to,
)
.is_err()
{
eprintln!("could not send an error message"); eprintln!("could not send an error message");
} }
Err(msg.into()) Err(msg.into())
@ -59,15 +60,9 @@ impl Message {
pub fn send_oack( pub fn send_oack(
socket: &UdpSocket, socket: &UdpSocket,
options: &Vec<TransferOption>, options: Vec<TransferOption>,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let mut buf = Opcode::Oack.as_bytes().to_vec(); socket.send(&Packet::Oack { options }.serialize()?)?;
for option in options {
buf = [buf, option.as_bytes()].concat();
}
socket.send(&buf)?;
Ok(()) Ok(())
} }
@ -96,13 +91,3 @@ impl Message {
Ok((packet, from)) Ok((packet, from))
} }
} }
fn build_error_buf(code: ErrorCode, msg: &str) -> Vec<u8> {
[
&Opcode::Error.as_bytes()[..],
&code.as_bytes()[..],
&msg.as_bytes()[..],
&[0x00],
]
.concat()
}

View file

@ -21,6 +21,9 @@ pub enum Packet {
code: ErrorCode, code: ErrorCode,
msg: String, msg: String,
}, },
Oack {
options: Vec<TransferOption>,
},
} }
impl Packet { impl Packet {
@ -32,7 +35,17 @@ impl Packet {
Opcode::Data => parse_data(buf), Opcode::Data => parse_data(buf),
Opcode::Ack => parse_ack(buf), Opcode::Ack => parse_ack(buf),
Opcode::Error => parse_error(buf), Opcode::Error => parse_error(buf),
_ => Err("invalid packet".into()), _ => Err("Invalid packet".into()),
}
}
pub fn serialize(&self) -> Result<Vec<u8>, Box<dyn Error>> {
match self {
Packet::Data { block_num, data } => Ok(serialize_data(block_num, data)),
Packet::Ack(block_num) => Ok(serialize_ack(block_num)),
Packet::Error { code, msg } => Ok(serialize_error(code, msg)),
Packet::Oack { options } => Ok(serialize_oack(options)),
_ => Err("Invalid packet".into()),
} }
} }
} }
@ -57,7 +70,7 @@ impl Opcode {
0x0004 => Ok(Opcode::Ack), 0x0004 => Ok(Opcode::Ack),
0x0005 => Ok(Opcode::Error), 0x0005 => Ok(Opcode::Error),
0x0006 => Ok(Opcode::Oack), 0x0006 => Ok(Opcode::Oack),
_ => Err("invalid opcode"), _ => Err("Invalid opcode"),
} }
} }
@ -105,13 +118,13 @@ impl OptionType {
"blksize" => Ok(OptionType::BlockSize), "blksize" => Ok(OptionType::BlockSize),
"tsize" => Ok(OptionType::TransferSize), "tsize" => Ok(OptionType::TransferSize),
"timeout" => Ok(OptionType::Timeout), "timeout" => Ok(OptionType::Timeout),
_ => Err("invalid option type".into()), _ => Err("Invalid option type".into()),
} }
} }
} }
#[repr(u16)] #[repr(u16)]
#[derive(PartialEq, Debug)] #[derive(Clone, Copy, PartialEq, Debug)]
pub enum ErrorCode { pub enum ErrorCode {
NotDefined = 0, NotDefined = 0,
FileNotFound = 1, FileNotFound = 1,
@ -134,7 +147,7 @@ impl ErrorCode {
5 => Ok(ErrorCode::UnknownId), 5 => Ok(ErrorCode::UnknownId),
6 => Ok(ErrorCode::FileExists), 6 => Ok(ErrorCode::FileExists),
7 => Ok(ErrorCode::NoSuchUser), 7 => Ok(ErrorCode::NoSuchUser),
_ => Err("invalid error code"), _ => Err("Invalid error code"),
} }
} }
@ -192,7 +205,7 @@ fn parse_rq(buf: &[u8], opcode: Opcode) -> Result<Packet, Box<dyn Error>> {
mode, mode,
options, options,
}), }),
_ => Err("non request opcode".into()), _ => Err("Non request opcode".into()),
} }
} }
@ -213,6 +226,39 @@ fn parse_error(buf: &[u8]) -> Result<Packet, Box<dyn Error>> {
Ok(Packet::Error { code, msg }) Ok(Packet::Error { code, msg })
} }
fn serialize_data(block_num: &u16, data: &Vec<u8>) -> Vec<u8> {
[
&Opcode::Data.as_bytes(),
&block_num.to_be_bytes(),
data.as_slice(),
]
.concat()
}
fn serialize_ack(block_num: &u16) -> Vec<u8> {
[Opcode::Ack.as_bytes(), block_num.to_be_bytes()].concat()
}
fn serialize_error(code: &ErrorCode, msg: &String) -> Vec<u8> {
[
&Opcode::Error.as_bytes()[..],
&code.as_bytes()[..],
&msg.as_bytes()[..],
&[0x00],
]
.concat()
}
fn serialize_oack(options: &Vec<TransferOption>) -> Vec<u8> {
let mut buf = Opcode::Oack.as_bytes().to_vec();
for option in options {
buf = [buf, option.as_bytes()].concat();
}
buf
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -29,28 +29,28 @@ impl Server {
filename, filename,
mut options, mut options,
.. ..
} => match self.handle_rrq(filename.clone(), &mut options, &from) { } => {
Ok(_) => { println!("Sending {filename} to {from}");
println!("Sending {filename} to {from}"); if let Err(err) = self.handle_rrq(filename.clone(), &mut options, &from) {
eprintln!("{err}")
} }
Err(err) => eprintln!("{err}"), }
},
Packet::Wrq { Packet::Wrq {
filename, filename,
mut options, mut options,
.. ..
} => match self.handle_wrq(filename.clone(), &mut options, &from) { } => {
Ok(_) => { println!("Receiving {filename} from {from}");
println!("Receiving {filename} from {from}"); if let Err(err) = self.handle_wrq(filename.clone(), &mut options, &from) {
eprintln!("{err}")
} }
Err(err) => eprintln!("{err}"), }
},
_ => { _ => {
Message::send_error_to( Message::send_error_to(
&self.socket, &self.socket,
&from, &from,
ErrorCode::IllegalOperation, ErrorCode::IllegalOperation,
"invalid request", "invalid request".to_string(),
) )
.unwrap_or_else(|err| eprintln!("{err}")); .unwrap_or_else(|err| eprintln!("{err}"));
} }
@ -71,7 +71,7 @@ impl Server {
&self.socket, &self.socket,
to, to,
ErrorCode::FileNotFound, ErrorCode::FileNotFound,
"file does not exist", "file does not exist".to_string(),
); );
} }
ErrorCode::AccessViolation => { ErrorCode::AccessViolation => {
@ -79,7 +79,7 @@ impl Server {
&self.socket, &self.socket,
to, to,
ErrorCode::AccessViolation, ErrorCode::AccessViolation,
"file access violation", "file access violation".to_string(),
); );
} }
ErrorCode::FileExists => Worker::send( ErrorCode::FileExists => Worker::send(
@ -106,7 +106,7 @@ impl Server {
&self.socket, &self.socket,
to, to,
ErrorCode::FileExists, ErrorCode::FileExists,
"requested file already exists", "requested file already exists".to_string(),
); );
} }
ErrorCode::AccessViolation => { ErrorCode::AccessViolation => {
@ -114,7 +114,7 @@ impl Server {
&self.socket, &self.socket,
to, to,
ErrorCode::AccessViolation, ErrorCode::AccessViolation,
"file access violation", "file access violation".to_string(),
); );
} }
ErrorCode::FileNotFound => Worker::receive( ErrorCode::FileNotFound => Worker::receive(

View file

@ -42,9 +42,9 @@ impl Worker {
let mut handle_send = || -> Result<(), Box<dyn Error>> { let mut handle_send = || -> Result<(), Box<dyn Error>> {
let socket = setup_socket(&addr, &remote)?; let socket = setup_socket(&addr, &remote)?;
let work_type = WorkType::Send(Path::new(&filename).metadata().unwrap().len()); let work_type = WorkType::Send(Path::new(&filename).metadata().unwrap().len());
let worker_options = parse_options(&mut options, &work_type)?; accept_request(&socket, &options, &work_type)?;
send_oack(&socket, &options, &work_type)?; check_response(&socket)?;
send_file(&socket, &worker_options, &filename, &mut options)?; send_file(&socket, &filename, &mut options)?;
Ok(()) Ok(())
}; };
@ -65,9 +65,8 @@ impl Worker {
let mut handle_receive = || -> Result<(), Box<dyn Error>> { let mut handle_receive = || -> Result<(), Box<dyn Error>> {
let socket = setup_socket(&addr, &remote)?; let socket = setup_socket(&addr, &remote)?;
let work_type = WorkType::Receive; let work_type = WorkType::Receive;
let worker_options = parse_options(&mut options, &work_type)?; accept_request(&socket, &options, &work_type)?;
send_oack(&socket, &options, &work_type)?; receive_file(&socket, &filename, &mut options)?;
receive_file(&socket, &worker_options, &filename, &mut options)?;
Ok(()) Ok(())
}; };
@ -81,13 +80,11 @@ impl Worker {
fn send_file( fn send_file(
socket: &UdpSocket, socket: &UdpSocket,
worker_options: &WorkerOptions,
filename: &String, filename: &String,
options: &mut Vec<TransferOption>, options: &mut Vec<TransferOption>,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let mut file = File::open(filename).unwrap(); let mut file = File::open(filename).unwrap();
let worker_options = parse_options(options, &WorkType::Send(file.metadata().unwrap().len()))?;
parse_options(options, &WorkType::Send(file.metadata().unwrap().len()))?;
let mut block_number = 1; let mut block_number = 1;
loop { loop {
@ -96,7 +93,7 @@ fn send_file(
let mut retry_cnt = 0; let mut retry_cnt = 0;
loop { loop {
Message::send_data(socket, block_number, &chunk[..size])?; Message::send_data(socket, block_number, chunk[..size].to_vec())?;
match Message::recv(socket) { match Message::recv(socket) {
Ok(Packet::Ack(received_block_number)) => { Ok(Packet::Ack(received_block_number)) => {
@ -106,12 +103,12 @@ fn send_file(
} }
} }
Ok(Packet::Error { code, msg }) => { Ok(Packet::Error { code, msg }) => {
return Err(format!("received error code {code}, with message {msg}").into()); return Err(format!("Received error code {code}, with message {msg}").into());
} }
_ => { _ => {
retry_cnt += 1; retry_cnt += 1;
if retry_cnt == MAX_RETRIES { if retry_cnt == MAX_RETRIES {
return Err(format!("transfer timed out after {MAX_RETRIES} tries").into()); return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into());
} }
} }
} }
@ -128,13 +125,11 @@ fn send_file(
fn receive_file( fn receive_file(
socket: &UdpSocket, socket: &UdpSocket,
worker_options: &WorkerOptions,
filename: &String, filename: &String,
options: &mut Vec<TransferOption>, options: &mut Vec<TransferOption>,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let mut file = File::create(filename).unwrap(); let mut file = File::create(filename).unwrap();
let worker_options = parse_options(options, &WorkType::Receive)?;
parse_options(options, &WorkType::Receive)?;
let mut block_number: u16 = 0; let mut block_number: u16 = 0;
loop { loop {
@ -155,13 +150,13 @@ fn receive_file(
} }
} }
Ok(Packet::Error { code, msg }) => { Ok(Packet::Error { code, msg }) => {
return Err(format!("received error code {code}: {msg}").into()); return Err(format!("Received error code {code}: {msg}").into());
} }
Err(err) => { Err(err) => {
retry_cnt += 1; retry_cnt += 1;
if retry_cnt == MAX_RETRIES { if retry_cnt == MAX_RETRIES {
return Err( return Err(
format!("transfer timed out after {MAX_RETRIES} tries: {err}").into(), format!("Transfer timed out after {MAX_RETRIES} tries: {err}").into(),
); );
} }
} }
@ -179,6 +174,34 @@ fn receive_file(
Ok(()) Ok(())
} }
fn accept_request(
socket: &UdpSocket,
options: &Vec<TransferOption>,
work_type: &WorkType,
) -> Result<(), Box<dyn Error>> {
if options.len() > 0 {
Message::send_oack(socket, options.to_vec())?;
} else if *work_type == WorkType::Receive {
Message::send_ack(socket, 0)?
}
Ok(())
}
fn check_response(socket: &UdpSocket) -> Result<(), Box<dyn Error>> {
if let Packet::Ack(received_block_number) = Message::recv(&socket)? {
if received_block_number != 0 {
Message::send_error(
&socket,
ErrorCode::IllegalOperation,
"invalid oack response".to_string(),
)?;
}
}
Ok(())
}
fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result<UdpSocket, Box<dyn Error>> { fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result<UdpSocket, Box<dyn Error>> {
let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?;
socket.connect(remote)?; socket.connect(remote)?;
@ -212,7 +235,7 @@ fn parse_options(
}, },
OptionType::Timeout => { OptionType::Timeout => {
if *value == 0 { if *value == 0 {
return Err("invalid timeout value".into()); return Err("Invalid timeout value".into());
} }
worker_options.timeout = *value as u64; worker_options.timeout = *value as u64;
} }
@ -221,22 +244,3 @@ fn parse_options(
Ok(worker_options) Ok(worker_options)
} }
fn send_oack(
socket: &UdpSocket,
options: &Vec<TransferOption>,
work_type: &WorkType,
) -> Result<(), Box<dyn Error>> {
if options.len() > 0 {
Message::send_oack(socket, options)?;
if let Packet::Ack(received_block_number) = Message::recv(socket)? {
if received_block_number != 0 {
Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?;
}
}
} else if *work_type == WorkType::Receive {
Message::send_ack(socket, 0)?
}
Ok(())
}