基于tornado的websocket和xterm.js实现webssh
目录:
webssh相关
主函数配置和路由相关
main.py
#!-*-coding:utf8-*-
import os
import uuid
import tornado.web
import tornado.ioloop
from handler import IndexHandler, WsockHandler
# Directory containing this file; the templates/ and static/ directories are
# resolved relative to it. abspath() keeps the result meaningful even when
# __file__ is a bare relative name (e.g. started as `python main.py`), where
# dirname() alone would return '' and tie the paths to the working directory.
base_dir = os.path.dirname(os.path.abspath(__file__))
def make_handlers(loop):
    """Return the URL routing table for the application.

    "/" serves the index page and opens SSH sessions; "/ws" is the websocket
    endpoint. Each route gets its own ``dict(loop=loop)`` of init kwargs, which
    tornado forwards to the handler's ``initialize(loop)``.
    """
    routes = (
        (r'/', IndexHandler),
        (r'/ws', WsockHandler),
    )
    return [(pattern, handler_cls, dict(loop=loop))
            for pattern, handler_cls in routes]
def make_app(handlers):
    """Create the tornado Application that serves *handlers*.

    Templates and static files are looked up under ``base_dir``.
    """
    return tornado.web.Application(
        handlers,
        template_path=os.path.join(base_dir, 'templates'),
        static_path=os.path.join(base_dir, 'static'),
        # Regenerated on every start of the process.
        cookie_secret=uuid.uuid4().hex,
        # 0 turns off tornado's periodic websocket pings.
        websocket_ping_interval=0,
        xsrf_cookies=True,
    )
def main(port=10000, address='0.0.0.0'):
    """Start the webssh server and run the IOLoop until it is stopped.

    :param port: TCP port to listen on (defaults to 10000, the previous
        hard-coded value).
    :param address: address to bind (defaults to '0.0.0.0', the previous
        hard-coded value).
    """
    loop = tornado.ioloop.IOLoop.current()
    app = make_app(make_handlers(loop))
    app.listen(port, address)
    loop.start()


if __name__ == '__main__':
    main()
处理请求的handler
handler.py
#!-*-coding:utf8-*-
import io
import ipaddress
import json
import logging
import socket
import struct
import threading
import traceback
import weakref

import paramiko
import tornado.gen
import tornado.web
import tornado.websocket
from tornado.ioloop import IOLoop
from tornado.util import basestring_type

from worker import Worker, recycle_worker, workers

# Future used to hand the SSH connection result from the connect thread back
# to the coroutine; fall back to tornado's Future if concurrent.futures is
# not importable.
try:
    from concurrent.futures import Future
except ImportError:
    from tornado.concurrent import Future

# Fall back to ValueError when json.decoder.JSONDecodeError is unavailable.
try:
    from json.decoder import JSONDecodeError
except ImportError:
    JSONDecodeError = ValueError
# Seconds a newly created worker may stay unclaimed by a websocket before
# recycle_worker is called on it (scheduled with call_later in IndexHandler.post).
DELAY = 3
def to_str(s):
    """Return *s* as text: bytes are decoded as UTF-8, anything else is returned unchanged."""
    return s.decode('utf-8') if isinstance(s, bytes) else s
def is_valid_ipv4_address(ipstr):
    """Return True if *ipstr* (text or UTF-8 bytes) parses as an IPv4 address."""
    try:
        ipaddress.IPv4Address(to_str(ipstr))
    except ipaddress.AddressValueError:
        return False
    else:
        return True
def is_valid_ipv6_address(ipstr):
    """Return True if *ipstr* (text or UTF-8 bytes) parses as an IPv6 address."""
    try:
        ipaddress.IPv6Address(to_str(ipstr))
    except ipaddress.AddressValueError:
        return False
    else:
        return True
def is_valid_port(port):
    """Return True if *port* lies in the TCP port range 1..65535."""
    return 65536 > port > 0
def parse_encoding(data):
    """Extract an encoding name from the output of the ``locale`` command.

    Takes the first line whose text after the last '=' is non-empty, strips
    surrounding double quotes, and returns the part after the last '.'
    (e.g. ``LANG=en_US.UTF-8`` -> ``'UTF-8'``). Returns None if no line has a
    value.
    """
    for entry in data.split('\n'):
        value = entry.rsplit('=', 1)[-1]
        if not value:
            continue
        return value.strip('"').rsplit('.', 1)[-1]
    return None
