基于tornado的websocket和xterm.js实现webssh

时间:Jan. 17, 2020 分类:

目录:

webssh相关

主函数配置和路由相关

main.py

#!-*-coding:utf8-*-
import os
import uuid
import tornado.web
import tornado.ioloop

from handler import IndexHandler, WsockHandler

# Directory containing this file; the templates/ and static/ paths are built from it.
base_dir = os.path.dirname(__file__)


def make_handlers(loop):
    """Build the URL routing table.

    :param loop: object passed to every handler as its ``loop`` init argument
    :return: list of ``(pattern, handler class, init kwargs)`` tuples
    """
    routes = (
        (r'/', IndexHandler),    # index page requests
        (r'/ws', WsockHandler),  # websocket communication
    )
    return [(pattern, handler_cls, dict(loop=loop))
            for pattern, handler_cls in routes]


def make_app(handlers):
    """Create the tornado application serving the given routes.

    :param handlers: routing table handed to ``tornado.web.Application``
    :return: the configured ``tornado.web.Application`` instance
    """
    settings = {
        'template_path': os.path.join(base_dir, 'templates'),
        'static_path': os.path.join(base_dir, 'static'),
        # A new random secret is generated on every call.
        'cookie_secret': uuid.uuid4().hex,
        'websocket_ping_interval': 0,
        'xsrf_cookies': True,
    }
    return tornado.web.Application(handlers, **settings)


def main():
    """Listen on 0.0.0.0:10000 and run the IO loop."""
    io_loop = tornado.ioloop.IOLoop.current()
    routes = make_handlers(io_loop)
    application = make_app(routes)
    application.listen(10000, '0.0.0.0')
    io_loop.start()


# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main()

处理请求的handler

handler.py

#!-*-coding:utf8-*-
import io
import json
import logging
import socket
import struct
import threading
import traceback
import weakref
import paramiko
import ipaddress
import tornado.web

from tornado.ioloop import IOLoop
from tornado.util import basestring_type
from worker import Worker, recycle_worker, workers

# 导入并发相关的包

try:
    from concurrent.futures import Future
except ImportError:
    from tornado.concurrent import Future

try:
    from json.decoder import JSONDecodeError
except ImportError:
    JSONDecodeError = ValueError


# Delay passed to loop.call_later before recycle_worker runs on a newly
# created worker (tornado interprets the value as seconds).
DELAY = 3


def to_str(s):
    """Return *s* decoded as UTF-8 when it is ``bytes``, otherwise unchanged."""
    return s.decode('utf-8') if isinstance(s, bytes) else s


def is_valid_ipv4_address(ipstr):
    """Tell whether *ipstr* (text or UTF-8 bytes) is accepted as an IPv4 address."""
    candidate = to_str(ipstr)
    try:
        ipaddress.IPv4Address(candidate)
        return True
    except ipaddress.AddressValueError:
        return False


def is_valid_ipv6_address(ipstr):
    """Tell whether *ipstr* (text or UTF-8 bytes) is accepted as an IPv6 address."""
    candidate = to_str(ipstr)
    try:
        ipaddress.IPv6Address(candidate)
        return True
    except ipaddress.AddressValueError:
        return False


def is_valid_port(port):
    """Return True when *port* is greater than 0 and less than 65536."""
    above_zero = 0 < port
    return above_zero and port < 65536


def parse_encoding(data):
    """Extract an encoding name from ``NAME=value`` lines.

    The first line whose text after the last ``=`` is non-empty is used:
    surrounding double quotes are stripped and the part after the last ``.``
    is returned. Returns ``None`` when no line has such a value.
    """
    values = (line.split('=')[-1] for line in data.split('\n'))
    first = next((value for value in values if value), None)
    if first is None:
        return None
    return first.strip('"').split('.')[-1]


