Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement forgot password feature #5534

Merged
merged 14 commits into from
Jul 5, 2024
Prev Previous commit
Next Next commit
Add TokenManager class to ensure only the latest reset password token…
… is usable
  • Loading branch information
xielong committed Jul 4, 2024
commit 62e11ef552447c30d876dba8f2bdb8de5d66c2de
2 changes: 1 addition & 1 deletion api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SecurityConfig(BaseModel):
default=None,
)

RESET_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description='Expiry time in hours for reset token',
default=24,
)
Expand Down
6 changes: 3 additions & 3 deletions api/controllers/console/auth/forgot_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def post(self):
args = parser.parse_args()
token = args['token']

reset_data = AccountService.get_reset_data(token)
reset_data = AccountService.get_reset_password_data(token)

if reset_data is None:
return {'is_valid': False, 'email': None}
Expand All @@ -73,12 +73,12 @@ def post(self):
raise PasswordMismatchError()

token = args['token']
reset_data = AccountService.get_reset_data(token)
reset_data = AccountService.get_reset_password_data(token)

if reset_data is None:
raise InvalidTokenError()

AccountService.revoke_reset_token(token)
AccountService.revoke_reset_password_token(token)

salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
Expand Down
110 changes: 78 additions & 32 deletions api/services/account_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def load_logged_in_account(*, account_id: str, token: str):

@classmethod
def send_reset_password_email(cls, account: Account):
token = cls.generate_reset_token(account)
token = TokenManager.generate_token(account, 'reset_password')
send_reset_password_mail_task.delay(
language=account.interface_language,
to=account.email,
Expand All @@ -234,40 +234,18 @@ def send_reset_password_email(cls, account: Account):
return token

@classmethod
def generate_reset_token(cls, account: Account) -> str:
token = str(uuid.uuid4())
reset_data = {
'account_id': account.id,
'email': account.email,
}
expiryHours = current_app.config['RESET_TOKEN_EXPIRY_HOURS']
redis_client.setex(
cls._get_reset_token_key(token),
expiryHours * 60 * 60,
json.dumps(reset_data)
)
return token
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, 'reset_password')

@classmethod
def _get_reset_token_key(cls, token: str) -> str:
return f'reset_password:token:{token}'
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, 'reset_password')

@classmethod
def revoke_reset_token(cls, token: str):
redis_client.delete(cls._get_reset_token_key(token))

@classmethod
def get_reset_data(cls, token: str) -> Optional[dict[str, Any]]:
key = cls._get_reset_token_key(token)
reset_data_json = redis_client.get(key)
if reset_data_json is None:
return None
reset_data = json.loads(reset_data_json)
return reset_data

def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"


class TenantService:

@staticmethod
Expand Down Expand Up @@ -343,16 +321,18 @@ def switch_tenant(account: Account, tenant_id: int = None) -> None:
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")

tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant,
TenantAccountJoin.tenant_id == Tenant.id).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
).first()
).first()

if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
tenant_account_join.current = True
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
Expand Down Expand Up @@ -564,7 +544,8 @@ def register(cls, email, name,
return account

@classmethod
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal',
inviter: Account = None) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()

Expand Down Expand Up @@ -684,3 +665,68 @@ def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) ->

invitation = json.loads(data)
return invitation


class TokenManager:

@classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode('utf-8')
cls.revoke_token(old_token, token_type)

token = str(uuid.uuid4())
token_data = {
'account_id': account.id,
'email': account.email,
'token_type': token_type
}
if additional_data:
token_data.update(additional_data)

expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
token_key = cls._get_token_key(token, token_type)
redis_client.setex(
token_key,
expiry_hours * 60 * 60,
json.dumps(token_data)
)

cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token

@classmethod
def _get_token_key(cls, token: str, token_type: str) -> str:
return f'{token_type}:token:{token}'

@classmethod
def revoke_token(cls, token: str, token_type: str):
token_key = cls._get_token_key(token, token_type)
redis_client.delete(token_key)

@classmethod
def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
key = cls._get_token_key(token, token_type)
token_data_json = redis_client.get(key)
if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}")
return None
token_data = json.loads(token_data_json)
return token_data

@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
key = cls._get_account_token_key(account_id, token_type)
current_token = redis_client.get(key)
return current_token

@classmethod
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
key = cls._get_account_token_key(account_id, token_type)
redis_client.setex(key, expiry_hours * 60 * 60, token)

@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
return f'{token_type}:account:{account_id}'