class MixinHandler(object):
    """Helpers shared by IndexHandler and WsockHandler."""

    def get_real_client_addr(self):
        """Return the client address forwarded in X-Real-Ip / X-Real-Port.

        :return: None when neither header is present (presumably not behind
            nginx); ``(ip, port)`` when both are valid; False, after logging a
            warning, when the headers are present but invalid.
        """
        headers = self.request.headers
        ip = headers.get('X-Real-Ip')
        port = headers.get('X-Real-Port')

        if ip is None and port is None:
            return None

        if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip):
            try:
                port_num = int(port)
            except (TypeError, ValueError):
                port_num = None
            if port_num is not None and is_valid_port(port_num):
                return (ip, port_num)

        logging.warning('Bad nginx configuration.')
        return False
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Handler for the "/" route.

    GET renders the index page. POST opens an SSH session with the submitted
    form fields and parks the resulting Worker in ``workers`` so that a
    websocket connection to "/ws?id=<worker id>" can claim it.
    """

    def initialize(self, loop):
        """Store the IOLoop passed in the route's init kwargs."""
        self.loop = loop

    def get(self):
        """Render index.html."""
        self.render('index.html')

    @tornado.gen.coroutine
    def post(self):
        """Open an SSH session and reply with JSON describing the outcome.

        The response is ``{"id": ..., "status": ..., "encoding": ...}``. On
        success ``status`` is None. On failure ``id`` and ``encoding`` are None
        and ``status`` holds the error message.
        """
        worker_id = None
        status = None
        encoding = None

        # ssh_connect() blocks, so run it on a daemon thread. Its outcome
        # arrives through this Future, which the coroutine yields on.
        future = Future()
        t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
        # Thread.setDaemon() is deprecated (Python 3.10+); set the attribute.
        t.daemon = True
        t.start()

        try:
            worker = yield future
        except Exception as exc:
            status = str(exc)
        else:
            worker_id = worker.id
            # Park the worker until a websocket claims it by id.
            workers[worker_id] = worker
            # recycle_worker closes the worker if no websocket handler has been
            # attached to it within DELAY seconds.
            self.loop.call_later(DELAY, recycle_worker, worker)
            encoding = worker.encoding

        self.write(dict(id=worker_id, status=status, encoding=encoding))

    def ssh_connect_wrapped(self, future):
        """Run ssh_connect() and publish its outcome on *future*.

        Runs on the background thread started by post(). Any exception is
        logged with its traceback and stored on the future, not raised.
        """
        try:
            worker = self.ssh_connect()
        except Exception as exc:
            logging.error(traceback.format_exc())
            future.set_exception(exc)
        else:
            future.set_result(worker)

    def ssh_connect(self):
        """Open an SSH connection and an interactive shell channel.

        :return: a Worker wrapping the SSH client and its non-blocking shell
            channel.
        :raises ValueError: on connection, authentication or host-key
            failures, and on empty or invalid request arguments.
        """
        # Read and validate the arguments before creating any client.
        args = self.get_args()
        dst_addr = (args[0], args[1])

        ssh = paramiko.SSHClient()
        # Unknown host keys are accepted and remembered automatically.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        logging.info('Connecting to {}:{}'.format(*dst_addr))

        try:
            try:
                ssh.connect(*args, timeout=6)
            except socket.error:
                raise ValueError('Unable to connect to {}:{}'.format(*dst_addr))
            except paramiko.AuthenticationException:
                # Base class of BadAuthenticationType; also covers a plain
                # wrong password.
                raise ValueError('SSH authentication failed.')
            except paramiko.BadHostKeyException:
                raise ValueError('Bad host key.')

            # An interactive shell channel with an xterm pty, as opposed to the
            # one-shot channel exec_command() uses.
            chan = ssh.invoke_shell(term='xterm')
            chan.setblocking(0)
            worker = Worker(self.loop, ssh, chan, dst_addr)
            worker.src_addr = self.get_client_addr()
            worker.encoding = self.get_default_encoding(ssh)
        except Exception:
            # Don't leak the client (and its connection) when setup fails.
            ssh.close()
            raise
        return worker

    def get_args(self):
        """Return ``(hostname, port, username, password)`` from the request."""
        hostname = self.get_value('hostname')
        port = self.get_port()
        username = self.get_value('username')
        password = self.get_argument('password')
        # Never write the password to the logs.
        logging.debug((hostname, port, username))
        return (hostname, port, username, password)

    def get_value(self, name):
        """Return the request argument *name*.

        :raises ValueError: if the argument is empty.
        """
        value = self.get_argument(name)
        if not value:
            raise ValueError('Empty {}'.format(name))
        return value

    def get_port(self):
        """Return the ``port`` request argument as an int.

        :raises ValueError: if it is empty, not an integer, or outside 1-65535.
        """
        value = self.get_value('port')
        try:
            port = int(value)
        except ValueError:
            pass
        else:
            if is_valid_port(port):
                return port
        raise ValueError('Invalid port {}'.format(value))

    def get_client_addr(self):
        """Return the proxy-forwarded client address, else the peer address."""
        return (self.get_real_client_addr()
                or self.request.connection.context.address)

    def get_default_encoding(self, ssh):
        """Guess the remote encoding from the output of ``locale``.

        Falls back to 'utf-8' when ``locale`` cannot be run, its output is not
        valid UTF-8, or parse_encoding() finds nothing.
        """
        try:
            _, stdout, _ = ssh.exec_command('locale')
        except paramiko.SSHException:
            result = None
        else:
            try:
                data = stdout.read().decode('utf-8')
            except UnicodeDecodeError:
                result = None
            else:
                result = parse_encoding(data)
        return result if result else 'utf-8'
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """Websocket endpoint ("/ws") that bridges a browser terminal to a Worker."""

    def initialize(self, loop):
        """Store the IOLoop; no worker is attached until open() succeeds."""
        self.loop = loop
        # Weak reference to the attached Worker (None until open() claims one).
        self.worker_ref = None

    def get_client_addr(self):
        """Return the proxy-forwarded client address, else the peer address."""
        return (self.get_real_client_addr()
                or self.request.connection.context.address)

    def open(self):
        """Claim the Worker named by the ``id`` query argument (/ws?id=...).

        The worker is handed over only if this client's IP equals the IP that
        created it (``worker.src_addr``). Otherwise, including when ``id`` is
        missing or unknown, the websocket is closed.
        """
        self.src_addr = self.get_client_addr()
        logging.info('Connected from {}:{}'.format(*self.src_addr))

        # A missing id is treated like an unknown one instead of raising.
        worker = workers.get(self.get_argument('id', None))
        if worker and worker.src_addr[0] == self.src_addr[0]:
            # Claimed: take it out of the pending table so recycle_worker and
            # other websockets can no longer reach it by id.
            workers.pop(worker.id)
            self.set_nodelay(True)
            worker.set_handler(self)
            self.worker_ref = weakref.ref(worker)
            # The Worker instance is the IOLoop callback for the channel fd.
            self.loop.add_handler(worker.fd, worker, IOLoop.READ)
        else:
            self.close(reason='Websocket authentication failed.')

    def on_message(self, message):
        """Handle a JSON message from the browser.

        ``{"resize": [w, h]}`` resizes the remote pty. ``{"data": "..."}``
        queues text for the SSH channel. Malformed messages are ignored.
        """
        logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            # No worker was attached, or it has already been released.
            return

        try:
            msg = json.loads(message)
        except JSONDecodeError:
            return
        if not isinstance(msg, dict):
            return

        resize = msg.get('resize')
        # Only a 2-item list/tuple is accepted; len() on other types (e.g. an
        # int) would raise outside the try below.
        if isinstance(resize, (list, tuple)) and len(resize) == 2:
            try:
                worker.chan.resize_pty(*resize)
            except (TypeError, struct.error, paramiko.SSHException):
                pass

        data = msg.get('data')
        if data and isinstance(data, basestring_type):
            worker.data_to_dst.append(data)
            worker.on_write()

    def on_close(self):
        """Close the attached worker, recording why the websocket went away."""
        logging.info('Disconnected from {}:{}'.format(*self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            if self.close_reason is None:
                self.close_reason = 'client disconnected'
            worker.close(reason=self.close_reason)
负责在ws与ssh通道之间转发数据的worker
worker.py
#!-*-coding:utf8-*-
import logging
import tornado.websocket
from tornado.ioloop import IOLoop
from tornado.iostream import _ERRNO_CONNRESET
from tornado.util import errno_from_exception
# Maximum number of bytes read per chan.recv() call in Worker.on_read.
BUF_SIZE = 1024
# Workers created by an HTTP POST and not yet claimed by a websocket, keyed by
# worker id.
workers = {}


def recycle_worker(worker):
    """Close *worker* unless a websocket handler has been attached to it.

    An unclaimed worker is removed from ``workers`` (if still present) and
    closed with reason 'worker recycled'. A claimed worker is left untouched.
    """
    if not worker.handler:
        logging.warning('Recycling worker {}'.format(worker.id))
        workers.pop(worker.id, None)
        worker.close(reason='worker recycled')
class Worker(object):
    """Bridges one paramiko shell channel and one websocket handler.

    The instance itself is registered as the IOLoop callback for the channel's
    file descriptor. Channel output is forwarded to the websocket, and text
    queued in ``data_to_dst`` is sent to the channel.
    """

    def __init__(self, loop, ssh, chan, dst_addr):
        self.loop = loop
        self.ssh = ssh
        self.chan = chan
        self.dst_addr = dst_addr
        self.fd = chan.fileno()
        self.id = str(id(self))
        # Text waiting to be sent to the SSH channel.
        self.data_to_dst = []
        # The websocket handler, set once via set_handler().
        self.handler = None
        # Event mask currently registered for self.fd with the IOLoop.
        self.mode = IOLoop.READ
        # Set by close(). Makes close() idempotent and turns later I/O into
        # no-ops.
        self.closed = False

    def __call__(self, fd, events):
        """IOLoop callback: dispatch READ/WRITE/ERROR events for the channel fd."""
        if events & IOLoop.READ:
            self.on_read()
        if events & IOLoop.WRITE:
            self.on_write()
        if events & IOLoop.ERROR:
            self.close(reason='error event occurred')

    def set_handler(self, handler):
        """Attach *handler* unless a handler is already attached."""
        if not self.handler:
            self.handler = handler

    def update_handler(self, mode):
        """Re-register self.fd with the IOLoop for *mode* if it changed."""
        # After close() the fd has been removed from the IOLoop; re-registering
        # it would resurrect a dead handler entry.
        if self.closed or self.mode == mode:
            return
        self.loop.update_handler(self.fd, mode)
        self.mode = mode

    def on_read(self):
        """Read available channel output and forward it to the websocket."""
        if self.closed:
            return
        logging.debug('worker {} on read'.format(self.id))
        try:
            data = self.chan.recv(BUF_SIZE)
        except (OSError, IOError) as e:
            logging.error(e)
            # A closed channel will never recover, so don't keep polling it.
            if self.chan.closed or errno_from_exception(e) in _ERRNO_CONNRESET:
                self.close(reason='chan error on reading')
        else:
            logging.debug('{!r} from {}:{}'.format(data, *self.dst_addr))
            if not data:
                # Empty read: the channel is finished.
                self.close(reason='chan closed')
                return

            logging.debug('{!r} to {}:{}'.format(data, *self.handler.src_addr))
            try:
                self.handler.write_message(data, binary=True)
            except tornado.websocket.WebSocketClosedError:
                self.close(reason='websocket closed')

    def on_write(self):
        """Send queued websocket input to the channel.

        Any unsent remainder stays in ``data_to_dst`` for a later retry.
        """
        if self.closed:
            return
        logging.debug('worker {} on write'.format(self.id))
        if not self.data_to_dst:
            return

        data = ''.join(self.data_to_dst)
        logging.debug('{!r} to {}:{}'.format(data, *self.dst_addr))

        try:
            sent = self.chan.send(data)
        except (OSError, IOError) as e:
            logging.error(e)
            if self.chan.closed or errno_from_exception(e) in _ERRNO_CONNRESET:
                self.close(reason='chan error on writing')
            else:
                # Keep READ registered alongside WRITE. update_handler replaces
                # the whole event mask, so WRITE alone would stop channel output
                # from reaching the browser while the write is pending.
                self.update_handler(IOLoop.READ | IOLoop.WRITE)
        else:
            self.data_to_dst = []
            # send() may accept only part of the data; keep the rest queued.
            data = data[sent:]
            if data:
                self.data_to_dst.append(data)
                self.update_handler(IOLoop.READ | IOLoop.WRITE)
            else:
                self.update_handler(IOLoop.READ)

    def close(self, reason=None):
        """Tear down the worker: IOLoop registration, websocket, channel and client.

        Safe to call more than once. The websocket's on_close calls back into
        this method, and repeated calls are no-ops.
        """
        if self.closed:
            return
        self.closed = True

        logging.info(
            'Closing worker {} with reason {}'.format(self.id, reason)
        )
        if self.handler:
            self.loop.remove_handler(self.fd)
            self.handler.close()
        self.chan.close()
        self.ssh.close()
        logging.info('Connection to {}:{} lost'.format(*self.dst_addr))