class MixinHandler(object):
    """Mixin that reads the client address from proxy-supplied headers."""

    def get_real_client_addr(self):
        """Return the client address given by ``X-Real-Ip`` / ``X-Real-Port``.

        :return: ``None`` when neither header is present, ``(ip, port)`` when
            both hold valid values, and ``False`` (after logging a warning)
            in every other case
        """
        headers = self.request.headers
        ip = headers.get('X-Real-Ip')
        port = headers.get('X-Real-Port')

        # Neither header set: suppose the server is not behind nginx.
        if ip is None and port is None:
            return None

        if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip):
            try:
                port = int(port)
            except (TypeError, ValueError):
                port = None
            if port is not None and is_valid_port(port):
                return (ip, port)

        logging.warning('Bad nginx configuration.')
        return False


class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Handler for the "/" route: serves the page and opens SSH sessions."""

    def initialize(self, loop):
        """Store the IO loop supplied through the routing table."""
        self.loop = loop

    def get(self):
        """Render the index page."""
        self.render('index.html')

    @tornado.gen.coroutine
    def post(self):
        """Open an SSH session and report the outcome to the client.

        Writes a dict with the keys ``id``, ``status`` and ``encoding``. On
        success ``id`` is the new worker's id, ``encoding`` its encoding and
        ``status`` is ``None``; on failure ``status`` carries the error text
        and the other two are ``None``.
        """
        worker_id = None
        status = None
        encoding = None

        # The blocking SSH connect runs in its own thread; its result or
        # exception is handed back through this future.
        future = Future()
        t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
        # Attribute form instead of the deprecated setDaemon().
        t.daemon = True
        t.start()

        try:
            worker = yield future
        except Exception as exc:
            status = str(exc)
        else:
            worker_id = worker.id
            # Register the worker so the websocket request can look it up.
            workers[worker_id] = worker
            # After DELAY, recycle_worker closes the worker if no handler
            # has been attached to it by then.
            self.loop.call_later(DELAY, recycle_worker, worker)
            encoding = worker.encoding

        self.write(dict(id=worker_id, status=status, encoding=encoding))

    def ssh_connect_wrapped(self, future):
        """Run :meth:`ssh_connect` and publish its outcome on *future*.

        :param future: receives the worker on success, the raised exception
            (which is also logged with its traceback) on failure
        """
        try:
            worker = self.ssh_connect()
        except Exception as exc:
            logging.error(traceback.format_exc())
            future.set_exception(exc)
        else:
            future.set_result(worker)

    def ssh_connect(self):
        """Open the SSH connection described by the request arguments.

        :return: a ``Worker`` wrapping the SSH client and its shell channel
        :raises ValueError: for invalid arguments, an unreachable host,
            failed authentication or a bad host key
        """
        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy adds unknown host keys automatically,
        # so an unknown server is not rejected -- confirm this is intended.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        args = self.get_args()
        dst_addr = (args[0], args[1])
        logging.info('Connecting to {}:{}'.format(*dst_addr))

        try:
            try:
                ssh.connect(*args, timeout=6)
            except socket.error:
                raise ValueError('Unable to connect to {}:{}'.format(*dst_addr))
            except paramiko.AuthenticationException:
                # Also covers its subclass BadAuthenticationType.
                raise ValueError('SSH authentication failed.')
            except paramiko.BadHostKeyException:
                raise ValueError('Bad host key.')

            # An interactive shell channel (not an exec channel).
            chan = ssh.invoke_shell(term='xterm')
            chan.setblocking(0)
            worker = Worker(self.loop, ssh, chan, dst_addr)
            worker.src_addr = self.get_client_addr()
            worker.encoding = self.get_default_encoding(ssh)
        except Exception:
            # Do not leak the client when any setup step fails.
            ssh.close()
            raise
        return worker

    def get_args(self):
        """Collect ``(hostname, port, username, password)`` from the request."""
        hostname = self.get_value('hostname')
        port = self.get_port()
        username = self.get_value('username')
        password = self.get_argument('password')
        args = (hostname, port, username, password)
        # The password is deliberately kept out of the log.
        logging.debug((hostname, port, username))
        return args

    def get_value(self, name):
        """Return the request argument *name*.

        :param name: argument name
        :raises ValueError: if the argument is empty
        """
        value = self.get_argument(name)
        if not value:
            raise ValueError('Empty {}'.format(name))
        return value

    def get_port(self):
        """Return the ``port`` request argument as an int.

        :raises ValueError: if it is empty, not an integer or not a valid port
        """
        value = self.get_value('port')
        try:
            port = int(value)
        except ValueError:
            pass
        else:
            if is_valid_port(port):
                return port

        raise ValueError('Invalid port {}'.format(value))

    def get_client_addr(self):
        """Return the header-reported client address, else the connection's."""
        return self.get_real_client_addr() or self.request.connection.context.address

    def get_default_encoding(self, ssh):
        """Derive the encoding from the remote ``locale`` command output.

        :param ssh: connected SSH client
        :return: the parsed encoding, or ``'utf-8'`` when the command fails
            or nothing could be parsed
        """
        try:
            _, stdout, _ = ssh.exec_command('locale')
        except paramiko.SSHException:
            result = None
        else:
            data = stdout.read().decode('utf-8')
            result = parse_encoding(data)

        return result if result else 'utf-8'


