Passed
Push — main ( 9ee7a5...30953c )
by Eran
01:37
created

graphinate.renderers.graphql._graphql_app()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import contextlib
2
import webbrowser
3
4
import strawberry
5
from starlette.applications import Starlette
6
from starlette.requests import Request
7
from starlette.responses import RedirectResponse
8
from starlette.schemas import SchemaGenerator
9
from starlette.types import ASGIApp
10
from strawberry.asgi import GraphQL
11
from strawberry.extensions.tracing import OpenTelemetryExtension
12
13
from graphinate.server.starlette import routes
14
15
# Port used by `server` / `_starlette_app` when the caller does not supply one.
DEFAULT_PORT: int = 8072

# URL path at which the GraphQL ASGI app is mounted (for both HTTP and WebSocket routes).
GRAPHQL_ROUTE_PATH: str = "/graphql"
18
19
20
def _openapi_schema(request: Request) -> ASGIApp:
    """
    Generate the OpenAPI schema for the GraphQL API and its companion routes.

    The base document lists the GraphQL endpoint, the GraphiQL UI, Prometheus
    metrics, the 3D graph viewer and the Voyager schema viewer. It is handed to
    Starlette's ``SchemaGenerator``, which renders it as the response for
    ``request``.

    Args:
        request (Request): The HTTP request object.

    Returns:
        ASGIApp: An OpenAPI response containing the schema for the listed routes.
    """
    schema_data = {
        'openapi': '3.0.0',
        # NOTE(review): version is hard-coded and may drift from the released package version.
        'info': {'title': 'Graphinate API', 'version': '0.8.2'},
        'paths': {
            # Use the shared constant so the documented path always matches the route
            # that `_starlette_app` actually mounts the GraphQL app on.
            GRAPHQL_ROUTE_PATH: {'get': {'responses': {200: {'description': 'GraphQL'}}}},
            '/graphiql': {'get': {'responses': {200: {'description': 'GraphiQL UI.'}}}},
            '/metrics': {'get': {'responses': {200: {'description': 'Prometheus metrics.'}}}},
            '/viewer': {'get': {'responses': {200: {'description': '3D Force-Directed Graph Viewer'}}}},
            '/voyager': {'get': {'responses': {200: {'description': 'Voyager GraphQL Schema Viewer'}}}},
        },
    }

    # A fresh generator per request keeps the base document from being shared
    # (and possibly mutated) across requests.
    generator = SchemaGenerator(schema_data)
    return generator.OpenAPIResponse(request=request)
44
45
46
def _graphql_app(graphql_schema: strawberry.Schema) -> strawberry.asgi.GraphQL:
    """
    Wrap a Strawberry schema in an ASGI GraphQL app with OpenTelemetry tracing.

    ``OpenTelemetryExtension`` is registered on the schema (at most once) before
    the ASGI app is created. The app is configured with ``graphiql=True``.

    Args:
        graphql_schema: The Strawberry schema to expose. Its ``extensions``
            attribute is updated so that it includes ``OpenTelemetryExtension``.

    Returns:
        The ``strawberry.asgi.GraphQL`` application serving ``graphql_schema``.
    """
    extensions = graphql_schema.extensions
    if not isinstance(extensions, list):
        # NOTE(review): Schema.extensions is declared as an iterable and may be a
        # tuple (no .append) — materialize it as a list; confirm against the
        # strawberry version in use. A list is kept as-is to preserve its identity.
        extensions = list(extensions)
        graphql_schema.extensions = extensions

    # Wrapping the same schema twice must not register the tracer twice; accept
    # both the extension class and an already-constructed instance.
    already_traced = any(
        ext is OpenTelemetryExtension or isinstance(ext, OpenTelemetryExtension)
        for ext in extensions
    )
    if not already_traced:
        extensions.append(OpenTelemetryExtension)

    return GraphQL(graphql_schema, graphiql=True)
50
51
52
def _starlette_app(graphql_app: strawberry.asgi.GraphQL | None = None, port: int = DEFAULT_PORT, **kwargs) -> Starlette:
    """
    Build the Starlette application that hosts Graphinate's HTTP endpoints.

    The app is created from ``routes()``, wrapped in ``PrometheusMiddleware``
    with a ``/metrics`` endpoint, and redirects ``/`` to ``/viewer``. When
    ``graphql_app`` is given, it is mounted at ``GRAPHQL_ROUTE_PATH`` for both
    HTTP and WebSocket traffic, and the ``/schema`` and ``/openapi.json``
    endpoints are added.

    Args:
        graphql_app: Optional Strawberry ASGI GraphQL application to mount.
        port: Port used to build the ``http://localhost:<port>/viewer`` URL that
            is opened when browsing is requested.
        **kwargs: Extra options. Only ``browse`` is read: when truthy, the viewer
            URL is opened with ``webbrowser.open`` in the app's lifespan startup.

    Returns:
        The configured ``Starlette`` application (not yet running).
    """
    def open_url(endpoint):
        webbrowser.open(f'http://localhost:{port}/{endpoint}')

    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette):  # pragma: no cover
        # Code before `yield` runs on application startup (Starlette lifespan).
        if kwargs.get('browse'):
            open_url('viewer')
        yield

    app = Starlette(
        lifespan=lifespan,
        routes=routes()
    )

    # Local import, presumably to defer loading the Prometheus integration until
    # an app is actually built.
    from starlette_prometheus import PrometheusMiddleware, metrics
    app.add_middleware(PrometheusMiddleware)
    app.add_route("/metrics", metrics)

    # Explicit None check: the parameter is Optional, and truthiness is not the
    # contract being tested here.
    if graphql_app is not None:
        # The same ASGI app handles plain HTTP queries and WebSocket subscriptions.
        app.add_route(GRAPHQL_ROUTE_PATH, graphql_app)
        app.add_websocket_route(GRAPHQL_ROUTE_PATH, graphql_app)
        app.add_route("/schema", route=_openapi_schema, include_in_schema=False)
        app.add_route("/openapi.json", route=_openapi_schema, include_in_schema=False)

    async def redirect_to_viewer(request):
        return RedirectResponse(url='/viewer')

    app.add_route('/', redirect_to_viewer)

    return app
83
84
85
def server(graphql_schema: strawberry.Schema, port: int = DEFAULT_PORT, *, host: str = '0.0.0.0', **kwargs) -> None:
    """
    Serve a Strawberry GraphQL schema through Graphinate's Starlette app using uvicorn.

    The schema is wrapped by ``_graphql_app``, mounted into the app built by
    ``_starlette_app``, and then handed to ``uvicorn.run``. Per uvicorn's docs,
    that call blocks until the server stops.

    Args:
        graphql_schema: The Strawberry GraphQL schema.
        port: The port number to run the server on. Defaults to ``DEFAULT_PORT`` (8072).
        host: Network interface to bind. Defaults to ``'0.0.0.0'`` (all interfaces,
            the previously hard-coded value). Pass ``'127.0.0.1'`` to keep the
            server reachable only from the local machine.
        **kwargs: Forwarded to ``_starlette_app``. For example, ``browse=True``
            opens the viewer in a web browser on startup.

    Returns:
        None.
    """
    graphql_app = _graphql_app(graphql_schema)

    app = _starlette_app(graphql_app, port=port, **kwargs)

    # Local import: uvicorn is only needed when actually running the server.
    import uvicorn
    uvicorn.run(app, host=host, port=port)
100
101
102
# Public API of this module: only the `server` entry point is exported; the
# underscore-prefixed app builders are internal.
__all__ = ["server"]
103