-
Notifications
You must be signed in to change notification settings - Fork 3
/
blueprint.py
322 lines (253 loc) · 10.3 KB
/
blueprint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# Copyright 2024 Canonical Ltd.
# See LICENSE file for licensing details.
"""Provides API blueprint for flask to run the policy checks.
Note that this blueprint requires the application to be run with a single worker due to the use of
an in-memory set to store the one time tokens. This is done to reduce the complexity of deployments
as the alternative would be to require a database.
"""
import http
import json
import logging
import os
import secrets
import tempfile
from enum import Enum
from hmac import compare_digest
from pathlib import Path
from typing import cast
from flask import Blueprint, Response, request
from flask_httpauth import HTTPTokenAuth
from flask_pydantic import validate
from github import GithubException
from repo_policy_compliance import (
PullRequestInput,
PushInput,
ScheduleInput,
UsedPolicy,
WorkflowDispatchInput,
database,
exceptions,
github_client,
policy,
pull_request,
push,
schedule,
workflow_dispatch,
)
from repo_policy_compliance.check import Result
# Blueprint exposing the policy endpoints; every protected route uses bearer-token auth.
repo_policy_compliance = Blueprint("repo_policy_compliance", __name__)
auth = HTTPTokenAuth(scheme="Bearer")

# The policy document is persisted to a temporary file that has to outlive any single
# function call, so it is intentionally not opened inside a ``with`` statement.
policy_document_file = tempfile.NamedTemporaryFile()  # pylint: disable=consider-using-with
policy_document_path = Path(policy_document_file.name)

# Names of environment variables (not their values). Bandit mistakes the charm token variable
# name for a hard-coded secret, hence the nosec marker.
CHARM_TOKEN_ENV_NAME = "CHARM_TOKEN"  # nosec
PULL_REQUEST_DISALLOW_FORK_ENV_NAME = "PULL_REQUEST_DISALLOW_FORK"

# Route paths. Bandit mistakes the one time token route for a secret, hence the nosec marker.
ONE_TIME_TOKEN_ENDPOINT = "/one-time-token"  # nosec
POLICY_ENDPOINT = "/policy"
HEALTH_ENDPOINT = "/health"
AUTH_HEALTH_ENDPOINT = "/auth-health"

# Check-run routes, one per triggering event plus the legacy, default and testing variants.
CHECK_RUN_ENDPOINT = "/check-run"
PULL_REQUEST_CHECK_RUN_ENDPOINT = "/pull_request/check-run"
WORKFLOW_DISPATCH_CHECK_RUN_ENDPOINT = "/workflow_dispatch/check-run"
PUSH_CHECK_RUN_ENDPOINT = "/push/check-run"
SCHEDULE_CHECK_RUN_ENDPOINT = "/schedule/check-run"
DEFAULT_CHECK_RUN_ENDPOINT = "/default/check-run"
ALWAYS_FAIL_CHECK_RUN_ENDPOINT = "/always-fail/check-run"
class Users(str, Enum):
    """Identities that a bearer token can resolve to.

    Attributes:
        CHARM: The charm, which is allowed to request one time tokens.
        RUNNER: A runner, which is allowed to ask whether a run may proceed.
    """

    CHARM = "charm"
    RUNNER = "runner"


# Each user has exactly one role, named identically to the user itself.
CHARM_ROLE = Users.CHARM
RUNNER_ROLE = Users.RUNNER
@auth.verify_token
def verify_token(token: str) -> str | None:
    """Verify the authentication token.

    The charm token is read from the ``CHARM_TOKEN`` environment variable, falling back to
    ``FLASK_CHARM_TOKEN``. If neither is set (or both are empty) an error is logged and every
    token is rejected.

    Args:
        token: The token to check.

    Returns:
        ``Users.CHARM`` if the token equals the charm token, ``Users.RUNNER`` if
        ``database.check_token`` accepts it, otherwise None.
    """
    charm_token = os.getenv(CHARM_TOKEN_ENV_NAME) or os.getenv(f"FLASK_{CHARM_TOKEN_ENV_NAME}")
    if not charm_token:
        logging.error(
            (
                "%s environment variable is required for generating one time tokens, it is not "
                "defined or empty"
            ),
            CHARM_TOKEN_ENV_NAME,
        )
        return None

    # hmac.compare_digest only accepts str arguments that are pure ASCII and raises TypeError
    # otherwise. The token comes straight from a client-supplied header, so compare the UTF-8
    # encodings: still constant-time, but arbitrary input is rejected instead of causing a 500.
    if compare_digest(token.encode("utf-8"), charm_token.encode("utf-8")):
        return Users.CHARM
    if database.check_token(token=token):
        return Users.RUNNER
    return None
