use std::{ error::Error, fs::File, io::{Read, Write}, net::{SocketAddr, UdpSocket}, path::{Path, PathBuf}, thread, time::Duration, }; use crate::{ packet::{ErrorCode, OptionType, Packet, TransferOption}, Message, }; pub struct Worker; #[derive(Debug, PartialEq, Eq)] pub struct WorkerOptions { blk_size: usize, t_size: usize, timeout: u64, } #[derive(PartialEq, Eq)] enum WorkType { Receive, Send(u64), } const MAX_RETRIES: u32 = 6; const DEFAULT_TIMEOUT_SECS: u64 = 5; const DEFAULT_BLOCK_SIZE: usize = 512; impl Worker { pub fn send( addr: SocketAddr, remote: SocketAddr, file_path: PathBuf, mut options: Vec, ) { thread::spawn(move || { let mut handle_send = || -> Result<(), Box> { let socket = setup_socket(&addr, &remote)?; let work_type = WorkType::Send(Path::new(&file_path).metadata().unwrap().len()); accept_request(&socket, &options, &work_type)?; check_response(&socket)?; send_file(&socket, &file_path, &mut options)?; Ok(()) }; if let Err(err) = handle_send() { eprintln!("{err}"); } }); } pub fn receive( addr: SocketAddr, remote: SocketAddr, file_path: PathBuf, mut options: Vec, ) { thread::spawn(move || { let mut handle_receive = || -> Result<(), Box> { let socket = setup_socket(&addr, &remote)?; let work_type = WorkType::Receive; accept_request(&socket, &options, &work_type)?; receive_file(&socket, &file_path, &mut options)?; Ok(()) }; if let Err(err) = handle_receive() { eprintln!("{err}"); } }); } } fn send_file( socket: &UdpSocket, file_path: &PathBuf, options: &mut Vec, ) -> Result<(), Box> { let mut file = File::open(file_path).unwrap(); let worker_options = parse_options(options, &WorkType::Send(file.metadata().unwrap().len()))?; let mut block_number = 1; loop { let mut chunk = vec![0; worker_options.blk_size]; let size = file.read(&mut chunk)?; let mut retry_cnt = 0; loop { Message::send_data(socket, block_number, chunk[..size].to_vec())?; match Message::recv(socket) { Ok(Packet::Ack(received_block_number)) => { if received_block_number == block_number { block_number = block_number.wrapping_add(1); break; } } Ok(Packet::Error { code, msg }) => { return Err(format!("Received error code {code}, with message {msg}").into()); } _ => { retry_cnt += 1; if retry_cnt == MAX_RETRIES { return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into()); } } } } if size < worker_options.blk_size { break; }; } println!( "Sent {} to {}", file_path.display(), socket.peer_addr().unwrap() ); Ok(()) } fn receive_file( socket: &UdpSocket, file_path: &PathBuf, options: &mut Vec, ) -> Result<(), Box> { let mut file = File::create(file_path).unwrap(); let worker_options = parse_options(options, &WorkType::Receive)?; let mut block_number: u16 = 0; loop { let size; let mut retry_cnt = 0; loop { match Message::recv_data(socket, worker_options.blk_size) { Ok(Packet::Data { block_num: received_block_number, data, }) => { if received_block_number == block_number.wrapping_add(1) { block_number = received_block_number; file.write(&data)?; size = data.len(); break; } } Ok(Packet::Error { code, msg }) => { return Err(format!("Received error code {code}: {msg}").into()); } Err(err) => { retry_cnt += 1; if retry_cnt == MAX_RETRIES { return Err( format!("Transfer timed out after {MAX_RETRIES} tries: {err}").into(), ); } } _ => {} } } Message::send_ack(socket, block_number)?; if size < worker_options.blk_size { break; }; } println!( "Received {} from {}", file_path.display(), socket.peer_addr().unwrap() ); Ok(()) } fn accept_request( socket: &UdpSocket, options: &Vec, work_type: &WorkType, ) -> Result<(), Box> { if options.len() > 0 { Message::send_oack(socket, options.to_vec())?; } else if *work_type == WorkType::Receive { Message::send_ack(socket, 0)? } Ok(()) } fn check_response(socket: &UdpSocket) -> Result<(), Box> { if let Packet::Ack(received_block_number) = Message::recv(&socket)? { if received_block_number != 0 { Message::send_error( &socket, ErrorCode::IllegalOperation, "invalid oack response".to_string(), )?; } } Ok(()) } fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result> { let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; socket.connect(remote)?; socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; socket.set_write_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; Ok(socket) } fn parse_options( options: &mut Vec, work_type: &WorkType, ) -> Result> { let mut worker_options = WorkerOptions { blk_size: DEFAULT_BLOCK_SIZE, t_size: 0, timeout: DEFAULT_TIMEOUT_SECS, }; for option in &mut *options { let TransferOption { option, value } = option; match option { OptionType::BlockSize => worker_options.blk_size = *value, OptionType::TransferSize => match work_type { WorkType::Send(size) => { *value = *size as usize; worker_options.t_size = *size as usize; } WorkType::Receive => { worker_options.t_size = *value; } }, OptionType::Timeout => { if *value == 0 { return Err("Invalid timeout value".into()); } worker_options.timeout = *value as u64; } } } Ok(worker_options) } #[cfg(test)] mod tests { use super::*; #[test] fn parses_send_options() { let mut options = vec![ TransferOption { option: OptionType::BlockSize, value: 1024, }, TransferOption { option: OptionType::TransferSize, value: 0, }, TransferOption { option: OptionType::Timeout, value: 5, }, ]; let work_type = WorkType::Send(12341234); let worker_options = parse_options(&mut options, &work_type).unwrap(); assert_eq!(options[0].value, worker_options.blk_size); assert_eq!(options[1].value, worker_options.t_size); assert_eq!(options[2].value as u64, worker_options.timeout); } #[test] fn parses_receive_options() { let mut options = vec![ TransferOption { option: OptionType::BlockSize, value: 1024, }, TransferOption { option: OptionType::TransferSize, value: 44554455, }, TransferOption { option: OptionType::Timeout, value: 5, }, ]; let work_type = WorkType::Receive; let worker_options = parse_options(&mut options, &work_type).unwrap(); assert_eq!(options[0].value, worker_options.blk_size); assert_eq!(options[1].value, worker_options.t_size); assert_eq!(options[2].value as u64, worker_options.timeout); } #[test] fn parses_default_options() { assert_eq!( parse_options(&mut vec![], &WorkType::Receive).unwrap(), WorkerOptions { blk_size: DEFAULT_BLOCK_SIZE, t_size: 0, timeout: DEFAULT_TIMEOUT_SECS, } ); } }