def _get_csrf_token(self):
    """Return the raw CSRF token sent with the current request, or ``None``.

    Submitted form data is searched first. Any form key that ends with the
    configured ``WTF_CSRF_FIELD_NAME`` is accepted, so a prefixed form field
    such as ``{prefix}-csrf_token`` is found as well. The first non-empty
    matching value wins. If no form field supplies a token, each header named
    in ``WTF_CSRF_HEADERS`` is tried in order, and the first non-empty one
    is returned.
    """
    config = current_app.config
    suffix = config['WTF_CSRF_FIELD_NAME']
    submitted = request.form

    # Suffix match rather than equality, so form prefixes are tolerated.
    matching_values = (
        submitted[name] for name in submitted if name.endswith(suffix)
    )
    for value in matching_values:
        if value:
            return value

    # Fall back to request headers (e.g. for AJAX clients).
    for header in config['WTF_CSRF_HEADERS']:
        value = request.headers.get(header)
        if value:
            return value

    return None
def generate_csrf(secret_key=None, token_key=None):
    """Generate a CSRF token. The token is cached for a request, so multiple
    calls to this function will generate the same token.

    During testing, it might be useful to access the signed token in
    ``g.csrf_token`` and the raw token in ``session['csrf_token']``.

    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    """
    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )

    # Already signed earlier in this request: reuse the cached value.
    if field_name in g:
        return g.get(field_name)

    # The raw (unsigned) token lives in the session and is created only once.
    if field_name not in session:
        session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()

    signer = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
    signed = signer.dumps(session[field_name])
    setattr(g, field_name, signed)
    return signed
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Check if the given data is a valid CSRF token. This compares the given
    signed token to the one stored in the session.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    # Stdlib constant-time comparison. This replaces Werkzeug's
    # ``safe_str_cmp``, which is deprecated in Werkzeug 2.0 and removed in 2.1.
    import hmac

    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )
    time_limit = _get_config(
        time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
    )

    if not data:
        raise ValidationError('The CSRF token is missing.')

    if field_name not in session:
        raise ValidationError('The CSRF session token is missing.')

    s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

    try:
        token = s.loads(data, max_age=time_limit)
    except SignatureExpired:
        raise ValidationError('The CSRF token has expired.')
    except BadData:
        raise ValidationError('The CSRF token is invalid.')

    expected = session[field_name]
    # Encode text to UTF-8 before comparing, as safe_str_cmp did, so
    # non-ASCII text cannot make compare_digest raise TypeError.
    if isinstance(expected, str):
        expected = expected.encode('utf-8')
    if isinstance(token, str):
        token = token.encode('utf-8')

    # Constant-time comparison avoids leaking how many leading bytes matched.
    if not hmac.compare_digest(expected, token):
        raise ValidationError('The CSRF tokens do not match.')
def init_app(self, app):
    """Register CSRF protection on ``app``.

    This does four things:

    - Stores this extension object under ``app.extensions['csrf']``.
    - Fills in the ``WTF_CSRF_*`` configuration defaults.
    - Exposes ``generate_csrf`` to templates as ``csrf_token``.
    - Installs a ``before_request`` hook that calls ``self.protect()``.

    The hook only calls ``self.protect()`` when protection is enabled, checked
    by default, the method is listed in ``WTF_CSRF_METHODS``, and the routed
    view and blueprint are not exempt.
    """
    app.extensions['csrf'] = self

    config = app.config
    config.setdefault('WTF_CSRF_ENABLED', True)
    config.setdefault('WTF_CSRF_CHECK_DEFAULT', True)
    # Always normalise to a set; the request hook tests membership against it.
    config['WTF_CSRF_METHODS'] = set(
        config.get('WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH', 'DELETE'])
    )
    remaining_defaults = (
        ('WTF_CSRF_FIELD_NAME', 'csrf_token'),
        ('WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token']),
        ('WTF_CSRF_TIME_LIMIT', 3600),
        ('WTF_CSRF_SSL_STRICT', True),
    )
    for key, default in remaining_defaults:
        config.setdefault(key, default)

    app.jinja_env.globals['csrf_token'] = generate_csrf
    app.context_processor(lambda: {'csrf_token': generate_csrf})

    @app.before_request
    def csrf_protect():
        settings = app.config
        if not (settings['WTF_CSRF_ENABLED']
                and settings['WTF_CSRF_CHECK_DEFAULT']):
            return
        if request.method not in settings['WTF_CSRF_METHODS']:
            return

        endpoint = request.endpoint
        if not endpoint:
            return
        view = app.view_functions.get(endpoint)
        if not view:
            return

        if request.blueprint in self._exempt_blueprints:
            return
        # Exempt views are keyed by "<module>.<function name>".
        qualified_name = '%s.%s' % (view.__module__, view.__name__)
        if qualified_name not in self._exempt_views:
            self.protect()