From f25a9f596d7d6ed07262291fbc69ce5eaa1bfd37 Mon Sep 17 00:00:00 2001 From: GnoCiYeH Date: Mon, 15 Apr 2024 22:01:32 +0800 Subject: [PATCH] =?UTF-8?q?socket=E7=BB=9F=E4=B8=80=E6=94=B9=E7=94=A8`Glob?= =?UTF-8?q?alSocketHandle`,=E5=B9=B6=E4=B8=94=E4=BF=AE=E5=A4=8Dfcntl=20SET?= =?UTF-8?q?FD=E7=9A=84=E9=94=99=E8=AF=AF=20(#730)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * socket统一改用`GlobalSocketHandle`,并且修复fcntl SETFD的错误 --------- Co-authored-by: longjin --- kernel/src/arch/x86_64/syscall/mod.rs | 2 +- kernel/src/filesystem/vfs/syscall.rs | 38 +++- kernel/src/net/net_core.rs | 11 +- kernel/src/net/socket/handle.rs | 39 +++++ kernel/src/net/socket/inet.rs | 238 +++++++++++++++----------- kernel/src/net/socket/mod.rs | 50 ++---- kernel/src/net/socket/unix.rs | 20 ++- kernel/src/net/syscall.rs | 9 +- 8 files changed, 253 insertions(+), 154 deletions(-) create mode 100644 kernel/src/net/socket/handle.rs diff --git a/kernel/src/arch/x86_64/syscall/mod.rs b/kernel/src/arch/x86_64/syscall/mod.rs index fb9fb05f5..599779aec 100644 --- a/kernel/src/arch/x86_64/syscall/mod.rs +++ b/kernel/src/arch/x86_64/syscall/mod.rs @@ -88,7 +88,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) { mfence(); let pid = ProcessManager::current_pcb().pid(); let show = false; - // let show = if syscall_num != SYS_SCHED && pid.data() > 3 { + // let show = if syscall_num != SYS_SCHED && pid.data() >= 7 { // true // } else { // false diff --git a/kernel/src/filesystem/vfs/syscall.rs b/kernel/src/filesystem/vfs/syscall.rs index c7a24033f..bd885b848 100644 --- a/kernel/src/filesystem/vfs/syscall.rs +++ b/kernel/src/filesystem/vfs/syscall.rs @@ -1024,6 +1024,15 @@ impl Syscall { oldfd: i32, newfd: i32, fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>, + ) -> Result { + Self::do_dup3(oldfd, newfd, FileMode::empty(), fd_table_guard) + } + + fn do_dup3( + oldfd: i32, + newfd: i32, + flags: FileMode, + fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>, ) -> Result { // 确认oldfd, newid是否有效 if !(FileDescriptorVec::validate_fd(oldfd) && FileDescriptorVec::validate_fd(newfd)) { @@ -1047,8 +1056,12 @@ impl Syscall { .get_file_by_fd(oldfd) .ok_or(SystemError::EBADF)?; let new_file = old_file.try_clone().ok_or(SystemError::EBADF)?; - // dup2默认非cloexec - new_file.set_close_on_exec(false); + + if flags.contains(FileMode::O_CLOEXEC) { + new_file.set_close_on_exec(true); + } else { + new_file.set_close_on_exec(false); + } // 申请文件描述符,并把文件对象存入其中 let res = fd_table_guard .alloc_fd(new_file, Some(newfd)) @@ -1064,8 +1077,9 @@ impl Syscall { /// - `cmd`:命令 /// - `arg`:参数 pub fn fcntl(fd: i32, cmd: FcntlCommand, arg: i32) -> Result { + // kdebug!("fcntl ({cmd:?}) fd: {fd}, arg={arg}"); match cmd { - FcntlCommand::DupFd => { + FcntlCommand::DupFd | FcntlCommand::DupFdCloexec => { if arg < 0 || arg as usize >= FileDescriptorVec::PROCESS_MAX_FD { return Err(SystemError::EBADF); } @@ -1074,7 +1088,16 @@ impl Syscall { let binding = ProcessManager::current_pcb().fd_table(); let mut fd_table_guard = binding.write(); if fd_table_guard.get_file_by_fd(i as i32).is_none() { - return Self::do_dup2(fd, i as i32, &mut fd_table_guard); + if cmd == FcntlCommand::DupFd { + return Self::do_dup2(fd, i as i32, &mut fd_table_guard); + } else { + return Self::do_dup3( + fd, + i as i32, + FileMode::O_CLOEXEC, + &mut fd_table_guard, + ); + } } } return Err(SystemError::EMFILE); @@ -1083,12 +1106,15 @@ impl Syscall { // Get file descriptor flags. let binding = ProcessManager::current_pcb().fd_table(); let fd_table_guard = binding.read(); + if let Some(file) = fd_table_guard.get_file_by_fd(fd) { // drop guard 以避免无法调度的问题 drop(fd_table_guard); if file.close_on_exec() { return Ok(FD_CLOEXEC as usize); + } else { + return Ok(0); } } return Err(SystemError::EBADF); @@ -1145,8 +1171,8 @@ impl Syscall { // TODO: unimplemented // 未实现的命令,返回0,不报错。 - // kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd); - return Ok(0); + kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd); + return Err(SystemError::ENOSYS); } } } diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 144f7ec18..6e727effe 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -12,7 +12,7 @@ use crate::{ use super::{ event_poll::{EPollEventType, EventPoll}, - socket::{inet::TcpSocket, HANDLE_MAP, SOCKET_SET}, + socket::{handle::GlobalSocketHandle, inet::TcpSocket, HANDLE_MAP, SOCKET_SET}, }; /// The network poll function, which will be called by timer. @@ -188,7 +188,8 @@ pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> { fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { for (handle, socket_type) in sockets.iter() { let handle_guard = HANDLE_MAP.read_irqsave(); - let item = handle_guard.get(&handle); + let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle); + let item = handle_guard.get(&global_handle); if item.is_none() { continue; } @@ -203,7 +204,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { match socket_type { smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => { handle_guard - .get(&handle) + .get(&global_handle) .unwrap() .wait_queue .wakeup_any(events); @@ -217,7 +218,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { events |= TcpSocket::CAN_CONNECT; } handle_guard - .get(&handle) + .get(&global_handle) .unwrap() .wait_queue .wakeup_any(events); @@ -227,7 +228,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { } drop(handle_guard); let mut handle_guard = HANDLE_MAP.write_irqsave(); - let handle_item = handle_guard.get_mut(&handle).unwrap(); + let handle_item = handle_guard.get_mut(&global_handle).unwrap(); EventPoll::wakeup_epoll( &handle_item.epitems, EPollEventType::from_bits_truncate(events as u32), diff --git a/kernel/src/net/socket/handle.rs b/kernel/src/net/socket/handle.rs new file mode 100644 index 000000000..97701ace0 --- /dev/null +++ b/kernel/src/net/socket/handle.rs @@ -0,0 +1,39 @@ +use ida::IdAllocator; +use smoltcp::iface::SocketHandle; + +int_like!(KernelHandle, usize); + +/// # socket的句柄管理组件 +/// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。 +/// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。 +#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] +pub enum GlobalSocketHandle { + Smoltcp(SocketHandle), + Kernel(KernelHandle), +} + +static KERNEL_HANDLE_IDA: IdAllocator = IdAllocator::new(0, usize::MAX); + +impl GlobalSocketHandle { + pub fn new_smoltcp_handle(handle: SocketHandle) -> Self { + return Self::Smoltcp(handle); + } + + pub fn new_kernel_handle() -> Self { + return Self::Kernel(KernelHandle::new(KERNEL_HANDLE_IDA.alloc().unwrap())); + } + + pub fn smoltcp_handle(&self) -> Option { + if let Self::Smoltcp(sh) = *self { + return Some(sh); + } + None + } + + pub fn kernel_handle(&self) -> Option { + if let Self::Kernel(kh) = *self { + return Some(kh); + } + None + } +} diff --git a/kernel/src/net/socket/inet.rs b/kernel/src/net/socket/inet.rs index f2822a88a..2e4e25962 100644 --- a/kernel/src/net/socket/inet.rs +++ b/kernel/src/net/socket/inet.rs @@ -1,7 +1,6 @@ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use smoltcp::{ - iface::SocketHandle, - socket::{raw, tcp, udp}, + socket::{raw, tcp, udp, AnySocket}, wire, }; use system_error::SystemError; @@ -18,8 +17,8 @@ use crate::{ }; use super::{ - GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod, - SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, + handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, + SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, }; /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 @@ -27,7 +26,7 @@ use super::{ /// ref: https://man7.org/linux/man-pages/man7/raw.7.html #[derive(Debug, Clone)] pub struct RawSocket { - handle: Arc, + handle: GlobalSocketHandle, /// 用户发送的数据包是否包含了IP头. /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) @@ -68,8 +67,7 @@ impl RawSocket { ); // 把socket添加到socket集合中,并得到socket的句柄 - let handle: Arc = - GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket)); + let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); let metadata = SocketMetadata::new( SocketType::Raw, @@ -88,12 +86,20 @@ impl RawSocket { } impl Socket for RawSocket { + fn close(&mut self) { + let mut socket_set_guard = SOCKET_SET.lock_irqsave(); + socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? + drop(socket_set_guard); + poll_ifaces(); + } + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { poll_ifaces(); loop { // 如何优化这里? let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handle.0); + let socket = + socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); match socket.recv_slice(buf) { Ok(len) => { @@ -126,7 +132,8 @@ impl Socket for RawSocket { // 如果用户发送的数据包,包含IP头,则直接发送 if self.header_included { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handle.0); + let socket = + socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); match socket.send_slice(buf) { Ok(_) => { return Ok(buf.len()); @@ -141,7 +148,7 @@ impl Socket for RawSocket { if let Some(Endpoint::Ip(Some(endpoint))) = to { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket: &mut raw::Socket = - socket_set_guard.get_mut::(self.handle.0); + socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone(); @@ -209,8 +216,8 @@ impl Socket for RawSocket { Box::new(self.clone()) } - fn socket_handle(&self) -> SocketHandle { - self.handle.0 + fn socket_handle(&self) -> GlobalSocketHandle { + self.handle } fn as_any_ref(&self) -> &dyn core::any::Any { @@ -227,7 +234,7 @@ impl Socket for RawSocket { /// https://man7.org/linux/man-pages/man7/udp.7.html #[derive(Debug, Clone)] pub struct UdpSocket { - pub handle: Arc, + pub handle: GlobalSocketHandle, remote_endpoint: Option, // 记录远程endpoint提供给connect(), 应该使用IP地址。 metadata: SocketMetadata, } @@ -257,8 +264,8 @@ impl UdpSocket { let socket = udp::Socket::new(rx_buffer, tx_buffer); // 把socket添加到socket集合中,并得到socket的句柄 - let handle: Arc = - GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket)); + let handle: GlobalSocketHandle = + GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); let metadata = SocketMetadata::new( SocketType::Udp, @@ -301,13 +308,21 @@ impl UdpSocket { } impl Socket for UdpSocket { + fn close(&mut self) { + let mut socket_set_guard = SOCKET_SET.lock_irqsave(); + socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? + drop(socket_set_guard); + poll_ifaces(); + } + /// @brief 在read函数执行之前,请先bind到本地的指定端口 fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { loop { // kdebug!("Wait22 to Read"); poll_ifaces(); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handle.0); + let socket = + socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); // kdebug!("Wait to Read"); @@ -344,7 +359,7 @@ impl Socket for UdpSocket { // kdebug!("udp write: remote = {:?}", remote_endpoint); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handle.0); + let socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); // kdebug!("is open()={}", socket.is_open()); // kdebug!("socket endpoint={:?}", socket.endpoint()); if socket.can_send() { @@ -369,14 +384,14 @@ impl Socket for UdpSocket { fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { let mut sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get_mut::(self.handle.0); + let socket = sockets.get_mut::(self.handle.smoltcp_handle().unwrap()); // kdebug!("UDP Bind to {:?}", endpoint); return self.do_bind(socket, endpoint); } fn poll(&self) -> EPollEventType { let sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get::(self.handle.0); + let socket = sockets.get::(self.handle.smoltcp_handle().unwrap()); return SocketPollMethod::udp_poll( socket, @@ -417,7 +432,7 @@ impl Socket for UdpSocket { fn endpoint(&self) -> Option { let sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get::(self.handle.0); + let socket = sockets.get::(self.handle.smoltcp_handle().unwrap()); let listen_endpoint = socket.endpoint(); if listen_endpoint.port == 0 { @@ -440,8 +455,8 @@ impl Socket for UdpSocket { return self.remote_endpoint.clone(); } - fn socket_handle(&self) -> SocketHandle { - self.handle.0 + fn socket_handle(&self) -> GlobalSocketHandle { + self.handle } fn as_any_ref(&self) -> &dyn core::any::Any { @@ -458,7 +473,7 @@ impl Socket for UdpSocket { /// https://man7.org/linux/man-pages/man7/tcp.7.html #[derive(Debug, Clone)] pub struct TcpSocket { - handles: Vec>, + handles: Vec, local_endpoint: Option, // save local endpoint for bind() is_listening: bool, metadata: SocketMetadata, @@ -483,7 +498,7 @@ impl TcpSocket { /// @return 返回创建的tcp的socket pub fn new(options: SocketOptions) -> Self { // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄 - let handles: Vec> = vec![GlobalSocketHandle::new( + let handles: Vec = vec![GlobalSocketHandle::new_smoltcp_handle( SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), )]; @@ -542,6 +557,15 @@ impl TcpSocket { } impl Socket for TcpSocket { + fn close(&mut self) { + for handle in self.handles.iter() { + let mut socket_set_guard = SOCKET_SET.lock_irqsave(); + socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? + drop(socket_set_guard); + } + poll_ifaces(); + } + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { if HANDLE_MAP .read_irqsave() @@ -558,7 +582,8 @@ impl Socket for TcpSocket { poll_ifaces(); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handles.get(0).unwrap().0); + let socket = socket_set_guard + .get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); // 如果socket已经关闭,返回错误 if !socket.is_active() { @@ -626,7 +651,8 @@ impl Socket for TcpSocket { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handles.get(0).unwrap().0); + let socket = socket_set_guard + .get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); if socket.is_open() { if socket.can_send() { @@ -653,7 +679,8 @@ impl Socket for TcpSocket { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); // kdebug!("tcp socket:poll, socket'len={}",self.handle.len()); - let socket = socket_set_guard.get_mut::(self.handles.get(0).unwrap().0); + let socket = socket_set_guard + .get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); return SocketPollMethod::tcp_poll( socket, HANDLE_MAP @@ -668,7 +695,8 @@ impl Socket for TcpSocket { let mut sockets = SOCKET_SET.lock_irqsave(); // kdebug!("tcp socket:connect, socket'len={}",self.handle.len()); - let socket = sockets.get_mut::(self.handles.get(0).unwrap().0); + let socket = + sockets.get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); if let Endpoint::Ip(Some(ip)) = endpoint { let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; @@ -689,7 +717,9 @@ impl Socket for TcpSocket { loop { poll_ifaces(); let mut sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get_mut::(self.handles.get(0).unwrap().0); + let socket = sockets.get_mut::( + self.handles.get(0).unwrap().smoltcp_handle().unwrap(), + ); match socket.state() { tcp::State::Established => { @@ -741,9 +771,9 @@ impl Socket for TcpSocket { let mut handle_guard = HANDLE_MAP.write_irqsave(); self.handles.extend((handlen..backlog).map(|_| { let socket = Self::create_new_socket(); - let handle = GlobalSocketHandle::new(sockets.add(socket)); + let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); let handle_item = SocketHandleItem::new(); - handle_guard.insert(handle.0, handle_item); + handle_guard.insert(handle, handle_item); handle })); // kdebug!("tcp socket:listen, socket'len={}",self.handle.len()); @@ -753,7 +783,7 @@ impl Socket for TcpSocket { for i in 0..backlog { let handle = self.handles.get(i).unwrap(); - let socket = sockets.get_mut::(handle.0); + let socket = sockets.get_mut::(handle.smoltcp_handle().unwrap()); if !socket.is_listening() { // kdebug!("Tcp Socket is already listening on {local_endpoint}"); @@ -803,79 +833,89 @@ impl Socket for TcpSocket { // 随机获取访问的socket的handle let index: usize = rand() % self.handles.len(); let handle = self.handles.get(index).unwrap(); - let socket = sockets.get_mut::(handle.0); - - if socket.is_active() { - // kdebug!("tcp accept: socket.is_active()"); - let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?; - - let new_socket = { - // The new TCP socket used for sending and receiving data. - let mut tcp_socket = Self::create_new_socket(); - self.do_listen(&mut tcp_socket, endpoint) - .expect("do_listen failed"); - - // tcp_socket.listen(endpoint).unwrap(); - - // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了 - // 因此需要再为当前的socket分配一个新的handle - let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket)); - let old_handle = ::core::mem::replace( - &mut *self.handles.get_mut(index).unwrap(), - new_handle.clone(), - ); - - let metadata = SocketMetadata::new( - SocketType::Tcp, - Self::DEFAULT_TX_BUF_SIZE, - Self::DEFAULT_RX_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - self.metadata.options, - ); - let new_socket = Box::new(TcpSocket { - handles: vec![old_handle.clone()], - local_endpoint: self.local_endpoint, - is_listening: false, - metadata, - }); - // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len()); - - // 更新端口与 socket 的绑定 - if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() { - PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?; - PORT_MANAGER.bind_port( - self.metadata.socket_type, - ip.port, - *new_socket.clone(), - )?; - } + let socket = sockets + .iter_mut() + .find(|y| { + tcp::Socket::downcast(y.1) + .map(|y| y.is_active()) + .unwrap_or(false) + }) + .map(|y| tcp::Socket::downcast_mut(y.1).unwrap()); + if let Some(socket) = socket { + if socket.is_active() { + // kdebug!("tcp accept: socket.is_active()"); + let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?; + + let new_socket = { + // The new TCP socket used for sending and receiving data. + let mut tcp_socket = Self::create_new_socket(); + self.do_listen(&mut tcp_socket, endpoint) + .expect("do_listen failed"); + + // tcp_socket.listen(endpoint).unwrap(); + + // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了 + // 因此需要再为当前的socket分配一个新的handle + let new_handle = + GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket)); + let old_handle = ::core::mem::replace( + &mut *self.handles.get_mut(index).unwrap(), + new_handle, + ); + + let metadata = SocketMetadata::new( + SocketType::Tcp, + Self::DEFAULT_TX_BUF_SIZE, + Self::DEFAULT_RX_BUF_SIZE, + Self::DEFAULT_METADATA_BUF_SIZE, + self.metadata.options, + ); + + let new_socket = Box::new(TcpSocket { + handles: vec![old_handle], + local_endpoint: self.local_endpoint, + is_listening: false, + metadata, + }); + // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len()); + + // 更新端口与 socket 的绑定 + if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() { + PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?; + PORT_MANAGER.bind_port( + self.metadata.socket_type, + ip.port, + *new_socket.clone(), + )?; + } - // 更新handle表 - let mut handle_guard = HANDLE_MAP.write_irqsave(); - // 先删除原来的 + // 更新handle表 + let mut handle_guard = HANDLE_MAP.write_irqsave(); + // 先删除原来的 - let item = handle_guard.remove(&old_handle.0).unwrap(); + let item = handle_guard.remove(&old_handle).unwrap(); - // 按照smoltcp行为,将新的handle绑定到原来的item - handle_guard.insert(new_handle.0, item); - let new_item = SocketHandleItem::new(); + // 按照smoltcp行为,将新的handle绑定到原来的item + handle_guard.insert(new_handle, item); + let new_item = SocketHandleItem::new(); - // 插入新的item - handle_guard.insert(old_handle.0, new_item); + // 插入新的item + handle_guard.insert(old_handle, new_item); - new_socket - }; - // kdebug!("tcp accept: new socket: {:?}", new_socket); - drop(sockets); - poll_ifaces(); + new_socket + }; + // kdebug!("tcp accept: new socket: {:?}", new_socket); + drop(sockets); + poll_ifaces(); - return Ok((new_socket, Endpoint::Ip(Some(remote_ep)))); + return Ok((new_socket, Endpoint::Ip(Some(remote_ep)))); + } } // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); drop(sockets); - SocketHandleItem::sleep(handle.0, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave()); + SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave()); // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); } } @@ -887,7 +927,8 @@ impl Socket for TcpSocket { let sockets = SOCKET_SET.lock_irqsave(); // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len()); - let socket = sockets.get::(self.handles.get(0).unwrap().0); + let socket = + sockets.get::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); if let Some(ep) = socket.local_endpoint() { result = Some(Endpoint::Ip(Some(ep))); } @@ -899,7 +940,8 @@ impl Socket for TcpSocket { let sockets = SOCKET_SET.lock_irqsave(); // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); - let socket = sockets.get::(self.handles.get(0).unwrap().0); + let socket = + sockets.get::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); } @@ -911,10 +953,10 @@ impl Socket for TcpSocket { Box::new(self.clone()) } - fn socket_handle(&self) -> SocketHandle { + fn socket_handle(&self) -> GlobalSocketHandle { // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); - self.handles.get(0).unwrap().0 + *self.handles.get(0).unwrap() } fn as_any_ref(&self) -> &dyn core::any::Any { diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index c5163381f..9dd73c1ef 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -9,7 +9,7 @@ use alloc::{ }; use hashbrown::HashMap; use smoltcp::{ - iface::{SocketHandle, SocketSet}, + iface::SocketSet, socket::{self, tcp, udp}, }; use system_error::SystemError; @@ -29,16 +29,17 @@ use crate::{ }; use self::{ + handle::GlobalSocketHandle, inet::{RawSocket, TcpSocket, UdpSocket}, unix::{SeqpacketSocket, StreamSocket}, }; use super::{ event_poll::{EPollEventType, EPollItem, EventPoll}, - net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, }; +pub mod handle; pub mod inet; pub mod unix; @@ -48,7 +49,7 @@ lazy_static! { pub static ref SOCKET_SET: SpinLock> = SpinLock::new(SocketSet::new(vec![])); /// SocketHandle表,每个SocketHandle对应一个SocketHandleItem, /// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁 - pub static ref HANDLE_MAP: RwLock> = RwLock::new(HashMap::new()); + pub static ref HANDLE_MAP: RwLock> = RwLock::new(HashMap::new()); /// 端口管理器 pub static ref PORT_MANAGER: PortManager = PortManager::new(); } @@ -83,6 +84,11 @@ pub(super) fn new_socket( return Err(SystemError::EAFNOSUPPORT); } }; + + let handle_item = SocketHandleItem::new(); + HANDLE_MAP + .write_irqsave() + .insert(socket.socket_handle(), handle_item); Ok(socket) } @@ -224,9 +230,7 @@ pub trait Socket: Sync + Send + Debug + Any { Ok(()) } - fn socket_handle(&self) -> SocketHandle { - todo!() - } + fn socket_handle(&self) -> GlobalSocketHandle; fn write_buffer(&self, _buf: &[u8]) -> Result { todo!() @@ -272,6 +276,8 @@ pub trait Socket: Sync + Send + Debug + Any { Ok(()) } + + fn close(&mut self); } impl Clone for Box { @@ -329,6 +335,7 @@ impl IndexNode for SocketInode { .write_irqsave() .remove(&socket.socket_handle()) .unwrap(); + socket.close(); } Ok(()) @@ -409,9 +416,9 @@ impl SocketHandleItem { /// ## 在socket的等待队列上睡眠 pub fn sleep( - socket_handle: SocketHandle, + socket_handle: GlobalSocketHandle, events: u64, - handle_map_guard: RwLockReadGuard<'_, HashMap>, + handle_map_guard: RwLockReadGuard<'_, HashMap>, ) { unsafe { handle_map_guard @@ -544,33 +551,6 @@ impl PortManager { } } -/// # socket的句柄管理组件 -/// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。 -/// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。 -#[derive(Debug)] -pub struct GlobalSocketHandle(SocketHandle); - -impl GlobalSocketHandle { - pub fn new(handle: SocketHandle) -> Arc { - return Arc::new(Self(handle)); - } -} - -impl Clone for GlobalSocketHandle { - fn clone(&self) -> Self { - Self(self.0) - } -} - -impl Drop for GlobalSocketHandle { - fn drop(&mut self) { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - socket_set_guard.remove(self.0); // 删除的时候,会发送一条FINISH的信息? - drop(socket_set_guard); - poll_ifaces(); - } -} - /// @brief socket的类型 #[derive(Debug, Clone, Copy, PartialEq)] pub enum SocketType { diff --git a/kernel/src/net/socket/unix.rs b/kernel/src/net/socket/unix.rs index 0082d1ede..83cd0c5d2 100644 --- a/kernel/src/net/socket/unix.rs +++ b/kernel/src/net/socket/unix.rs @@ -3,13 +3,16 @@ use system_error::SystemError; use crate::{libs::spinlock::SpinLock, net::Endpoint}; -use super::{Socket, SocketInode, SocketMetadata, SocketOptions, SocketType}; +use super::{ + handle::GlobalSocketHandle, Socket, SocketInode, SocketMetadata, SocketOptions, SocketType, +}; #[derive(Debug, Clone)] pub struct StreamSocket { metadata: SocketMetadata, buffer: Arc>>, peer_inode: Option>, + handle: GlobalSocketHandle, } impl StreamSocket { @@ -37,11 +40,18 @@ impl StreamSocket { metadata, buffer, peer_inode: None, + handle: GlobalSocketHandle::new_kernel_handle(), } } } impl Socket for StreamSocket { + fn socket_handle(&self) -> GlobalSocketHandle { + self.handle + } + + fn close(&mut self) {} + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { let mut buffer = self.buffer.lock_irqsave(); @@ -110,6 +120,7 @@ pub struct SeqpacketSocket { metadata: SocketMetadata, buffer: Arc>>, peer_inode: Option>, + handle: GlobalSocketHandle, } impl SeqpacketSocket { @@ -137,11 +148,14 @@ impl SeqpacketSocket { metadata, buffer, peer_inode: None, + handle: GlobalSocketHandle::new_kernel_handle(), } } } impl Socket for SeqpacketSocket { + fn close(&mut self) {} + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { let mut buffer = self.buffer.lock_irqsave(); @@ -188,6 +202,10 @@ impl Socket for SeqpacketSocket { Ok(len) } + fn socket_handle(&self) -> GlobalSocketHandle { + self.handle + } + fn metadata(&self) -> SocketMetadata { self.metadata.clone() } diff --git a/kernel/src/net/syscall.rs b/kernel/src/net/syscall.rs index 05f08713a..cc9bf5696 100644 --- a/kernel/src/net/syscall.rs +++ b/kernel/src/net/syscall.rs @@ -19,7 +19,7 @@ use crate::{ }; use super::{ - socket::{new_socket, PosixSocketType, Socket, SocketHandleItem, SocketInode, HANDLE_MAP}, + socket::{new_socket, PosixSocketType, Socket, SocketInode}, Endpoint, Protocol, ShutdownType, }; @@ -44,13 +44,6 @@ impl Syscall { let socket = new_socket(address_family, socket_type, protocol)?; - if address_family != AddressFamily::Unix { - let handle_item = SocketHandleItem::new(); - HANDLE_MAP - .write_irqsave() - .insert(socket.socket_handle(), handle_item); - } - let socketinode: Arc = SocketInode::new(socket); let f = File::new(socketinode, FileMode::O_RDWR)?; // 把socket添加到当前进程的文件描述符表中