Skip to content
Snippets Groups Projects
Commit 67ad1d86 authored by Thomas Kluyver's avatar Thomas Kluyver
Browse files

Initial work on reproducibility proxy

parents
No related branches found
No related tags found
No related merge requests found
__pycache__/
/dist/
*.sqlite
import argparse
import asyncio
import logging
import os
import sys
from tornado.httpserver import HTTPServer
from tornado.netutil import bind_sockets
from .proxy import make_app, listen
log = logging.getLogger(__name__)
try:
from cal_tools.restful_config import restful_config
except ImportError:
restful_config = {}
def main():
ap = argparse.ArgumentParser()
ap.add_argument('--debug', action='store_true')
ap.add_argument('--db', default='calparrot.sqlite')
ap.add_argument()
args = ap.parse_args()
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
calcat_cfg = restful_config.get('calcat', {})
calcat_cfg.setdefault('base_url', 'http://exflcalproxy:8080/')
calcat_cfg.setdefault('use_oauth2', False)
log.info("Upstream is %s (user %s)", calcat_cfg['base_url'], calcat_cfg['user_email'])
app = make_app(calcat_cfg, args.db)
sockets = bind_sockets(0, '127.0.0.1')
server = HTTPServer(app)
server.add_sockets(sockets)
port = sockets[0].getsockname()[1]
print(f"http://127.0.0.1:{port}")
if os.fork() == 0:
asyncio.run(listen(app))
else:
# Parent process - exit now
return 0
if __name__ == '__main__':
sys.exit(main())
import asyncio
import sqlite3
SCHEMA = """
CREATE TABLE IF NOT EXISTS responses(
url STRING,
req_body BLOB,
status INTEGER,
reason TEXT,
headers TEXT,
body BLOB
);
CREATE UNIQUE INDEX IF NOT EXISTS request_keys ON responses(url, req_body);
"""
class ResponsesDB:
def __init__(self, path='calparrot.sqlite'):
self.path = path
self.conn = sqlite3.connect(path)
self.conn.executescript(SCHEMA)
def insert_pending(self, url, req_body) -> bool:
try:
with self.conn:
self.conn.execute(
"INSERT INTO responses(url, req_body, status) VALUES(?, ?, -1)",
(url, req_body)
)
return True
except sqlite3.IntegrityError:
return False
def add_result(self, url, req_body, status, reason, headers, body):
with self.conn:
self.conn.execute(
"UPDATE responses SET status=?, reason=?, headers=?, body=? WHERE url=?",
(status, reason, headers, body, url)
)
def get(self, url, req_body):
row = self.conn.execute(
"SELECT status, reason, headers, body FROM responses"
" WHERE url=? AND req_body=?", (url, req_body)
).fetchone()
if row is None:
raise KeyError(url)
async def wait_get_response(self, url, req_body, poll_interval=0.5):
while True:
status, reason, headers, body = self.get(url, req_body)
if status != -1:
return status, reason, headers, body
await asyncio.sleep(poll_interval)
if __name__ == '__main__':
ResponsesDB()
"""Allow internal GET requests without Oauth to an XFEL service like CalCat
Accept requests on the internal network, add an Oauth token and forward them to
a server which requires Oauth. This is to simplify getting access to calibration
constants.
"""
import asyncio
import logging
from dataclasses import dataclass
from urllib.parse import parse_qs, urlencode
from oauthlib.oauth2 import BackendApplicationClient, TokenExpiredError
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPResponse
from tornado.httputil import HTTPHeaders
from tornado.web import RequestHandler, Application
from .db import ResponsesDB
__version__ = "0.1"
log = logging.getLogger(__name__)
REMOVE_RESPONSE_HEADERS = {
'Content-Length', 'Transfer-Encoding', 'Connection'
}
class XFELOauthClient:
"""HTTP client using Oauth client credientials flow"""
token = None
def __init__(self, client_id, client_secret, user_email, scope, token_url):
self.client_id = client_id
self.client_secret = client_secret
self.user_email = user_email
self.scope = scope
self.token_url = token_url
self.http_client = AsyncHTTPClient()
self.oauth_client = BackendApplicationClient(client_id, scope=scope)
async def get_token(self):
response = await self.http_client.fetch(HTTPRequest(
self.token_url,
method="POST",
headers={
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
},
body=self.oauth_client.prepare_request_body(
client_secret=self.client_secret, include_client_id=True,
)
))
self.oauth_client.parse_request_body_response(response.body)
log.info("Obtained access token, expires in %d seconds",
self.oauth_client.expires_in)
async def fetch(self, url, *, headers=None, **kwargs):
if self.oauth_client.access_token is None:
await self.get_token()
try:
_, headers, _ = self.oauth_client.add_token(url, headers=headers)
except TokenExpiredError:
log.info("Oauth token expired, will request new token")
await self.get_token()
_, headers, _ = self.oauth_client.add_token(url, headers=headers)
headers['X-User-Email'] = self.user_email
return await self.http_client.fetch(url, headers=headers, **kwargs)
@dataclass
class ProxyResponse:
status: int
reason: str
headers: HTTPHeaders
body: bytes
@classmethod
def from_client_response(cls, resp: HTTPResponse):
headers = HTTPHeaders()
for k, v in resp.headers.get_all():
if k not in REMOVE_RESPONSE_HEADERS:
headers.add(k, v)
return cls(
status=resp.code, reason=resp.reason, headers=headers, body=resp.body
)
@classmethod
def from_stored(cls, status: int, reason: str, headers: str, body: bytes):
headers = HTTPHeaders.parse(headers)
return cls(status=status, reason=reason, headers=headers, body=body)
class ProxyHandler(RequestHandler):
def initialize(self, base_url, upstream_client, response_store: ResponsesDB):
self.base_url = base_url
self.upstream_client = upstream_client
self.response_store = response_store
async def _get_upstream(self, req_path) -> ProxyResponse:
url = self.base_url + req_path
headers = self.request.headers.copy()
del headers['Host']
body = self.request.body
log.info("Forwarding request for %s to %s", req_path, url)
# Send the request on with an Oauth token
response = await self.upstream_client.fetch(
url, headers=headers, body=body, raise_error=False, follow_redirects=False,
# Some GET requests use a body, which requires this:
allow_nonstandard_methods=True,
)
log.info("Got response status %d (%s)", response.code, response.reason)
return ProxyResponse.from_client_response(response)
def _record_response(self, req_path, resp: ProxyResponse):
if resp.status < 500:
# Don't store server errors
self.response_store.add_result(
req_path, self.request.body,
status=resp.status,
reason=resp.reason,
headers=str(resp.headers),
body=resp.body
)
async def _get_with_cache(self, req_path) -> ProxyResponse:
if self.response_store.insert_pending(req_path, self.request.body):
# We put a pending response in the database: now query upstream
resp = await self._get_upstream(req_path)
self._record_response(req_path, resp)
else:
# This URL is already in the database
try:
status, reason, headers, body = await asyncio.wait_for(
self.response_store.wait_get_response(req_path, self.request.body),
timeout=15
)
resp = ProxyResponse.from_stored(status, reason, headers, body)
except TimeoutError:
# Still pending - the process fetching it may have crashed.
resp = await self._get_upstream(req_path)
self._record_response(req_path, resp)
return resp
def _normalize_query(self, query: str) -> str:
# Sort parameters by key name in a query string, so we don't store
# separate entries based on random differences in order.
d = parse_qs(query, keep_blank_values=True)
return urlencode(sorted(d.items()), doseq=True)
async def get(self, path):
req_path = path
if self.request.query is not None:
req_path += '?' + self._normalize_query(self.request.query)
response = await self._get_with_cache(req_path)
# Forward the response to the client
self.set_status(response.status, response.reason)
for header, value in response.headers.get_all():
self.add_header(header, value)
if response.body:
await self.finish(response.body)
class ShutdownHandler(RequestHandler):
def post(self):
self.application.settings['calparrot_quit_event'].set()
def make_app(creds, db_path='calparrot.sqlite'):
base_url = creds['base_url'].rstrip('/') # e.g. https://in.xfel.eu/calibration
if creds['use_oauth2']:
client = XFELOauthClient(
creds['client_id'],
creds['client_secret'],
creds['user_email'],
scope=None,
token_url=base_url + '/oauth/token'
)
else:
client = AsyncHTTPClient()
return Application([
('/.calparrot/stop', ShutdownHandler),
('(/.*)', ProxyHandler, {
'base_url': base_url,
'upstream_client': client,
'response_store': ResponsesDB(db_path),
}),
], calparrot_quit_event=asyncio.Event())
async def listen(app):
quit_evt = app.settings['calparrot_quit_event']
await quit_evt
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment