fix return types

This commit is contained in:
Waylon S. Walker 2024-10-16 21:13:53 -05:00
parent f64e488ab1
commit 7c1b153020
3 changed files with 63 additions and 14 deletions

View file

@ -20,13 +20,13 @@ get-rtf:
http GET :8000/example Content-Type:application/rtf http GET :8000/example Content-Type:application/rtf
get-json: get-json:
http GET :8000 Content-Type:application/json http GET :8000/example Content-Type:application/json
get-html: get-html:
http GET :8000 Content-Type:text/html http GET :8000/example Content-Type:text/html
get-md: get-md:
http GET :8000 Content-Type:application/markdown http GET :8000/example Content-Type:application/markdown
livez: livez:

View file

@ -3,14 +3,18 @@ ACCEPT_TYPES = {
"text/html": "html", "text/html": "html",
"application/html": "html", "application/html": "html",
"text/html-partial": "html", "text/html-partial": "html",
"text/html-fragment": "html", "application/html-partial": "html",
"text/rich": "rtf", "text/rich": "rtf",
"application/rtf": "rtf", "application/rtf": "rtf",
"text/rtf": "rtf", "text/rtf": "rtf",
"text/rich": "rtf",
"text/plain": "text", "text/plain": "text",
"application/text": "text", "application/text": "text",
"application/plain": "text",
"application/markdown": "markdown", "application/markdown": "markdown",
"application/md": "markdown",
"text/markdown": "markdown", "text/markdown": "markdown",
"text/md": "markdown",
"text/x-markdown": "markdown", "text/x-markdown": "markdown",
"image/png": "png", "image/png": "png",
"application/pdf": "pdf", "application/pdf": "pdf",

View file

@ -119,6 +119,32 @@ def set_prefers(
content_type = None content_type = None
hx_request_header = request.headers.get("hx-request") hx_request_header = request.headers.get("hx-request")
user_agent = request.headers.get("user-agent", "").lower() user_agent = request.headers.get("user-agent", "").lower()
referer = request.headers.get("referer", "")
if "," in content_type:
content_type = content_type.split(",")[0]
request.state.bound_logger.info(
"content_type set",
content_type=content_type,
hx_request_header=hx_request_header,
user_agent=user_agent,
referer=referer,
)
if content_type == "*/*":
content_type = None
if ("/docs" in referer or "/redoc" in referer) and content_type is None:
content_type = "application/json"
elif is_browser_request(user_agent) and content_type is None:
request.state.bound_logger.info("browser agent request")
content_type = "text/html"
elif is_rtf_request(user_agent) and content_type is None:
request.state.bound_logger.info("rtf agent request")
content_type = "application/rtf"
elif content_type is None:
request.state.bound_logger.info("no content type request")
content_type = content_type or "application/json"
if hx_request_header == "true": if hx_request_header == "true":
content_type = "text/html-partial" content_type = "text/html-partial"
@ -131,14 +157,19 @@ def set_prefers(
elif is_rtf_request(user_agent) and content_type is None: elif is_rtf_request(user_agent) and content_type is None:
content_type = "text/rtf" content_type = "text/rtf"
else: # else:
content_type = "application/json" # content_type = "application/json"
partial = "partial" in content_type partial = "partial" in content_type
# if content_type in ACCEPT_TYPES: # if content_type in ACCEPT_TYPES:
for accept_type, accept_value in ACCEPT_TYPES.items(): # for accept_type, accept_value in ACCEPT_TYPES.items():
if accept_type in content_type: # if accept_type in content_type:
request.state.prefers = Prefers(**{accept_value: True}, partial=partial) if content_type in ACCEPT_TYPES:
request.state.prefers = Prefers(
**{ACCEPT_TYPES[content_type]: True}, partial=partial
)
else:
request.state.prefers = Prefers(JSON=True, partial=partial)
request.state.content_type = content_type request.state.content_type = content_type
request.state.bound_logger = request.state.bound_logger.bind( request.state.bound_logger = request.state.bound_logger.bind(
@ -314,6 +345,9 @@ def handle_not_found(request: Request, call_next, data: str):
async def respond_based_on_content_type(request: Request, call_next): async def respond_based_on_content_type(request: Request, call_next):
requested_path = request.url.path requested_path = request.url.path
if requested_path in ["/docs", "/redoc", "/openapi.json", "/static/app.css"]: if requested_path in ["/docs", "/redoc", "/openapi.json", "/static/app.css"]:
request.state.bound_logger.info(
"protected route returning non-dynamic response"
)
return await call_next(request) return await call_next(request)
try: try:
@ -325,8 +359,8 @@ async def respond_based_on_content_type(request: Request, call_next):
"content_type", "content_type",
request.headers.get("content-type", request.headers.get("Accept")), request.headers.get("content-type", request.headers.get("Accept")),
) )
if "raw" in content_type: # if "raw" in content_type:
return await call_next(request) # return await call_next(request)
if content_type == "*/*": if content_type == "*/*":
content_type = None content_type = None
if ("/docs" in referer or "/redoc" in referer) and content_type is None: if ("/docs" in referer or "/redoc" in referer) and content_type is None:
@ -343,27 +377,30 @@ async def respond_based_on_content_type(request: Request, call_next):
data = body.decode("utf-8") data = body.decode("utf-8")
if response.status_code == 404: if response.status_code == 404:
request.state.bound_logger.info("404 not found")
return handle_not_found( return handle_not_found(
request=request, request=request,
call_next=call_next, call_next=call_next,
data=data, data=data,
) )
if response.status_code == 422:
return response
if str(response.status_code)[0] not in "123": if str(response.status_code)[0] not in "123":
request.state.bound_logger.info("non-200 response")
return response return response
return await handle_response(request, response, data) return await handle_response(request, response, data)
# except TemplateNotFound: # except TemplateNotFound:
# return HTMLResponse(content="Template Not Found ", status_code=404) # return HTMLResponse(content="Template Not Found ", status_code=404)
except StarletteHTTPException as exc: except StarletteHTTPException as exc:
request.state.bound_logger.info("starlette exception")
return HTMLResponse( return HTMLResponse(
content=f"Error {exc.status_code}: {exc.detail}", content=f"Error {exc.status_code}: {exc.detail}",
status_code=exc.status_code, status_code=exc.status_code,
) )
except RequestValidationError as exc: except RequestValidationError as exc:
request.state.bound_logger.info("request validation error")
return JSONResponse(status_code=422, content={"detail": exc.errors()}) return JSONResponse(status_code=422, content={"detail": exc.errors()})
except Exception as e: except Exception as e:
request.state.bound_logger.info("internal server error")
print(traceback.format_exc()) print(traceback.format_exc())
return HTMLResponse(content=f"Internal Server Error: {e!s}", status_code=500) return HTMLResponse(content=f"Internal Server Error: {e!s}", status_code=500)
@ -373,21 +410,24 @@ async def handle_response(request: Request, response: Response, data: str):
template_name = getattr(request.state, "template_name", "default_template.html") template_name = getattr(request.state, "template_name", "default_template.html")
if request.state.prefers.partial: if request.state.prefers.partial:
request.state.bound_logger = logger.bind(template_name=template_name)
template_name = "partial_" + template_name template_name = "partial_" + template_name
if request.state.prefers.JSON: if request.state.prefers.JSON:
request.state.bound_logger.info("returning JSON")
return JSONResponse( return JSONResponse(
content=json_data, content=json_data,
) )
if request.state.prefers.html: if request.state.prefers.html:
request.state.bound_logger.info("returning html")
return templates.TemplateResponse( return templates.TemplateResponse(
template_name, template_name,
{"request": request, "data": json_data}, {"request": request, "data": json_data},
headers=response.headers,
) )
if request.state.prefers.markdown: if request.state.prefers.markdown:
request.state.bound_logger.info("returning markdown")
import html2text import html2text
template = templates.get_template(template_name) template = templates.get_template(template_name)
@ -396,18 +436,21 @@ async def handle_response(request: Request, response: Response, data: str):
return PlainTextResponse(content=markdown_content, headers=response.headers) return PlainTextResponse(content=markdown_content, headers=response.headers)
if request.state.prefers.text: if request.state.prefers.text:
request.state.bound_logger.info("returning plain text")
plain_text_content = format_json_as_plain_text(json_data) plain_text_content = format_json_as_plain_text(json_data)
return PlainTextResponse( return PlainTextResponse(
content=plain_text_content, content=plain_text_content,
) )
if request.state.prefers.rtf: if request.state.prefers.rtf:
request.state.bound_logger.info("returning rich text")
rich_text_content = format_json_as_rich_text(json_data, template_name) rich_text_content = format_json_as_rich_text(json_data, template_name)
return PlainTextResponse( return PlainTextResponse(
content=rich_text_content, content=rich_text_content,
) )
if request.state.prefers.png: if request.state.prefers.png:
request.state.bound_logger.info("returning PNG")
template = templates.get_template(template_name) template = templates.get_template(template_name)
html_content = template.render(data=json_data) html_content = template.render(data=json_data)
screenshot = get_screenshot(html_content) screenshot = get_screenshot(html_content)
@ -417,6 +460,7 @@ async def handle_response(request: Request, response: Response, data: str):
) )
if request.state.prefers.pdf: if request.state.prefers.pdf:
request.state.bound_logger.info("returning PDF")
template = templates.get_template(template_name) template = templates.get_template(template_name)
html_content = template.render(data=json_data) html_content = template.render(data=json_data)
scale = float( scale = float(
@ -430,6 +474,7 @@ async def handle_response(request: Request, response: Response, data: str):
media_type="application/pdf", media_type="application/pdf",
) )
request.state.bound_logger.info("returning DEFAULT JSON")
return JSONResponse( return JSONResponse(
content=json_data, content=json_data,
) )