@auth.get_user_roles
def get_user_roles(user: str) -> str | None:
    """Translate an authenticated user into the role it holds.

    Args:
        user: The identity produced by the token verification callback.

    Returns:
        The role granted to that user, or None when the user is not recognised.
    """
    if user == Users.CHARM:
        return CHARM_ROLE
    if user == Users.RUNNER:
        return RUNNER_ROLE
    # Not expected to be reached: every accepted token resolves to one of the users above.
    return None  # pragma: no cover
@repo_policy_compliance.route(ONE_TIME_TOKEN_ENDPOINT)
@auth.login_required(role=CHARM_ROLE)
def one_time_token() -> str:
    """Issue a fresh one time token that a runner can authenticate with.

    The token is 32 random bytes from ``secrets`` rendered as hex. It is registered through
    ``database.add_token`` before being returned.

    Returns:
        The newly registered one time token.
    """
    fresh_token = secrets.token_hex(32)
    database.add_token(fresh_token)
    return fresh_token
@repo_policy_compliance.route(POLICY_ENDPOINT, methods=["POST"])
@auth.login_required(role=CHARM_ROLE)
def policy_endpoint() -> Response:
    """Replace the stored policy document with the one in the request body.

    The JSON body is validated with ``policy.check``. Only a valid document is serialised to
    the policy document file, where the check-run endpoints pick it up on their next request.

    Returns:
        204 if the policy was updated, or 400 with the validation reason if it is invalid.
    """
    data = cast(dict, request.json)
    if not (policy_report := policy.check(document=data)).result:
        return Response(response=policy_report.reason, status=http.HTTPStatus.BAD_REQUEST)

    policy_document_path.write_text(json.dumps(data), encoding="utf-8")
    return Response(status=http.HTTPStatus.NO_CONTENT)
def _get_policy_document() -> dict | UsedPolicy:
    """Return the policy that check runs should be evaluated against.

    A document previously stored through the policy endpoint takes precedence. Without one, the
    ``PULL_REQUEST_DISALLOW_FORK`` (or ``FLASK_PULL_REQUEST_DISALLOW_FORK``) environment
    variable decides the policy: a case-insensitive ``"true"`` selects every policy, and any
    other value allows pull requests from forks.

    Returns:
        The parsed stored policy document, or the ``UsedPolicy`` value to apply.
    """
    stored_contents = policy_document_path.read_text(encoding="utf-8")
    if stored_contents:
        return cast(dict, json.loads(stored_contents))

    flag_value = os.getenv(PULL_REQUEST_DISALLOW_FORK_ENV_NAME, "") or os.getenv(
        f"FLASK_{PULL_REQUEST_DISALLOW_FORK_ENV_NAME}", ""
    )
    disallow_fork = flag_value.lower() == "true"
    return UsedPolicy.ALL if disallow_fork else UsedPolicy.PULL_REQUEST_ALLOW_FORK
