feat(users): add id from idp to db table if possible (also update exi… by Avantol13 · Pull Request #975 · uc-cdis/fence · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat(users): add id from idp to db table if possible (also update exi… #975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions fence/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def build_redirect_url(hostname, path):
return redirect_base + path


def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
def login_user(
username, provider, fence_idp=None, shib_idp=None, email=None, id_from_idp=None
):
"""
Login a user with the given username and provider. Set values in Flask
session to indicate the user being logged in. In addition, commit the user
Expand All @@ -70,6 +72,8 @@ def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
shib_idp (str, optional): Downstreawm shibboleth IdP
email (str, optional): email of user (may or may not match username depending
on the IdP)
id_from_idp (str, optional): id from the IDP (which may be different than
the username)
"""

def set_flask_session_values(user):
Expand All @@ -93,6 +97,7 @@ def set_flask_session_values(user):
user = query_for_user(session=current_session, username=username)
if user:
_update_users_email(user, email)
_update_users_id_from_idp(user, id_from_idp)

# This expression is relevant to those users who already have user and
# idp info persisted to the database. We return early to avoid
Expand All @@ -101,11 +106,16 @@ def set_flask_session_values(user):
set_flask_session_values(user)
return
else:
# we need a new user
user = User(username=username)

if email:
user = User(username=username, email=email)
else:
user = User(username=username)
user.email = email

if id_from_idp:
user.id_from_idp = id_from_idp

# setup idp connection for new user (or existing user w/o it setup)
idp = (
current_session.query(IdentityProvider)
.filter(IdentityProvider.name == provider)
Expand Down Expand Up @@ -271,3 +281,17 @@ def _update_users_email(user, email):

current_session.add(user)
current_session.commit()


def _update_users_id_from_idp(user, id_from_idp):
    """
    Store the IdP-issued identifier on ``user`` when a non-empty value is
    supplied and it differs from what is already saved; otherwise do nothing.

    Args:
        user (User): user record whose id_from_idp may be refreshed
        id_from_idp (str, optional): identifier reported by the IdP (which may
            be different than the username)
    """
    # Nothing to persist: no id was provided, or the stored one already matches.
    if not id_from_idp or user.id_from_idp == id_from_idp:
        return

    logger.info(
        f"Updating username {user.username}'s id_from_idp from {user.id_from_idp} to {id_from_idp}"
    )
    user.id_from_idp = id_from_idp

    # Only touch the database when the value actually changed.
    current_session.add(user)
    current_session.commit()
34 changes: 24 additions & 10 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def get(self):


class DefaultOAuth2Callback(Resource):
def __init__(self, idp_name, client, username_field="email", email_field="email"):
def __init__(
self,
idp_name,
client,
username_field="email",
email_field="email",
id_from_idp_field="sub",
):
"""
Construct a resource for a login callback endpoint

Expand All @@ -65,14 +72,18 @@ def __init__(self, idp_name, client, username_field="email", email_field="email"
client (fence.resources.openid.idp_oauth2.Oauth2ClientBase):
Some instaniation of this base client class or a child class
username_field (str, optional): default field from response to
retrieve the username
retrieve the unique username
email_field (str, optional): default field from response to
retrieve the email (if available)
id_from_idp_field (str, optional): default field from response to
retrieve the idp-specific ID for this user (could be the same
as username_field)
"""
self.idp_name = idp_name
self.client = client
self.username_field = username_field
self.email_field = email_field
self.id_from_idp_field = id_from_idp_field

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -101,33 +112,36 @@ def get(self):
result = self.client.get_user_id(code)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure that RASOauth2Client.get_user_id (and get_user_id of any other clients for which we want to store sub) will need to be updated to also return sub.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, really good catch. Yup, just pushed an update

username = result.get(self.username_field)
email = result.get(self.email_field)
id_from_idp = result.get(self.id_from_idp_field)
if username:
resp = _login(username, self.idp_name, email=email)
self.post_login(flask.g.user, result)
resp = _login(username, self.idp_name, email=email, id_from_idp=id_from_idp)
self.post_login(
user=flask.g.user, token_result=result, id_from_idp=id_from_idp
)
return resp
raise UserError(result)

def post_login(self, user=None, token_result=None):
prepare_login_log(self.idp_name)
def post_login(self, user=None, token_result=None, id_from_idp=None):
prepare_login_log(self.idp_name, id_from_idp=id_from_idp)


def prepare_login_log(idp_name):
def prepare_login_log(idp_name, id_from_idp=None):
flask.g.audit_data = {
"username": flask.g.user.username,
"sub": flask.g.user.id,
"sub": id_from_idp,
"idp": idp_name,
"fence_idp": flask.session.get("fence_idp"),
"shib_idp": flask.session.get("shib_idp"),
"client_id": flask.session.get("client_id"),
}


def _login(username, idp_name, email=None):
def _login(username, idp_name, email=None, id_from_idp=None):
"""
Login user with given username, then redirect if session has a saved
redirect.
"""
login_user(username, idp_name, email=email)
login_user(username, idp_name, email=email, id_from_idp=id_from_idp)

if config["REGISTER_USERS_ON"]:
if not flask.g.user.additional_info.get("registration_info"):
Expand Down
4 changes: 2 additions & 2 deletions fence/blueprints/login/ras.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self):
username_field="username",
)

def post_login(self, user=None, token_result=None):
def post_login(self, user=None, token_result=None, id_from_idp=None):
# TODO: I'm not convinced this code should be in post_login.
# Just putting it in here for now, but might refactor later.
# This saves us a call to RAS /userinfo, but will not make sense
Expand Down Expand Up @@ -187,4 +187,4 @@ def post_login(self, user=None, token_result=None):
)
sync.sync_single_user_visas(user, current_session)

