Commit 9480693a authored by Thomas Gatzweiler's avatar Thomas Gatzweiler
Browse files

Async IO Hack

parent 46e4bfab
......@@ -2,6 +2,7 @@ extern crate ssh;
extern crate log;
use std::env;
use std::error::Error;
use std::fs::File;
use std::io::{self, Write};
use std::process;
......
use std::fs::{File, OpenOptions};
use std::io::{self, Read, Write};
use std::os::unix::io::{FromRawFd, IntoRawFd, RawFd};
use std::fs::OpenOptions;
use std::io;
use std::os::unix::io::{FromRawFd, IntoRawFd};
use std::os::unix::process::CommandExt;
use std::path::PathBuf;
use std::process::{self, Stdio};
use std::thread::{self, JoinHandle};
use std::sync::mpsc;
use sys;
use connection::ConnectionEvent;
pub type ChannelId = u32;
#[derive(Debug)]
pub struct Channel {
id: ChannelId,
peer_id: ChannelId,
process: Option<process::Child>,
pty: Option<(RawFd, PathBuf)>,
master: Option<File>,
pty: Option<sys::Pty>,
window_size: u32,
peer_window_size: u32,
max_packet_size: u32,
read_thread: Option<JoinHandle<()>>,
events: mpsc::Sender<ConnectionEvent>,
}
#[derive(Debug)]
......@@ -38,18 +37,17 @@ pub enum ChannelRequest {
impl Channel {
pub fn new(
id: ChannelId, peer_id: ChannelId, peer_window_size: u32,
max_packet_size: u32
max_packet_size: u32, events: mpsc::Sender<ConnectionEvent>
) -> Channel {
Channel {
id: id,
peer_id: peer_id,
process: None,
master: None,
pty: None,
window_size: peer_window_size,
peer_window_size: peer_window_size,
max_packet_size: max_packet_size,
read_thread: None,
events: events,
}
}
......@@ -57,6 +55,10 @@ impl Channel {
self.id
}
pub fn peer_id(&self) -> ChannelId {
self.peer_id
}
pub fn window_size(&self) -> u32 {
self.window_size
}
......@@ -75,83 +77,78 @@ impl Channel {
pixel_height,
..
} => {
let (master_fd, tty_path) = sys::getpty();
sys::set_winsize(
master_fd,
chars,
rows,
pixel_width,
pixel_height,
);
self.read_thread = Some(thread::spawn(move || {
use libc::dup;
let master2 = unsafe { dup(master_fd) };
println!("dup result: {}", dup as u32);
let mut master = unsafe { File::from_raw_fd(master2) };
loop {
use std::str::from_utf8_unchecked;
let mut buf = [0; 4096];
let count = master.read(&mut buf).unwrap();
if count == 0 {
break;
}
println!("Read: {}", unsafe {
from_utf8_unchecked(&buf[0..count])
});
}
println!("Quitting read thread.");
}));
self.pty = Some((master_fd, tty_path));
self.master = Some(unsafe { File::from_raw_fd(master_fd) });
if let Ok(mut pty) = sys::Pty::get() {
pty.set_winsize(chars, rows, pixel_width, pixel_height);
let events = self.events.clone();
let id = self.id;
pty.subscribe(move || {
events.send(ConnectionEvent::ChannelData(id)).map_err(
|_| (),
)
});
self.pty = Some(pty);
}
}
ChannelRequest::Shell => {
if let Some(&(_, ref tty_path)) = self.pty.as_ref() {
if let Some(ref pty) = self.pty {
let stdin = OpenOptions::new()
.read(true)
.write(true)
.open(&tty_path)
.open(pty.path())
.unwrap()
.into_raw_fd();
let stdout = OpenOptions::new()
.read(true)
.write(true)
.open(&tty_path)
.open(pty.path())
.unwrap()
.into_raw_fd();
let stderr = OpenOptions::new()
.read(true)
.write(true)
.open(&tty_path)
.open(pty.path())
.unwrap()
.into_raw_fd();
process::Command::new("login")
.stdin(unsafe { Stdio::from_raw_fd(stdin) })
.stdout(unsafe { Stdio::from_raw_fd(stdout) })
.stderr(unsafe { Stdio::from_raw_fd(stderr) })
.before_exec(|| sys::before_exec())
.spawn()
.unwrap();
self.process = Some(
process::Command::new("login")
.stdin(unsafe { Stdio::from_raw_fd(stdin) })
.stdout(unsafe { Stdio::from_raw_fd(stdout) })
.stderr(unsafe { Stdio::from_raw_fd(stderr) })
.before_exec(|| sys::before_exec())
.spawn()
.unwrap(),
);
}
}
}
debug!("Channel Request: {:?}", request);
}
pub fn data(&mut self, data: &[u8]) -> io::Result<()> {
if let Some(ref mut master) = self.master {
master.write_all(data)?;
master.flush()
pub fn write(&mut self, data: &[u8]) -> io::Result<()> {
match self.pty
{
Some(ref mut pty) => pty.write(data),
_ => Ok(()),
}
else {
Ok(())
}
pub fn read(&mut self, data: &mut [u8]) -> io::Result<usize> {
match self.pty
{
Some(ref mut pty) => pty.read(data),
_ => Ok(0),
}
}
}
impl Drop for Channel {
fn drop(&mut self) {
self.process.take().map(|mut p| p.kill());
}
}
use std::collections::{BTreeMap, VecDeque};
use std::io::{self, BufReader, Read, Write};
use std::sync::Arc;
use std::sync::mpsc;
use channel::{Channel, ChannelId, ChannelRequest};
use encryption::{AesCtr, Decryptor, Encryption};
......@@ -18,6 +19,11 @@ enum ConnectionState {
Established,
}
pub enum ConnectionEvent {
ChannelData(ChannelId),
StreamData,
}
#[derive(Clone)]
pub enum ConnectionType {
Server(Arc<ServerConfig>),
......@@ -42,10 +48,14 @@ pub struct Connection {
seq: (u32, u32),
tx_queue: VecDeque<Packet>,
channels: BTreeMap<ChannelId, Channel>,
pub events_tx: mpsc::Sender<ConnectionEvent>,
events_rx: mpsc::Receiver<ConnectionEvent>,
}
impl<'a> Connection {
pub fn new(conn_type: ConnectionType) -> Connection {
let (events_tx, events_rx) = mpsc::channel();
Connection {
conn_type: conn_type,
hash_data: HashData::default(),
......@@ -57,6 +67,8 @@ impl<'a> Connection {
seq: (0, 0),
tx_queue: VecDeque::new(),
channels: BTreeMap::new(),
events_rx: events_rx,
events_tx: events_tx,
}
}
......@@ -67,23 +79,43 @@ impl<'a> Connection {
let mut reader = BufReader::new(stream);
loop {
let packet = self.recv(&mut reader)?;
let response = self.process(packet)?;
match self.events_rx.recv()
{
Ok(ConnectionEvent::ChannelData(id)) => {
if let Some(ref mut channel) = self.channels.get_mut(&id) {
let mut buf = [0; 4096];
let count = channel.read(&mut buf)?;
if count > 0 {
let mut res = Packet::new(MessageType::ChannelData);
res.write_uint32(channel.peer_id())?;
res.write_bytes(&buf[..count])?;
}
}
}
Ok(ConnectionEvent::StreamData) => {
let packet = self.recv(&mut reader)?;
let response = self.process(packet)?;
let mut stream = reader.get_mut();
let mut stream = reader.get_mut();
if let Some(packet) = response {
self.send(&mut stream, packet)?;
if let Some(packet) = response {
self.send(&mut stream, packet)?;
}
}
Err(_) => {}
}
// Send additional packets from the queue
let mut packets: Vec<Packet> = self.tx_queue.drain(..).collect();
let mut stream = reader.get_mut();
for packet in packets.drain(..) {
self.send(&mut stream, packet)?;
}
}
}
fn recv(&mut self, mut stream: &mut Read) -> Result<Packet> {
let packet = if let Some((ref mut c2s, _)) = self.encryption {
let mut decryptor = Decryptor::new(&mut **c2s, &mut stream);
......@@ -296,7 +328,13 @@ impl<'a> Connection {
0
};
let channel = Channel::new(id, peer_id, window_size, max_packet_size);
let channel = Channel::new(
id,
peer_id,
window_size,
max_packet_size,
self.events_tx.clone(),
);
let mut res = Packet::new(MessageType::ChannelOpenConfirmation);
res.write_uint32(peer_id)?;
......@@ -304,7 +342,7 @@ impl<'a> Connection {
res.write_uint32(channel.window_size())?;
res.write_uint32(channel.max_packet_size())?;
debug!("Open {:?}", channel);
debug!("Open Channel {}", id);
self.channels.insert(id, channel);
......@@ -357,7 +395,7 @@ impl<'a> Connection {
let data = reader.read_string()?;
let mut channel = self.channels.get_mut(&channel_id).unwrap();
channel.data(data.as_slice())?;
channel.write(data.as_slice())?;
Ok(None)
}
......
......@@ -9,8 +9,6 @@ extern crate syscall;
#[cfg(not(target_os = "redox"))]
extern crate libc;
mod error;
mod algorithm;
mod packet;
mod message;
mod connection;
......@@ -19,15 +17,17 @@ mod encryption;
mod mac;
mod channel;
pub mod error;
pub mod algorithm;
pub mod public_key;
pub mod server;
pub use self::server::{Server, ServerConfig};
#[cfg(target_os = "redox")]
#[path = "sys/redox.rs"]
#[path = "sys/redox/mod.rs"]
pub mod sys;
#[cfg(not(target_os = "redox"))]
#[path = "sys/unix.rs"]
#[path = "sys/linux/mod.rs"]
pub mod sys;
pub use self::server::{Server, ServerConfig};
use std::io;
use std::net::TcpListener;
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
use std::thread;
use connection::{Connection, ConnectionType};
use connection::{Connection, ConnectionEvent, ConnectionType};
use public_key::KeyPair;
use sys;
pub struct ServerConfig {
pub host: String,
......@@ -38,7 +40,7 @@ impl Server {
let result = connection.run(&mut stream);
if let Some(error) = result.err() {
println!("sshd: {}", error)
println!("sshd: {}", error);
}
});
}
......
use std::io::Result;
use std::os::unix::io::RawFd;
mod pty;
pub use self::pty::Pty;
pub fn before_exec() -> Result<()> {
use libc;
unsafe {
libc::setsid();
libc::ioctl(0, libc::TIOCSCTTY, 1);
}
Ok(())
}
pub fn fork() -> usize {
use libc;
unsafe { libc::fork() as usize }
}
use libc;
use std::ffi::CStr;
use std::fs::{File, OpenOptions};
use std::io::{self, Read, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use std::path::PathBuf;
use std::sync::mpsc;
use std::thread::{self, JoinHandle};
pub struct Pty {
master: File,
path: PathBuf,
sub_thread: Option<JoinHandle<()>>,
sub_thread_tx: Option<mpsc::Sender<ThreadCommand>>,
}
enum ThreadCommand {
WaitForData,
Stop,
}
impl Pty {
pub fn get() -> Result<Pty, ()> {
const TIOCPKT: libc::c_ulong = 0x5420;
let master_fd = OpenOptions::new()
.read(true)
.write(true)
.custom_flags(libc::O_NONBLOCK)
.open("/dev/ptmx")
.unwrap()
.into_raw_fd();
unsafe {
use std::io::Error;
let mut flag: libc::c_int = 1;
if libc::ioctl(
master_fd,
TIOCPKT,
&mut flag as *mut libc::c_int,
) < 0
{
error!("ioctl: {:?}", Error::last_os_error());
return Err(());
}
if libc::grantpt(master_fd) < 0 {
error!("grantpt: {:?}", Error::last_os_error());
return Err(());
}
if libc::unlockpt(master_fd) < 0 {
error!("unlockpt: {:?}", Error::last_os_error());
return Err(());
}
}
let tty_path = unsafe {
PathBuf::from(
CStr::from_ptr(libc::ptsname(master_fd))
.to_string_lossy()
.into_owned(),
)
};
let master = unsafe { File::from_raw_fd(master_fd) };
Ok(Pty {
master: master,
path: tty_path,
sub_thread: None,
sub_thread_tx: None,
})
}
pub fn subscribe<F>(&mut self, callback: F)
where
F: Fn() -> Result<(), ()> + Send + 'static,
{
let (thread_tx, thread_rx) = mpsc::channel();
let mut pollfd = libc::pollfd {
fd: self.master.as_raw_fd(),
events: libc::POLLIN,
revents: 0,
};
self.sub_thread = Some(thread::spawn(move || {
loop {
match thread_rx.recv()
{
Ok(ThreadCommand::WaitForData) => {}
Ok(ThreadCommand::Stop) => return,
Err(_) => return,
}
// Clear receive queue
while !thread_rx.try_recv().is_err() {}
unsafe { libc::poll(&mut pollfd as *mut libc::pollfd, 1, -1) };
if callback().is_err() {
return;
}
}
}));
self.sub_thread_tx = Some(thread_tx);
}
pub fn path<'a>(&'a self) -> &'a PathBuf {
&self.path
}
pub fn set_winsize(&self, row: u16, col: u16, xpixel: u16, ypixel: u16) {
let size = libc::winsize {
ws_row: row,
ws_col: col,
ws_xpixel: xpixel,
ws_ypixel: ypixel,
};
unsafe {
let fd = self.master.as_raw_fd();
libc::ioctl(fd, libc::TIOCSWINSZ, &size as *const libc::winsize);
}
}
pub fn write(&mut self, data: &[u8]) -> io::Result<()> {
self.master.write_all(data)?;
self.master.flush()
}
pub fn read(&mut self, data: &mut [u8]) -> io::Result<usize> {
match self.master.read(data)
{
Ok(count) => {
self.sub_thread_tx.as_ref().map(|tx| {
tx.send(ThreadCommand::WaitForData)
});
Ok(count)
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.sub_thread_tx.as_ref().map(|tx| {
tx.send(ThreadCommand::WaitForData)
});
Ok(0)
}
Err(e) => Err(e),
}
}
}
impl Drop for Pty {
fn drop(&mut self) {
self.sub_thread_tx.take().map(
|tx| tx.send(ThreadCommand::Stop),
);
self.sub_thread = None;
}
}
use std::io::Result;
use std::os::unix::io::RawFd;
use std::path::PathBuf;
pub mod pty;
pub fn before_exec() -> Result<()> {
Ok(())
......@@ -12,20 +12,3 @@ pub fn fork() -> usize {
}
pub fn set_winsize(fd: RawFd, row: u16, col: u16, xpixel: u16, ypixel: u16) {}
pub fn getpty() -> (RawFd, PathBuf) {
use syscall;
let master = syscall::open("pty:", syscall::O_RDWR | syscall::O_CREAT)
.unwrap();
let mut buf: [u8; 4096] = [0; 4096];
let count = syscall::fpath(master, &mut buf).unwrap();
(
master,
PathBuf::from(unsafe {
String::from_utf8_unchecked(Vec::from(&buf[..count]))
}),
)
}
use std::io::Result;
use std::os::unix::io::RawFd;
use std::path::PathBuf;
pub fn getpty() -> (RawFd, PathBuf) {
use syscall;
let master = syscall::open(
"pty:",
syscall::O_RDWR | syscall::O_CREAT | syscall::O_NONBLOCK,
).unwrap();
let mut buf: [u8; 4096] = [0; 4096];
let count = syscall::fpath(master, &mut buf).unwrap();
let path = String::from_utf8(Vec::from(&buf[..count]).or(())).unwrap();
(master, PathBuf::from(path))
}
use std::io::Result;
use std::os::unix::io::RawFd;
use std::path::PathBuf;
pub fn before_exec() -> Result<()> {
use libc;
unsafe {
libc::setsid();
libc::ioctl(0, libc::TIOCSCTTY, 1);