# Keeping /check-run pointing to this for backwards compatibility
@repo_policy_compliance.route(CHECK_RUN_ENDPOINT, methods=["POST"])
@repo_policy_compliance.route(PULL_REQUEST_CHECK_RUN_ENDPOINT, methods=["POST"])
@auth.login_required(role=RUNNER_ROLE)
@validate()
def pull_request_check_run(body: PullRequestInput) -> Response:
"""Check whether a pull request run should proceed.
Args:
body: The request body after it is validated.
Returns:
Either to proceed with the run or an error not to proceed with a reason why.
"""
policy_document = _get_policy_document()
if (
report := pull_request(input_=body, policy_document=policy_document)
).result == Result.FAIL:
return Response(response=report.reason, status=http.HTTPStatus.FORBIDDEN)
if report.result == Result.ERROR:
return Response(response=report.reason, status=http.HTTPStatus.INTERNAL_SERVER_ERROR)
return Response(status=http.HTTPStatus.NO_CONTENT)
@repo_policy_compliance.route(WORKFLOW_DISPATCH_CHECK_RUN_ENDPOINT, methods=["POST"])
@auth.login_required(role=RUNNER_ROLE)
@validate()
def workflow_dispatch_check_run(body: WorkflowDispatchInput) -> Response:
"""Check whether a workflow dispatch run should proceed.
Args:
body: The request body after it is validated.
Returns:
Either to proceed with the run or an error not to proceed with a reason why.
"""
policy_document = _get_policy_document()
if (
report := workflow_dispatch(input_=body, policy_document=policy_document)
).result == Result.FAIL:
return Response(response=report.reason, status=http.HTTPStatus.FORBIDDEN)
if report.result == Result.ERROR:
return Response(response=report.reason, status=http.HTTPStatus.INTERNAL_SERVER_ERROR)
return Response(status=http.HTTPStatus.NO_CONTENT)
# Include a default endpoint that works the same as push to be used for other events
@repo_policy_compliance.route(DEFAULT_CHECK_RUN_ENDPOINT, methods=["POST"])
@repo_policy_compliance.route(PUSH_CHECK_RUN_ENDPOINT, methods=["POST"])
@auth.login_required(role=RUNNER_ROLE)
@validate()
def push_check_run(body: PushInput) -> Response:
"""Check whether a push run should proceed.
Args:
body: The request body after it is validated.
Returns:
Either to proceed with the run or an error not to proceed with a reason why.
"""
policy_document = _get_policy_document()
if (report := push(input_=body, policy_document=policy_document)).result == Result.FAIL:
return Response(response=report.reason, status=http.HTTPStatus.FORBIDDEN)
if report.result == Result.ERROR:
return Response(response=report.reason, status=http.HTTPStatus.INTERNAL_SERVER_ERROR)
return Response(status=http.HTTPStatus.NO_CONTENT)
@repo_policy_compliance.route(SCHEDULE_CHECK_RUN_ENDPOINT, methods=["POST"])
@auth.login_required(role=RUNNER_ROLE)
@validate()
def schedule_check_run(body: ScheduleInput) -> Response:
"""Check whether a schedule run should proceed.
Args:
body: The request body after it is validated.
Returns:
Either to proceed with the run or an error not to proceed with a reason why.
"""
policy_document = _get_policy_document()
if (report := schedule(input_=body, policy_document=policy_document)).result == Result.FAIL:
return Response(response=report.reason, status=http.HTTPStatus.FORBIDDEN)
if report.result == Result.ERROR:
return Response(response=report.reason, status=http.HTTPStatus.INTERNAL_SERVER_ERROR)
return Response(status=http.HTTPStatus.NO_CONTENT)
@repo_policy_compliance.route(HEALTH_ENDPOINT, methods=["GET"])
def health() -> Response:
    """Health check endpoint.

    Verifies GitHub connectivity by building a client with ``github_client.get`` and fetching
    the ``canonical/repo-policy-compliance`` repository.

    Returns:
        500 response if GitHub connectivity is not correctly configured or GitHub returns an
        error, 204 response otherwise.
    """
    try:
        client = github_client.get()
        client.get_repo("canonical/repo-policy-compliance")
    except exceptions.ConfigurationError as exc:
        return Response(response=str(exc), status=http.HTTPStatus.INTERNAL_SERVER_ERROR)
    except GithubException as exc:
        return Response(
            response=f"could not communicate with GitHub, {exc}",
            status=http.HTTPStatus.INTERNAL_SERVER_ERROR,
        )
    return Response(status=http.HTTPStatus.NO_CONTENT)
@repo_policy_compliance.route(AUTH_HEALTH_ENDPOINT, methods=["GET"])
@auth.login_required(role=RUNNER_ROLE)
def auth_health() -> Response:
    """Confirm that a runner-authenticated request gets through.

    Reaching this body already means authentication succeeded, so there is nothing else to
    check.

    Returns:
        An empty 204 response.
    """
    no_content = http.HTTPStatus.NO_CONTENT
    return Response(status=no_content)
@repo_policy_compliance.route(ALWAYS_FAIL_CHECK_RUN_ENDPOINT, methods=["POST"])
@auth.login_required(role=RUNNER_ROLE)
def always_fail_check_run() -> Response:
    """Reject every authenticated request; intended for exercising the failure path in tests.

    Returns:
        A 403 response carrying a fixed explanatory message.
    """
    reason = "Endpoint designed for testing that always fails"
    return Response(response=reason, status=http.HTTPStatus.FORBIDDEN)