在正式开始Web开发前,我们需要编写一个Web框架。
由于aiohttp相对比较底层,直接用它编写URL处理函数需要重复很多解析请求参数的步骤,所以我们需要基于aiohttp自己封装一个处理URL的Web框架。
def _as_coroutine_function(fn):
    """Wrap the plain (non-async) callable *fn* in a native coroutine function.

    This replaces ``asyncio.coroutine``, which was removed in Python 3.11.
    The ``@get``/``@post`` wrappers are ordinary functions even when the
    decorated handler is ``async def``, so whatever *fn* returns is awaited
    whenever it is awaitable.
    """
    import functools  # local import: this excerpt has no import block of its own

    # functools.wraps copies __name__ and __dict__ (so __method__/__route__
    # survive) and sets __wrapped__, which inspect.signature() follows.
    @functools.wraps(fn)
    async def coroutine_fn(*args, **kw):
        result = fn(*args, **kw)
        if inspect.isawaitable(result):
            result = await result
        return result

    return coroutine_fn


# Define add_route(): register a single URL handler function.
def add_route(app, fn):
    """Register *fn* on ``app.router`` using the method and path set by @get/@post.

    Args:
        app: the aiohttp application whose router receives the route.
        fn: a handler carrying ``__method__`` and ``__route__`` attributes.

    Raises:
        ValueError: if *fn* lacks ``__method__`` or ``__route__``.
    """
    method = getattr(fn, '__method__', None)
    path = getattr(fn, '__route__', None)
    if path is None or method is None:
        raise ValueError('@get or @post not defined in %s.' % str(fn))
    if not asyncio.iscoroutinefunction(fn) and not inspect.isgeneratorfunction(fn):
        fn = _as_coroutine_function(fn)
    logging.info('add route %s %s => %s(%s)' % (method, path, fn.__name__, ', '.join(inspect.signature(fn).parameters.keys())))
    app.router.add_route(method, path, RequestHandler(app, fn))
定义装饰器@get和@post,用于给URL处理函数附加请求方法(__method__)和路径(__route__)信息:
def _route_decorator(http_method, path):
    """Build a decorator that tags a URL handler with *http_method* and *path*."""
    def decorate(handler):
        @functools.wraps(handler)
        def tagged(*args, **kw):
            return handler(*args, **kw)
        # add_route() reads these two attributes to register the handler.
        tagged.__method__ = http_method
        tagged.__route__ = path
        return tagged
    return decorate


# Decorator @get('/path'): mark a function as the GET handler for *path*.
def get(path):
    """Return a decorator that marks a function as the GET handler for *path*."""
    return _route_decorator('GET', path)


# Decorator @post('/path'): mark a function as the POST handler for *path*.
def post(path):
    """Return a decorator that marks a function as the POST handler for *path*."""
    return _route_decorator('POST', path)
在www目录下新建coroweb.py,完整代码如下:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import asyncio
import os
import inspect
import logging
import functools

from urllib import parse

from aiohttp import web

# apis is the pagination module, written later in the tutorial.
# APIError signals a logical error that occurred during an API call.
from apis import APIError


# Decorator @get('/path'): mark a function as the GET handler for a URL.
def get(path):
    """Return a decorator that tags a function with ``__method__ = 'GET'`` and ``__route__ = path``."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            return func(*args, **kw)
        wrapper.__method__ = 'GET'
        wrapper.__route__ = path
        return wrapper
    return decorator


# Decorator @post('/path'): mark a function as the POST handler for a URL.
def post(path):
    """Return a decorator that tags a function with ``__method__ = 'POST'`` and ``__route__ = path``."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            return func(*args, **kw)
        wrapper.__method__ = 'POST'
        wrapper.__route__ = path
        return wrapper
    return decorator


# Signature-inspection helpers used by RequestHandler.

def get_required_kw_args(fn):
    """Return the names of *fn*'s keyword-only parameters that have no default."""
    params = inspect.signature(fn).parameters
    return tuple(
        name
        for name, param in params.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
        # Compare the sentinel by identity: `==` would call the default
        # value's own __eq__, which can raise or return a non-bool.
        and param.default is inspect.Parameter.empty
    )


def get_named_kw_args(fn):
    """Return the names of all keyword-only parameters of *fn*."""
    params = inspect.signature(fn).parameters
    return tuple(
        name
        for name, param in params.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    )


def has_named_kw_args(fn):
    """Return True if *fn* declares at least one keyword-only parameter."""
    params = inspect.signature(fn).parameters
    return any(param.kind == inspect.Parameter.KEYWORD_ONLY for param in params.values())


