# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2015-2023 CERN.
# Copyright (C) 2022 KTH Royal Institute of Technology
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""Database models for accounts."""
import uuid
from datetime import datetime
from flask import current_app, session
from flask_babel import refresh
from flask_security import RoleMixin, UserMixin
from invenio_db import db
from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import validates
from sqlalchemy_utils import IPAddressType, Timestamp
from sqlalchemy_utils.types import JSONType
from .errors import AlreadyLinkedError
from .profiles import UserPreferenceDict, UserProfileDict
from .utils import validate_username
# JSON column type used by the User model's JSON columns: the generic ``JSON``
# type, replaced by ``JSONB`` on PostgreSQL (with Python ``None`` stored as SQL
# NULL) and by sqlalchemy-utils' ``JSONType`` on SQLite and MySQL.
json_field = (
    db.JSON()
    .with_variant(postgresql.JSONB(none_as_null=True), "postgresql")
    .with_variant(JSONType(), "sqlite")
    .with_variant(JSONType(), "mysql")
)
# Many-to-many association table between ``accounts_user`` and ``accounts_role``;
# both foreign keys carry explicit constraint names.
userrole = db.Table(
    "accounts_userrole",
    db.Column(
        "user_id",
        db.Integer(),
        db.ForeignKey(column="accounts_user.id", name="fk_accounts_userrole_user_id"),
    ),
    db.Column(
        "role_id",
        db.String(80),
        db.ForeignKey(column="accounts_role.id", name="fk_accounts_userrole_role_id"),
    ),
)
"""Association table linking users (``user_id``) to roles (``role_id``)."""
class Role(db.Model, Timestamp, RoleMixin):
    """A named role that can be assigned to users."""

    __tablename__ = "accounts_role"

    # Primary key: a random UUID4 string is generated when none is supplied.
    # SQLAlchemy passes the execution context to one-argument default callables;
    # it is not needed here.
    id = db.Column(
        db.String(80),
        primary_key=True,
        default=lambda _context: str(uuid.uuid4()),
    )

    name = db.Column(db.String(80), unique=True)
    """Unique name of the role."""

    description = db.Column(db.String(255))
    """Free-text description of the role."""

    is_managed = db.Column(db.Boolean(), default=True, nullable=False)
    """Whether Invenio manages this role (as opposed to an external provider)."""

    # Turns on SQLAlchemy's version counter for this table.
    version_id = db.Column(db.Integer, nullable=False)
    """Version counter used by SQLAlchemy for optimistic concurrency control."""

    __mapper_args__ = {"version_id_col": version_id}

    def __str__(self):
        """Render the role as ``"<name> - <description>"``."""
        return f"{self.name} - {self.description}"
class User(db.Model, Timestamp, UserMixin):
    """User data model."""

    __tablename__ = "accounts_user"

    id = db.Column(db.Integer, primary_key=True)

    _username = db.Column("username", db.String(255), nullable=True, unique=True)
    """Lower-case version of the username, to assert uniqueness."""

    _displayname = db.Column("displayname", db.String(255), nullable=True)
    """Case-preserving version of the username."""

    email = db.Column(db.String(255), unique=True)
    """User email."""

    password = db.Column(db.String(255))
    """User password."""

    active = db.Column(db.Boolean(name="active"))
    """Flag to say if the user is active or not."""

    confirmed_at = db.Column(db.DateTime)
    """When the user confirmed the email address."""

    roles = db.relationship(
        "Role", secondary=userrole, backref=db.backref("users", lazy="dynamic")
    )
    """List of the user's roles."""

    # Enables SQLAlchemy version counter
    version_id = db.Column(db.Integer, nullable=False)
    """Used by SQLAlchemy for optimistic concurrency control."""

    _user_profile = db.Column(
        "profile",
        json_field,
        default=lambda: dict(),
        nullable=True,
    )
    """The user profile as a JSON field."""

    _preferences = db.Column(
        "preferences",
        json_field,
        default=lambda: dict(),
        nullable=True,
    )
    """The user's preferences stored in a JSON field."""

    __mapper_args__ = {"version_id_col": version_id}

    # Login bookkeeping is kept in a separate, eagerly joined
    # ``LoginInformation`` row (at most one per user); it is exposed through
    # the ``current_login_*`` / ``last_login_*`` / ``login_count`` properties.
    login_info = db.relationship(
        "LoginInformation", back_populates="user", uselist=False, lazy="joined"
    )

    # When the user was blocked; ``None`` if never blocked
    # (NOTE(review): presumably set by moderation code elsewhere -- confirm).
    blocked_at = db.Column(
        db.DateTime,
        nullable=True,
    )

    # When the user was verified; set at creation time when
    # ``ACCOUNTS_DEFAULT_USERS_VERIFIED`` is enabled (see ``__init__``).
    verified_at = db.Column(
        db.DateTime,
        nullable=True,
    )

    def __init__(self, *args, **kwargs):
        """Constructor.

        :param user_profile: Optional mapping used to build the user profile
            (defaults to an empty profile).
        :param preferences: Optional mapping of user preferences. Missing
            ``visibility``, ``email_visibility``, ``locale`` and ``timezone``
            entries are filled from the application configuration. The
            caller's mapping is copied and never modified; ``None`` is
            treated like an empty mapping.

        All remaining arguments are forwarded to the SQLAlchemy model
        constructor.
        """
        config = current_app.config
        self.verified_at = (
            datetime.utcnow() if config.get("ACCOUNTS_DEFAULT_USERS_VERIFIED") else None
        )

        user_profile = kwargs.pop("user_profile", {})
        # Copy before filling in defaults so the caller's dict is not mutated.
        preferences = dict(kwargs.pop("preferences", None) or {})
        preference_defaults = (
            ("visibility", "ACCOUNTS_DEFAULT_USER_VISIBILITY", "restricted"),
            ("email_visibility", "ACCOUNTS_DEFAULT_EMAIL_VISIBILITY", "restricted"),
            ("locale", "BABEL_DEFAULT_LOCALE", "en"),
            ("timezone", "BABEL_DEFAULT_TIMEZONE", "Europe/Zurich"),
        )
        for key, config_key, fallback in preference_defaults:
            preferences.setdefault(key, config.get(config_key, fallback))

        super().__init__(*args, **kwargs)
        # Assign through the property setters so the values are wrapped in
        # their dict types.
        self.user_profile = user_profile
        self.preferences = preferences

    @hybrid_property
    def username(self):
        """Get username (the case-preserving display name)."""
        return self._displayname

    @username.setter
    def username(self, username):
        """Set username.

        .. note:: The username will be converted to lowercase.
                  The display name will contain the original version.

        :raises ValueError: If ``validate_username`` rejects the username.
        """
        if username is None:
            self._displayname = None
            self._username = None
        else:
            # if the username can't be validated, a ValueError will be raised
            validate_username(username)
            self._displayname = username
            self._username = username.lower()

    @hybrid_property
    def user_profile(self):
        """Get the user profile (``None`` if unset)."""
        # NOTE: accessing this property requires an initialized app for config
        if self._user_profile is None:
            return None
        if not isinstance(self._user_profile, UserProfileDict):
            # Cache the wrapped dict (as ``preferences`` does) so in-place edits
            # to the returned object are kept instead of being lost on a
            # throwaway copy.
            self._user_profile = UserProfileDict(**self._user_profile)
        return self._user_profile

    @user_profile.setter
    def user_profile(self, value):
        """Set the user profile; ``None`` clears it."""
        if value is None:
            self._user_profile = None
        else:
            self._user_profile = UserProfileDict(**value)

    @hybrid_property
    def preferences(self):
        """Get the user preferences (``None`` if unset)."""
        # NOTE: accessing this property requires an initialized app for config
        if self._preferences is None:
            return None
        if not isinstance(self._preferences, UserPreferenceDict):
            self._preferences = UserPreferenceDict(**self._preferences)
        return self._preferences

    @preferences.setter
    def preferences(self, value):
        """Set the user preferences; ``None`` clears them."""
        if value is None:
            self._preferences = None
        else:
            self._preferences = UserPreferenceDict(**value)
        # Preferences include locale/timezone, so drop Flask-Babel's cached
        # translations/timezone for the current context.
        refresh()

    def _get_login_info_attr(self, attr_name):
        """Read ``attr_name`` from ``login_info``, or ``None`` if there is none."""
        if self.login_info is None:
            return None
        return getattr(self.login_info, attr_name)

    def _set_login_info_attr(self, attr_name, value):
        """Write ``attr_name`` on ``login_info``, creating the row if missing."""
        if self.login_info is None:
            self.login_info = LoginInformation()
        setattr(self.login_info, attr_name, value)

    @property
    def current_login_at(self):
        """When user logged into the current session."""
        return self._get_login_info_attr("current_login_at")

    @property
    def current_login_ip(self):
        """Current user IP address."""
        return self._get_login_info_attr("current_login_ip")

    @property
    def last_login_at(self):
        """When the user logged-in for the last time."""
        return self._get_login_info_attr("last_login_at")

    @property
    def last_login_ip(self):
        """Last user IP address."""
        return self._get_login_info_attr("last_login_ip")

    @property
    def login_count(self):
        """Count how many times the user logged in."""
        return self._get_login_info_attr("login_count")

    # Setters forward to ``login_info`` (created on demand); property setter
    # return values are ignored, so nothing is returned.
    @current_login_at.setter
    def current_login_at(self, value):
        self._set_login_info_attr("current_login_at", value)

    @current_login_ip.setter
    def current_login_ip(self, value):
        self._set_login_info_attr("current_login_ip", value)

    @last_login_at.setter
    def last_login_at(self, value):
        self._set_login_info_attr("last_login_at", value)

    @last_login_ip.setter
    def last_login_ip(self, value):
        self._set_login_info_attr("last_login_ip", value)

    @login_count.setter
    def login_count(self, value):
        self._set_login_info_attr("login_count", value)

    def __str__(self):
        """Representation."""
        return "User <id={0.id}, email={0.email}>".format(self)
