better errors

This commit is contained in:
Waylon S. Walker 2024-10-17 08:20:37 -05:00
parent c8afba360b
commit 19db26b0cb
5 changed files with 84 additions and 82 deletions

View file

@ -1,6 +1,4 @@
from fastapi import APIRouter from fastapi import APIRouter, Depends, Request
from fastapi import Depends
from fastapi import Request
from fastapi_dynamic_response.base.schema import Message from fastapi_dynamic_response.base.schema import Message
from fastapi_dynamic_response.dependencies import get_content_type from fastapi_dynamic_response.dependencies import get_content_type
@ -17,6 +15,16 @@ async def get_example(
return {"message": "Hello, this is an example", "data": [1, 2, 3, 4]} return {"message": "Hello, this is an example", "data": [1, 2, 3, 4]}
@router.get("/error")
async def get_error(
request: Request,
content_type: str = Depends(get_content_type),
):
request.state.template_name = "example.html"
0 / 0
return {"message": "Hello, this is an example", "data": [1, 2, 3, 4]}
@router.get("/another-example") @router.get("/another-example")
async def another_example( async def another_example(
request: Request, request: Request,
@ -30,6 +38,16 @@ async def another_example(
} }
@router.get("/message")
async def message(
request: Request,
message_id: int,
content_type: str = Depends(get_content_type),
):
request.state.template_name = "post_message.html"
return {"message": message.message}
@router.post("/message") @router.post("/message")
async def message( async def message(
request: Request, request: Request,

View file

@ -3,15 +3,13 @@
import logging import logging
from fastapi_dynamic_response.settings import Settings from fastapi_dynamic_response.settings import settings
import structlog import structlog
logger = structlog.get_logger() logger = structlog.get_logger()
def configure_logging_two(): def configure_logging():
settings = Settings()
# Clear existing loggers # Clear existing loggers
logging.config.dictConfig( logging.config.dictConfig(
{ {
@ -73,6 +71,3 @@ def configure_logging_two():
logger.info("Logging configured") logger.info("Logging configured")
logger.info(f"Environment: {settings.ENV}") logger.info(f"Environment: {settings.ENV}")
configure_logging = configure_logging_two

View file

@ -5,19 +5,21 @@ from fastapi_dynamic_response import globals
from fastapi_dynamic_response.__about__ import __version__ from fastapi_dynamic_response.__about__ import __version__
from fastapi_dynamic_response.base.router import router as base_router from fastapi_dynamic_response.base.router import router as base_router
from fastapi_dynamic_response.dependencies import get_content_type from fastapi_dynamic_response.dependencies import get_content_type
from fastapi_dynamic_response.zpages.router import router as zpages_router
from fastapi_dynamic_response.settings import settings
from fastapi_dynamic_response.logging_config import configure_logging
from fastapi_dynamic_response.middleware import ( from fastapi_dynamic_response.middleware import (
Sitemap, Sitemap,
add_process_time_header, add_process_time_header,
catch_exceptions_middleware, catch_exceptions_middleware,
log_requests, log_requests,
respond_based_on_content_type, respond_based_on_content_type,
set_bound_logger,
set_prefers, set_prefers,
set_span_id, set_span_id,
) )
from fastapi_dynamic_response.zpages.router import router as zpages_router
from fastapi_dynamic_response.logging_config import configure_logging
configure_logging() configure_logging()
app = FastAPI( app = FastAPI(
@ -28,7 +30,7 @@ app = FastAPI(
openapi_url=None, openapi_url=None,
# openapi_tags=tags_metadata, # openapi_tags=tags_metadata,
# exception_handlers=exception_handlers, # exception_handlers=exception_handlers,
debug=True, debug=settings.DEBUG,
dependencies=[ dependencies=[
# Depends(set_prefers), # Depends(set_prefers),
# Depends(set_span_id), # Depends(set_span_id),
@ -47,6 +49,7 @@ app.middleware("http")(Sitemap(app))
app.middleware("http")(set_prefers) app.middleware("http")(set_prefers)
app.middleware("http")(set_span_id) app.middleware("http")(set_span_id)
app.middleware("http")(catch_exceptions_middleware) app.middleware("http")(catch_exceptions_middleware)
app.middleware("http")(set_bound_logger)
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")

View file

@ -1,4 +1,6 @@
from difflib import get_close_matches from difflib import get_close_matches
from fastapi_dynamic_response.settings import settings
from io import BytesIO from io import BytesIO
import json import json
import time import time
@ -7,10 +9,6 @@ from typing import Any, Dict
from uuid import uuid4 from uuid import uuid4
from fastapi import Request, Response from fastapi import Request, Response
from fastapi.exceptions import (
HTTPException as StarletteHTTPException,
RequestValidationError,
)
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
import html2text import html2text
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -79,10 +77,16 @@ async def add_process_time_header(request: Request, call_next):
start_time = time.perf_counter() start_time = time.perf_counter()
response = await call_next(request) response = await call_next(request)
process_time = time.perf_counter() - start_time process_time = time.perf_counter() - start_time
response.headers["X-Process-Time"] = str(process_time) if str(response.status_code)[0] in "123":
response.headers["X-Process-Time"] = str(process_time)
return response return response
def set_bound_logger(request: Request, call_next):
request.state.bound_logger = logger.bind()
return call_next(request)
async def set_span_id(request: Request, call_next): async def set_span_id(request: Request, call_next):
span_id = uuid4() span_id = uuid4()
request.state.span_id = span_id request.state.span_id = span_id
@ -90,8 +94,9 @@ async def set_span_id(request: Request, call_next):
response = await call_next(request) response = await call_next(request)
response.headers["x-request-id"] = str(span_id) if str(response.status_code)[0] in "123":
response.headers["x-span-id"] = str(span_id) response.headers["x-request-id"] = str(span_id)
response.headers["x-span-id"] = str(span_id)
return response return response
@ -121,7 +126,7 @@ def set_prefers(
user_agent = request.headers.get("user-agent", "").lower() user_agent = request.headers.get("user-agent", "").lower()
referer = request.headers.get("referer", "") referer = request.headers.get("referer", "")
if "," in content_type: if content_type and "," in content_type:
content_type = content_type.split(",")[0] content_type = content_type.split(",")[0]
request.state.bound_logger.info( request.state.bound_logger.info(
@ -308,19 +313,6 @@ def format_json_as_rich_text(data: dict, template_name: str) -> str:
return capture.get() return capture.get()
async def respond_based_on_content_type(
request: Request,
call_next,
content_type: str,
data: str,
):
requested_path = request.url.path
if requested_path in ["/docs", "/redoc", "/openapi.json"]:
return await call_next(request)
return await call_next(request)
def handle_not_found(request: Request, call_next, data: str): def handle_not_found(request: Request, call_next, data: str):
requested_path = request.url.path requested_path = request.url.path
# available_routes = [route.path for route in app.router.routes if route.path] # available_routes = [route.path for route in app.router.routes if route.path]
@ -353,59 +345,44 @@ async def respond_based_on_content_type(request: Request, call_next):
try: try:
response = await call_next(request) response = await call_next(request)
user_agent = request.headers.get("user-agent", "").lower()
referer = request.headers.get("referer", "")
content_type = request.query_params.get(
"content_type",
request.headers.get("content-type", request.headers.get("Accept")),
)
# if "raw" in content_type:
# return await call_next(request)
if content_type == "*/*":
content_type = None
if ("/docs" in referer or "/redoc" in referer) and content_type is None:
content_type = "application/json"
elif is_browser_request(user_agent) and content_type is None:
content_type = "text/html"
elif is_rtf_request(user_agent) and content_type is None:
content_type = "application/rtf"
elif content_type is None:
content_type = content_type or "application/json"
body = b"".join([chunk async for chunk in response.body_iterator])
data = body.decode("utf-8")
if response.status_code == 404: if response.status_code == 404:
request.state.bound_logger.info("404 not found") request.state.bound_logger.info("404 not found")
return handle_not_found( body = b"".join([chunk async for chunk in response.body_iterator])
data = body.decode("utf-8")
response = handle_not_found(
request=request, request=request,
call_next=call_next, call_next=call_next,
data=data, data=data,
) )
if str(response.status_code)[0] not in "123": elif str(response.status_code)[0] not in "123":
request.state.bound_logger.info("non-200 response") request.state.bound_logger.info(f"non-200 response {response.status_code}")
# return await handle_response(request, response, data)
return response return response
else:
body = b"".join([chunk async for chunk in response.body_iterator])
data = body.decode("utf-8")
return await handle_response(request, response, data) return await handle_response(request, response, data)
# except TemplateNotFound:
# return HTMLResponse(content="Template Not Found ", status_code=404)
except StarletteHTTPException as exc:
request.state.bound_logger.info("starlette exception")
return HTMLResponse(
content=f"Error {exc.status_code}: {exc.detail}",
status_code=exc.status_code,
)
except RequestValidationError as exc:
request.state.bound_logger.info("request validation error")
return JSONResponse(status_code=422, content={"detail": exc.errors()})
except Exception as e: except Exception as e:
request.state.bound_logger.info("internal server error") request.state.bound_logger.info("internal server error")
print(traceback.format_exc()) # print(traceback.format_exc())
return HTMLResponse(content=f"Internal Server Error: {e!s}", status_code=500) raise e
if settings.ENV == "local":
return HTMLResponse(
content=f"Internal Server Error: {e!s} {traceback.format_exc()}",
status_code=500,
)
else:
return HTMLResponse(
content=f"Internal Server Error: {e!s}", status_code=500
)
async def handle_response(request: Request, response: Response, data: str): async def handle_response(
request: Request,
response: Response,
data: str,
):
json_data = json.loads(data) json_data = json.loads(data)
template_name = getattr(request.state, "template_name", "default_template.html") template_name = getattr(request.state, "template_name", "default_template.html")
@ -433,7 +410,7 @@ async def handle_response(request: Request, response: Response, data: str):
template = templates.get_template(template_name) template = templates.get_template(template_name)
html_content = template.render(data=json_data) html_content = template.render(data=json_data)
markdown_content = html2text.html2text(html_content) markdown_content = html2text.html2text(html_content)
return PlainTextResponse(content=markdown_content, headers=response.headers) return PlainTextResponse(content=markdown_content)
if request.state.prefers.text: if request.state.prefers.text:
request.state.bound_logger.info("returning plain text") request.state.bound_logger.info("returning plain text")
@ -483,13 +460,11 @@ async def handle_response(request: Request, response: Response, data: str):
# Initialize the logger # Initialize the logger
async def log_requests(request: Request, call_next): async def log_requests(request: Request, call_next):
# Log request details # Log request details
request.state.bound_logger = logger.bind(
method=request.method, path=request.url.path
)
request.state.bound_logger.info( request.state.bound_logger.info(
"Request received", "Request received",
# span_id=request.state.span_id,
method=request.method,
path=request.url.path,
# headers=dict(request.headers),
# prefers=request.state.prefers,
) )
# logger.info( # logger.info(
# headers=dict(request.headers), # headers=dict(request.headers),

View file

@ -1,8 +1,19 @@
from pydantic import model_validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
class Settings(BaseSettings): class Settings(BaseSettings):
ENV: str = "local" ENV: str = "local"
DEBUG: bool = False
class Config: class Config:
env_file = "config.env" env_file = "config.env"
@model_validator(mode="after")
def validate_debug(self):
if self.ENV == "local" and self.DEBUG is False:
self.DEBUG = True
return self
settings = Settings()