Skip to content

Commit 2980463

Browse files
committed
Merge branch '302-amr' into 'main'
Implement AMR claim Closes #302 See merge request yaal/canaille!307
2 parents 1e34a98 + 79cc154 commit 2980463

File tree

5 files changed

+301
-8
lines changed

5 files changed

+301
-8
lines changed

CHANGES.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
[0.x.x] - Unreleased
2+
--------------------
3+
4+
Added
5+
^^^^^
6+
- OIDC ``amr`` claim support. :issue:`302`
7+
18
[0.1.0] - 2025-11-13
29
--------------------
310

canaille/app/session.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
class UserSession:
2222
user: User | None = None
2323
last_login_datetime: datetime.datetime | None = None
24+
authentication_methods: list[str] | None = None
2425

2526
@classmethod
2627
def deserialize(cls, payload):
@@ -38,13 +39,17 @@ def deserialize(cls, payload):
3839
)
3940
if payload.get("last_login_datetime")
4041
else None,
42+
authentication_methods=payload.get("authentication_methods"),
4143
)
4244

4345
def serialize(self):
44-
return {
46+
payload = {
4547
"user": self.user.id,
4648
"last_login_datetime": self.last_login_datetime.isoformat(),
4749
}
50+
if self.authentication_methods is not None:
51+
payload["authentication_methods"] = self.authentication_methods
52+
return payload
4853

4954

5055
def current_user_session():
@@ -67,7 +72,12 @@ def save_user_session() -> None:
6772
def login_user(user, remember: bool = True) -> None:
6873
"""Open a session for the user."""
6974
now = datetime.datetime.now(datetime.timezone.utc)
70-
obj = UserSession(user=user, last_login_datetime=now)
75+
authentication_methods = g.auth.achieved if hasattr(g, "auth") and g.auth else None
76+
obj = UserSession(
77+
user=user,
78+
last_login_datetime=now,
79+
authentication_methods=authentication_methods,
80+
)
7181
g.session = obj
7282
try:
7383
previous = (

canaille/oidc/provider.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,35 @@
4444
AUTHORIZATION_CODE_LIFETIME = 84400
4545
JWT_JTI_CACHE_LIFETIME = 3600
4646

47+
AMR_MAPPING = {
48+
"password": ["pwd"],
49+
"otp": ["otp"],
50+
"sms": ["sms", "mca"],
51+
"email": ["mca"],
52+
}
53+
54+
55+
def compute_amr_values(authentication_methods):
56+
"""Convert internal authentication methods to AMR values (RFC 8176).
57+
58+
Returns a list of AMR values, automatically adding 'mfa' if multiple
59+
factors were used.
60+
"""
61+
if not authentication_methods:
62+
return None
63+
64+
amr_values = []
65+
for method in authentication_methods:
66+
if method in AMR_MAPPING:
67+
amr_values.extend(AMR_MAPPING[method])
68+
69+
amr_values = list(dict.fromkeys(amr_values))
70+
71+
if len(authentication_methods) > 1:
72+
amr_values.append("mfa")
73+
74+
return amr_values if amr_values else None
75+
4776

4877
def get_bearer_token(request):
4978
"""Get the Bearer token from the request headers."""
@@ -86,6 +115,12 @@ def save_authorization_code(code, request):
86115
nonce = request.payload.data.get("nonce")
87116
now = datetime.datetime.now(datetime.timezone.utc)
88117
scope = request.client.get_allowed_scope(request.payload.scope)
118+
authentication_methods = (
119+
g.session.authentication_methods
120+
if hasattr(g, "session") and g.session
121+
else None
122+
)
123+
amr = compute_amr_values(authentication_methods)
89124
code = models.AuthorizationCode(
90125
authorization_code_id=gen_salt(48),
91126
code=code,
@@ -99,6 +134,7 @@ def save_authorization_code(code, request):
99134
challenge=request.payload.data.get("code_challenge"),
100135
challenge_method=request.payload.data.get("code_challenge_method"),
101136
auth_time=g.session.last_login_datetime,
137+
amr=amr,
102138
)
103139
Backend.instance.save(code)
104140
return code.code

tests/core/test_auth_password.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,13 @@ def test_signin_and_out(testclient, user, caplog):
6161
res = res.follow(status=200)
6262

6363
with testclient.session_transaction() as session:
64-
assert [{"user": user.id, "last_login_datetime": mock.ANY}] == session.get(
65-
"sessions"
66-
)
64+
assert [
65+
{
66+
"user": user.id,
67+
"last_login_datetime": mock.ANY,
68+
"authentication_methods": ["password"],
69+
}
70+
] == session.get("sessions")
6771
assert "auth" not in session
6872

6973
res = testclient.get("/login", status=200)
@@ -149,9 +153,13 @@ def test_signin_with_alternate_attribute(testclient, user):
149153
res = res.follow(status=200)
150154

151155
with testclient.session_transaction() as session:
152-
assert [{"user": user.id, "last_login_datetime": mock.ANY}] == session.get(
153-
"sessions"
154-
)
156+
assert [
157+
{
158+
"user": user.id,
159+
"last_login_datetime": mock.ANY,
160+
"authentication_methods": ["password"],
161+
}
162+
] == session.get("sessions")
155163

156164

157165
def test_password_page_without_signin_in_redirects_to_login_page(testclient, user):

tests/oidc/test_amr.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
from urllib.parse import parse_qs
2+
from urllib.parse import urlsplit
3+
4+
from joserfc import jwt
5+
6+
from canaille.app import models
7+
from canaille.oidc.jose import registry
8+
from canaille.oidc.provider import compute_amr_values
9+
from tests.core.test_auth_otp import generate_otp
10+
11+
from . import client_credentials
12+
13+
14+
def test_amr_password_only(testclient, user, client, server_jwk, backend):
15+
"""Test that AMR contains only 'pwd' for password-only authentication."""
16+
testclient.app.config["CANAILLE"]["AUTHENTICATION_FACTORS"] = ["password"]
17+
18+
res = testclient.get("/login")
19+
res.form["login"] = "user"
20+
res = res.form.submit().follow()
21+
res.form["password"] = "correct horse battery staple"
22+
res = res.form.submit()
23+
24+
res = testclient.get(
25+
"/oauth/authorize",
26+
params=dict(
27+
response_type="code",
28+
client_id=client.client_id,
29+
scope="openid",
30+
nonce="somenonce",
31+
redirect_uri="https://client.test/redirect1",
32+
),
33+
)
34+
res = res.form.submit(name="answer", value="accept", status=302)
35+
36+
params = parse_qs(urlsplit(res.location).query)
37+
code = params["code"][0]
38+
39+
res = testclient.post(
40+
"/oauth/token",
41+
params=dict(
42+
grant_type="authorization_code",
43+
code=code,
44+
scope="openid",
45+
redirect_uri=client.redirect_uris[0],
46+
),
47+
headers={"Authorization": f"Basic {client_credentials(client)}"},
48+
)
49+
50+
id_token = res.json["id_token"]
51+
claims = jwt.decode(id_token, server_jwk, registry=registry)
52+
assert claims.claims["amr"] == ["pwd"]
53+
54+
for consent in backend.query(models.Consent, client=client, subject=user):
55+
backend.delete(consent)
56+
57+
58+
def test_amr_password_and_otp(testclient, user, client, server_jwk, backend):
59+
"""Test that AMR contains 'pwd', 'otp', and 'mfa' for password + OTP authentication."""
60+
testclient.app.config["CANAILLE"]["AUTHENTICATION_FACTORS"] = ["password", "otp"]
61+
62+
res = testclient.get("/login")
63+
res.form["login"] = "user"
64+
res = res.form.submit().follow()
65+
res.form["password"] = "correct horse battery staple"
66+
res = res.form.submit().follow()
67+
68+
totp_period = int(
69+
testclient.app.config["CANAILLE"]["TOTP_LIFETIME"].total_seconds()
70+
)
71+
res.form["otp"] = generate_otp("TOTP", user.secret_token, totp_period=totp_period)
72+
res = res.form.submit()
73+
74+
res = testclient.get(
75+
"/oauth/authorize",
76+
params=dict(
77+
response_type="code",
78+
client_id=client.client_id,
79+
scope="openid",
80+
nonce="somenonce",
81+
redirect_uri="https://client.test/redirect1",
82+
),
83+
)
84+
res = res.form.submit(name="answer", value="accept", status=302)
85+
86+
params = parse_qs(urlsplit(res.location).query)
87+
code = params["code"][0]
88+
89+
res = testclient.post(
90+
"/oauth/token",
91+
params=dict(
92+
grant_type="authorization_code",
93+
code=code,
94+
scope="openid",
95+
redirect_uri=client.redirect_uris[0],
96+
),
97+
headers={"Authorization": f"Basic {client_credentials(client)}"},
98+
)
99+
100+
id_token = res.json["id_token"]
101+
claims = jwt.decode(id_token, server_jwk, registry=registry)
102+
assert set(claims.claims["amr"]) == {"pwd", "otp", "mfa"}
103+
104+
for consent in backend.query(models.Consent, client=client, subject=user):
105+
backend.delete(consent)
106+
107+
108+
def test_amr_password_and_sms(smtpd, testclient, user, client, server_jwk, backend):
109+
"""Test that AMR contains 'pwd', 'sms', and 'mfa' for password + SMS authentication."""
110+
testclient.app.config["CANAILLE"]["AUTHENTICATION_FACTORS"] = ["password", "sms"]
111+
112+
res = testclient.get("/login")
113+
res.form["login"] = "user"
114+
res = res.form.submit().follow()
115+
res.form["password"] = "correct horse battery staple"
116+
res = res.form.submit().follow()
117+
118+
backend.reload(user)
119+
res.form["otp"] = user.one_time_password
120+
res = res.form.submit()
121+
122+
res = testclient.get(
123+
"/oauth/authorize",
124+
params=dict(
125+
response_type="code",
126+
client_id=client.client_id,
127+
scope="openid",
128+
nonce="somenonce",
129+
redirect_uri="https://client.test/redirect1",
130+
),
131+
)
132+
res = res.form.submit(name="answer", value="accept", status=302)
133+
134+
params = parse_qs(urlsplit(res.location).query)
135+
code = params["code"][0]
136+
137+
res = testclient.post(
138+
"/oauth/token",
139+
params=dict(
140+
grant_type="authorization_code",
141+
code=code,
142+
scope="openid",
143+
redirect_uri=client.redirect_uris[0],
144+
),
145+
headers={"Authorization": f"Basic {client_credentials(client)}"},
146+
)
147+
148+
id_token = res.json["id_token"]
149+
claims = jwt.decode(id_token, server_jwk, registry=registry)
150+
assert set(claims.claims["amr"]) == {"pwd", "sms", "mca", "mfa"}
151+
152+
for consent in backend.query(models.Consent, client=client, subject=user):
153+
backend.delete(consent)
154+
155+
156+
def test_amr_password_and_email(smtpd, testclient, user, client, server_jwk, backend):
157+
"""Test that AMR contains 'pwd', 'mca', and 'mfa' for password + email authentication."""
158+
testclient.app.config["CANAILLE"]["AUTHENTICATION_FACTORS"] = [
159+
"password",
160+
"email",
161+
]
162+
163+
res = testclient.get("/login")
164+
res.form["login"] = "user"
165+
res = res.form.submit().follow()
166+
res.form["password"] = "correct horse battery staple"
167+
res = res.form.submit().follow()
168+
169+
backend.reload(user)
170+
res.form["otp"] = user.one_time_password
171+
res = res.form.submit()
172+
173+
res = testclient.get(
174+
"/oauth/authorize",
175+
params=dict(
176+
response_type="code",
177+
client_id=client.client_id,
178+
scope="openid",
179+
nonce="somenonce",
180+
redirect_uri="https://client.test/redirect1",
181+
),
182+
)
183+
res = res.form.submit(name="answer", value="accept", status=302)
184+
185+
params = parse_qs(urlsplit(res.location).query)
186+
code = params["code"][0]
187+
188+
res = testclient.post(
189+
"/oauth/token",
190+
params=dict(
191+
grant_type="authorization_code",
192+
code=code,
193+
scope="openid",
194+
redirect_uri=client.redirect_uris[0],
195+
),
196+
headers={"Authorization": f"Basic {client_credentials(client)}"},
197+
)
198+
199+
id_token = res.json["id_token"]
200+
claims = jwt.decode(id_token, server_jwk, registry=registry)
201+
assert set(claims.claims["amr"]) == {"pwd", "mca", "mfa"}
202+
203+
for consent in backend.query(models.Consent, client=client, subject=user):
204+
backend.delete(consent)
205+
206+
207+
def test_compute_amr_values_none():
208+
"""Test compute_amr_values with None returns None."""
209+
assert compute_amr_values(None) is None
210+
211+
212+
def test_compute_amr_values_empty_list():
213+
"""Test compute_amr_values with empty list returns None."""
214+
assert compute_amr_values([]) is None
215+
216+
217+
def test_compute_amr_values_unknown_method():
218+
"""Test compute_amr_values with unknown method returns None."""
219+
assert compute_amr_values(["unknown_method"]) is None
220+
221+
222+
def test_compute_amr_values_mixed_known_unknown():
223+
"""Test compute_amr_values with mixed known and unknown methods."""
224+
result = compute_amr_values(["password", "unknown_method", "otp"])
225+
assert set(result) == {"pwd", "otp", "mfa"}
226+
227+
228+
def test_compute_amr_values_deduplication():
229+
"""Test that AMR values are deduplicated (email and sms both give mca)."""
230+
result = compute_amr_values(["email", "sms"])
231+
assert result.count("mca") == 1
232+
assert set(result) == {"mca", "sms", "mfa"}

0 commit comments

Comments
 (0)