class SessionActivity(db.Model, Timestamp):
    """Activity record of a user session.

    Every row describes one session owned by a user.
    """

    __tablename__ = "accounts_user_session_activity"

    sid_s = db.Column(db.String(255), primary_key=True)
    """Serialized session ID, i.e. the key of the session in the kv-session
    store used by `flask-kvsession` (named as in that package on purpose).
    """

    user_id = db.Column(
        db.Integer,
        db.ForeignKey(User.id, name="fk_accounts_session_activity_user_id"),
    )
    """ID of the user owning the session."""

    user = db.relationship(User, backref="active_sessions")

    ip = db.Column(db.String(80), nullable=True)
    """IP address the session was seen from."""

    country = db.Column(db.String(3), nullable=True)
    """Country of the session (at most 3 characters; presumably a code)."""

    browser = db.Column(db.String(80), nullable=True)
    """Browser name."""

    browser_version = db.Column(db.String(30), nullable=True)
    """Browser version."""

    os = db.Column(db.String(80), nullable=True)
    """Operating system name."""

    device = db.Column(db.String(80), nullable=True)
    """Device of the user."""

    @classmethod
    def query_by_expired(cls):
        """Return a query over sessions created before the lifetime cutoff.

        The cutoff is "now" (UTC) minus the app's permanent session lifetime.
        """
        max_age = current_app.permanent_session_lifetime
        cutoff = datetime.utcnow() - max_age
        return cls.query.filter(cls.created < cutoff)

    @classmethod
    def query_by_user(cls, user_id):
        """Return a query over all sessions owned by ``user_id``."""
        return cls.query.filter(cls.user_id == user_id)

    @classmethod
    def is_current(cls, sid_s):
        """Tell whether ``sid_s`` is the ID of the session being served."""
        current_sid = session.sid_s
        return current_sid == sid_s
