Add experimental single port support

This commit is contained in:
Altuğ Bakan 2023-05-07 21:04:56 +03:00
parent 2239b7ac9f
commit 8925922ef5
9 changed files with 706 additions and 379 deletions

2
Cargo.lock generated
View file

@ -4,4 +4,4 @@ version = 3
[[package]] [[package]]
name = "tftpd" name = "tftpd"
version = "0.2.2" version = "0.2.4"

View file

@ -1,6 +1,6 @@
[package] [package]
name = "tftpd" name = "tftpd"
version = "0.2.2" version = "0.2.4"
authors = ["Altuğ Bakan <mail@alt.ug>"] authors = ["Altuğ Bakan <mail@alt.ug>"]
edition = "2021" edition = "2021"
description = "Multithreaded TFTP server daemon" description = "Multithreaded TFTP server daemon"

View file

@ -25,6 +25,8 @@ pub struct Config {
pub port: u16, pub port: u16,
/// Default directory of the TFTP Server. (default: current working directory) /// Default directory of the TFTP Server. (default: current working directory)
pub directory: PathBuf, pub directory: PathBuf,
/// Use a single port for both sending and receiving. (default: false)
pub single_port: bool,
} }
impl Config { impl Config {
@ -38,6 +40,7 @@ impl Config {
ip_address: Ipv4Addr::new(127, 0, 0, 1), ip_address: Ipv4Addr::new(127, 0, 0, 1),
port: 69, port: 69,
directory: env::current_dir().unwrap_or_else(|_| env::temp_dir()), directory: env::current_dir().unwrap_or_else(|_| env::temp_dir()),
single_port: false,
}; };
args.next(); args.next();
@ -68,6 +71,9 @@ impl Config {
return Err("Missing directory after flag".into()); return Err("Missing directory after flag".into());
} }
} }
"-s" | "--single-port" => {
config.single_port = true;
}
"-h" | "--help" => { "-h" | "--help" => {
println!("TFTP Server Daemon\n"); println!("TFTP Server Daemon\n");
println!("Usage: tftpd [OPTIONS]\n"); println!("Usage: tftpd [OPTIONS]\n");
@ -77,6 +83,7 @@ impl Config {
" -p, --port <PORT>\t\tSet the listening port of the server (default: 69)" " -p, --port <PORT>\t\tSet the listening port of the server (default: 69)"
); );
println!(" -d, --directory <DIRECTORY>\tSet the listening port of the server (default: Current Working Directory)"); println!(" -d, --directory <DIRECTORY>\tSet the listening port of the server (default: Current Working Directory)");
println!(" -s, --single-port\t\tUse a single port for both sending and receiving (default: false)");
println!(" -h, --help\t\t\tPrint help information"); println!(" -h, --help\t\t\tPrint help information");
process::exit(0); process::exit(0);
} }
@ -97,7 +104,7 @@ mod tests {
#[test] #[test]
fn parses_full_config() { fn parses_full_config() {
let config = Config::new( let config = Config::new(
vec!["/", "-i", "0.0.0.0", "-p", "1234", "-d", "/"] vec!["/", "-i", "0.0.0.0", "-p", "1234", "-d", "/", "-s"]
.iter() .iter()
.map(|s| s.to_string()), .map(|s| s.to_string()),
) )
@ -106,6 +113,7 @@ mod tests {
assert_eq!(config.ip_address, Ipv4Addr::new(0, 0, 0, 0)); assert_eq!(config.ip_address, Ipv4Addr::new(0, 0, 0, 0));
assert_eq!(config.port, 1234); assert_eq!(config.port, 1234);
assert_eq!(config.directory, PathBuf::from_str("/").unwrap()); assert_eq!(config.directory, PathBuf::from_str("/").unwrap());
assert!(config.single_port);
} }
#[test] #[test]

View file

@ -20,6 +20,7 @@ mod convert;
mod message; mod message;
mod packet; mod packet;
mod server; mod server;
mod socket;
mod window; mod window;
mod worker; mod worker;
@ -32,5 +33,7 @@ pub use packet::OptionType;
pub use packet::Packet; pub use packet::Packet;
pub use packet::TransferOption; pub use packet::TransferOption;
pub use server::Server; pub use server::Server;
pub use socket::ServerSocket;
pub use socket::Socket;
pub use window::Window; pub use window::Window;
pub use worker::Worker; pub use worker::Worker;

View file

@ -7,7 +7,7 @@ fn main() {
process::exit(1) process::exit(1)
}); });
let server = Server::new(&config).unwrap_or_else(|err| { let mut server = Server::new(&config).unwrap_or_else(|err| {
eprintln!( eprintln!(
"Problem creating server on {}:{}: {err}", "Problem creating server on {}:{}: {err}",
config.ip_address, config.port config.ip_address, config.port

View file

@ -1,9 +1,6 @@
use std::{ use std::{error::Error, net::SocketAddr};
error::Error,
net::{SocketAddr, UdpSocket},
};
use crate::{ErrorCode, Packet, TransferOption}; use crate::{ErrorCode, Packet, Socket, TransferOption};
/// Message `struct` is used for easy message transmission of common TFTP /// Message `struct` is used for easy message transmission of common TFTP
/// message types. /// message types.
@ -30,41 +27,34 @@ impl Message {
/// Sends a data packet to the socket's connected remote. See /// Sends a data packet to the socket's connected remote. See
/// [`UdpSocket`] for more information about connected /// [`UdpSocket`] for more information about connected
/// sockets. /// sockets.
pub fn send_data( pub fn send_data<T: Socket>(
socket: &UdpSocket, socket: &T,
block_num: u16, block_num: u16,
data: Vec<u8>, data: Vec<u8>,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
socket.send(&Packet::Data { block_num, data }.serialize()?)?; socket.send(&Packet::Data { block_num, data })
Ok(())
} }
/// Sends an acknowledgement packet to the socket's connected remote. See /// Sends an acknowledgement packet to the socket's connected remote. See
/// [`UdpSocket`] for more information about connected /// [`UdpSocket`] for more information about connected
/// sockets. /// sockets.
pub fn send_ack(socket: &UdpSocket, block_number: u16) -> Result<(), Box<dyn Error>> { pub fn send_ack<T: Socket>(socket: &T, block_number: u16) -> Result<(), Box<dyn Error>> {
socket.send(&Packet::Ack(block_number).serialize()?)?; socket.send(&Packet::Ack(block_number))
Ok(())
} }
/// Sends an error packet to the socket's connected remote. See /// Sends an error packet to the socket's connected remote. See
/// [`UdpSocket`] for more information about connected /// [`UdpSocket`] for more information about connected
/// sockets. /// sockets.
pub fn send_error( pub fn send_error<T: Socket>(
socket: &UdpSocket, socket: &T,
code: ErrorCode, code: ErrorCode,
msg: &str, msg: &str,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
if socket if socket
.send( .send(&Packet::Error {
&Packet::Error {
code, code,
msg: msg.to_string(), msg: msg.to_string(),
} })
.serialize()?,
)
.is_err() .is_err()
{ {
eprintln!("could not send an error message"); eprintln!("could not send an error message");
@ -74,8 +64,8 @@ impl Message {
} }
/// Sends an error packet to the supplied [`SocketAddr`]. /// Sends an error packet to the supplied [`SocketAddr`].
pub fn send_error_to( pub fn send_error_to<T: Socket>(
socket: &UdpSocket, socket: &T,
to: &SocketAddr, to: &SocketAddr,
code: ErrorCode, code: ErrorCode,
msg: &str, msg: &str,
@ -85,8 +75,7 @@ impl Message {
&Packet::Error { &Packet::Error {
code, code,
msg: msg.to_string(), msg: msg.to_string(),
} },
.serialize()?,
to, to,
) )
.is_err() .is_err()
@ -98,46 +87,47 @@ impl Message {
/// Sends an option acknowledgement packet to the socket's connected remote. /// Sends an option acknowledgement packet to the socket's connected remote.
/// See [`UdpSocket`] for more information about connected sockets. /// See [`UdpSocket`] for more information about connected sockets.
pub fn send_oack( pub fn send_oack<T: Socket>(
socket: &UdpSocket, socket: &T,
options: Vec<TransferOption>, options: Vec<TransferOption>,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
socket.send(&Packet::Oack(options).serialize()?)?; socket.send(&Packet::Oack(options))
Ok(())
} }
/// Receives a packet from the socket's connected remote, and returns the /// Receives a packet from the socket's connected remote, and returns the
/// parsed [`Packet`]. This function cannot handle large data packets due to /// parsed [`Packet`]. This function cannot handle large data packets due to
/// the limited buffer size. For handling data packets, see [`Message::recv_with_size()`]. /// the limited buffer size. For handling data packets, see [`Message::recv_with_size()`].
pub fn recv(socket: &UdpSocket) -> Result<Packet, Box<dyn Error>> { pub fn recv<T: Socket>(socket: &T) -> Result<Packet, Box<dyn Error>> {
let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; let mut buf = [0; MAX_REQUEST_PACKET_SIZE];
let number_of_bytes = socket.recv(&mut buf)?; socket.recv(&mut buf)
let packet = Packet::deserialize(&buf[..number_of_bytes])?;
Ok(packet)
} }
/// Receives a packet from any incoming remote request, and returns the /// Receives a packet from any incoming remote request, and returns the
/// parsed [`Packet`] and the requesting [`SocketAddr`]. This function cannot handle /// parsed [`Packet`] and the requesting [`SocketAddr`]. This function cannot handle
/// large data packets due to the limited buffer size, so it is intended for /// large data packets due to the limited buffer size, so it is intended for
/// only accepting incoming requests. For handling data packets, see [`Message::recv_with_size()`]. /// only accepting incoming requests. For handling data packets, see [`Message::recv_with_size()`].
pub fn recv_from(socket: &UdpSocket) -> Result<(Packet, SocketAddr), Box<dyn Error>> { pub fn recv_from<T: Socket>(socket: &T) -> Result<(Packet, SocketAddr), Box<dyn Error>> {
let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; let mut buf = [0; MAX_REQUEST_PACKET_SIZE];
let (number_of_bytes, from) = socket.recv_from(&mut buf)?; socket.recv_from(&mut buf)
let packet = Packet::deserialize(&buf[..number_of_bytes])?;
Ok((packet, from))
} }
/// Receives a data packet from the socket's connected remote, and returns the /// Receives a data packet from the socket's connected remote, and returns the
/// parsed [`Packet`]. The received packet can actually be of any type, however, /// parsed [`Packet`]. The received packet can actually be of any type, however,
/// this function also allows supplying the buffer size for an incoming request. /// this function also allows supplying the buffer size for an incoming request.
pub fn recv_with_size(socket: &UdpSocket, size: usize) -> Result<Packet, Box<dyn Error>> { pub fn recv_with_size<T: Socket>(socket: &T, size: usize) -> Result<Packet, Box<dyn Error>> {
let mut buf = vec![0; size + 4]; let mut buf = vec![0; size + 4];
let number_of_bytes = socket.recv(&mut buf)?; socket.recv(&mut buf)
let packet = Packet::deserialize(&buf[..number_of_bytes])?; }
Ok(packet) /// Receives a data packet from any incoming remote request, and returns the
/// parsed [`Packet`] and the requesting [`SocketAddr`]. The received packet can
/// actually be of any type, however, this function also allows supplying the
/// buffer size for an incoming request.
pub fn recv_from_with_size<T: Socket>(
socket: &T,
size: usize,
) -> Result<(Packet, SocketAddr), Box<dyn Error>> {
let mut buf = vec![0; size + 4];
socket.recv_from(&mut buf)
} }
} }

View file

@ -1,8 +1,16 @@
use crate::{Config, Message, Worker}; use crate::{Config, Message, OptionType, ServerSocket, Socket, Worker};
use crate::{ErrorCode, Packet, TransferOption}; use crate::{ErrorCode, Packet, TransferOption};
use std::cmp::max;
use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::net::{SocketAddr, UdpSocket}; use std::net::{SocketAddr, UdpSocket};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::mpsc::Sender;
use std::time::Duration;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_BLOCK_SIZE: usize = 512;
const DEFAULT_WINDOW_SIZE: u16 = 1;
/// Server `struct` is used for handling incoming TFTP requests. /// Server `struct` is used for handling incoming TFTP requests.
/// ///
@ -22,6 +30,9 @@ use std::path::{Path, PathBuf};
pub struct Server { pub struct Server {
socket: UdpSocket, socket: UdpSocket,
directory: PathBuf, directory: PathBuf,
single_port: bool,
largest_block_size: usize,
clients: HashMap<SocketAddr, Sender<Packet>>,
} }
impl Server { impl Server {
@ -32,15 +43,24 @@ impl Server {
let server = Server { let server = Server {
socket, socket,
directory: config.directory.clone(), directory: config.directory.clone(),
single_port: config.single_port,
largest_block_size: DEFAULT_BLOCK_SIZE,
clients: HashMap::new(),
}; };
Ok(server) Ok(server)
} }
/// Starts listening for connections. Note that this function does not finish running until termination. /// Starts listening for connections. Note that this function does not finish running until termination.
pub fn listen(&self) { pub fn listen(&mut self) {
loop { loop {
if let Ok((packet, from)) = Message::recv_from(&self.socket) { let received = if self.single_port {
Message::recv_from_with_size(&self.socket, self.largest_block_size)
} else {
Message::recv_from(&self.socket)
};
if let Ok((packet, from)) = received {
match packet { match packet {
Packet::Rrq { Packet::Rrq {
filename, filename,
@ -63,6 +83,7 @@ impl Server {
} }
} }
_ => { _ => {
if self.route_packet(packet, &from).is_err() {
Message::send_error_to( Message::send_error_to(
&self.socket, &self.socket,
&from, &from,
@ -71,13 +92,14 @@ impl Server {
) )
.unwrap_or_else(|_| eprintln!("Received invalid request")); .unwrap_or_else(|_| eprintln!("Received invalid request"));
} }
}
}; };
} }
} }
} }
fn handle_rrq( fn handle_rrq(
&self, &mut self,
filename: String, filename: String,
options: &mut [TransferOption], options: &mut [TransferOption],
to: &SocketAddr, to: &SocketAddr,
@ -97,25 +119,63 @@ impl Server {
"file access violation", "file access violation",
), ),
ErrorCode::FileExists => { ErrorCode::FileExists => {
Worker::send( let worker_options =
self.socket.local_addr()?, parse_options(options, RequestType::Read(file_path.metadata()?.len()))?;
*to,
file_path.to_path_buf(), if self.single_port {
options.to_vec(), let mut socket = create_single_socket(&self.socket, to)?;
socket.set_read_timeout(worker_options.timeout)?;
socket.set_write_timeout(worker_options.timeout)?;
self.clients.insert(*to, socket.sender());
self.largest_block_size =
max(self.largest_block_size, worker_options.block_size);
accept_request(
&socket,
options,
RequestType::Read(file_path.metadata()?.len()),
)?;
let worker = Worker::new(
socket,
file_path.clone(),
worker_options.block_size,
worker_options.timeout,
worker_options.window_size,
); );
Ok(()) worker.send()
} else {
let socket = create_multi_socket(&self.socket.local_addr()?, to)?;
socket.set_read_timeout(Some(worker_options.timeout))?;
socket.set_write_timeout(Some(worker_options.timeout))?;
accept_request(
&socket,
options,
RequestType::Read(file_path.metadata()?.len()),
)?;
let worker = Worker::new(
socket,
file_path.clone(),
worker_options.block_size,
worker_options.timeout,
worker_options.window_size,
);
worker.send()
} }
_ => Err("unexpected error code when checking file".into()), }
_ => Err("Unexpected error code when checking file".into()),
} }
} }
fn handle_wrq( fn handle_wrq(
&self, &mut self,
filename: String, file_name: String,
options: &mut [TransferOption], options: &mut [TransferOption],
to: &SocketAddr, to: &SocketAddr,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let file_path = &self.directory.join(filename); let file_path = &self.directory.join(file_name);
match check_file_exists(file_path, &self.directory) { match check_file_exists(file_path, &self.directory) {
ErrorCode::FileExists => Message::send_error_to( ErrorCode::FileExists => Message::send_error_to(
&self.socket, &self.socket,
@ -130,19 +190,161 @@ impl Server {
"file access violation", "file access violation",
), ),
ErrorCode::FileNotFound => { ErrorCode::FileNotFound => {
Worker::receive( let worker_options = parse_options(options, RequestType::Write)?;
self.socket.local_addr()?,
*to, if self.single_port {
file_path.to_path_buf(), let mut socket = create_single_socket(&self.socket, to)?;
options.to_vec(), socket.set_read_timeout(worker_options.timeout)?;
socket.set_write_timeout(worker_options.timeout)?;
self.clients.insert(*to, socket.sender());
self.largest_block_size =
max(self.largest_block_size, worker_options.block_size);
accept_request(&socket, options, RequestType::Write)?;
let worker = Worker::new(
socket,
file_path.clone(),
worker_options.block_size,
worker_options.timeout,
worker_options.window_size,
); );
worker.receive()
} else {
let socket = create_multi_socket(&self.socket.local_addr()?, to)?;
socket.set_read_timeout(Some(worker_options.timeout))?;
socket.set_write_timeout(Some(worker_options.timeout))?;
accept_request(&socket, options, RequestType::Write)?;
let worker = Worker::new(
socket,
file_path.clone(),
worker_options.block_size,
worker_options.timeout,
worker_options.window_size,
);
worker.receive()
}
}
_ => Err("Unexpected error code when checking file".into()),
}
}
fn route_packet(&self, packet: Packet, to: &SocketAddr) -> Result<(), Box<dyn Error>> {
if self.clients.contains_key(to) {
self.clients[to].send(packet)?;
Ok(()) Ok(())
} } else {
_ => Err("unexpected error code when checking file".into()), Err("No client found for packet".into())
} }
} }
} }
#[derive(Debug, PartialEq)]
struct WorkerOptions {
block_size: usize,
transfer_size: u64,
timeout: Duration,
window_size: u16,
}
#[derive(Debug, PartialEq)]
enum RequestType {
Read(u64),
Write,
}
fn parse_options(
options: &mut [TransferOption],
request_type: RequestType,
) -> Result<WorkerOptions, &'static str> {
let mut worker_options = WorkerOptions {
block_size: DEFAULT_BLOCK_SIZE,
transfer_size: 0,
timeout: DEFAULT_TIMEOUT,
window_size: DEFAULT_WINDOW_SIZE,
};
for option in options {
let TransferOption {
option: option_type,
value,
} = option;
match option_type {
OptionType::BlockSize => worker_options.block_size = *value,
OptionType::TransferSize => match request_type {
RequestType::Read(size) => {
*value = size as usize;
worker_options.transfer_size = size;
}
RequestType::Write => worker_options.transfer_size = *value as u64,
},
OptionType::Timeout => {
if *value == 0 {
return Err("Invalid timeout value");
}
worker_options.timeout = Duration::from_secs(*value as u64);
}
OptionType::Windowsize => {
if *value == 0 || *value > u16::MAX as usize {
return Err("Invalid windowsize value");
}
worker_options.window_size = *value as u16;
}
}
}
Ok(worker_options)
}
fn create_single_socket(
socket: &UdpSocket,
remote: &SocketAddr,
) -> Result<ServerSocket, Box<dyn Error>> {
let socket = ServerSocket::new(socket.try_clone()?, *remote);
Ok(socket)
}
fn create_multi_socket(
addr: &SocketAddr,
remote: &SocketAddr,
) -> Result<UdpSocket, Box<dyn Error>> {
let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?;
socket.connect(remote)?;
Ok(socket)
}
fn accept_request<T: Socket>(
socket: &T,
options: &[TransferOption],
request_type: RequestType,
) -> Result<(), Box<dyn Error>> {
if !options.is_empty() {
Message::send_oack(socket, options.to_vec())?;
if let RequestType::Read(_) = request_type {
check_response(socket)?;
}
} else if request_type == RequestType::Write {
Message::send_ack(socket, 0)?
}
Ok(())
}
fn check_response<T: Socket>(socket: &T) -> Result<(), Box<dyn Error>> {
if let Packet::Ack(received_block_number) = Message::recv(socket)? {
if received_block_number != 0 {
Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?;
}
}
Ok(())
}
fn check_file_exists(file: &Path, directory: &PathBuf) -> ErrorCode { fn check_file_exists(file: &Path, directory: &PathBuf) -> ErrorCode {
if !validate_file_path(file, directory) { if !validate_file_path(file, directory) {
return ErrorCode::AccessViolation; return ErrorCode::AccessViolation;
@ -185,4 +387,69 @@ mod tests {
&PathBuf::from("/dir/test") &PathBuf::from("/dir/test")
)); ));
} }
#[test]
fn parses_write_options() {
let mut options = vec![
TransferOption {
option: OptionType::BlockSize,
value: 1024,
},
TransferOption {
option: OptionType::TransferSize,
value: 0,
},
TransferOption {
option: OptionType::Timeout,
value: 5,
},
];
let work_type = RequestType::Read(12341234);
let worker_options = parse_options(&mut options, work_type).unwrap();
assert_eq!(options[0].value, worker_options.block_size);
assert_eq!(options[1].value, worker_options.transfer_size as usize);
assert_eq!(options[2].value as u64, worker_options.timeout.as_secs());
}
#[test]
fn parses_read_options() {
let mut options = vec![
TransferOption {
option: OptionType::BlockSize,
value: 1024,
},
TransferOption {
option: OptionType::TransferSize,
value: 44554455,
},
TransferOption {
option: OptionType::Timeout,
value: 5,
},
];
let work_type = RequestType::Write;
let worker_options = parse_options(&mut options, work_type).unwrap();
assert_eq!(options[0].value, worker_options.block_size);
assert_eq!(options[1].value, worker_options.transfer_size as usize);
assert_eq!(options[2].value as u64, worker_options.timeout.as_secs());
}
#[test]
fn parses_default_options() {
assert_eq!(
parse_options(&mut [], RequestType::Write).unwrap(),
WorkerOptions {
block_size: DEFAULT_BLOCK_SIZE,
transfer_size: 0,
timeout: DEFAULT_TIMEOUT,
window_size: DEFAULT_WINDOW_SIZE,
}
);
}
} }

207
src/socket.rs Normal file
View file

@ -0,0 +1,207 @@
use std::{
error::Error,
net::{SocketAddr, UdpSocket},
sync::{
mpsc::{self, Receiver, Sender},
Mutex,
},
time::Duration,
};
use crate::Packet;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
/// Socket `trait` is used for easy message transmission of common TFTP
/// message types. This `trait` is implemented for [`UdpSocket`] and used
/// for abstraction of single socket communication.
pub trait Socket: Send + Sync + 'static {
/// Sends a [`Packet`] to the socket's connected remote [`Socket`].
fn send(&self, packet: &Packet) -> Result<(), Box<dyn Error>>;
/// Sends a [`Packet`] to the specified remote [`Socket`].
fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box<dyn Error>>;
/// Receives a [`Packet`] from the socket's connected remote [`Socket`].
fn recv(&self, buf: &mut [u8]) -> Result<Packet, Box<dyn Error>>;
/// Receives a [`Packet`] from any remote [`Socket`] and returns the [`SocketAddr`]
/// of the remote [`Socket`].
fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box<dyn Error>>;
/// Returns the remote [`SocketAddr`] if it exists.
fn remote_addr(&self) -> Result<SocketAddr, Box<dyn Error>>;
/// Sets the read timeout for the [`Socket`].
fn set_read_timeout(&mut self, dur: Duration) -> Result<(), Box<dyn Error>>;
/// Sets the write timeout for the [`Socket`].
fn set_write_timeout(&mut self, dur: Duration) -> Result<(), Box<dyn Error>>;
}
impl Socket for UdpSocket {
fn send(&self, packet: &Packet) -> Result<(), Box<dyn Error>> {
self.send(&packet.serialize()?)?;
Ok(())
}
fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box<dyn Error>> {
self.send_to(&packet.serialize()?, to)?;
Ok(())
}
fn recv(&self, buf: &mut [u8]) -> Result<Packet, Box<dyn Error>> {
let amt = self.recv(buf)?;
let packet = Packet::deserialize(&buf[..amt])?;
Ok(packet)
}
fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box<dyn Error>> {
let (amt, addr) = self.recv_from(buf)?;
let packet = Packet::deserialize(&buf[..amt])?;
Ok((packet, addr))
}
fn remote_addr(&self) -> Result<SocketAddr, Box<dyn Error>> {
Ok(self.peer_addr()?)
}
fn set_read_timeout(&mut self, dur: Duration) -> Result<(), Box<dyn Error>> {
UdpSocket::set_read_timeout(self, Some(dur))?;
Ok(())
}
fn set_write_timeout(&mut self, dur: Duration) -> Result<(), Box<dyn Error>> {
UdpSocket::set_write_timeout(self, Some(dur))?;
Ok(())
}
}
/// ServerSocket `struct` is used as an abstraction layer for a server
/// [`Socket`]. This `struct` is used for abstraction of single socket
/// communication.
///
/// # Example
///
/// ```rust
/// use std::net::{SocketAddr, UdpSocket};
/// use std::str::FromStr;
/// use tftpd::{Socket, ServerSocket, Packet};
///
/// let socket = ServerSocket::new(
/// UdpSocket::bind("127.0.0.1:0").unwrap(),
/// SocketAddr::from_str("127.0.0.1:50000").unwrap(),
/// );
/// socket.send(&Packet::Ack(1)).unwrap();
/// ```
pub struct ServerSocket {
socket: UdpSocket,
remote: SocketAddr,
sender: Mutex<Sender<Packet>>,
receiver: Mutex<Receiver<Packet>>,
timeout: Duration,
}
impl Socket for ServerSocket {
fn send(&self, packet: &Packet) -> Result<(), Box<dyn Error>> {
self.send_to(packet, &self.remote)
}
fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box<dyn Error>> {
self.socket.send_to(&packet.serialize()?, to)?;
Ok(())
}
fn recv(&self, _buf: &mut [u8]) -> Result<Packet, Box<dyn Error>> {
if let Ok(receiver) = self.receiver.lock() {
if let Ok(packet) = receiver.recv_timeout(self.timeout) {
Ok(packet)
} else {
Err("Failed to receive".into())
}
} else {
Err("Failed to lock mutex".into())
}
}
fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box<dyn Error>> {
Ok((self.recv(buf)?, self.remote))
}
fn remote_addr(&self) -> Result<SocketAddr, Box<dyn Error>> {
Ok(self.remote)
}
fn set_read_timeout(&mut self, dur: Duration) -> Result<(), Box<dyn Error>> {
self.timeout = dur;
Ok(())
}
fn set_write_timeout(&mut self, dur: Duration) -> Result<(), Box<dyn Error>> {
self.socket.set_write_timeout(Some(dur))?;
Ok(())
}
}
impl ServerSocket {
/// Creates a new [`ServerSocket`] from a [`UdpSocket`] and a remote [`SocketAddr`].
pub fn new(socket: UdpSocket, remote: SocketAddr) -> Self {
let (sender, receiver) = mpsc::channel();
Self {
socket,
remote,
sender: Mutex::new(sender),
receiver: Mutex::new(receiver),
timeout: DEFAULT_TIMEOUT,
}
}
/// Returns a [`Sender`] for sending [`Packet`]s to the remote [`Socket`].
pub fn sender(&self) -> Sender<Packet> {
self.sender.lock().unwrap().clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
#[test]
fn test_recv() {
let socket = ServerSocket::new(
UdpSocket::bind("127.0.0.1:0").unwrap(),
SocketAddr::from_str("127.0.0.1:50000").unwrap(),
);
socket.sender.lock().unwrap().send(Packet::Ack(1)).unwrap();
let packet = socket.recv(&mut []).unwrap();
assert_eq!(packet, Packet::Ack(1));
socket
.sender
.lock()
.unwrap()
.send(Packet::Data {
block_num: 15,
data: vec![0x01, 0x02, 0x03],
})
.unwrap();
let packet = socket.recv(&mut []).unwrap();
assert_eq!(
packet,
Packet::Data {
block_num: 15,
data: vec![0x01, 0x02, 0x03]
}
);
}
}

View file

@ -1,13 +1,15 @@
use std::{ use std::{
error::Error, error::Error,
fs::{self, File}, fs::{self, File},
net::{SocketAddr, UdpSocket},
path::PathBuf, path::PathBuf,
thread, thread,
time::{Duration, SystemTime}, time::{Duration, Instant},
}; };
use crate::{ErrorCode, Message, OptionType, Packet, TransferOption, Window}; use crate::{Message, Packet, Socket, Window};
const MAX_RETRIES: u32 = 6;
const TIMEOUT_BUFFER: Duration = Duration::from_secs(1);
/// Worker `struct` is used for multithreaded file sending and receiving. /// Worker `struct` is used for multithreaded file sending and receiving.
/// It creates a new socket using the Server's IP and a random port /// It creates a new socket using the Server's IP and a random port
@ -18,55 +20,64 @@ use crate::{ErrorCode, Message, OptionType, Packet, TransferOption, Window};
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use std::{net::SocketAddr, path::PathBuf, str::FromStr}; /// use std::{net::{UdpSocket, SocketAddr}, path::PathBuf, str::FromStr, time::Duration};
/// use tftpd::Worker; /// use tftpd::Worker;
/// ///
/// // Send a file, responding to a read request. /// // Send a file, responding to a read request.
/// Worker::send( /// let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
/// SocketAddr::from_str("127.0.0.1:1234").unwrap(), /// socket.connect(SocketAddr::from_str("127.0.0.1:12345").unwrap()).unwrap();
/// SocketAddr::from_str("127.0.0.1:4321").unwrap(), ///
/// PathBuf::from_str("/home/rust/test.txt").unwrap(), /// let worker = Worker::new(
/// vec![] /// socket,
/// PathBuf::from_str("Cargo.toml").unwrap(),
/// 512,
/// Duration::from_secs(1),
/// 1,
/// ); /// );
///
/// worker.send().unwrap();
/// ``` /// ```
pub struct Worker; pub struct Worker<T>
where
#[derive(Debug, PartialEq, Eq)] T: Socket,
struct WorkerOptions { {
socket: T,
file_name: PathBuf,
blk_size: usize, blk_size: usize,
t_size: usize, timeout: Duration,
timeout: u64,
windowsize: u16, windowsize: u16,
} }
#[derive(PartialEq, Eq)] impl<T> Worker<T>
enum WorkType { where
Receive, T: Socket,
Send(u64), {
} /// Creates a new [`Worker`] with the supplied options.
pub fn new(
socket: T,
file_name: PathBuf,
blk_size: usize,
timeout: Duration,
windowsize: u16,
) -> Worker<T> {
Worker {
socket,
file_name,
blk_size,
timeout,
windowsize,
}
}
const MAX_RETRIES: u32 = 6;
const DEFAULT_TIMEOUT_SECS: u64 = 5;
const TIMEOUT_BUFFER_SECS: u64 = 1;
const DEFAULT_BLOCK_SIZE: usize = 512;
impl Worker {
/// Sends a file to the remote [`SocketAddr`] that has sent a read request using /// Sends a file to the remote [`SocketAddr`] that has sent a read request using
/// a random port, asynchronously. /// a random port, asynchronously.
pub fn send( pub fn send(self) -> Result<(), Box<dyn Error>> {
addr: SocketAddr, let file_name = self.file_name.clone();
remote: SocketAddr, let remote_addr = self.socket.remote_addr().unwrap();
file_path: PathBuf,
mut options: Vec<TransferOption>,
) {
thread::spawn(move || {
let mut handle_send = || -> Result<(), Box<dyn Error>> {
let socket = setup_socket(&addr, &remote)?;
let work_type = WorkType::Send(file_path.metadata()?.len());
let worker_options = parse_options(&mut options, &work_type)?;
accept_request(&socket, &options, &work_type)?; thread::spawn(move || {
send_file(&socket, File::open(&file_path)?, &worker_options)?; let handle_send = || -> Result<(), Box<dyn Error>> {
self.send_file(File::open(&file_name)?)?;
Ok(()) Ok(())
}; };
@ -75,8 +86,8 @@ impl Worker {
Ok(_) => { Ok(_) => {
println!( println!(
"Sent {} to {}", "Sent {} to {}",
file_path.file_name().unwrap().to_str().unwrap(), &file_name.file_name().unwrap().to_string_lossy(),
remote &remote_addr
); );
} }
Err(err) => { Err(err) => {
@ -84,24 +95,19 @@ impl Worker {
} }
} }
}); });
Ok(())
} }
/// Receives a file from the remote [`SocketAddr`] that has sent a write request using /// Receives a file from the remote [`SocketAddr`] that has sent a write request using
/// a random port, asynchronously. /// the supplied socket, asynchronously.
pub fn receive( pub fn receive(self) -> Result<(), Box<dyn Error>> {
addr: SocketAddr, let file_name = self.file_name.clone();
remote: SocketAddr, let remote_addr = self.socket.remote_addr().unwrap();
file_path: PathBuf,
mut options: Vec<TransferOption>,
) {
thread::spawn(move || {
let mut handle_receive = || -> Result<(), Box<dyn Error>> {
let socket = setup_socket(&addr, &remote)?;
let work_type = WorkType::Receive;
let worker_options = parse_options(&mut options, &work_type)?;
accept_request(&socket, &options, &work_type)?; thread::spawn(move || {
receive_file(&socket, File::create(&file_path)?, &worker_options)?; let handle_receive = || -> Result<(), Box<dyn Error>> {
self.receive_file(File::create(&file_name)?)?;
Ok(()) Ok(())
}; };
@ -110,48 +116,41 @@ impl Worker {
Ok(_) => { Ok(_) => {
println!( println!(
"Received {} from {}", "Received {} from {}",
file_path.file_name().unwrap().to_str().unwrap(), &file_name.file_name().unwrap().to_string_lossy(),
remote remote_addr
); );
} }
Err(err) => { Err(err) => {
eprintln!("{err}"); eprintln!("{err}");
if fs::remove_file(&file_path).is_err() { if fs::remove_file(&file_name).is_err() {
eprintln!( eprintln!("Error while cleaning {}", &file_name.to_str().unwrap());
"Error while cleaning {}",
file_path.file_name().unwrap().to_str().unwrap()
);
} }
} }
} }
}); });
}
}
fn send_file( Ok(())
socket: &UdpSocket, }
file: File,
worker_options: &WorkerOptions, fn send_file(self, file: File) -> Result<(), Box<dyn Error>> {
) -> Result<(), Box<dyn Error>> {
let mut block_number = 1; let mut block_number = 1;
let mut window = Window::new(worker_options.windowsize, worker_options.blk_size, file); let mut window = Window::new(self.windowsize, self.blk_size, file);
loop { loop {
let filled = window.fill()?; let filled = window.fill()?;
let mut retry_cnt = 0; let mut retry_cnt = 0;
let mut time = let mut time = Instant::now() - (self.timeout + TIMEOUT_BUFFER);
SystemTime::now() - Duration::from_secs(DEFAULT_TIMEOUT_SECS + TIMEOUT_BUFFER_SECS);
loop { loop {
if time.elapsed()? >= Duration::from_secs(DEFAULT_TIMEOUT_SECS) { if time.elapsed() >= self.timeout {
send_window(socket, &window, block_number)?; send_window(&self.socket, &window, block_number)?;
time = SystemTime::now(); time = Instant::now();
} }
match Message::recv(socket) { match Message::recv(&self.socket) {
Ok(Packet::Ack(received_block_number)) => { Ok(Packet::Ack(received_block_number)) => {
let diff = received_block_number.wrapping_sub(block_number); let diff = received_block_number.wrapping_sub(block_number);
if diff <= worker_options.windowsize { if diff <= self.windowsize {
block_number = received_block_number.wrapping_add(1); block_number = received_block_number.wrapping_add(1);
window.remove(diff + 1)?; window.remove(diff + 1)?;
break; break;
@ -163,7 +162,9 @@ fn send_file(
_ => { _ => {
retry_cnt += 1; retry_cnt += 1;
if retry_cnt == MAX_RETRIES { if retry_cnt == MAX_RETRIES {
return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into()); return Err(
format!("Transfer timed out after {MAX_RETRIES} tries").into()
);
} }
} }
} }
@ -175,22 +176,18 @@ fn send_file(
} }
Ok(()) Ok(())
} }
fn receive_file( fn receive_file(self, file: File) -> Result<(), Box<dyn Error>> {
socket: &UdpSocket,
file: File,
worker_options: &WorkerOptions,
) -> Result<(), Box<dyn Error>> {
let mut block_number: u16 = 0; let mut block_number: u16 = 0;
let mut window = Window::new(worker_options.windowsize, worker_options.blk_size, file); let mut window = Window::new(self.windowsize, self.blk_size, file);
loop { loop {
let mut size; let mut size;
let mut retry_cnt = 0; let mut retry_cnt = 0;
loop { loop {
match Message::recv_with_size(socket, worker_options.blk_size) { match Message::recv_with_size(&self.socket, self.blk_size) {
Ok(Packet::Data { Ok(Packet::Data {
block_num: received_block_number, block_num: received_block_number,
data, data,
@ -200,7 +197,7 @@ fn receive_file(
size = data.len(); size = data.len();
window.add(data)?; window.add(data)?;
if size < worker_options.blk_size { if size < self.blk_size {
break; break;
} }
@ -215,24 +212,27 @@ fn receive_file(
_ => { _ => {
retry_cnt += 1; retry_cnt += 1;
if retry_cnt == MAX_RETRIES { if retry_cnt == MAX_RETRIES {
return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into()); return Err(
format!("Transfer timed out after {MAX_RETRIES} tries").into()
);
} }
} }
} }
} }
window.empty()?; window.empty()?;
Message::send_ack(socket, block_number)?; Message::send_ack(&self.socket, block_number)?;
if size < worker_options.blk_size { if size < self.blk_size {
break; break;
}; };
} }
Ok(()) Ok(())
}
} }
fn send_window( fn send_window<T: Socket>(
socket: &UdpSocket, socket: &T,
window: &Window, window: &Window,
mut block_num: u16, mut block_num: u16,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
@ -243,151 +243,3 @@ fn send_window(
Ok(()) Ok(())
} }
fn accept_request(
socket: &UdpSocket,
options: &Vec<TransferOption>,
work_type: &WorkType,
) -> Result<(), Box<dyn Error>> {
if !options.is_empty() {
Message::send_oack(socket, options.to_vec())?;
if let WorkType::Send(_) = work_type {
check_response(socket)?;
}
} else if *work_type == WorkType::Receive {
Message::send_ack(socket, 0)?
}
Ok(())
}
fn check_response(socket: &UdpSocket) -> Result<(), Box<dyn Error>> {
if let Packet::Ack(received_block_number) = Message::recv(socket)? {
if received_block_number != 0 {
Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?;
}
}
Ok(())
}
fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result<UdpSocket, Box<dyn Error>> {
let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?;
socket.connect(remote)?;
socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?;
socket.set_write_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?;
Ok(socket)
}
fn parse_options(
options: &mut Vec<TransferOption>,
work_type: &WorkType,
) -> Result<WorkerOptions, Box<dyn Error>> {
let mut worker_options = WorkerOptions {
blk_size: DEFAULT_BLOCK_SIZE,
t_size: 0,
timeout: DEFAULT_TIMEOUT_SECS,
windowsize: 1,
};
for option in &mut *options {
let TransferOption { option, value } = option;
match option {
OptionType::BlockSize => worker_options.blk_size = *value,
OptionType::TransferSize => match work_type {
WorkType::Send(size) => {
*value = *size as usize;
worker_options.t_size = *size as usize;
}
WorkType::Receive => {
worker_options.t_size = *value;
}
},
OptionType::Timeout => {
if *value == 0 {
return Err("Invalid timeout value".into());
}
worker_options.timeout = *value as u64;
}
OptionType::Windowsize => {
if *value == 0 || *value > u16::MAX as usize {
return Err("Invalid windowsize value".into());
}
worker_options.windowsize = *value as u16;
}
}
}
Ok(worker_options)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parses_send_options() {
let mut options = vec![
TransferOption {
option: OptionType::BlockSize,
value: 1024,
},
TransferOption {
option: OptionType::TransferSize,
value: 0,
},
TransferOption {
option: OptionType::Timeout,
value: 5,
},
];
let work_type = WorkType::Send(12341234);
let worker_options = parse_options(&mut options, &work_type).unwrap();
assert_eq!(options[0].value, worker_options.blk_size);
assert_eq!(options[1].value, worker_options.t_size);
assert_eq!(options[2].value as u64, worker_options.timeout);
}
#[test]
fn parses_receive_options() {
let mut options = vec![
TransferOption {
option: OptionType::BlockSize,
value: 1024,
},
TransferOption {
option: OptionType::TransferSize,
value: 44554455,
},
TransferOption {
option: OptionType::Timeout,
value: 5,
},
];
let work_type = WorkType::Receive;
let worker_options = parse_options(&mut options, &work_type).unwrap();
assert_eq!(options[0].value, worker_options.blk_size);
assert_eq!(options[1].value, worker_options.t_size);
assert_eq!(options[2].value as u64, worker_options.timeout);
}
#[test]
fn parses_default_options() {
assert_eq!(
parse_options(&mut vec![], &WorkType::Receive).unwrap(),
WorkerOptions {
blk_size: DEFAULT_BLOCK_SIZE,
t_size: 0,
timeout: DEFAULT_TIMEOUT_SECS,
windowsize: 1,
}
);
}
}