def has_var_kw_arg(fn):
    """Return True if *fn* accepts arbitrary keyword arguments (``**kw``)."""
    params = inspect.signature(fn).parameters
    return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values())


def has_request_arg(fn):
    """Return True if *fn* has a parameter named ``request``.

    Raises:
        ValueError: if ``request`` is followed by an ordinary parameter; only
            ``*args``, keyword-only parameters and ``**kw`` may come after it.
    """
    sig = inspect.signature(fn)
    allowed_after = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.VAR_KEYWORD,
    )
    found = False
    for name, param in sig.parameters.items():
        if name == 'request':
            found = True
            continue
        if found and param.kind not in allowed_after:
            raise ValueError('request parameter must be the last named parameter in function: %s%s' % (fn.__name__, str(sig)))
    return found
# Define RequestHandler: work out which arguments a URL handler function needs,
# gather them from the request, and call the handler with them.
class RequestHandler(object):
    """Callable adapter that maps an aiohttp request onto a handler's keyword arguments.

    The handler's signature is inspected once, in ``__init__``. Each request then
    only gathers the arguments the handler declares.
    """

    def __init__(self, app, fn):
        self._app = app
        self._func = fn
        self._has_request_arg = has_request_arg(fn)
        self._has_var_kw_arg = has_var_kw_arg(fn)
        self._has_named_kw_args = has_named_kw_args(fn)
        self._named_kw_args = get_named_kw_args(fn)
        self._required_kw_args = get_required_kw_args(fn)

    async def __call__(self, request):
        """Collect arguments from *request*, call the handler and return its result.

        Returns ``web.HTTPBadRequest`` for a missing or unsupported Content-Type,
        an unparsable or non-object JSON body, or a missing required argument.
        An ``APIError`` raised by the handler is turned into a dict with
        ``error``, ``data`` and ``message`` keys.
        """
        kw = None
        if self._has_var_kw_arg or self._has_named_kw_args or self._required_kw_args:
            if request.method == 'POST':
                if not request.content_type:
                    return web.HTTPBadRequest(text='Missing Content-Type.')
                ct = request.content_type.lower()
                if ct.startswith('application/json'):
                    try:
                        params = await request.json()
                    except ValueError:
                        # Malformed JSON (JSONDecodeError / UnicodeDecodeError are
                        # ValueError subclasses) is a client error, not a 500.
                        return web.HTTPBadRequest(text='Invalid JSON body.')
                    if not isinstance(params, dict):
                        return web.HTTPBadRequest(text='JSON body must be object.')
                    kw = params
                elif ct.startswith('application/x-www-form-urlencoded') or ct.startswith('multipart/form-data'):
                    params = await request.post()
                    kw = dict(**params)
                else:
                    return web.HTTPBadRequest(text='Unsupported Content-Type: %s' % request.content_type)
            if request.method == 'GET':
                qs = request.query_string
                if qs:
                    # Keep only the first value of each repeated query parameter.
                    kw = {k: v[0] for k, v in parse.parse_qs(qs, True).items()}
        if kw is None:
            kw = dict(**request.match_info)
        else:
            if not self._has_var_kw_arg and self._named_kw_args:
                # Without **kw the handler cannot accept undeclared names: drop them.
                kw = {name: kw[name] for name in self._named_kw_args if name in kw}
            # URL path variables (match_info) override body/query values.
            for k, v in request.match_info.items():
                if k in kw:
                    logging.warning('Duplicate arg name in named arg and kw args: %s' % k)
                kw[k] = v
        if self._has_request_arg:
            kw['request'] = request
        # Every keyword-only parameter without a default must have been supplied.
        if self._required_kw_args:
            for name in self._required_kw_args:
                if name not in kw:
                    return web.HTTPBadRequest(text='Missing argument: %s' % name)
        logging.info('call with args: %s' % str(kw))
        try:
            r = await self._func(**kw)
            return r
        except APIError as e:
            return dict(error=e.error, data=e.data, message=e.message)


# Define add_static(): serve the files under the static/ folder.
def add_static(app):
    """Serve the ``static`` directory next to this file under the ``/static/`` URL prefix."""
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'static')
    app.router.add_static('/static/', path)
    logging.info('add static %s => %s' % ('/static/', path))


