diff --git a/Cargo.lock b/Cargo.lock index e84a412..26d86b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,4 +4,4 @@ version = 3 [[package]] name = "tftpd" -version = "0.0.0" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 7f1a563..d16e5a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tftpd" -version = "0.0.0" +version = "0.1.0" authors = ["Altuğ Bakan "] edition = "2021" description = "TFTP Server Daemon implemented in Rust" diff --git a/src/main.rs b/src/main.rs index 3b6827d..076cea6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,14 +3,14 @@ use tftpd::{Config, Server}; fn main() { let config = Config::new(env::args()).unwrap_or_else(|err| { - eprintln!("Problem parsing arguments: {}", err); + eprintln!("Problem parsing arguments: {err}"); process::exit(1) }); let server = Server::new(&config).unwrap_or_else(|err| { eprintln!( - "Problem creating server on {}:{}: {}", - config.ip_address, config.port, err + "Problem creating server on {}:{}: {err}", + config.ip_address, config.port ); process::exit(1) }); diff --git a/src/message.rs b/src/message.rs index 32014db..5ae4c4b 100644 --- a/src/message.rs +++ b/src/message.rs @@ -7,17 +7,28 @@ use crate::packet::{ErrorCode, Opcode, Packet, TransferOption}; pub struct Message; +const MAX_REQUEST_PACKET_SIZE: usize = 512; + impl Message { - pub fn send_data(socket: &UdpSocket, data: &[u8]) -> Result<(), Box> { - let buf = [&Opcode::Data.as_bytes()[..], data].concat(); + pub fn send_data( + socket: &UdpSocket, + block_number: u16, + data: &[u8], + ) -> Result<(), Box> { + let buf = [ + &Opcode::Data.as_bytes()[..], + &block_number.to_be_bytes(), + data, + ] + .concat(); socket.send(&buf)?; Ok(()) } - pub fn send_ack(socket: &UdpSocket, block: u16) -> Result<(), Box> { - let buf = [Opcode::Ack.as_bytes(), block.to_be_bytes()].concat(); + pub fn send_ack(socket: &UdpSocket, block_number: u16) -> Result<(), Box> { + let buf = [Opcode::Ack.as_bytes(), block_number.to_be_bytes()].concat(); socket.send(&buf)?; @@ -29,16 +40,21 @@ impl Message { code: ErrorCode, msg: &str, ) -> Result<(), Box> { - socket.send(&get_error_buf(code, msg))?; + socket.send(&build_error_buf(code, msg))?; Ok(()) } - pub fn send_error_to(socket: &UdpSocket, to: &SocketAddr, code: ErrorCode, msg: &str) { - eprintln!("{msg}"); - if socket.send_to(&get_error_buf(code, msg), to).is_err() { + pub fn send_error_to<'a>( + socket: &UdpSocket, + to: &SocketAddr, + code: ErrorCode, + msg: &'a str, + ) -> Result<(), Box> { + if socket.send_to(&build_error_buf(code, msg), to).is_err() { eprintln!("could not send an error message"); } + Err(msg.into()) } pub fn send_oack( @@ -56,19 +72,32 @@ impl Message { Ok(()) } - pub fn receive_ack(socket: &UdpSocket) -> Result> { - let mut buf = [0; 4]; - socket.recv(&mut buf)?; + pub fn recv(socket: &UdpSocket) -> Result> { + let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; + let number_of_bytes = socket.recv(&mut buf)?; + let packet = Packet::deserialize(&buf[..number_of_bytes])?; - if let Ok(Packet::Ack(block)) = Packet::deserialize(&buf) { - Ok(block) - } else { - Err("invalid ack".into()) - } + Ok(packet) + } + + pub fn recv_data(socket: &UdpSocket, size: usize) -> Result> { + let mut buf = vec![0; size + 4]; + let number_of_bytes = socket.recv(&mut buf)?; + let packet = Packet::deserialize(&buf[..number_of_bytes])?; + + Ok(packet) + } + + pub fn recv_from(socket: &UdpSocket) -> Result<(Packet, SocketAddr), Box> { + let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; + let (number_of_bytes, from) = socket.recv_from(&mut buf)?; + let packet = Packet::deserialize(&buf[..number_of_bytes])?; + + Ok((packet, from)) } } -fn get_error_buf(code: ErrorCode, msg: &str) -> Vec { +fn build_error_buf(code: ErrorCode, msg: &str) -> Vec { [ &Opcode::Error.as_bytes()[..], &code.as_bytes()[..], diff --git a/src/packet.rs b/src/packet.rs index f377103..95bd61b 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -1,7 +1,7 @@ use crate::Convert; -use std::error::Error; +use std::{error::Error, fmt}; -pub enum Packet<'a> { +pub enum Packet { Rrq { filename: String, mode: String, @@ -14,7 +14,7 @@ pub enum Packet<'a> { }, Data { block_num: u16, - data: &'a [u8], + data: Vec, }, Ack(u16), Error { @@ -23,9 +23,9 @@ pub enum Packet<'a> { }, } -impl<'a> Packet<'a> { - pub fn deserialize(buf: &'a [u8]) -> Result> { - let opcode = Opcode::from_u16(Convert::to_u16(&buf[0..1])?)?; +impl Packet { + pub fn deserialize(buf: &[u8]) -> Result> { + let opcode = Opcode::from_u16(Convert::to_u16(&buf[0..=1])?)?; match opcode { Opcode::Rrq | Opcode::Wrq => parse_rq(buf, opcode), @@ -66,7 +66,7 @@ impl Opcode { } } -#[derive(Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq)] pub struct TransferOption { pub option: OptionType, pub value: usize, @@ -84,7 +84,7 @@ impl TransferOption { } } -#[derive(Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum OptionType { BlockSize, TransferSize, @@ -100,10 +100,6 @@ impl OptionType { } } - fn as_bytes(&self) -> &[u8] { - self.as_str().as_bytes() - } - fn from_str(value: &str) -> Result { match value { "blksize" => Ok(OptionType::BlockSize), @@ -147,6 +143,21 @@ impl ErrorCode { } } +impl fmt::Display for ErrorCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ErrorCode::NotDefined => write!(f, "Not Defined"), + ErrorCode::FileNotFound => write!(f, "File Not Found"), + ErrorCode::AccessViolation => write!(f, "Access Violation"), + ErrorCode::DiskFull => write!(f, "Disk Full"), + ErrorCode::IllegalOperation => write!(f, "Illegal Operation"), + ErrorCode::UnknownId => write!(f, "Unknown ID"), + ErrorCode::FileExists => write!(f, "File Exists"), + ErrorCode::NoSuchUser => write!(f, "No Such User"), + } + } +} + fn parse_rq(buf: &[u8], opcode: Opcode) -> Result> { let mut options = vec![]; let filename: String; @@ -188,7 +199,7 @@ fn parse_rq(buf: &[u8], opcode: Opcode) -> Result> { fn parse_data(buf: &[u8]) -> Result> { Ok(Packet::Data { block_num: Convert::to_u16(&buf[2..])?, - data: &buf[4..], + data: buf[4..].to_vec(), }) } @@ -239,11 +250,11 @@ mod tests { &[0x00], &"octet".as_bytes(), &[0x00], - &OptionType::TransferSize.as_bytes(), + &OptionType::TransferSize.as_str().as_bytes(), &[0x00], &"0".as_bytes(), &[0x00], - &OptionType::Timeout.as_bytes(), + &OptionType::Timeout.as_str().as_bytes(), &[0x00], &"5".as_bytes(), &[0x00], @@ -311,11 +322,11 @@ mod tests { &[0x00], &"octet".as_bytes(), &[0x00], - &OptionType::TransferSize.as_bytes(), + &OptionType::TransferSize.as_str().as_bytes(), &[0x00], &"12341234".as_bytes(), &[0x00], - &OptionType::BlockSize.as_bytes(), + &OptionType::BlockSize.as_str().as_bytes(), &[0x00], &"1024".as_bytes(), &[0x00], diff --git a/src/server.rs b/src/server.rs index a396c62..d8fc6e5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,9 +2,7 @@ use crate::packet::{ErrorCode, Packet, TransferOption}; use crate::{Config, Message, Worker}; use std::error::Error; use std::net::{SocketAddr, UdpSocket}; -use std::path::{Path, PathBuf}; - -const MAX_REQUEST_PACKET_SIZE: usize = 512; +use std::path::PathBuf; pub struct Server { socket: UdpSocket, @@ -25,63 +23,86 @@ impl Server { pub fn listen(&self) { loop { - let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; - if let Ok((number_of_bytes, from)) = self.socket.recv_from(&mut buf) { - if let Ok(packet) = Packet::deserialize(&buf[..number_of_bytes]) { - self.handle_packet(&packet, &from) - } + if let Ok((packet, from)) = Message::recv_from(&self.socket) { + match packet { + Packet::Rrq { + filename, + mut options, + .. + } => match self.handle_rrq(filename.clone(), &mut options, &from) { + Ok(_) => { + println!("Sending {filename} to {from}"); + } + Err(err) => eprintln!("{err}"), + }, + Packet::Wrq { + filename, + mut options, + .. + } => match self.handle_wrq(filename.clone(), &mut options, &from) { + Ok(_) => { + println!("Receiving {filename} from {from}"); + } + Err(err) => eprintln!("{err}"), + }, + _ => { + Message::send_error_to( + &self.socket, + &from, + ErrorCode::IllegalOperation, + "invalid request", + ) + .unwrap_or_else(|err| eprintln!("{err}")); + } + }; } } } - fn handle_packet(&self, packet: &Packet, from: &SocketAddr) { - match &packet { - Packet::Rrq { - filename, options, .. - } => self.validate_rrq(filename, options, from), - Packet::Wrq { - filename, options, .. - } => self.validate_wrq(filename, options, from), - _ => { - Message::send_error_to( - &self.socket, - from, - ErrorCode::IllegalOperation, - "invalid request", - ); - } - } - } - - fn validate_rrq(&self, filename: &String, options: &Vec, to: &SocketAddr) { - match self.check_file_exists(&Path::new(&filename)) { + fn handle_rrq( + &self, + filename: String, + options: &mut Vec, + to: &SocketAddr, + ) -> Result<(), Box> { + match check_file_exists(&get_full_path(&filename, &self.directory), &self.directory) { ErrorCode::FileNotFound => { - Message::send_error_to( + return Message::send_error_to( &self.socket, to, ErrorCode::FileNotFound, - "requested file does not exist", + "file does not exist", ); } ErrorCode::AccessViolation => { - Message::send_error_to( + return Message::send_error_to( &self.socket, to, ErrorCode::AccessViolation, - "requested file is not in the directory", + "file access violation", ); } - ErrorCode::FileExists => self - .handle_rrq(filename, options, to) - .unwrap_or_else(|err| eprintln!("could not handle read request: {err}")), + ErrorCode::FileExists => Worker::send( + self.socket.local_addr().unwrap(), + *to, + filename, + options.to_vec(), + ), _ => {} } + + Ok(()) } - fn validate_wrq(&self, filename: &String, options: &Vec, to: &SocketAddr) { - match self.check_file_exists(&Path::new(&filename)) { + fn handle_wrq( + &self, + filename: String, + options: &mut Vec, + to: &SocketAddr, + ) -> Result<(), Box> { + match check_file_exists(&get_full_path(&filename, &self.directory), &self.directory) { ErrorCode::FileExists => { - Message::send_error_to( + return Message::send_error_to( &self.socket, to, ErrorCode::FileExists, @@ -89,53 +110,85 @@ impl Server { ); } ErrorCode::AccessViolation => { - Message::send_error_to( + return Message::send_error_to( &self.socket, to, ErrorCode::AccessViolation, - "requested file is not in the directory", + "file access violation", ); } - ErrorCode::FileNotFound => self - .handle_wrq(filename, options, to) - .unwrap_or_else(|err| eprintln!("could not handle write request: {err}")), + ErrorCode::FileNotFound => Worker::receive( + self.socket.local_addr().unwrap(), + *to, + filename, + options.to_vec(), + ), _ => {} }; - } - - fn handle_rrq( - &self, - filename: &String, - options: &Vec, - to: &SocketAddr, - ) -> Result<(), Box> { - let mut worker = Worker::new(&self.socket.local_addr().unwrap(), to)?; - worker.send_file(Path::new(&filename), options)?; Ok(()) } - - fn handle_wrq( - &self, - filename: &String, - options: &Vec, - to: &SocketAddr, - ) -> Result<(), Box> { - let mut worker = Worker::new(&self.socket.local_addr().unwrap(), to)?; - worker.receive_file(Path::new(&filename), options)?; - - Ok(()) - } - - fn check_file_exists(&self, file: &Path) -> ErrorCode { - if !file.ancestors().any(|a| a == &self.directory) { - return ErrorCode::AccessViolation; - } - - if !file.exists() { - return ErrorCode::FileNotFound; - } - - ErrorCode::FileExists - } +} + +fn check_file_exists(file: &PathBuf, directory: &PathBuf) -> ErrorCode { + if !validate_file_path(file, directory) { + return ErrorCode::AccessViolation; + } + + if !file.exists() { + return ErrorCode::FileNotFound; + } + + ErrorCode::FileExists +} + +fn validate_file_path(file: &PathBuf, directory: &PathBuf) -> bool { + !file.to_str().unwrap().contains("..") && file.ancestors().any(|a| a == directory) +} + +fn get_full_path(filename: &str, directory: &PathBuf) -> PathBuf { + let mut file = directory.clone(); + file.push(PathBuf::from(filename)); + file +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn gets_full_path() { + assert_eq!( + get_full_path("test.txt", &PathBuf::from("/dir/test")), + PathBuf::from("/dir/test/test.txt") + ); + + assert_eq!( + get_full_path("some_dir/test.txt", &PathBuf::from("/dir/test")), + PathBuf::from("/dir/test/some_dir/test.txt") + ); + } + + #[test] + fn validates_file_path() { + assert!(validate_file_path( + &PathBuf::from("/dir/test/file"), + &PathBuf::from("/dir/test") + )); + + assert!(!validate_file_path( + &PathBuf::from("/system/data.txt"), + &PathBuf::from("/dir/test") + )); + + assert!(!validate_file_path( + &PathBuf::from("~/some_data.txt"), + &PathBuf::from("/dir/test") + )); + + assert!(!validate_file_path( + &PathBuf::from("/dir/test/../file"), + &PathBuf::from("/dir/test") + )); + } } diff --git a/src/worker.rs b/src/worker.rs index 7e2e00e..6a2aaf5 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,111 +1,242 @@ use std::{ error::Error, fs::File, - io::Read, + io::{Read, Write}, net::{SocketAddr, UdpSocket}, path::Path, + thread, time::Duration, }; use crate::{ - packet::{OptionType, TransferOption}, + packet::{ErrorCode, OptionType, Packet, TransferOption}, Message, }; -pub struct Worker { - socket: UdpSocket, +pub struct Worker; + +pub struct WorkerOptions { blk_size: usize, t_size: usize, - timeout: usize, + timeout: u64, +} + +#[derive(PartialEq, Eq)] +enum WorkType { + Receive, + Send(u64), } const MAX_RETRIES: u32 = 6; +const DEFAULT_TIMEOUT_SECS: u64 = 5; +const DEFAULT_BLOCK_SIZE: usize = 512; impl Worker { - pub fn new(addr: &SocketAddr, remote: &SocketAddr) -> Result> { - let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; - socket.connect(remote)?; - Ok(Worker { - socket, - blk_size: 512, - t_size: 0, - timeout: 5, - }) + pub fn send( + addr: SocketAddr, + remote: SocketAddr, + filename: String, + mut options: Vec, + ) { + thread::spawn(move || { + let mut handle_send = || -> Result<(), Box> { + let socket = setup_socket(&addr, &remote)?; + let work_type = WorkType::Send(Path::new(&filename).metadata().unwrap().len()); + let worker_options = parse_options(&mut options, &work_type)?; + send_oack(&socket, &options, &work_type)?; + send_file(&socket, &worker_options, &filename, &mut options)?; + + Ok(()) + }; + + if let Err(err) = handle_send() { + eprintln!("{err}"); + } + }); } - pub fn send_file( - &mut self, - file: &Path, - options: &Vec, - ) -> Result<(), Box> { - let mut file = File::open(file).unwrap(); + pub fn receive( + addr: SocketAddr, + remote: SocketAddr, + filename: String, + mut options: Vec, + ) { + thread::spawn(move || { + let mut handle_receive = || -> Result<(), Box> { + let socket = setup_socket(&addr, &remote)?; + let work_type = WorkType::Receive; + let worker_options = parse_options(&mut options, &work_type)?; + send_oack(&socket, &options, &work_type)?; + receive_file(&socket, &worker_options, &filename, &mut options)?; - self.parse_options(options, Some(&file)); - Message::send_oack(&self.socket, options)?; + Ok(()) + }; - self.socket - .set_write_timeout(Some(Duration::from_secs(self.timeout as u64)))?; + if let Err(err) = handle_receive() { + eprintln!("{err}"); + } + }); + } +} + +fn send_file( + socket: &UdpSocket, + worker_options: &WorkerOptions, + filename: &String, + options: &mut Vec, +) -> Result<(), Box> { + let mut file = File::open(filename).unwrap(); + + parse_options(options, &WorkType::Send(file.metadata().unwrap().len()))?; + + let mut block_number = 1; + loop { + let mut chunk = vec![0; worker_options.blk_size]; + let size = file.read(&mut chunk)?; let mut retry_cnt = 0; loop { - let mut chunk = Vec::with_capacity(self.blk_size); - let size = file - .by_ref() - .take(self.blk_size as u64) - .read_to_end(&mut chunk)?; + Message::send_data(socket, block_number, &chunk[..size])?; - loop { - if Message::send_data(&self.socket, &chunk).is_err() { - return Err(format!("failed to send data").into()); + match Message::recv(socket) { + Ok(Packet::Ack(received_block_number)) => { + if received_block_number == block_number { + block_number = block_number.wrapping_add(1); + break; + } } - - if let Ok(block) = Message::receive_ack(&self.socket) { - todo!("handle block number"); - } else { + Ok(Packet::Error { code, msg }) => { + return Err(format!("received error code {code}, with message {msg}").into()); + } + _ => { retry_cnt += 1; if retry_cnt == MAX_RETRIES { return Err(format!("transfer timed out after {MAX_RETRIES} tries").into()); } } } - - if size < self.blk_size { - break; - }; } - Ok(()) + if size < worker_options.blk_size { + break; + }; } - pub fn receive_file( - &mut self, - file: &Path, - options: &Vec, - ) -> Result<(), Box> { - let mut file = File::open(file).unwrap(); + println!("Sent {filename} to {}", socket.peer_addr().unwrap()); + Ok(()) +} - self.parse_options(options, Some(&file)); - Message::send_oack(&self.socket, options)?; +fn receive_file( + socket: &UdpSocket, + worker_options: &WorkerOptions, + filename: &String, + options: &mut Vec, +) -> Result<(), Box> { + let mut file = File::create(filename).unwrap(); - todo!("file receiving"); + parse_options(options, &WorkType::Receive)?; - Ok(()) - } + let mut block_number: u16 = 0; + loop { + let size; - fn parse_options(&mut self, options: &Vec, file: Option<&File>) { - for option in options { - let TransferOption { option, value } = option; - - match option { - OptionType::BlockSize => self.blk_size = *value, - OptionType::TransferSize => { - self.t_size = match file { - Some(file) => file.metadata().unwrap().len() as usize, - None => *value, + let mut retry_cnt = 0; + loop { + match Message::recv_data(socket, worker_options.blk_size) { + Ok(Packet::Data { + block_num: received_block_number, + data, + }) => { + if received_block_number == block_number.wrapping_add(1) { + block_number = received_block_number; + file.write(&data)?; + size = data.len(); + break; } } - OptionType::Timeout => self.timeout = *value, + Ok(Packet::Error { code, msg }) => { + return Err(format!("received error code {code}: {msg}").into()); + } + Err(err) => { + retry_cnt += 1; + if retry_cnt == MAX_RETRIES { + return Err( + format!("transfer timed out after {MAX_RETRIES} tries: {err}").into(), + ); + } + } + _ => {} + } + } + + Message::send_ack(socket, block_number)?; + if size < worker_options.blk_size { + break; + }; + } + + println!("Received {filename} from {}", socket.peer_addr().unwrap()); + Ok(()) +} + +fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result> { + let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; + socket.connect(remote)?; + socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; + socket.set_write_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; + Ok(socket) +} + +fn parse_options( + options: &mut Vec, + work_type: &WorkType, +) -> Result> { + let mut worker_options = WorkerOptions { + blk_size: DEFAULT_BLOCK_SIZE, + t_size: 0, + timeout: DEFAULT_TIMEOUT_SECS, + }; + + for option in &mut *options { + let TransferOption { option, value } = option; + + match option { + OptionType::BlockSize => worker_options.blk_size = *value, + OptionType::TransferSize => match work_type { + WorkType::Send(size) => { + *value = *size as usize; + } + WorkType::Receive => { + worker_options.t_size = *value; + } + }, + OptionType::Timeout => { + if *value == 0 { + return Err("invalid timeout value".into()); + } + worker_options.timeout = *value as u64; } } } + + Ok(worker_options) +} + +fn send_oack( + socket: &UdpSocket, + options: &Vec, + work_type: &WorkType, +) -> Result<(), Box> { + if options.len() > 0 { + Message::send_oack(socket, options)?; + if let Packet::Ack(received_block_number) = Message::recv(socket)? { + if received_block_number != 0 { + Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?; + } + } + } else if *work_type == WorkType::Receive { + Message::send_ack(socket, 0)? + } + + Ok(()) }