class UserIdentity(db.Model, Timestamp):
    """Represent a UserIdentity record.

    Links a local user to an identifier at an external identity source.
    """

    __tablename__ = "accounts_useridentity"

    # External identifier; together with ``method`` it forms the primary key,
    # so one external id per source maps to at most one user.
    id = db.Column(db.String(255), primary_key=True, nullable=False)
    # Identity source (e.g. orcid, github).
    method = db.Column(db.String(255), primary_key=True, nullable=False)
    id_user = db.Column(db.Integer(), db.ForeignKey(User.id), nullable=False)

    user = db.relationship(User, backref="external_identifiers")

    # A user can have at most one identity per method.
    __table_args__ = (
        db.Index("accounts_useridentity_id_user_method", id_user, method, unique=True),
    )

    @classmethod
    def get_user(cls, method, external_id):
        """Get the user for a given identity, or ``None`` if it is not linked."""
        identity = cls.query.filter_by(id=external_id, method=method).one_or_none()
        if identity is not None:
            return identity.user
        return None

    @classmethod
    def create(cls, user, method, external_id):
        """Link a user to an external id.

        :param user: A :class:`invenio_accounts.models.User` instance.
        :param method: The identity source (e.g. orcid, github).
        :param external_id: The external identifier.
        :raises AlreadyLinkedError: Raised if the external id is already linked
            for this method, or the user already has an identity for this
            method.
        """
        try:
            # The savepoint is flushed on exit, so constraint violations
            # surface here and only the savepoint is rolled back.
            with db.session.begin_nested():
                db.session.add(cls(id=external_id, method=method, id_user=user.id))
        except IntegrityError as err:
            raise AlreadyLinkedError(
                user,
                # dict used for backward compatibility (came from oauthclient)
                {"id": external_id, "method": method},
            ) from err

    @classmethod
    def delete_by_external_id(cls, method, external_id):
        """Unlink a user from an external id."""
        with db.session.begin_nested():
            cls.query.filter_by(id=external_id, method=method).delete()

    @classmethod
    def delete_by_user(cls, method, user):
        """Unlink all of ``user``'s external ids for the given method."""
        with db.session.begin_nested():
            cls.query.filter_by(id_user=user.id, method=method).delete()