def _as_coroutine_function(fn):
    """Wrap the plain (non-async) callable *fn* in a native coroutine function.

    This replaces ``asyncio.coroutine``, which was removed in Python 3.11.
    The ``@get``/``@post`` wrappers are ordinary functions even when the
    decorated handler is ``async def``, so whatever *fn* returns is awaited
    whenever it is awaitable.
    """
    # functools.wraps copies __name__ and __dict__ (so __method__/__route__
    # survive) and sets __wrapped__, which inspect.signature() follows.
    @functools.wraps(fn)
    async def coroutine_fn(*args, **kw):
        result = fn(*args, **kw)
        if inspect.isawaitable(result):
            result = await result
        return result

    return coroutine_fn


# Define add_route(): register a single URL handler function.
def add_route(app, fn):
    """Register *fn* on ``app.router`` using the method and path set by @get/@post.

    Raises:
        ValueError: if *fn* lacks ``__method__`` or ``__route__``.
    """
    method = getattr(fn, '__method__', None)
    path = getattr(fn, '__route__', None)
    if path is None or method is None:
        raise ValueError('@get or @post not defined in %s.' % str(fn))
    if not asyncio.iscoroutinefunction(fn) and not inspect.isgeneratorfunction(fn):
        fn = _as_coroutine_function(fn)
    logging.info('add route %s %s => %s(%s)' % (method, path, fn.__name__, ', '.join(inspect.signature(fn).parameters.keys())))
    app.router.add_route(method, path, RequestHandler(app, fn))


# Define add_routes(): register every URL handler of a module automatically.
def add_routes(app, module_name):
    """Import *module_name* and register each public callable tagged by @get/@post.

    For a dotted name such as ``'pkg.handlers'``, the last component is fetched
    as an attribute of the imported parent package.
    """
    n = module_name.rfind('.')
    if n == -1:
        mod = __import__(module_name, globals(), locals())
    else:
        name = module_name[n + 1:]
        mod = getattr(__import__(module_name[:n], globals(), locals(), [name]), name)
    for attr in dir(mod):
        if attr.startswith('_'):
            continue
        fn = getattr(mod, attr)
        if callable(fn):
            method = getattr(fn, '__method__', None)
            path = getattr(fn, '__route__', None)
            if method and path:
                add_route(app, fn)
最后,在app.py中加入middleware、jinja2模板以及URL处理函数自动注册(add_routes)的支持。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
logging.basicConfig(level=logging.INFO)

import asyncio
import os
import json
import time
from datetime import datetime

from aiohttp import web
from jinja2 import Environment, FileSystemLoader

# config: the configuration module is written later in the tutorial; download it
# from GitHub into www/ beforehand to avoid an ImportError.
from config import configs

import orm
from coroweb import add_routes, add_static

# handlers: the URL-handler module is written later in the tutorial; download it
# from GitHub into www/ beforehand to avoid an ImportError.
from handlers import cookie2user, COOKIE_NAME


def init_jinja2(app, **kw):
    """Create a jinja2 Environment and store it as ``app['__templating__']``.

    Each option (autoescape, block/variable delimiters, auto_reload) is read
    from *kw* with the fallback shown below. ``path`` defaults to the
    ``templates`` directory next to this file. ``filters`` is an optional dict
    of name -> filter function registered on the environment.
    """
    logging.info('init jinja2...')
    options = dict(
        autoescape=kw.get('autoescape', True),
        block_start_string=kw.get('block_start_string', '{%'),
        block_end_string=kw.get('block_end_string', '%}'),
        variable_start_string=kw.get('variable_start_string', '{{'),
        variable_end_string=kw.get('variable_end_string', '}}'),
        auto_reload=kw.get('auto_reload', True),
    )
    path = kw.get('path', None)
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
    logging.info('set jinja2 template path: %s' % path)
    env = Environment(loader=FileSystemLoader(path), **options)
    filters = kw.get('filters', None)
    if filters is not None:
        for name, f in filters.items():
            env.filters[name] = f
    app['__templating__'] = env


# The middlewares below move cross-cutting concerns out of the individual URL
# handlers. Each factory receives the app and the next handler and returns a
# coroutine that wraps that handler.

async def logger_factory(app, handler):
    """Middleware: log the method and path of every request."""
    async def logger(request):
        logging.info('Request: %s %s' % (request.method, request.path))
        return await handler(request)
    return logger


async def auth_factory(app, handler):
    """Middleware: bind the user from the COOKIE_NAME cookie to ``request.__user__``.

    Requests whose path starts with ``/manage/`` are redirected to ``/signin``
    unless the bound user has a truthy ``admin`` attribute.
    """
    async def auth(request):
        logging.info('check user: %s %s' % (request.method, request.path))
        request.__user__ = None
        cookie_str = request.cookies.get(COOKIE_NAME)
        if cookie_str:
            user = await cookie2user(cookie_str)
            if user:
                logging.info('set current user: %s' % user.email)
                request.__user__ = user
        if request.path.startswith('/manage/') and (request.__user__ is None or not request.__user__.admin):
            return web.HTTPFound('/signin')
        return await handler(request)
    return auth


