Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement forgot password feature #5534

Merged
merged 14 commits into from
Jul 5, 2024
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,5 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json
pyrightconfig.json

.idea/
4 changes: 4 additions & 0 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class SecurityConfig(BaseModel):
default=None,
)

RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description='Expiry time in hours for reset token',
default=24,
)

class AppExecutionConfig(BaseModel):
"""
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth

# Import billing controllers
from .billing import billing
Expand Down
25 changes: 25 additions & 0 deletions api/controllers/console/auth/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
description = "{message}"
code = 500


class InvalidEmailError(BaseHTTPException):
error_code = 'invalid_email'
description = "The email address is not valid."
code = 400


class PasswordMismatchError(BaseHTTPException):
error_code = 'password_mismatch'
description = "The passwords do not match."
code = 400


class InvalidTokenError(BaseHTTPException):
error_code = 'invalid_or_expired_token'
description = "The token is invalid or has expired."
code = 400


class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = 'password_reset_rate_limit_exceeded'
description = "Password reset rate limit exceeded. Try again later."
code = 429

107 changes: 107 additions & 0 deletions api/controllers/console/auth/forgot_password.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import base64
import logging
import secrets

from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.auth.error import (
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
PasswordResetRateLimitExceededError,
)
from controllers.console.setup import setup_required
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
from services.errors.account import RateLimitExceededError


class ForgotPasswordSendEmailApi(Resource):

@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=str, required=True, location='json')
args = parser.parse_args()

email = args['email']

if not email_validate(email):
raise InvalidEmailError()

account = Account.query.filter_by(email=email).first()

if account:
try:
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
else:
# Return success to avoid revealing email registration status
logging.warning(f"Attempt to reset password for unregistered email: {email}")

return {"result": "success"}


class ForgotPasswordCheckApi(Resource):

@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
token = args['token']

reset_data = AccountService.get_reset_password_data(token)

if reset_data is None:
return {'is_valid': False, 'email': None}
return {'is_valid': True, 'email': reset_data.get('email')}


class ForgotPasswordResetApi(Resource):

@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
args = parser.parse_args()

new_password = args['new_password']
password_confirm = args['password_confirm']

if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError()

token = args['token']
reset_data = AccountService.get_reset_password_data(token)

if reset_data is None:
raise InvalidTokenError()

AccountService.revoke_reset_password_token(token)

salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()

password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()

account = Account.query.filter_by(email=reset_data.get('email')).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()

return {'result': 'success'}


api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
2 changes: 2 additions & 0 deletions api/controllers/console/workspace/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def get(self):
return {'data': integrate_data}




# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')
Expand Down
107 changes: 103 additions & 4 deletions api/libs/helper.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import json
import logging
import random
import re
import string
import subprocess
import time
import uuid
from collections.abc import Generator
from datetime import datetime
from hashlib import sha256
from typing import Union
from typing import Any, Optional, Union
from zoneinfo import available_timezones

from flask import Response, stream_with_context
from flask import Response, current_app, stream_with_context
from flask_restful import fields

from extensions.ext_redis import redis_client
from models.account import Account


def run(script):
return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
Expand Down Expand Up @@ -46,12 +51,12 @@ def uuid_value(value):
error = ('{value} is not a valid uuid.'
.format(value=value))
raise ValueError(error)

def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r'^[a-zA-Z0-9_]+$', value):
return value

raise ValueError(f'{value} is not a valid alphanumeric value')

def timestamp_value(timestamp):
Expand Down Expand Up @@ -163,3 +168,97 @@ def generate() -> Generator:

return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')


class TokenManager:

@classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode('utf-8')
cls.revoke_token(old_token, token_type)

token = str(uuid.uuid4())
token_data = {
'account_id': account.id,
'email': account.email,
'token_type': token_type
}
if additional_data:
token_data.update(additional_data)

expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
token_key = cls._get_token_key(token, token_type)
redis_client.setex(
token_key,
expiry_hours * 60 * 60,
json.dumps(token_data)
)

cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token

@classmethod
def _get_token_key(cls, token: str, token_type: str) -> str:
return f'{token_type}:token:{token}'

@classmethod
def revoke_token(cls, token: str, token_type: str):
token_key = cls._get_token_key(token, token_type)
redis_client.delete(token_key)

@classmethod
def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
key = cls._get_token_key(token, token_type)
token_data_json = redis_client.get(key)
if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}")
return None
token_data = json.loads(token_data_json)
return token_data

@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
key = cls._get_account_token_key(account_id, token_type)
current_token = redis_client.get(key)
return current_token

@classmethod
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
key = cls._get_account_token_key(account_id, token_type)
redis_client.setex(key, expiry_hours * 60 * 60, token)

@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
return f'{token_type}:account:{account_id}'


class RateLimiter:
def __init__(self, prefix: str, max_attempts: int, time_window: int):
self.prefix = prefix
self.max_attempts = max_attempts
self.time_window = time_window

def _get_key(self, email: str) -> str:
return f"{self.prefix}:{email}"

def is_rate_limited(self, email: str) -> bool:
key = self._get_key(email)
current_time = int(time.time())
window_start_time = current_time - self.time_window

redis_client.zremrangebyscore(key, '-inf', window_start_time)
attempts = redis_client.zcard(key)

if attempts and int(attempts) >= self.max_attempts:
return True
return False

def increment_rate_limit(self, email: str):
key = self._get_key(email)
current_time = int(time.time())

redis_client.zadd(key, {current_time: current_time})
redis_client.expire(key, self.time_window * 2)
33 changes: 33 additions & 0 deletions api/services/account_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
Expand All @@ -29,14 +30,22 @@
LinkAccountIntegrateError,
MemberNotInTenantError,
NoPermissionError,
RateLimitExceededError,
RoleAlreadyAssignedError,
TenantNotFound,
)
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task


class AccountService:

reset_password_rate_limiter = RateLimiter(
prefix="reset_password_rate_limit",
max_attempts=5,
time_window=60 * 60
)

@staticmethod
def load_user(user_id: str) -> Account:
account = Account.query.filter_by(id=user_id).first()
Expand Down Expand Up @@ -222,9 +231,33 @@ def load_logged_in_account(*, account_id: str, token: str):
return None
return AccountService.load_user(account_id)

@classmethod
def send_reset_password_email(cls, account):
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")

token = TokenManager.generate_token(account, 'reset_password')
send_reset_password_mail_task.delay(
language=account.interface_language,
to=account.email,
token=token
)
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
return token

@classmethod
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, 'reset_password')

@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, 'reset_password')


def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"


class TenantService:

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions api/services/errors/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ class MemberNotInTenantError(BaseServiceError):

class RoleAlreadyAssignedError(BaseServiceError):
pass


class RateLimitExceededError(BaseServiceError):
pass

Loading