Merge branch 'master' of gitorious.org:mediagoblin/mediagoblin
This commit is contained in:
commit
0f6ab7da86
@ -102,6 +102,21 @@ class OAuthCode_v0(declarative_base()):
|
|||||||
client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
|
client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthRefreshToken_v0(declarative_base()):
|
||||||
|
__tablename__ = 'oauth__refresh_tokens'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
created = Column(DateTime, nullable=False,
|
||||||
|
default=datetime.now)
|
||||||
|
|
||||||
|
token = Column(Unicode, index=True)
|
||||||
|
|
||||||
|
user_id = Column(Integer, ForeignKey(User.id), nullable=False)
|
||||||
|
|
||||||
|
# XXX: Is it OK to use OAuthClient_v0.id in this way?
|
||||||
|
client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
@RegisterMigration(1, MIGRATIONS)
|
@RegisterMigration(1, MIGRATIONS)
|
||||||
def remove_and_replace_token_and_code(db):
|
def remove_and_replace_token_and_code(db):
|
||||||
metadata = MetaData(bind=db.bind)
|
metadata = MetaData(bind=db.bind)
|
||||||
@ -122,3 +137,22 @@ def remove_and_replace_token_and_code(db):
|
|||||||
OAuthCode_v0.__table__.create(db.bind)
|
OAuthCode_v0.__table__.create(db.bind)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@RegisterMigration(2, MIGRATIONS)
|
||||||
|
def remove_refresh_token_field(db):
|
||||||
|
metadata = MetaData(bind=db.bind)
|
||||||
|
|
||||||
|
token_table = Table('oauth__tokens', metadata, autoload=True,
|
||||||
|
autoload_with=db.bind)
|
||||||
|
|
||||||
|
refresh_token = token_table.columns['refresh_token']
|
||||||
|
|
||||||
|
refresh_token.drop()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
@RegisterMigration(3, MIGRATIONS)
|
||||||
|
def create_refresh_token_table(db):
|
||||||
|
OAuthRefreshToken_v0.__table__.create(db.bind)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
@ -14,17 +14,17 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import uuid
|
|
||||||
import bcrypt
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from mediagoblin.db.base import Base
|
|
||||||
from mediagoblin.db.models import User
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, Unicode, Integer, DateTime, ForeignKey, Enum)
|
Column, Unicode, Integer, DateTime, ForeignKey, Enum)
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship, backref
|
||||||
|
from mediagoblin.db.base import Base
|
||||||
|
from mediagoblin.db.models import User
|
||||||
|
from mediagoblin.plugins.oauth.tools import generate_identifier, \
|
||||||
|
generate_secret, generate_token, generate_code, generate_refresh_token
|
||||||
|
|
||||||
# Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto
|
# Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto
|
||||||
# the models.
|
# the models.
|
||||||
@ -41,11 +41,14 @@ class OAuthClient(Base):
|
|||||||
name = Column(Unicode)
|
name = Column(Unicode)
|
||||||
description = Column(Unicode)
|
description = Column(Unicode)
|
||||||
|
|
||||||
identifier = Column(Unicode, unique=True, index=True)
|
identifier = Column(Unicode, unique=True, index=True,
|
||||||
secret = Column(Unicode, index=True)
|
default=generate_identifier)
|
||||||
|
secret = Column(Unicode, index=True, default=generate_secret)
|
||||||
|
|
||||||
owner_id = Column(Integer, ForeignKey(User.id))
|
owner_id = Column(Integer, ForeignKey(User.id))
|
||||||
owner = relationship(User, backref='registered_clients')
|
owner = relationship(
|
||||||
|
User,
|
||||||
|
backref=backref('registered_clients', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
redirect_uri = Column(Unicode)
|
redirect_uri = Column(Unicode)
|
||||||
|
|
||||||
@ -54,14 +57,8 @@ class OAuthClient(Base):
|
|||||||
u'public',
|
u'public',
|
||||||
name=u'oauth__client_type'))
|
name=u'oauth__client_type'))
|
||||||
|
|
||||||
def generate_identifier(self):
|
def update_secret(self):
|
||||||
self.identifier = unicode(uuid.uuid4())
|
self.secret = generate_secret()
|
||||||
|
|
||||||
def generate_secret(self):
|
|
||||||
self.secret = unicode(
|
|
||||||
bcrypt.hashpw(
|
|
||||||
unicode(uuid.uuid4()),
|
|
||||||
bcrypt.gensalt()))
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<{0} {1}:{2} ({3})>'.format(
|
return '<{0} {1}:{2} ({3})>'.format(
|
||||||
@ -76,10 +73,15 @@ class OAuthUserClient(Base):
|
|||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
|
|
||||||
user_id = Column(Integer, ForeignKey(User.id))
|
user_id = Column(Integer, ForeignKey(User.id))
|
||||||
user = relationship(User, backref='oauth_clients')
|
user = relationship(
|
||||||
|
User,
|
||||||
|
backref=backref('oauth_client_relations',
|
||||||
|
cascade='all, delete-orphan'))
|
||||||
|
|
||||||
client_id = Column(Integer, ForeignKey(OAuthClient.id))
|
client_id = Column(Integer, ForeignKey(OAuthClient.id))
|
||||||
client = relationship(OAuthClient, backref='users')
|
client = relationship(
|
||||||
|
OAuthClient,
|
||||||
|
backref=backref('oauth_user_relations', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
state = Column(Enum(
|
state = Column(Enum(
|
||||||
u'approved',
|
u'approved',
|
||||||
@ -103,15 +105,18 @@ class OAuthToken(Base):
|
|||||||
default=datetime.now)
|
default=datetime.now)
|
||||||
expires = Column(DateTime, nullable=False,
|
expires = Column(DateTime, nullable=False,
|
||||||
default=lambda: datetime.now() + timedelta(days=30))
|
default=lambda: datetime.now() + timedelta(days=30))
|
||||||
token = Column(Unicode, index=True)
|
token = Column(Unicode, index=True, default=generate_token)
|
||||||
refresh_token = Column(Unicode, index=True)
|
|
||||||
|
|
||||||
user_id = Column(Integer, ForeignKey(User.id), nullable=False,
|
user_id = Column(Integer, ForeignKey(User.id), nullable=False,
|
||||||
index=True)
|
index=True)
|
||||||
user = relationship(User)
|
user = relationship(
|
||||||
|
User,
|
||||||
|
backref=backref('oauth_tokens', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
|
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
|
||||||
client = relationship(OAuthClient)
|
client = relationship(
|
||||||
|
OAuthClient,
|
||||||
|
backref=backref('oauth_tokens', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
|
return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
|
||||||
@ -121,6 +126,34 @@ class OAuthToken(Base):
|
|||||||
self.user,
|
self.user,
|
||||||
self.client)
|
self.client)
|
||||||
|
|
||||||
|
class OAuthRefreshToken(Base):
|
||||||
|
__tablename__ = 'oauth__refresh_tokens'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
created = Column(DateTime, nullable=False,
|
||||||
|
default=datetime.now)
|
||||||
|
|
||||||
|
token = Column(Unicode, index=True,
|
||||||
|
default=generate_refresh_token)
|
||||||
|
|
||||||
|
user_id = Column(Integer, ForeignKey(User.id), nullable=False)
|
||||||
|
|
||||||
|
user = relationship(User, backref=backref('oauth_refresh_tokens',
|
||||||
|
cascade='all, delete-orphan'))
|
||||||
|
|
||||||
|
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
|
||||||
|
client = relationship(OAuthClient,
|
||||||
|
backref=backref(
|
||||||
|
'oauth_refresh_tokens',
|
||||||
|
cascade='all, delete-orphan'))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<{0} #{1} [{3}, {4}]>'.format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.id,
|
||||||
|
self.user,
|
||||||
|
self.client)
|
||||||
|
|
||||||
|
|
||||||
class OAuthCode(Base):
|
class OAuthCode(Base):
|
||||||
__tablename__ = 'oauth__codes'
|
__tablename__ = 'oauth__codes'
|
||||||
@ -130,14 +163,17 @@ class OAuthCode(Base):
|
|||||||
default=datetime.now)
|
default=datetime.now)
|
||||||
expires = Column(DateTime, nullable=False,
|
expires = Column(DateTime, nullable=False,
|
||||||
default=lambda: datetime.now() + timedelta(minutes=5))
|
default=lambda: datetime.now() + timedelta(minutes=5))
|
||||||
code = Column(Unicode, index=True)
|
code = Column(Unicode, index=True, default=generate_code)
|
||||||
|
|
||||||
user_id = Column(Integer, ForeignKey(User.id), nullable=False,
|
user_id = Column(Integer, ForeignKey(User.id), nullable=False,
|
||||||
index=True)
|
index=True)
|
||||||
user = relationship(User)
|
user = relationship(User, backref=backref('oauth_codes',
|
||||||
|
cascade='all, delete-orphan'))
|
||||||
|
|
||||||
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
|
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
|
||||||
client = relationship(OAuthClient)
|
client = relationship(OAuthClient, backref=backref(
|
||||||
|
'oauth_codes',
|
||||||
|
cascade='all, delete-orphan'))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
|
return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
|
||||||
@ -150,6 +186,7 @@ class OAuthCode(Base):
|
|||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
OAuthToken,
|
OAuthToken,
|
||||||
|
OAuthRefreshToken,
|
||||||
OAuthCode,
|
OAuthCode,
|
||||||
OAuthClient,
|
OAuthClient,
|
||||||
OAuthUserClient]
|
OAuthUserClient]
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
# GNU MediaGoblin -- federated, autonomous media hosting
|
# GNU MediaGoblin -- federated, autonomous media hosting
|
||||||
# Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
|
# Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
|
||||||
#
|
#
|
||||||
@ -14,13 +15,26 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from random import getrandbits
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from mediagoblin.plugins.oauth.models import OAuthClient
|
|
||||||
from mediagoblin.plugins.api.tools import json_response
|
from mediagoblin.plugins.api.tools import json_response
|
||||||
|
|
||||||
|
|
||||||
def require_client_auth(controller):
|
def require_client_auth(controller):
|
||||||
|
'''
|
||||||
|
View decorator
|
||||||
|
|
||||||
|
- Requires the presence of ``?client_id``
|
||||||
|
'''
|
||||||
|
# Avoid circular import
|
||||||
|
from mediagoblin.plugins.oauth.models import OAuthClient
|
||||||
|
|
||||||
@wraps(controller)
|
@wraps(controller)
|
||||||
def wrapper(request, *args, **kw):
|
def wrapper(request, *args, **kw):
|
||||||
if not request.GET.get('client_id'):
|
if not request.GET.get('client_id'):
|
||||||
@ -41,3 +55,60 @@ def require_client_auth(controller):
|
|||||||
return controller(request, client)
|
return controller(request, client)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def create_token(client, user):
|
||||||
|
'''
|
||||||
|
Create an OAuthToken and an OAuthRefreshToken entry in the database
|
||||||
|
|
||||||
|
Returns the data structure expected by the OAuth clients.
|
||||||
|
'''
|
||||||
|
from mediagoblin.plugins.oauth.models import OAuthToken, OAuthRefreshToken
|
||||||
|
|
||||||
|
token = OAuthToken()
|
||||||
|
token.user = user
|
||||||
|
token.client = client
|
||||||
|
token.save()
|
||||||
|
|
||||||
|
refresh_token = OAuthRefreshToken()
|
||||||
|
refresh_token.user = user
|
||||||
|
refresh_token.client = client
|
||||||
|
refresh_token.save()
|
||||||
|
|
||||||
|
# expire time of token in full seconds
|
||||||
|
# timedelta.total_seconds is python >= 2.7 or we would use that
|
||||||
|
td = token.expires - datetime.now()
|
||||||
|
exp_in = 86400*td.days + td.seconds # just ignore µsec
|
||||||
|
|
||||||
|
return {'access_token': token.token, 'token_type': 'bearer',
|
||||||
|
'refresh_token': refresh_token.token, 'expires_in': exp_in}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_identifier():
|
||||||
|
''' Generates a ``uuid.uuid4()`` '''
|
||||||
|
return unicode(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token():
|
||||||
|
''' Uses generate_identifier '''
|
||||||
|
return generate_identifier()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_refresh_token():
|
||||||
|
''' Uses generate_identifier '''
|
||||||
|
return generate_identifier()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_code():
|
||||||
|
''' Uses generate_identifier '''
|
||||||
|
return generate_identifier()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_secret():
|
||||||
|
'''
|
||||||
|
Generate a long string of pseudo-random characters
|
||||||
|
'''
|
||||||
|
# XXX: We might not want it to use bcrypt, since bcrypt takes its time to
|
||||||
|
# generate the result.
|
||||||
|
return unicode(getrandbits(192))
|
||||||
|
|
||||||
|
@ -16,21 +16,21 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import json
|
|
||||||
|
|
||||||
from urllib import urlencode
|
from urllib import urlencode
|
||||||
from uuid import uuid4
|
|
||||||
from datetime import datetime
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from mediagoblin.tools.response import render_to_response, redirect
|
from mediagoblin.tools.response import render_to_response, redirect
|
||||||
from mediagoblin.decorators import require_active_login
|
from mediagoblin.decorators import require_active_login
|
||||||
from mediagoblin.messages import add_message, SUCCESS, ERROR
|
from mediagoblin.messages import add_message, SUCCESS
|
||||||
from mediagoblin.tools.translate import pass_to_ugettext as _
|
from mediagoblin.tools.translate import pass_to_ugettext as _
|
||||||
from mediagoblin.plugins.oauth.models import OAuthCode, OAuthToken, \
|
from mediagoblin.plugins.oauth.models import OAuthCode, OAuthClient, \
|
||||||
OAuthClient, OAuthUserClient
|
OAuthUserClient, OAuthRefreshToken
|
||||||
from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \
|
from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \
|
||||||
AuthorizationForm
|
AuthorizationForm
|
||||||
from mediagoblin.plugins.oauth.tools import require_client_auth
|
from mediagoblin.plugins.oauth.tools import require_client_auth, \
|
||||||
|
create_token
|
||||||
from mediagoblin.plugins.api.tools import json_response
|
from mediagoblin.plugins.api.tools import json_response
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
@ -51,9 +51,6 @@ def register_client(request):
|
|||||||
client.owner_id = request.user.id
|
client.owner_id = request.user.id
|
||||||
client.redirect_uri = unicode(form.redirect_uri.data)
|
client.redirect_uri = unicode(form.redirect_uri.data)
|
||||||
|
|
||||||
client.generate_identifier()
|
|
||||||
client.generate_secret()
|
|
||||||
|
|
||||||
client.save()
|
client.save()
|
||||||
|
|
||||||
add_message(request, SUCCESS, _('The client {0} has been registered!')\
|
add_message(request, SUCCESS, _('The client {0} has been registered!')\
|
||||||
@ -92,9 +89,9 @@ def authorize_client(request):
|
|||||||
form.client_id.data).first()
|
form.client_id.data).first()
|
||||||
|
|
||||||
if not client:
|
if not client:
|
||||||
_log.error('''No such client id as received from client authorization
|
_log.error('No such client id as received from client authorization \
|
||||||
form.''')
|
form.')
|
||||||
return BadRequest()
|
raise BadRequest()
|
||||||
|
|
||||||
if form.validate():
|
if form.validate():
|
||||||
relation = OAuthUserClient()
|
relation = OAuthUserClient()
|
||||||
@ -105,7 +102,7 @@ def authorize_client(request):
|
|||||||
elif form.deny.data:
|
elif form.deny.data:
|
||||||
relation.state = u'rejected'
|
relation.state = u'rejected'
|
||||||
else:
|
else:
|
||||||
return BadRequest
|
raise BadRequest()
|
||||||
|
|
||||||
relation.save()
|
relation.save()
|
||||||
|
|
||||||
@ -136,7 +133,7 @@ def authorize(request, client):
|
|||||||
return json_response({
|
return json_response({
|
||||||
'status': 400,
|
'status': 400,
|
||||||
'errors':
|
'errors':
|
||||||
[u'Public clients MUST have a redirect_uri pre-set']},
|
[u'Public clients should have a redirect_uri pre-set.']},
|
||||||
_disable_cors=True)
|
_disable_cors=True)
|
||||||
|
|
||||||
redirect_uri = client.redirect_uri
|
redirect_uri = client.redirect_uri
|
||||||
@ -146,11 +143,10 @@ def authorize(request, client):
|
|||||||
if not redirect_uri:
|
if not redirect_uri:
|
||||||
return json_response({
|
return json_response({
|
||||||
'status': 400,
|
'status': 400,
|
||||||
'errors': [u'Can not find a redirect_uri for client: {0}'\
|
'errors': [u'No redirect_uri supplied!']},
|
||||||
.format(client.name)]}, _disable_cors=True)
|
_disable_cors=True)
|
||||||
|
|
||||||
code = OAuthCode()
|
code = OAuthCode()
|
||||||
code.code = unicode(uuid4())
|
|
||||||
code.user = request.user
|
code.user = request.user
|
||||||
code.client = client
|
code.client = client
|
||||||
code.save()
|
code.save()
|
||||||
@ -180,59 +176,79 @@ def authorize(request, client):
|
|||||||
|
|
||||||
|
|
||||||
def access_token(request):
|
def access_token(request):
|
||||||
|
'''
|
||||||
|
Access token endpoint provides access tokens to any clients that have the
|
||||||
|
right grants/credentials
|
||||||
|
'''
|
||||||
|
|
||||||
|
client = None
|
||||||
|
user = None
|
||||||
|
|
||||||
if request.GET.get('code'):
|
if request.GET.get('code'):
|
||||||
|
# Validate the code arg, then get the client object from the db.
|
||||||
code = OAuthCode.query.filter(OAuthCode.code ==
|
code = OAuthCode.query.filter(OAuthCode.code ==
|
||||||
request.GET.get('code')).first()
|
request.GET.get('code')).first()
|
||||||
|
|
||||||
if code:
|
if not code:
|
||||||
if code.client.type == u'confidential':
|
return json_response({
|
||||||
|
'error': 'invalid_request',
|
||||||
|
'error_description':
|
||||||
|
'Invalid code.'})
|
||||||
|
|
||||||
|
client = code.client
|
||||||
|
user = code.user
|
||||||
|
|
||||||
|
elif request.args.get('refresh_token'):
|
||||||
|
# Validate a refresh token, then get the client object from the db.
|
||||||
|
refresh_token = OAuthRefreshToken.query.filter(
|
||||||
|
OAuthRefreshToken.token ==
|
||||||
|
request.args.get('refresh_token')).first()
|
||||||
|
|
||||||
|
if not refresh_token:
|
||||||
|
return json_response({
|
||||||
|
'error': 'invalid_request',
|
||||||
|
'error_description':
|
||||||
|
'Invalid refresh token.'})
|
||||||
|
|
||||||
|
client = refresh_token.client
|
||||||
|
user = refresh_token.user
|
||||||
|
|
||||||
|
if client:
|
||||||
client_identifier = request.GET.get('client_id')
|
client_identifier = request.GET.get('client_id')
|
||||||
|
|
||||||
if not client_identifier:
|
if not client_identifier:
|
||||||
return json_response({
|
return json_response({
|
||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
'error_description':
|
'error_description':
|
||||||
'Missing client_id in request'})
|
'Missing client_id in request.'})
|
||||||
|
|
||||||
|
if not client_identifier == client.identifier:
|
||||||
|
return json_response({
|
||||||
|
'error': 'invalid_client',
|
||||||
|
'error_description':
|
||||||
|
'Mismatching client credentials.'})
|
||||||
|
|
||||||
|
if client.type == u'confidential':
|
||||||
client_secret = request.GET.get('client_secret')
|
client_secret = request.GET.get('client_secret')
|
||||||
|
|
||||||
if not client_secret:
|
if not client_secret:
|
||||||
return json_response({
|
return json_response({
|
||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
'error_description':
|
'error_description':
|
||||||
'Missing client_secret in request'})
|
'Missing client_secret in request.'})
|
||||||
|
|
||||||
if not client_secret == code.client.secret or \
|
if not client_secret == client.secret:
|
||||||
not client_identifier == code.client.identifier:
|
|
||||||
return json_response({
|
return json_response({
|
||||||
'error': 'invalid_client',
|
'error': 'invalid_client',
|
||||||
'error_description':
|
'error_description':
|
||||||
'The client_id or client_secret does not match the'
|
'Mismatching client credentials.'})
|
||||||
' code'})
|
|
||||||
|
|
||||||
token = OAuthToken()
|
|
||||||
token.token = unicode(uuid4())
|
|
||||||
token.user = code.user
|
|
||||||
token.client = code.client
|
|
||||||
token.save()
|
|
||||||
|
|
||||||
# expire time of token in full seconds
|
access_token_data = create_token(client, user)
|
||||||
# timedelta.total_seconds is python >= 2.7 or we would use that
|
|
||||||
td = token.expires - datetime.now()
|
|
||||||
exp_in = 86400*td.days + td.seconds # just ignore µsec
|
|
||||||
|
|
||||||
access_token_data = {
|
|
||||||
'access_token': token.token,
|
|
||||||
'token_type': 'bearer',
|
|
||||||
'expires_in': exp_in}
|
|
||||||
return json_response(access_token_data, _disable_cors=True)
|
return json_response(access_token_data, _disable_cors=True)
|
||||||
else:
|
|
||||||
return json_response({
|
return json_response({
|
||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
'error_description':
|
'error_description':
|
||||||
'Invalid code'})
|
'Missing `code` or `refresh_token` parameter in request.'})
|
||||||
else:
|
|
||||||
return json_response({
|
|
||||||
'error': 'invalid_request',
|
|
||||||
'error_descriptin':
|
|
||||||
'Missing `code` parameter in request'})
|
|
||||||
|
@ -71,7 +71,7 @@ class TestOAuth(object):
|
|||||||
assert response.status_int == 200
|
assert response.status_int == 200
|
||||||
|
|
||||||
# Should display an error
|
# Should display an error
|
||||||
assert ctx['form'].redirect_uri.errors
|
assert len(ctx['form'].redirect_uri.errors)
|
||||||
|
|
||||||
# Should not pass through
|
# Should not pass through
|
||||||
assert not client
|
assert not client
|
||||||
@ -79,12 +79,16 @@ class TestOAuth(object):
|
|||||||
def test_2_successful_public_client_registration(self, test_app):
|
def test_2_successful_public_client_registration(self, test_app):
|
||||||
''' Successfully register a public client '''
|
''' Successfully register a public client '''
|
||||||
self._setup(test_app)
|
self._setup(test_app)
|
||||||
|
uri = 'http://foo.example'
|
||||||
self.register_client(test_app, u'OMGOMG', 'public', 'OMG!',
|
self.register_client(test_app, u'OMGOMG', 'public', 'OMG!',
|
||||||
'http://foo.example')
|
uri)
|
||||||
|
|
||||||
client = self.db.OAuthClient.query.filter(
|
client = self.db.OAuthClient.query.filter(
|
||||||
self.db.OAuthClient.name == u'OMGOMG').first()
|
self.db.OAuthClient.name == u'OMGOMG').first()
|
||||||
|
|
||||||
|
# redirect_uri should be set
|
||||||
|
assert client.redirect_uri == uri
|
||||||
|
|
||||||
# Client should have been registered
|
# Client should have been registered
|
||||||
assert client
|
assert client
|
||||||
|
|
||||||
@ -116,7 +120,7 @@ class TestOAuth(object):
|
|||||||
redirect_uri = 'https://foo.example'
|
redirect_uri = 'https://foo.example'
|
||||||
response = test_app.get('/oauth/authorize', {
|
response = test_app.get('/oauth/authorize', {
|
||||||
'client_id': client.identifier,
|
'client_id': client.identifier,
|
||||||
'scope': 'admin',
|
'scope': 'all',
|
||||||
'redirect_uri': redirect_uri})
|
'redirect_uri': redirect_uri})
|
||||||
|
|
||||||
# User-agent should NOT be redirected
|
# User-agent should NOT be redirected
|
||||||
@ -142,6 +146,7 @@ class TestOAuth(object):
|
|||||||
return authorization_response, client_identifier
|
return authorization_response, client_identifier
|
||||||
|
|
||||||
def get_code_from_redirect_uri(self, uri):
|
def get_code_from_redirect_uri(self, uri):
|
||||||
|
''' Get the value of ?code= from an URI '''
|
||||||
return parse_qs(urlparse(uri).query)['code'][0]
|
return parse_qs(urlparse(uri).query)['code'][0]
|
||||||
|
|
||||||
def test_token_endpoint_successful_confidential_request(self, test_app):
|
def test_token_endpoint_successful_confidential_request(self, test_app):
|
||||||
@ -170,6 +175,11 @@ code={1}&client_secret={2}'.format(client_id, code, client.secret))
|
|||||||
assert type(token_data['expires_in']) == int
|
assert type(token_data['expires_in']) == int
|
||||||
assert token_data['expires_in'] > 0
|
assert token_data['expires_in'] > 0
|
||||||
|
|
||||||
|
# There should be a refresh token provided in the token data
|
||||||
|
assert len(token_data['refresh_token'])
|
||||||
|
|
||||||
|
return client_id, token_data
|
||||||
|
|
||||||
def test_token_endpont_missing_id_confidential_request(self, test_app):
|
def test_token_endpont_missing_id_confidential_request(self, test_app):
|
||||||
''' Unsuccessful request against token endpoint, missing client_id '''
|
''' Unsuccessful request against token endpoint, missing client_id '''
|
||||||
self._setup(test_app)
|
self._setup(test_app)
|
||||||
@ -192,4 +202,30 @@ code={0}&client_secret={1}'.format(code, client.secret))
|
|||||||
assert 'error' in token_data
|
assert 'error' in token_data
|
||||||
assert not 'access_token' in token_data
|
assert not 'access_token' in token_data
|
||||||
assert token_data['error'] == 'invalid_request'
|
assert token_data['error'] == 'invalid_request'
|
||||||
assert token_data['error_description'] == 'Missing client_id in request'
|
assert len(token_data['error_description'])
|
||||||
|
|
||||||
|
def test_refresh_token(self, test_app):
|
||||||
|
''' Try to get a new access token using the refresh token '''
|
||||||
|
# Get an access token and a refresh token
|
||||||
|
client_id, token_data =\
|
||||||
|
self.test_token_endpoint_successful_confidential_request(test_app)
|
||||||
|
|
||||||
|
client = self.db.OAuthClient.query.filter(
|
||||||
|
self.db.OAuthClient.identifier == client_id).first()
|
||||||
|
|
||||||
|
token_res = test_app.get('/oauth/access_token',
|
||||||
|
{'refresh_token': token_data['refresh_token'],
|
||||||
|
'client_id': client_id,
|
||||||
|
'client_secret': client.secret
|
||||||
|
})
|
||||||
|
|
||||||
|
assert token_res.status_int == 200
|
||||||
|
|
||||||
|
new_token_data = json.loads(token_res.body)
|
||||||
|
|
||||||
|
assert not 'error' in new_token_data
|
||||||
|
assert 'access_token' in new_token_data
|
||||||
|
assert 'token_type' in new_token_data
|
||||||
|
assert 'expires_in' in new_token_data
|
||||||
|
assert type(new_token_data['expires_in']) == int
|
||||||
|
assert new_token_data['expires_in'] > 0
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
import os
|
import os
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import shutil
|
import shutil
|
||||||
@ -28,7 +29,6 @@ from mediagoblin import mg_globals
|
|||||||
from mediagoblin.db.models import User, MediaEntry, Collection
|
from mediagoblin.db.models import User, MediaEntry, Collection
|
||||||
from mediagoblin.tools import testing
|
from mediagoblin.tools import testing
|
||||||
from mediagoblin.init.config import read_mediagoblin_config
|
from mediagoblin.init.config import read_mediagoblin_config
|
||||||
from mediagoblin.db.open import setup_connection_and_db_from_config
|
|
||||||
from mediagoblin.db.base import Session
|
from mediagoblin.db.base import Session
|
||||||
from mediagoblin.meddleware import BaseMeddleware
|
from mediagoblin.meddleware import BaseMeddleware
|
||||||
from mediagoblin.auth.lib import bcrypt_gen_password_hash
|
from mediagoblin.auth.lib import bcrypt_gen_password_hash
|
||||||
@ -50,7 +50,9 @@ USER_DEV_DIRECTORIES_TO_SETUP = ['media/public', 'media/queue']
|
|||||||
BAD_CELERY_MESSAGE = """\
|
BAD_CELERY_MESSAGE = """\
|
||||||
Sorry, you *absolutely* must run tests with the
|
Sorry, you *absolutely* must run tests with the
|
||||||
mediagoblin.init.celery.from_tests module. Like so:
|
mediagoblin.init.celery.from_tests module. Like so:
|
||||||
$ CELERY_CONFIG_MODULE=mediagoblin.init.celery.from_tests ./bin/py.test"""
|
|
||||||
|
$ CELERY_CONFIG_MODULE=mediagoblin.init.celery.from_tests {0}
|
||||||
|
""".format(sys.argv[0])
|
||||||
|
|
||||||
|
|
||||||
class BadCeleryEnviron(Exception): pass
|
class BadCeleryEnviron(Exception): pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user