async def data_factory(app, handler):
    """Middleware: pre-parse JSON or urlencoded-form POST bodies into ``request.__data__``.

    NOTE: this factory is not included in the middlewares list passed in init().
    """
    async def parse_data(request):
        if request.method == 'POST':
            if request.content_type.startswith('application/json'):
                request.__data__ = await request.json()
                logging.info('request json: %s' % str(request.__data__))
            elif request.content_type.startswith('application/x-www-form-urlencoded'):
                request.__data__ = await request.post()
                logging.info('request form: %s' % str(request.__data__))
        return await handler(request)
    return parse_data


async def response_factory(app, handler):
    """Middleware: convert a handler's return value into an aiohttp response.

    - ``web.StreamResponse``: returned unchanged.
    - ``bytes``: sent as application/octet-stream.
    - ``str``: ``'redirect:<url>'`` becomes a 302 redirect; anything else is HTML.
    - ``dict``: rendered with the template named by ``'__template__'``, or
      serialised as JSON when that key is absent.
    - ``int`` in [100, 600): used as the HTTP status code.
    - ``(status, message)`` with status in [100, 600): that status code with
      ``str(message)`` as a plain-text body.
    - anything else: ``str(r)`` as plain text.
    """
    async def response(request):
        logging.info('Response handler...')
        r = await handler(request)
        if isinstance(r, web.StreamResponse):
            return r
        if isinstance(r, bytes):
            resp = web.Response(body=r)
            resp.content_type = 'application/octet-stream'
            return resp
        if isinstance(r, str):
            if r.startswith('redirect:'):
                return web.HTTPFound(r[9:])
            resp = web.Response(body=r.encode('utf-8'))
            resp.content_type = 'text/html;charset=utf-8'
            return resp
        if isinstance(r, dict):
            template = r.get('__template__')
            if template is None:
                resp = web.Response(body=json.dumps(r, ensure_ascii=False, default=lambda o: o.__dict__).encode('utf-8'))
                resp.content_type = 'application/json;charset=utf-8'
                return resp
            r['__user__'] = request.__user__
            resp = web.Response(body=app['__templating__'].get_template(template).render(**r).encode('utf-8'))
            resp.content_type = 'text/html;charset=utf-8'
            return resp
        # web.Response() takes keyword-only arguments; the previous positional
        # calls web.Response(r) / web.Response(t, str(m)) raised TypeError.
        if isinstance(r, int) and 100 <= r < 600:
            return web.Response(status=r)
        if isinstance(r, tuple) and len(r) == 2:
            t, m = r
            if isinstance(t, int) and 100 <= t < 600:
                return web.Response(status=t, text=str(m))
        # default:
        resp = web.Response(body=str(r).encode('utf-8'))
        resp.content_type = 'text/plain;charset=utf-8'
        return resp
    return response


def datetime_filter(t):
    """jinja2 filter: render Unix timestamp *t* as a relative (Chinese) time string.

    Under a minute, under an hour, under a day and under a week each get a
    relative phrase; anything older is shown as the local-time date
    (year/month/day).
    """
    delta = int(time.time() - t)
    if delta < 60:
        return u'1分钟前'
    if delta < 3600:
        return u'%s分钟前' % (delta // 60)
    if delta < 86400:
        return u'%s小时前' % (delta // 3600)
    if delta < 604800:
        return u'%s天前' % (delta // 86400)
    dt = datetime.fromtimestamp(t)
    return u'%s年%s月%s日' % (dt.year, dt.month, dt.day)


async def init(loop):
    """Create the DB pool and the web app, register routes, and listen on 127.0.0.1:9000."""
    await orm.create_pool(loop=loop, **configs.db)
    # NOTE(review): Application(loop=...) and app.make_handler() appear to be
    # deprecated in aiohttp 3.x (AppRunner replaces them) — confirm against the
    # pinned aiohttp version before upgrading.
    app = web.Application(loop=loop, middlewares=[
        logger_factory, auth_factory, response_factory
    ])
    init_jinja2(app, filters=dict(datetime=datetime_filter))
    add_routes(app, 'handlers')
    add_static(app)
    srv = await loop.create_server(app.make_handler(), '127.0.0.1', 9000)
    logging.info('server started at http://127.0.0.1:9000...')
    return srv


loop = asyncio.get_event_loop()
loop.run_until_complete(init(loop))
loop.run_forever()