class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """Websocket handler that attaches a client to a previously created worker."""

    def initialize(self, loop):
        """Store the IO loop; no worker is attached yet."""
        self.loop = loop
        self.worker_ref = None

    def get_client_addr(self):
        """Return the header-reported client address, else the connection's."""
        return self.get_real_client_addr() or self.request.connection.context.address

    def open(self):
        """Attach this websocket to the worker named by the ``id`` argument.

        The worker must exist in ``workers`` and must have been created from
        the same client IP; otherwise the websocket is closed.
        """
        self.src_addr = self.get_client_addr()
        logging.info('Connected from {}:{}'.format(*self.src_addr))
        # The request looks like /ws?id=139747945258512
        worker = workers.get(self.get_argument('id'))
        if worker and worker.src_addr[0] == self.src_addr[0]:
            # Remove it from the registry of pending workers.
            workers.pop(worker.id)
            self.set_nodelay(True)
            worker.set_handler(self)
            self.worker_ref = weakref.ref(worker)
            self.loop.add_handler(worker.fd, worker, IOLoop.READ)
        else:
            self.close(reason='Websocket authentication failed.')

    def on_message(self, message):
        """Handle a JSON message from the browser.

        A ``resize`` entry of two items resizes the remote pty; a string
        ``data`` entry is queued and written to the SSH channel. Messages
        that are not JSON objects are ignored.
        """
        logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))
        # Same guard as on_close: there may be no worker attached, or the
        # weak reference may already be dead.
        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            return

        try:
            msg = json.loads(message)
        except JSONDecodeError:
            return

        if not isinstance(msg, dict):
            return

        resize = msg.get('resize')
        # Check the type first so len() cannot raise on client-supplied data.
        if isinstance(resize, (list, tuple)) and len(resize) == 2:
            try:
                worker.chan.resize_pty(*resize)
            except (TypeError, struct.error, paramiko.SSHException):
                pass

        data = msg.get('data')
        if data and isinstance(data, basestring_type):
            worker.data_to_dst.append(data)
            worker.on_write()

    def on_close(self):
        """Close the attached worker, if any, when the websocket closes."""
        logging.info('Disconnected from {}:{}'.format(*self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            if self.close_reason is None:
                self.close_reason = 'client disconnected'
            worker.close(reason=self.close_reason)

负责websocket与ssh通道之间数据读写的worker

worker.py

#!-*-coding:utf8-*-
import logging
import tornado.websocket

from tornado.ioloop import IOLoop
from tornado.iostream import _ERRNO_CONNRESET
from tornado.util import errno_from_exception


# Size argument passed to chan.recv() on each read.
BUF_SIZE = 1024
# Registry of workers, keyed by worker id.
workers = {}

# 回收work
def recycle_worker(worker):
    """Discard *worker* unless a handler has already been attached to it.

    An unattached worker is removed from ``workers`` and closed.
    """
    if not worker.handler:
        logging.warning('Recycling worker {}'.format(worker.id))
        workers.pop(worker.id, None)
        worker.close(reason='worker recycled')


class Worker(object):
    """Relays data between one SSH channel and one websocket handler."""

    def __init__(self, loop, ssh, chan, dst_addr):
        """
        :param loop: IO loop used to watch the channel's file descriptor
        :param ssh: SSH client owning the channel; closed together with it
        :param chan: SSH channel to read from and write to
        :param dst_addr: ``(host, port)`` of the SSH server, used in log lines
        """
        self.loop = loop
        self.ssh = ssh
        self.chan = chan
        self.dst_addr = dst_addr
        self.fd = chan.fileno()
        # NOTE(review): this id is what the websocket request presents to
        # claim the worker; id() values are not random -- consider a random
        # token instead.
        self.id = str(id(self))
        self.data_to_dst = []
        self.handler = None
        self.mode = IOLoop.READ
        # Set by close() so that repeated close() calls are no-ops.
        self.closed = False

    def __call__(self, fd, events):
        """IO loop callback: dispatch on the read/write/error event bits."""
        if events & IOLoop.READ:
            self.on_read()
        if events & IOLoop.WRITE:
            self.on_write()
        if events & IOLoop.ERROR:
            self.close(reason='error event occurred')

    def set_handler(self, handler):
        """Attach *handler* unless one is already attached."""
        if not self.handler:
            self.handler = handler

    def update_handler(self, mode):
        """Switch the events watched on the fd to *mode* if it changed."""
        if self.mode != mode:
            self.loop.update_handler(self.fd, mode)
            self.mode = mode

    def on_read(self):
        """Read terminal output from the channel and send it to the websocket."""
        logging.debug('worker {} on read'.format(self.id))
        try:
            data = self.chan.recv(BUF_SIZE)
        except (OSError, IOError) as e:
            logging.error(e)
            if errno_from_exception(e) in _ERRNO_CONNRESET:
                self.close(reason='chan error on reading')
        else:
            logging.debug('{!r} from {}:{}'.format(data, *self.dst_addr))
            # An empty read closes the worker instead of forwarding anything.
            if not data:
                self.close(reason='chan closed')
                return

            logging.debug('{!r} to {}:{}'.format(data, *self.handler.src_addr))
            try:
                self.handler.write_message(data, binary=True)
            except tornado.websocket.WebSocketClosedError:
                self.close(reason='websocket closed')

    def on_write(self):
        """Send the queued websocket data to the channel."""
        logging.debug('worker {} on write'.format(self.id))
        # Nothing queued for the terminal.
        if not self.data_to_dst:
            return

        data = ''.join(self.data_to_dst)
        logging.debug('{!r} to {}:{}'.format(data, *self.dst_addr))
        try:
            sent = self.chan.send(data)
        except (OSError, IOError) as e:
            logging.error(e)
            if errno_from_exception(e) in _ERRNO_CONNRESET:
                self.close(reason='chan error on writing')
            else:
                # Keep the queue and retry on the next write event.
                self.update_handler(IOLoop.WRITE)
        else:
            self.data_to_dst = []
            # Whatever was not sent is queued again.
            data = data[sent:]
            if data:
                self.data_to_dst.append(data)
                self.update_handler(IOLoop.WRITE)
            else:
                # Everything was sent: go back to watching for output.
                self.update_handler(IOLoop.READ)

    def close(self, reason=None):
        """Close the websocket handler (if attached), the channel and the client.

        Only the first call does any work. Closing the handler leads back
        here through the handler's ``on_close``, and without the guard the
        teardown would run twice.

        :param reason: text included in the log line
        """
        if self.closed:
            return
        self.closed = True

        logging.info(
            'Closing worker {} with reason {}'.format(self.id, reason)
        )
        if self.handler:
            self.loop.remove_handler(self.fd)
            self.handler.close()
        self.chan.close()
        self.ssh.close()
        logging.info('Connection to {}:{} lost'.format(*self.dst_addr))