super(RASCallback, self).post_login()
super(RASCallback, self).post_login(id_from_idp=id_from_idp)
6 changes: 3 additions & 3 deletions fence/blueprints/login/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def __init__(self):
idp_name=IdentityProvider.synapse,
client=flask.current_app.synapse_client,
username_field="fence_username",
id_from_idp_field="sub",
)

def post_login(self, user=None, token_result=None):
user.id_from_idp = token_result["sub"]
def post_login(self, user=None, token_result=None, id_from_idp=None):
user.email = token_result["email"]
user.display_name = "{given_name} {family_name}".format(**token_result)
info = {}
Expand Down Expand Up @@ -53,4 +53,4 @@ def post_login(self, user=None, token_result=None):
user.username, config["DREAM_CHALLENGE_GROUP"]
)

super(SynapseCallback, self).post_login()
super(SynapseCallback, self).post_login(id_from_idp=id_from_idp)
6 changes: 4 additions & 2 deletions fence/job/visa_update_cronjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ async def update_tokens(self, db_session):
Initialize a producer-consumer workflow.

Producer: Collects users from db and feeds it to the workers
Worker: Takes in the users from the Producer and passes it to the Updater to update the tokens and passes those updated tokens for JWT validation
Updater: Updates refresh_tokens and visas by calling the update_user_visas from the correct client
Worker: Takes in the users from the Producer and passes it to the Updater to
update the tokens and passes those updated tokens for JWT validation
Updater: Updates refresh_tokens and visas by calling the update_user_visas from
the correct client

"""
start_time = time.time()
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/cognito_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_user_id(self, code):
if claims["email"] and (
claims["email_verified"] or self.settings["assume_emails_verified"]
):
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
elif claims["email"]:
return {"error": "Email is not verified"}
else:
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/google_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims["email"] and claims["email_verified"]:
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
elif claims["email"]:
return {"error": "Email is not verified"}
else:
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/microsoft_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims.get("email"):
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
return {"error": "Can't get user's Microsoft email!"}
except Exception as exception:
self.logger.exception("Can't get user info")
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/okta_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims["email"]:
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
else:
return {"error": "Can't get user's email!"}
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/orcid_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims["sub"]:
return {"orcid": claims["sub"]}
return {"orcid": claims["sub"], "sub": claims["sub"]}
else:
return {"error": "Can't get user's orcid"}
except Exception as e:
Expand Down
6 changes: 5 additions & 1 deletion fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ def get_user_id(self, code):
self.logger.exception("{}: {}".format(err_msg, e))
return {"error": err_msg}

return {"username": username, "email": userinfo.get("email")}
return {
"username": username,
"email": userinfo.get("email"),
"sub": userinfo.get("sub"),
}

def refresh_cronjob_pkey_cache(self, issuer, kid, pkey_cache):
"""
Expand Down
24 changes: 19 additions & 5 deletions tests/login/test_login_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,25 @@
def test_login_user_already_in_db(db_session):
"""
Test that if a user is already in the database and logs in, the session will contain
the user's information.
the user's information (including additional information that may have been provided
during the login like email and id_from_idp)
"""
email = "testuser@gmail.com"
provider = "Test Provider"
id_from_idp = "Provider_ID_0001"

test_user = User(username=email, is_admin=False)
db_session.add(test_user)
db_session.commit()
user_id = str(test_user.id)
assert not test_user.email
assert not test_user.id_from_idp

login_user(email, provider)
login_user(email, provider, email=email, id_from_idp=id_from_idp)

assert test_user.identity_provider.name == provider
assert test_user.id_from_idp == id_from_idp
assert test_user.email == email
assert flask.session["username"] == email
assert flask.session["provider"] == provider
assert flask.session["user_id"] == user_id
Expand All @@ -33,18 +39,23 @@ def test_login_user_with_idp_already_in_db(db_session):
"""
email = "testuser@gmail.com"
provider = "Test Provider"
id_from_idp = "Provider_ID_0001"

test_user = User(username=email, is_admin=False)
test_user = User(
username=email, email=email, id_from_idp=id_from_idp, is_admin=False
)
test_idp = IdentityProvider(name=provider)
test_user.identity_provider = test_idp

db_session.add(test_user)
db_session.commit()
user_id = str(test_user.id)

login_user(email, provider)
login_user(email, provider, email=email, id_from_idp=id_from_idp)

assert test_user.identity_provider.name == provider
assert test_user.id_from_idp == id_from_idp
assert test_user.email == email
assert flask.session["username"] == email
assert flask.session["provider"] == provider
assert flask.session["user_id"] == user_id
Expand All @@ -58,12 +69,15 @@ def test_login_new_user(db_session):
"""
email = "testuser@gmail.com"
provider = "Test Provider"
id_from_idp = "Provider_ID_0001"

login_user(email, provider)
login_user(email, provider, email=email, id_from_idp=id_from_idp)

test_user = db_session.query(User).filter(User.username == email.lower()).first()

assert test_user.identity_provider.name == provider
assert test_user.id_from_idp == id_from_idp
assert test_user.email == email
assert flask.session["username"] == email
assert flask.session["provider"] == provider
assert flask.session["user_id"] == str(test_user.id)
Expand Down
3 changes: 2 additions & 1 deletion tests/login/test_microsoft_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_get_user_id(microsoft_oauth2_client):
return_value=return_value,
):
user_id = microsoft_oauth2_client.get_user_id(code="123")
assert user_id == expected_value # nosec
for key, value in expected_value.items():
assert return_value[key] == value


def test_get_user_id_missing_claim(microsoft_oauth2_client):
Expand Down
0