From f7c1827e85c8ff9e542bd8593055a65a8d2bc050 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Thu, 17 Nov 2022 13:49:44 +0000
Subject: [PATCH] Refactor a bit

---
 calparrot/__main__.py | 18 +++-------
 calparrot/db.py       |  6 ++++
 calparrot/proxy.py    | 79 +++++++++++++++++++++++++++++--------------
 3 files changed, 64 insertions(+), 39 deletions(-)

diff --git a/calparrot/__main__.py b/calparrot/__main__.py
index 22d87da..51e5752 100644
--- a/calparrot/__main__.py
+++ b/calparrot/__main__.py
@@ -4,10 +4,7 @@ import logging
 import os
 import sys
 
-from tornado.httpserver import HTTPServer
-from tornado.netutil import bind_sockets
-
-from .proxy import make_app, listen
+from .proxy import ProxyApp
 
 log = logging.getLogger(__name__)
 
@@ -30,19 +27,14 @@ async def amain():
     calcat_cfg.setdefault('use_oauth2', False)
     log.info("Upstream is %s", calcat_cfg['base_url'])
 
-    app = make_app(calcat_cfg, args.db)
-    sockets = bind_sockets(0, '127.0.0.1')
-    server = HTTPServer(app)
-    server.add_sockets(sockets)
-
-    port = sockets[0].getsockname()[1]
-    log.info(f"CalParrot serving constant queries on http://127.0.0.1:%d", port)
+    app = ProxyApp(calcat_cfg, args.db)
+    log.info("CalParrot serving constant queries on http://127.0.0.1:%d", app.port)
     if args.port_file:
         with open(args.port_file, 'w') as f:
-            f.write(str(port))
+            f.write(str(app.port))
 
     if os.fork() == 0:
-        await listen(app)
+        await app.serve()
     else:
         # Parent process - exit now
         return 0
diff --git a/calparrot/db.py b/calparrot/db.py
index 003f058..507978c 100644
--- a/calparrot/db.py
+++ b/calparrot/db.py
@@ -18,6 +18,10 @@ class ResponsesDB:
         self.path = path
         self.conn = sqlite3.connect(path)
         self.conn.executescript(SCHEMA)
+        self.stats = {'stored': 0, 'retrieved': 0}
+
+    def close(self):
+        self.conn.close()
 
     def insert_pending(self, url, req_body) -> bool:
         try:
@@ -36,6 +40,7 @@ class ResponsesDB:
                 "UPDATE responses SET status=?, reason=?, headers=?, body=? WHERE url=? AND req_body=?",
                 (status, reason, headers, body, url, req_body)
             )
+        self.stats['stored'] += 1
 
     def get(self, url, req_body):
         row = self.conn.execute(
@@ -49,6 +54,7 @@ class ResponsesDB:
         while True:
             status, reason, headers, body = self.get(url, req_body)
             if status != -1:
+                self.stats['retrieved'] += 1
                 return status, reason, headers, body
 
             await asyncio.sleep(poll_interval)
diff --git a/calparrot/proxy.py b/calparrot/proxy.py
index 602500d..a6c250e 100644
--- a/calparrot/proxy.py
+++ b/calparrot/proxy.py
@@ -11,7 +11,9 @@ from urllib.parse import parse_qs, urlencode
 
 from oauthlib.oauth2 import BackendApplicationClient, TokenExpiredError
 from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPResponse
+from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders
+from tornado.netutil import bind_sockets
 from tornado.web import RequestHandler, Application
 
 from .db import ResponsesDB
@@ -166,33 +168,58 @@ class ProxyHandler(RequestHandler):
 
 
 class ShutdownHandler(RequestHandler):
-    def post(self):
-        self.application.settings['calparrot_quit_event'].set()
-
-
-def make_app(creds, db_path='calparrot.sqlite'):
-    base_url = creds['base_url'].rstrip('/')  # e.g. https://in.xfel.eu/calibration
-
-    if creds['use_oauth2']:
-        client = XFELOauthClient(
-            creds['client_id'],
-            creds['client_secret'],
-            creds['user_email'],
-            scope=None,
-            token_url=base_url + '/oauth/token'
-        )
-    else:
-        client = AsyncHTTPClient()
-
-    return Application([
-        ('/.calparrot/stop', ShutdownHandler),
-        ('(/.*)', ProxyHandler, {
-            'base_url': base_url,
-            'upstream_client': client,
-            'response_store': ResponsesDB(db_path),
-        }),
-    ], calparrot_quit_event=asyncio.Event())
+    def initialize(self, proxy_app):
+        self.proxy_app = proxy_app
 
+    def post(self):
+        self.proxy_app.shutdown()
+
+
+class ProxyApp:
+    def __init__(self, creds, db_path='calparrot.sqlite'):
+        self.quit_event = asyncio.Event()
+        self.response_store = ResponsesDB(db_path)
+        base_url = creds['base_url'].rstrip('/')  # e.g. https://in.xfel.eu/calibration
+
+        if creds['use_oauth2']:
+            self.client = XFELOauthClient(
+                creds['client_id'],
+                creds['client_secret'],
+                creds['user_email'],
+                scope=None,
+                token_url=base_url + '/oauth/token'
+            )
+        else:
+            self.client = AsyncHTTPClient()
+
+        self.tornado_app = Application([
+            ('/.calparrot/stop', ShutdownHandler, {'proxy_app': self}),
+            ('(/.*)', ProxyHandler, {
+                'base_url': base_url,
+                'upstream_client': self.client,
+                'response_store': self.response_store,
+            }),
+        ])
+
+        # Bind a random port to listen on
+        sockets = bind_sockets(0, '127.0.0.1')
+        self.port = sockets[0].getsockname()[1]
+        self.server = HTTPServer(self.tornado_app)
+        self.server.add_sockets(sockets)
+
+    async def serve(self):
+        # The event loop is running the server (set up by add_sockets() above),
+        # so there's nothing to do until we're asked to shut down.
+        await self.quit_event.wait()
+
+        log.info("CalParrot shutting down")
+        self.server.stop()  # Stop accepting new connections
+        await self.server.close_all_connections()  # Close existing connections
+        self.response_store.close()
+        log.info("Query stats: %s", self.response_store.stats)
+
+    def shutdown(self):
+        self.quit_event.set()
 
 async def listen(app):
     quit_evt = app.settings['calparrot_quit_event']
-- 
GitLab