Use a single cursor for db operations when possible instead of reopening multiple times

This commit is contained in:
James Taylor 2019-06-10 15:27:17 -07:00
parent d637f5b29c
commit 9da073000a

View File

@ -63,27 +63,27 @@ def open_database():
# https://stackoverflow.com/questions/19522505/using-sqlite3-in-python-with-with-keyword # https://stackoverflow.com/questions/19522505/using-sqlite3-in-python-with-with-keyword
return contextlib.closing(connection) return contextlib.closing(connection)
def _subscribe(channels): def with_open_db(function, *args, **kwargs):
with open_database() as connection:
with connection as cursor:
return function(cursor, *args, **kwargs)
def _subscribe(cursor, channels):
''' channels is a list of (channel_id, channel_name) ''' ''' channels is a list of (channel_id, channel_name) '''
# set time_last_checked to 0 on all channels being subscribed to # set time_last_checked to 0 on all channels being subscribed to
channels = ( (channel_id, channel_name, 0) for channel_id, channel_name in channels) channels = ( (channel_id, channel_name, 0) for channel_id, channel_name in channels)
with open_database() as connection:
with connection as cursor:
cursor.executemany('''INSERT OR IGNORE INTO subscribed_channels (yt_channel_id, channel_name, time_last_checked) cursor.executemany('''INSERT OR IGNORE INTO subscribed_channels (yt_channel_id, channel_name, time_last_checked)
VALUES (?, ?, ?)''', channels) VALUES (?, ?, ?)''', channels)
# TODO: delete thumbnails # TODO: delete thumbnails
def _unsubscribe(channel_ids): def _unsubscribe(cursor, channel_ids):
''' channel_ids is a list of channel_ids ''' ''' channel_ids is a list of channel_ids '''
with open_database() as connection:
with connection as cursor:
cursor.executemany("DELETE FROM subscribed_channels WHERE yt_channel_id=?", ((channel_id, ) for channel_id in channel_ids)) cursor.executemany("DELETE FROM subscribed_channels WHERE yt_channel_id=?", ((channel_id, ) for channel_id in channel_ids))
def _get_videos(number, offset): def _get_videos(cursor, number, offset):
with open_database() as connection:
with connection as cursor:
db_videos = cursor.execute('''SELECT video_id, title, duration, channel_name db_videos = cursor.execute('''SELECT video_id, title, duration, channel_name
FROM videos FROM videos
INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id
@ -98,26 +98,20 @@ def _get_videos(number, offset):
'author': db_video[3], 'author': db_video[3],
} }
def _get_subscribed_channels(): def _get_subscribed_channels(cursor):
with open_database() as connection:
with connection as cursor:
for item in cursor.execute('''SELECT channel_name, yt_channel_id for item in cursor.execute('''SELECT channel_name, yt_channel_id
FROM subscribed_channels FROM subscribed_channels
ORDER BY channel_name'''): ORDER BY channel_name'''):
yield item yield item
def _add_tags(channel_ids, tags): def _add_tags(cursor, channel_ids, tags):
with open_database() as connection:
with connection as cursor:
pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids] pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids]
cursor.executemany('''INSERT OR IGNORE INTO tag_associations (tag, sql_channel_id) cursor.executemany('''INSERT OR IGNORE INTO tag_associations (tag, sql_channel_id)
SELECT ?, id FROM subscribed_channels WHERE yt_channel_id = ? ''', pairs) SELECT ?, id FROM subscribed_channels WHERE yt_channel_id = ? ''', pairs)
def _remove_tags(channel_ids, tags): def _remove_tags(cursor, channel_ids, tags):
with open_database() as connection:
with connection as cursor:
pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids] pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids]
cursor.executemany('''DELETE FROM tag_associations cursor.executemany('''DELETE FROM tag_associations
WHERE tag = ? AND sql_channel_id = ( WHERE tag = ? AND sql_channel_id = (
@ -133,15 +127,11 @@ def _get_tags(cursor, channel_id):
SELECT id FROM subscribed_channels WHERE yt_channel_id = ? SELECT id FROM subscribed_channels WHERE yt_channel_id = ?
)''', (channel_id,))] )''', (channel_id,))]
def _get_all_tags(): def _get_all_tags(cursor):
with open_database() as connection:
with connection as cursor:
return [row[0] for row in cursor.execute('''SELECT DISTINCT tag FROM tag_associations''')] return [row[0] for row in cursor.execute('''SELECT DISTINCT tag FROM tag_associations''')]
def _get_channel_names(channel_ids): def _get_channel_names(cursor, channel_ids):
''' returns list of (channel_id, channel_name) ''' ''' returns list of (channel_id, channel_name) '''
with open_database() as connection:
with connection as cursor:
result = [] result = []
for channel_id in channel_ids: for channel_id in channel_ids:
row = cursor.execute('''SELECT channel_name row = cursor.execute('''SELECT channel_name
@ -357,7 +347,7 @@ def import_subscriptions(env, start_response):
start_response('400 Bad Request', () ) start_response('400 Bad Request', () )
return b'400 Bad Request: Unsupported file format: ' + html.escape(content_type).encode('utf-8') + b'. Only subscription.json files (from Google Takeouts) and XML OPML files exported from Youtube\'s subscription manager page are supported' return b'400 Bad Request: Unsupported file format: ' + html.escape(content_type).encode('utf-8') + b'. Only subscription.json files (from Google Takeouts) and XML OPML files exported from Youtube\'s subscription manager page are supported'
_subscribe(channels) with_open_db(_subscribe, channels)
start_response('303 See Other', [('Location', util.URL_ORIGIN + '/subscription_manager'),] ) start_response('303 See Other', [('Location', util.URL_ORIGIN + '/subscription_manager'),] )
return b'' return b''
@ -388,7 +378,7 @@ def get_subscription_manager_page(env, start_response):
sort_link = util.URL_ORIGIN + '/subscription_manager' sort_link = util.URL_ORIGIN + '/subscription_manager'
main_list_html = '<ul class="tag-group-list">' main_list_html = '<ul class="tag-group-list">'
for tag in _get_all_tags(): for tag in _get_all_tags(cursor):
sub_list_html = '' sub_list_html = ''
for channel_id, channel_name in _channels_with_tag(cursor, tag, order=True): for channel_id, channel_name in _channels_with_tag(cursor, tag, order=True):
sub_list_html += sub_list_item_template.substitute( sub_list_html += sub_list_item_template.substitute(
@ -430,7 +420,7 @@ def get_subscription_manager_page(env, start_response):
sort_link = util.URL_ORIGIN + '/subscription_manager?group_by_tags=1' sort_link = util.URL_ORIGIN + '/subscription_manager?group_by_tags=1'
main_list_html = '<ol class="sub-list">' main_list_html = '<ol class="sub-list">'
for channel_name, channel_id in _get_subscribed_channels(): for channel_name, channel_id in _get_subscribed_channels(cursor):
main_list_html += sub_list_item_template.substitute( main_list_html += sub_list_item_template.substitute(
channel_url = util.URL_ORIGIN + '/channel/' + channel_id, channel_url = util.URL_ORIGIN + '/channel/' + channel_id,
channel_name = html.escape(channel_name), channel_name = html.escape(channel_name),
@ -461,11 +451,11 @@ def post_subscription_manager_page(env, start_response):
action = params['action'][0] action = params['action'][0]
if action == 'add_tags': if action == 'add_tags':
_add_tags(params['channel_ids'], [tag.lower() for tag in list_from_comma_separated_tags(params['tags'][0])]) with_open_db(_add_tags, params['channel_ids'], [tag.lower() for tag in list_from_comma_separated_tags(params['tags'][0])])
elif action == 'remove_tags': elif action == 'remove_tags':
_remove_tags(params['channel_ids'], [tag.lower() for tag in list_from_comma_separated_tags(params['tags'][0])]) with_open_db(_remove_tags, params['channel_ids'], [tag.lower() for tag in list_from_comma_separated_tags(params['tags'][0])])
elif action == 'unsubscribe': elif action == 'unsubscribe':
_unsubscribe(params['channel_ids']) with_open_db(_unsubscribe, params['channel_ids'])
elif action == 'unsubscribe_verify': elif action == 'unsubscribe_verify':
page = ''' page = '''
<span>Are you sure you want to unsubscribe from these channels?</span> <span>Are you sure you want to unsubscribe from these channels?</span>
@ -479,7 +469,7 @@ def post_subscription_manager_page(env, start_response):
<input type="submit" value="Yes, unsubscribe"> <input type="submit" value="Yes, unsubscribe">
</form> </form>
<ul>''' <ul>'''
for channel_id, channel_name in _get_channel_names(params['channel_ids']): for channel_id, channel_name in with_open_db(_get_channel_names, params['channel_ids']):
page += unsubscribe_list_item_template.substitute( page += unsubscribe_list_item_template.substitute(
channel_url = util.URL_ORIGIN + '/channel/' + channel_id, channel_url = util.URL_ORIGIN + '/channel/' + channel_id,
channel_name = html.escape(channel_name), channel_name = html.escape(channel_name),
@ -526,9 +516,11 @@ sidebar_channel_item_template = Template('''
</li>''') </li>''')
def get_subscriptions_page(env, start_response): def get_subscriptions_page(env, start_response):
with open_database() as connection:
with connection as cursor:
items_html = '''<nav class="item-grid">\n''' items_html = '''<nav class="item-grid">\n'''
for item in _get_videos(30, 0): for item in _get_videos(cursor, 30, 0):
if item['id'] in downloading_thumbnails: if item['id'] in downloading_thumbnails:
item['thumbnail'] = util.get_thumbnail_url(item['id']) item['thumbnail'] = util.get_thumbnail_url(item['id'])
else: else:
@ -538,12 +530,12 @@ def get_subscriptions_page(env, start_response):
tag_list_html = '' tag_list_html = ''
for tag_name in _get_all_tags(): for tag_name in _get_all_tags(cursor):
tag_list_html += sidebar_tag_item_template.substitute(tag_name = tag_name) tag_list_html += sidebar_tag_item_template.substitute(tag_name = tag_name)
sub_list_html = '' sub_list_html = ''
for channel_name, channel_id in _get_subscribed_channels(): for channel_name, channel_id in _get_subscribed_channels(cursor):
sub_list_html += sidebar_channel_item_template.substitute( sub_list_html += sidebar_channel_item_template.substitute(
channel_url = util.URL_ORIGIN + '/channel/' + channel_id, channel_url = util.URL_ORIGIN + '/channel/' + channel_id,
channel_name = html.escape(channel_name), channel_name = html.escape(channel_name),
@ -568,10 +560,10 @@ def post_subscriptions_page(env, start_response):
if len(params['channel_id']) != len(params['channel_name']): if len(params['channel_id']) != len(params['channel_name']):
start_response('400 Bad Request', ()) start_response('400 Bad Request', ())
return b'400 Bad Request, length of channel_id != length of channel_name' return b'400 Bad Request, length of channel_id != length of channel_name'
_subscribe(zip(params['channel_id'], params['channel_name'])) with_open_db(_subscribe, zip(params['channel_id'], params['channel_name']))
elif action == 'unsubscribe': elif action == 'unsubscribe':
_unsubscribe(params['channel_id']) with_open_db(_unsubscribe, params['channel_id'])
elif action == 'refresh': elif action == 'refresh':
type = params['type'][0] type = params['type'][0]