Merge subscriptions into master

commit ac32b24b2a by James Taylor, 2019-09-06 15:45:01 -07:00
65 changed files with 13566 additions and 48 deletions

12
python/atoma/__init__.py Normal file
@@ -0,0 +1,12 @@
from .atom import parse_atom_file, parse_atom_bytes
from .rss import parse_rss_file, parse_rss_bytes
from .json_feed import (
parse_json_feed, parse_json_feed_file, parse_json_feed_bytes
)
from .opml import parse_opml_file, parse_opml_bytes
from .exceptions import (
FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError
)
from .const import VERSION
__version__ = VERSION

284
python/atoma/atom.py Normal file
@@ -0,0 +1,284 @@
from datetime import datetime
import enum
from io import BytesIO
from typing import Optional, List
from xml.etree.ElementTree import Element
import attr
from .utils import (
parse_xml, get_child, get_text, get_datetime, FeedParseError, ns
)
class AtomTextType(enum.Enum):
text = "text"
html = "html"
xhtml = "xhtml"
@attr.s
class AtomTextConstruct:
text_type: str = attr.ib()
lang: Optional[str] = attr.ib()
value: str = attr.ib()
@attr.s
class AtomEntry:
title: AtomTextConstruct = attr.ib()
id_: str = attr.ib()
# Should be mandatory but many feeds use published instead
updated: Optional[datetime] = attr.ib()
authors: List['AtomPerson'] = attr.ib()
contributors: List['AtomPerson'] = attr.ib()
links: List['AtomLink'] = attr.ib()
categories: List['AtomCategory'] = attr.ib()
published: Optional[datetime] = attr.ib()
rights: Optional[AtomTextConstruct] = attr.ib()
summary: Optional[AtomTextConstruct] = attr.ib()
content: Optional[AtomTextConstruct] = attr.ib()
source: Optional['AtomFeed'] = attr.ib()
@attr.s
class AtomFeed:
title: Optional[AtomTextConstruct] = attr.ib()
id_: str = attr.ib()
# Should be mandatory but many feeds do not include it
updated: Optional[datetime] = attr.ib()
authors: List['AtomPerson'] = attr.ib()
contributors: List['AtomPerson'] = attr.ib()
links: List['AtomLink'] = attr.ib()
categories: List['AtomCategory'] = attr.ib()
generator: Optional['AtomGenerator'] = attr.ib()
subtitle: Optional[AtomTextConstruct] = attr.ib()
rights: Optional[AtomTextConstruct] = attr.ib()
icon: Optional[str] = attr.ib()
logo: Optional[str] = attr.ib()
entries: List[AtomEntry] = attr.ib()
@attr.s
class AtomPerson:
name: str = attr.ib()
uri: Optional[str] = attr.ib()
email: Optional[str] = attr.ib()
@attr.s
class AtomLink:
href: str = attr.ib()
rel: Optional[str] = attr.ib()
type_: Optional[str] = attr.ib()
hreflang: Optional[str] = attr.ib()
title: Optional[str] = attr.ib()
length: Optional[int] = attr.ib()
@attr.s
class AtomCategory:
term: str = attr.ib()
scheme: Optional[str] = attr.ib()
label: Optional[str] = attr.ib()
@attr.s
class AtomGenerator:
name: str = attr.ib()
uri: Optional[str] = attr.ib()
version: Optional[str] = attr.ib()
def _get_generator(element: Element, name,
optional: bool=True) -> Optional[AtomGenerator]:
child = get_child(element, name, optional)
if child is None:
return None
return AtomGenerator(
child.text.strip(),
child.attrib.get('uri'),
child.attrib.get('version'),
)
def _get_text_construct(element: Element, name,
optional: bool=True) -> Optional[AtomTextConstruct]:
child = get_child(element, name, optional)
if child is None:
return None
try:
text_type = AtomTextType(child.attrib['type'])
except KeyError:
text_type = AtomTextType.text
try:
lang = child.lang
except AttributeError:
lang = None
if child.text is None:
if optional:
return None
raise FeedParseError(
'Could not parse atom feed: "{}" text is required but is empty'
.format(name)
)
return AtomTextConstruct(
text_type,
lang,
child.text.strip()
)
def _get_person(element: Element) -> Optional[AtomPerson]:
try:
return AtomPerson(
get_text(element, 'feed:name', optional=False),
get_text(element, 'feed:uri'),
get_text(element, 'feed:email')
)
except FeedParseError:
return None
def _get_link(element: Element) -> AtomLink:
length = element.attrib.get('length')
length = int(length) if length else None
return AtomLink(
element.attrib['href'],
element.attrib.get('rel'),
element.attrib.get('type'),
element.attrib.get('hreflang'),
element.attrib.get('title'),
length
)
def _get_category(element: Element) -> AtomCategory:
return AtomCategory(
element.attrib['term'],
element.attrib.get('scheme'),
element.attrib.get('label'),
)
def _get_entry(element: Element,
default_authors: List[AtomPerson]) -> AtomEntry:
root = element
# Mandatory
title = _get_text_construct(root, 'feed:title')
id_ = get_text(root, 'feed:id')
# Optional
try:
source = _parse_atom(get_child(root, 'feed:source', optional=False),
parse_entries=False)
except FeedParseError:
source = None
source_authors = []
else:
source_authors = source.authors
authors = [_get_person(e)
for e in root.findall('feed:author', ns)] or default_authors
authors = [a for a in authors if a is not None]
authors = authors or default_authors or source_authors
contributors = [_get_person(e)
for e in root.findall('feed:contributor', ns) if e]
contributors = [c for c in contributors if c is not None]
links = [_get_link(e) for e in root.findall('feed:link', ns)]
categories = [_get_category(e) for e in root.findall('feed:category', ns)]
updated = get_datetime(root, 'feed:updated')
published = get_datetime(root, 'feed:published')
rights = _get_text_construct(root, 'feed:rights')
summary = _get_text_construct(root, 'feed:summary')
content = _get_text_construct(root, 'feed:content')
return AtomEntry(
title,
id_,
updated,
authors,
contributors,
links,
categories,
published,
rights,
summary,
content,
source
)
def _parse_atom(root: Element, parse_entries: bool=True) -> AtomFeed:
# Mandatory
id_ = get_text(root, 'feed:id', optional=False)
# Optional
title = _get_text_construct(root, 'feed:title')
updated = get_datetime(root, 'feed:updated')
authors = [_get_person(e)
for e in root.findall('feed:author', ns) if e]
authors = [a for a in authors if a is not None]
contributors = [_get_person(e)
for e in root.findall('feed:contributor', ns) if e]
contributors = [c for c in contributors if c is not None]
links = [_get_link(e)
for e in root.findall('feed:link', ns)]
categories = [_get_category(e)
for e in root.findall('feed:category', ns)]
generator = _get_generator(root, 'feed:generator')
subtitle = _get_text_construct(root, 'feed:subtitle')
rights = _get_text_construct(root, 'feed:rights')
icon = get_text(root, 'feed:icon')
logo = get_text(root, 'feed:logo')
if parse_entries:
entries = [_get_entry(e, authors)
for e in root.findall('feed:entry', ns)]
else:
entries = []
atom_feed = AtomFeed(
title,
id_,
updated,
authors,
contributors,
links,
categories,
generator,
subtitle,
rights,
icon,
logo,
entries
)
return atom_feed
def parse_atom_file(filename: str) -> AtomFeed:
"""Parse an Atom feed from a local XML file."""
root = parse_xml(filename).getroot()
return _parse_atom(root)
def parse_atom_bytes(data: bytes) -> AtomFeed:
"""Parse an Atom feed from a byte-string containing XML data."""
root = parse_xml(BytesIO(data)).getroot()
return _parse_atom(root)
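
A minimal usage sketch of the two entry points above, assuming the vendored package is importable as atoma; the feed XML is made up for illustration:

from atoma import parse_atom_bytes

xml = b"""<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <id>urn:example:feed</id>
  <title>Example Feed</title>
  <updated>2019-09-06T15:45:01-07:00</updated>
  <entry>
    <id>urn:example:entry-1</id>
    <title>First post</title>
  </entry>
</feed>"""

feed = parse_atom_bytes(xml)
print(feed.id_)                     # urn:example:feed
print(feed.entries[0].title.value)  # First post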

1
python/atoma/const.py Normal file
@@ -0,0 +1 @@
VERSION = '0.0.13'

14
python/atoma/exceptions.py Normal file
@@ -0,0 +1,14 @@
class FeedParseError(Exception):
"""Document is an invalid feed."""
class FeedDocumentError(Exception):
"""Document is not a supported file."""
class FeedXMLError(FeedDocumentError):
"""Document is not valid XML."""
class FeedJSONError(FeedDocumentError):
"""Document is not valid JSON."""

223
python/atoma/json_feed.py Normal file
@@ -0,0 +1,223 @@
from datetime import datetime, timedelta
import json
from typing import Optional, List
import attr
from .exceptions import FeedParseError, FeedJSONError
from .utils import try_parse_date
@attr.s
class JSONFeedAuthor:
name: Optional[str] = attr.ib()
url: Optional[str] = attr.ib()
avatar: Optional[str] = attr.ib()
@attr.s
class JSONFeedAttachment:
url: str = attr.ib()
mime_type: str = attr.ib()
title: Optional[str] = attr.ib()
size_in_bytes: Optional[int] = attr.ib()
duration: Optional[timedelta] = attr.ib()
@attr.s
class JSONFeedItem:
id_: str = attr.ib()
url: Optional[str] = attr.ib()
external_url: Optional[str] = attr.ib()
title: Optional[str] = attr.ib()
content_html: Optional[str] = attr.ib()
content_text: Optional[str] = attr.ib()
summary: Optional[str] = attr.ib()
image: Optional[str] = attr.ib()
banner_image: Optional[str] = attr.ib()
date_published: Optional[datetime] = attr.ib()
date_modified: Optional[datetime] = attr.ib()
author: Optional[JSONFeedAuthor] = attr.ib()
tags: List[str] = attr.ib()
attachments: List[JSONFeedAttachment] = attr.ib()
@attr.s
class JSONFeed:
version: str = attr.ib()
title: str = attr.ib()
home_page_url: Optional[str] = attr.ib()
feed_url: Optional[str] = attr.ib()
description: Optional[str] = attr.ib()
user_comment: Optional[str] = attr.ib()
next_url: Optional[str] = attr.ib()
icon: Optional[str] = attr.ib()
favicon: Optional[str] = attr.ib()
author: Optional[JSONFeedAuthor] = attr.ib()
expired: bool = attr.ib()
items: List[JSONFeedItem] = attr.ib()
def _get_items(root: dict) -> List[JSONFeedItem]:
rv = []
items = root.get('items', [])
if not items:
return rv
for item in items:
rv.append(_get_item(item))
return rv
def _get_item(item_dict: dict) -> JSONFeedItem:
return JSONFeedItem(
id_=_get_text(item_dict, 'id', optional=False),
url=_get_text(item_dict, 'url'),
external_url=_get_text(item_dict, 'external_url'),
title=_get_text(item_dict, 'title'),
content_html=_get_text(item_dict, 'content_html'),
content_text=_get_text(item_dict, 'content_text'),
summary=_get_text(item_dict, 'summary'),
image=_get_text(item_dict, 'image'),
banner_image=_get_text(item_dict, 'banner_image'),
date_published=_get_datetime(item_dict, 'date_published'),
date_modified=_get_datetime(item_dict, 'date_modified'),
author=_get_author(item_dict),
tags=_get_tags(item_dict, 'tags'),
attachments=_get_attachments(item_dict, 'attachments')
)
def _get_attachments(root, name) -> List[JSONFeedAttachment]:
rv = list()
for attachment_dict in root.get(name, []):
rv.append(JSONFeedAttachment(
_get_text(attachment_dict, 'url', optional=False),
_get_text(attachment_dict, 'mime_type', optional=False),
_get_text(attachment_dict, 'title'),
_get_int(attachment_dict, 'size_in_bytes'),
_get_duration(attachment_dict, 'duration_in_seconds')
))
return rv
def _get_tags(root, name) -> List[str]:
tags = root.get(name, [])
return [tag for tag in tags if isinstance(tag, str)]
def _get_datetime(root: dict, name, optional: bool=True) -> Optional[datetime]:
text = _get_text(root, name, optional)
if text is None:
return None
return try_parse_date(text)
def _get_expired(root: dict) -> bool:
if root.get('expired') is True:
return True
return False
def _get_author(root: dict) -> Optional[JSONFeedAuthor]:
author_dict = root.get('author')
if not author_dict:
return None
rv = JSONFeedAuthor(
name=_get_text(author_dict, 'name'),
url=_get_text(author_dict, 'url'),
avatar=_get_text(author_dict, 'avatar'),
)
if rv.name is None and rv.url is None and rv.avatar is None:
return None
return rv
def _get_int(root: dict, name: str, optional: bool=True) -> Optional[int]:
rv = root.get(name)
if not optional and rv is None:
raise FeedParseError('Could not parse feed: "{}" int is required but '
'is empty'.format(name))
if optional and rv is None:
return None
if not isinstance(rv, int):
raise FeedParseError('Could not parse feed: "{}" is not an int'
.format(name))
return rv
def _get_duration(root: dict, name: str,
optional: bool=True) -> Optional[timedelta]:
duration = _get_int(root, name, optional)
if duration is None:
return None
return timedelta(seconds=duration)
def _get_text(root: dict, name: str, optional: bool=True) -> Optional[str]:
rv = root.get(name)
if not optional and rv is None:
raise FeedParseError('Could not parse feed: "{}" text is required but '
'is empty'.format(name))
if optional and rv is None:
return None
if not isinstance(rv, str):
raise FeedParseError('Could not parse feed: "{}" is not a string'
.format(name))
return rv
def parse_json_feed(root: dict) -> JSONFeed:
return JSONFeed(
version=_get_text(root, 'version', optional=False),
title=_get_text(root, 'title', optional=False),
home_page_url=_get_text(root, 'home_page_url'),
feed_url=_get_text(root, 'feed_url'),
description=_get_text(root, 'description'),
user_comment=_get_text(root, 'user_comment'),
next_url=_get_text(root, 'next_url'),
icon=_get_text(root, 'icon'),
favicon=_get_text(root, 'favicon'),
author=_get_author(root),
expired=_get_expired(root),
items=_get_items(root)
)
def parse_json_feed_file(filename: str) -> JSONFeed:
"""Parse a JSON feed from a local json file."""
with open(filename) as f:
try:
root = json.load(f)
except json.decoder.JSONDecodeError:
raise FeedJSONError('Not a valid JSON document')
return parse_json_feed(root)
def parse_json_feed_bytes(data: bytes) -> JSONFeed:
"""Parse a JSON feed from a byte-string containing JSON data."""
try:
root = json.loads(data)
except json.decoder.JSONDecodeError:
raise FeedJSONError('Not a valid JSON document')
return parse_json_feed(root)
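
A similar sketch for the JSON Feed entry point; the document is illustrative, and only "version" and "title" are required by parse_json_feed:

from atoma import parse_json_feed_bytes

data = b'''{
    "version": "https://jsonfeed.org/version/1",
    "title": "Example Feed",
    "items": [{"id": "1", "content_text": "Hello, world"}]
}'''

feed = parse_json_feed_bytes(data)
print(feed.title)                  # Example Feed
print(feed.items[0].content_text)  # Hello, world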

107
python/atoma/opml.py Normal file
@@ -0,0 +1,107 @@
from datetime import datetime
from io import BytesIO
from typing import Optional, List
from xml.etree.ElementTree import Element
import attr
from .utils import parse_xml, get_text, get_int, get_datetime
@attr.s
class OPMLOutline:
text: Optional[str] = attr.ib()
type: Optional[str] = attr.ib()
xml_url: Optional[str] = attr.ib()
description: Optional[str] = attr.ib()
html_url: Optional[str] = attr.ib()
language: Optional[str] = attr.ib()
title: Optional[str] = attr.ib()
version: Optional[str] = attr.ib()
outlines: List['OPMLOutline'] = attr.ib()
@attr.s
class OPML:
title: Optional[str] = attr.ib()
owner_name: Optional[str] = attr.ib()
owner_email: Optional[str] = attr.ib()
date_created: Optional[datetime] = attr.ib()
date_modified: Optional[datetime] = attr.ib()
expansion_state: Optional[str] = attr.ib()
vertical_scroll_state: Optional[int] = attr.ib()
window_top: Optional[int] = attr.ib()
window_left: Optional[int] = attr.ib()
window_bottom: Optional[int] = attr.ib()
window_right: Optional[int] = attr.ib()
outlines: List[OPMLOutline] = attr.ib()
def _get_outlines(element: Element) -> List[OPMLOutline]:
rv = list()
for outline in element.findall('outline'):
rv.append(OPMLOutline(
outline.attrib.get('text'),
outline.attrib.get('type'),
outline.attrib.get('xmlUrl'),
outline.attrib.get('description'),
outline.attrib.get('htmlUrl'),
outline.attrib.get('language'),
outline.attrib.get('title'),
outline.attrib.get('version'),
_get_outlines(outline)
))
return rv
def _parse_opml(root: Element) -> OPML:
head = root.find('head')
body = root.find('body')
return OPML(
get_text(head, 'title'),
get_text(head, 'ownerName'),
get_text(head, 'ownerEmail'),
get_datetime(head, 'dateCreated'),
get_datetime(head, 'dateModified'),
get_text(head, 'expansionState'),
get_int(head, 'vertScrollState'),
get_int(head, 'windowTop'),
get_int(head, 'windowLeft'),
get_int(head, 'windowBottom'),
get_int(head, 'windowRight'),
outlines=_get_outlines(body)
)
def parse_opml_file(filename: str) -> OPML:
"""Parse an OPML document from a local XML file."""
root = parse_xml(filename).getroot()
return _parse_opml(root)
def parse_opml_bytes(data: bytes) -> OPML:
"""Parse an OPML document from a byte-string containing XML data."""
root = parse_xml(BytesIO(data)).getroot()
return _parse_opml(root)
def get_feed_list(opml_obj: OPML) -> List[str]:
"""Walk an OPML document to extract the list of feed it contains."""
rv = list()
def collect(obj):
for outline in obj.outlines:
if outline.type == 'rss' and outline.xml_url:
rv.append(outline.xml_url)
if outline.outlines:
collect(outline)
collect(opml_obj)
return rv
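
A short sketch of get_feed_list walking nested outlines; the OPML content is illustrative:

from atoma import parse_opml_bytes
from atoma.opml import get_feed_list

data = b"""<opml version="2.0">
  <head><title>Subscriptions</title></head>
  <body>
    <outline text="News">
      <outline type="rss" text="Example Blog"
               xmlUrl="https://example.com/feed.xml"/>
    </outline>
  </body>
</opml>"""

print(get_feed_list(parse_opml_bytes(data)))
# ['https://example.com/feed.xml']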

221
python/atoma/rss.py Normal file
@@ -0,0 +1,221 @@
from datetime import datetime
from io import BytesIO
from typing import Optional, List
from xml.etree.ElementTree import Element
import attr
from .utils import (
parse_xml, get_child, get_text, get_int, get_datetime, FeedParseError
)
@attr.s
class RSSImage:
url: str = attr.ib()
title: Optional[str] = attr.ib()
link: str = attr.ib()
width: int = attr.ib()
height: int = attr.ib()
description: Optional[str] = attr.ib()
@attr.s
class RSSEnclosure:
url: str = attr.ib()
length: Optional[int] = attr.ib()
type: Optional[str] = attr.ib()
@attr.s
class RSSSource:
title: str = attr.ib()
url: Optional[str] = attr.ib()
@attr.s
class RSSItem:
title: Optional[str] = attr.ib()
link: Optional[str] = attr.ib()
description: Optional[str] = attr.ib()
author: Optional[str] = attr.ib()
categories: List[str] = attr.ib()
comments: Optional[str] = attr.ib()
enclosures: List[RSSEnclosure] = attr.ib()
guid: Optional[str] = attr.ib()
pub_date: Optional[datetime] = attr.ib()
source: Optional[RSSSource] = attr.ib()
# Extension
content_encoded: Optional[str] = attr.ib()
@attr.s
class RSSChannel:
title: Optional[str] = attr.ib()
link: Optional[str] = attr.ib()
description: Optional[str] = attr.ib()
language: Optional[str] = attr.ib()
copyright: Optional[str] = attr.ib()
managing_editor: Optional[str] = attr.ib()
web_master: Optional[str] = attr.ib()
pub_date: Optional[datetime] = attr.ib()
last_build_date: Optional[datetime] = attr.ib()
categories: List[str] = attr.ib()
generator: Optional[str] = attr.ib()
docs: Optional[str] = attr.ib()
ttl: Optional[int] = attr.ib()
image: Optional[RSSImage] = attr.ib()
items: List[RSSItem] = attr.ib()
# Extension
content_encoded: Optional[str] = attr.ib()
def _get_image(element: Element, name,
optional: bool=True) -> Optional[RSSImage]:
child = get_child(element, name, optional)
if child is None:
return None
return RSSImage(
get_text(child, 'url', optional=False),
get_text(child, 'title'),
get_text(child, 'link', optional=False),
get_int(child, 'width') or 88,
get_int(child, 'height') or 31,
get_text(child, 'description')
)
def _get_source(element: Element, name,
optional: bool=True) -> Optional[RSSSource]:
child = get_child(element, name, optional)
if child is None:
return None
return RSSSource(
child.text.strip(),
child.attrib.get('url'),
)
def _get_enclosure(element: Element) -> RSSEnclosure:
length = element.attrib.get('length')
try:
length = int(length)
except (TypeError, ValueError):
length = None
return RSSEnclosure(
element.attrib['url'],
length,
element.attrib.get('type'),
)
def _get_link(element: Element) -> Optional[str]:
"""Attempt to retrieve item link.
Use the GUID as a fallback if it is a permalink.
"""
link = get_text(element, 'link')
if link is not None:
return link
guid = get_child(element, 'guid')
if guid is not None and guid.attrib.get('isPermaLink') == 'true':
return get_text(element, 'guid')
return None
def _get_item(element: Element) -> RSSItem:
root = element
title = get_text(root, 'title')
link = _get_link(root)
description = get_text(root, 'description')
author = get_text(root, 'author')
categories = [e.text for e in root.findall('category')]
comments = get_text(root, 'comments')
enclosure = [_get_enclosure(e) for e in root.findall('enclosure')]
guid = get_text(root, 'guid')
pub_date = get_datetime(root, 'pubDate')
source = _get_source(root, 'source')
content_encoded = get_text(root, 'content:encoded')
return RSSItem(
title,
link,
description,
author,
categories,
comments,
enclosure,
guid,
pub_date,
source,
content_encoded
)
def _parse_rss(root: Element) -> RSSChannel:
rss_version = root.get('version')
if rss_version != '2.0':
raise FeedParseError('Cannot process RSS feed version "{}"'
.format(rss_version))
root = root.find('channel')
title = get_text(root, 'title')
link = get_text(root, 'link')
description = get_text(root, 'description')
language = get_text(root, 'language')
copyright = get_text(root, 'copyright')
managing_editor = get_text(root, 'managingEditor')
web_master = get_text(root, 'webMaster')
pub_date = get_datetime(root, 'pubDate')
last_build_date = get_datetime(root, 'lastBuildDate')
categories = [e.text for e in root.findall('category')]
generator = get_text(root, 'generator')
docs = get_text(root, 'docs')
ttl = get_int(root, 'ttl')
image = _get_image(root, 'image')
items = [_get_item(e) for e in root.findall('item')]
content_encoded = get_text(root, 'content:encoded')
return RSSChannel(
title,
link,
description,
language,
copyright,
managing_editor,
web_master,
pub_date,
last_build_date,
categories,
generator,
docs,
ttl,
image,
items,
content_encoded
)
def parse_rss_file(filename: str) -> RSSChannel:
"""Parse an RSS feed from a local XML file."""
root = parse_xml(filename).getroot()
return _parse_rss(root)
def parse_rss_bytes(data: bytes) -> RSSChannel:
"""Parse an RSS feed from a byte-string containing XML data."""
root = parse_xml(BytesIO(data)).getroot()
return _parse_rss(root)
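
A minimal sketch of the RSS entry point; note that _parse_rss rejects any version other than "2.0". The channel below is illustrative:

from atoma import parse_rss_bytes

data = b"""<rss version="2.0"><channel>
  <title>Example Channel</title>
  <link>https://example.com</link>
  <description>Demo</description>
  <item>
    <title>First post</title>
    <guid>https://example.com/1</guid>
    <pubDate>Fri, 06 Sep 2019 15:45:01 -0700</pubDate>
  </item>
</channel></rss>"""

channel = parse_rss_bytes(data)
print(channel.title)              # Example Channel
print(channel.items[0].pub_date)  # timezone-aware datetime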

224
python/atoma/simple.py Normal file
@@ -0,0 +1,224 @@
"""Simple API that abstracts away the differences between feed types."""
from datetime import datetime, timedelta
import html
import os
from typing import Optional, List, Tuple
import urllib.parse
import attr
from . import atom, rss, json_feed
from .exceptions import (
FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError
)
@attr.s
class Attachment:
link: str = attr.ib()
mime_type: Optional[str] = attr.ib()
title: Optional[str] = attr.ib()
size_in_bytes: Optional[int] = attr.ib()
duration: Optional[timedelta] = attr.ib()
@attr.s
class Article:
id: str = attr.ib()
title: Optional[str] = attr.ib()
link: Optional[str] = attr.ib()
content: str = attr.ib()
published_at: Optional[datetime] = attr.ib()
updated_at: Optional[datetime] = attr.ib()
attachments: List[Attachment] = attr.ib()
@attr.s
class Feed:
title: str = attr.ib()
subtitle: Optional[str] = attr.ib()
link: Optional[str] = attr.ib()
updated_at: Optional[datetime] = attr.ib()
articles: List[Article] = attr.ib()
def _adapt_atom_feed(atom_feed: atom.AtomFeed) -> Feed:
articles = list()
for entry in atom_feed.entries:
if entry.content is not None:
content = entry.content.value
elif entry.summary is not None:
content = entry.summary.value
else:
content = ''
published_at, updated_at = _get_article_dates(entry.published,
entry.updated)
# Find article link and attachments
article_link = None
attachments = list()
for candidate_link in entry.links:
if candidate_link.rel in ('alternate', None):
article_link = candidate_link.href
elif candidate_link.rel == 'enclosure':
attachments.append(Attachment(
title=_get_attachment_title(candidate_link.title,
candidate_link.href),
link=candidate_link.href,
mime_type=candidate_link.type_,
size_in_bytes=candidate_link.length,
duration=None
))
if entry.title is None:
entry_title = None
elif entry.title.text_type in (atom.AtomTextType.html,
atom.AtomTextType.xhtml):
entry_title = html.unescape(entry.title.value).strip()
else:
entry_title = entry.title.value
articles.append(Article(
entry.id_,
entry_title,
article_link,
content,
published_at,
updated_at,
attachments
))
# Find feed link
link = None
for candidate_link in atom_feed.links:
if candidate_link.rel == 'self':
link = candidate_link.href
break
return Feed(
atom_feed.title.value if atom_feed.title else atom_feed.id_,
atom_feed.subtitle.value if atom_feed.subtitle else None,
link,
atom_feed.updated,
articles
)
def _adapt_rss_channel(rss_channel: rss.RSSChannel) -> Feed:
articles = list()
for item in rss_channel.items:
attachments = [
Attachment(link=e.url, mime_type=e.type, size_in_bytes=e.length,
title=_get_attachment_title(None, e.url), duration=None)
for e in item.enclosures
]
articles.append(Article(
item.guid or item.link,
item.title,
item.link,
item.content_encoded or item.description or '',
item.pub_date,
None,
attachments
))
if rss_channel.title is None and rss_channel.link is None:
        raise FeedParseError('RSS feed has neither a title nor a link')
return Feed(
rss_channel.title if rss_channel.title else rss_channel.link,
rss_channel.description,
rss_channel.link,
rss_channel.pub_date,
articles
)
def _adapt_json_feed(json_feed: json_feed.JSONFeed) -> Feed:
articles = list()
for item in json_feed.items:
attachments = [
Attachment(a.url, a.mime_type,
_get_attachment_title(a.title, a.url),
a.size_in_bytes, a.duration)
for a in item.attachments
]
articles.append(Article(
item.id_,
item.title,
item.url,
item.content_html or item.content_text or '',
item.date_published,
item.date_modified,
attachments
))
return Feed(
json_feed.title,
json_feed.description,
json_feed.feed_url,
None,
articles
)
def _get_article_dates(published_at: Optional[datetime],
updated_at: Optional[datetime]
) -> Tuple[Optional[datetime], Optional[datetime]]:
if published_at and updated_at:
return published_at, updated_at
if updated_at:
return updated_at, None
if published_at:
return published_at, None
raise FeedParseError('Article does not have proper dates')
def _get_attachment_title(attachment_title: Optional[str], link: str) -> str:
if attachment_title:
return attachment_title
parsed_link = urllib.parse.urlparse(link)
return os.path.basename(parsed_link.path)
def _simple_parse(pairs, content) -> Feed:
is_xml = True
is_json = True
for parser, adapter in pairs:
try:
return adapter(parser(content))
except FeedXMLError:
is_xml = False
except FeedJSONError:
is_json = False
except FeedParseError:
continue
if not is_xml and not is_json:
raise FeedDocumentError('File is not a supported feed type')
raise FeedParseError('File is not a valid supported feed')
def simple_parse_file(filename: str) -> Feed:
"""Parse an Atom, RSS or JSON feed from a local file."""
pairs = (
(rss.parse_rss_file, _adapt_rss_channel),
(atom.parse_atom_file, _adapt_atom_feed),
(json_feed.parse_json_feed_file, _adapt_json_feed)
)
return _simple_parse(pairs, filename)
def simple_parse_bytes(data: bytes) -> Feed:
"""Parse an Atom, RSS or JSON feed from a byte-string containing data."""
pairs = (
(rss.parse_rss_bytes, _adapt_rss_channel),
(atom.parse_atom_bytes, _adapt_atom_feed),
(json_feed.parse_json_feed_bytes, _adapt_json_feed)
)
return _simple_parse(pairs, data)
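
A sketch of the unified API; because _get_article_dates raises FeedParseError when an item has no date at all, the illustrative item carries a pubDate:

from atoma.simple import simple_parse_bytes

data = b"""<rss version="2.0"><channel>
  <title>Example Channel</title>
  <link>https://example.com</link>
  <description>Demo</description>
  <item>
    <title>First post</title>
    <link>https://example.com/1</link>
    <pubDate>Fri, 06 Sep 2019 15:45:01 -0700</pubDate>
  </item>
</channel></rss>"""

feed = simple_parse_bytes(data)
article = feed.articles[0]
print(feed.title, article.link, article.published_at)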

84
python/atoma/utils.py Normal file
@@ -0,0 +1,84 @@
from datetime import datetime, timezone
from xml.etree.ElementTree import Element
from typing import Optional
import dateutil.parser
from defusedxml.ElementTree import parse as defused_xml_parse, ParseError
from .exceptions import FeedXMLError, FeedParseError
ns = {
'content': 'http://purl.org/rss/1.0/modules/content/',
'feed': 'http://www.w3.org/2005/Atom'
}
def parse_xml(xml_content):
try:
return defused_xml_parse(xml_content)
except ParseError:
raise FeedXMLError('Not a valid XML document')
def get_child(element: Element, name,
optional: bool=True) -> Optional[Element]:
child = element.find(name, namespaces=ns)
if child is None and not optional:
raise FeedParseError(
'Could not parse feed: "{}" does not have a "{}"'
.format(element.tag, name)
)
elif child is None:
return None
return child
def get_text(element: Element, name, optional: bool=True) -> Optional[str]:
child = get_child(element, name, optional)
if child is None:
return None
if child.text is None:
if optional:
return None
raise FeedParseError(
'Could not parse feed: "{}" text is required but is empty'
.format(name)
)
return child.text.strip()
def get_int(element: Element, name, optional: bool=True) -> Optional[int]:
text = get_text(element, name, optional)
if text is None:
return None
return int(text)
def get_datetime(element: Element, name,
optional: bool=True) -> Optional[datetime]:
text = get_text(element, name, optional)
if text is None:
return None
return try_parse_date(text)
def try_parse_date(date_str: str) -> Optional[datetime]:
try:
date = dateutil.parser.parse(date_str, fuzzy=True)
except (ValueError, OverflowError):
return None
if date.tzinfo is None:
# TZ naive datetime, make it a TZ aware datetime by assuming it
# contains UTC time
date = date.replace(tzinfo=timezone.utc)
return date
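
A two-line illustration of the UTC assumption in try_parse_date:

from atoma.utils import try_parse_date

print(try_parse_date('2019-09-06 15:45:01'))              # naive input -> +00:00
print(try_parse_date('Fri, 06 Sep 2019 15:45:01 -0700'))  # keeps its -07:00 offset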

65
python/attr/__init__.py Normal file
@@ -0,0 +1,65 @@
from __future__ import absolute_import, division, print_function
from functools import partial
from . import converters, exceptions, filters, validators
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has
from ._make import (
NOTHING,
Attribute,
Factory,
attrib,
attrs,
fields,
fields_dict,
make_class,
validate,
)
__version__ = "18.2.0"
__title__ = "attrs"
__description__ = "Classes Without Boilerplate"
__url__ = "https://www.attrs.org/"
__uri__ = __url__
__doc__ = __description__ + " <" + __uri__ + ">"
__author__ = "Hynek Schlawack"
__email__ = "hs@ox.cx"
__license__ = "MIT"
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"
s = attributes = attrs
ib = attr = attrib
dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
__all__ = [
"Attribute",
"Factory",
"NOTHING",
"asdict",
"assoc",
"astuple",
"attr",
"attrib",
"attributes",
"attrs",
"converters",
"evolve",
"exceptions",
"fields",
"fields_dict",
"filters",
"get_run_validators",
"has",
"ib",
"make_class",
"s",
"set_run_validators",
"validate",
"validators",
]
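
A minimal sketch of the exported aliases; attr.s with explicit attr.ib() is the spelling the vendored atoma modules use:

import attr

@attr.s
class Point:
    x: int = attr.ib()
    y: int = attr.ib(default=0)

p = Point(1)
print(p)               # Point(x=1, y=0)
print(attr.asdict(p))  # {'x': 1, 'y': 0}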

252
python/attr/__init__.pyi Normal file
@@ -0,0 +1,252 @@
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Sequence,
Mapping,
Tuple,
Type,
TypeVar,
Union,
overload,
)
# `import X as X` is required to make these public
from . import exceptions as exceptions
from . import filters as filters
from . import converters as converters
from . import validators as validators
_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)
_ValidatorType = Callable[[Any, Attribute, _T], Any]
_ConverterType = Callable[[Any], _T]
_FilterType = Callable[[Attribute, Any], bool]
# FIXME: in reality, if multiple validators are passed they must be in a list or tuple,
# but those are invariant and so would prevent subtypes of _ValidatorType from working
# when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]
# _make --
NOTHING: object
# NOTE: Factory lies about its return type to make this possible: `x: List[int] = Factory(list)`
# Work around mypy issue #4554 in the common case by using an overload.
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Union[Callable[[Any], _T], Callable[[], _T]],
takes_self: bool = ...,
) -> _T: ...
class Attribute(Generic[_T]):
name: str
default: Optional[_T]
validator: Optional[_ValidatorType[_T]]
repr: bool
cmp: bool
hash: Optional[bool]
init: bool
converter: Optional[_ConverterType[_T]]
metadata: Dict[Any, Any]
type: Optional[Type[_T]]
kw_only: bool
def __lt__(self, x: Attribute) -> bool: ...
def __le__(self, x: Attribute) -> bool: ...
def __gt__(self, x: Attribute) -> bool: ...
def __ge__(self, x: Attribute) -> bool: ...
# NOTE: We had several choices for the annotation to use for type arg:
# 1) Type[_T]
# - Pros: Handles simple cases correctly
# - Cons: Might produce less informative errors in the case of conflicting TypeVars
# e.g. `attr.ib(default='bad', type=int)`
# 2) Callable[..., _T]
# - Pros: Better error messages than #1 for conflicting TypeVars
# - Cons: Terrible error messages for validator checks.
# e.g. attr.ib(type=int, validator=validate_str)
# -> error: Cannot infer function type argument
# 3) type (and do all of the work in the mypy plugin)
# - Pros: Simple here, and we could customize the plugin with our own errors.
# - Cons: Would need to write mypy plugin code to handle all the cases.
# We chose option #1.
# `attr` lies about its return type to make the following possible:
# attr() -> Any
# attr(8) -> int
# attr(validator=<some callable>) -> Whatever the callable expects.
# This makes this type of assignments possible:
# x: int = attr(8)
#
# This form catches explicit None or no default but with no other arguments returns Any.
@overload
def attrib(
default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the other arguments.
@overload
def attrib(
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: Optional[_ConverterType[_T]] = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def attrib(
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: Optional[_ConverterType[_T]] = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: Optional[_ConverterType[_T]] = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: object = ...,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
) -> Any: ...
@overload
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
) -> _C: ...
@overload
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
) -> Callable[[_C], _C]: ...
# TODO: add support for returning NamedTuple from the mypy plugin
class _Fields(Tuple[Attribute, ...]):
def __getattr__(self, name: str) -> Attribute: ...
def fields(cls: type) -> _Fields: ...
def fields_dict(cls: type) -> Dict[str, Attribute]: ...
def validate(inst: Any) -> None: ...
# TODO: add support for returning a proper attrs class from the mypy plugin
# we use Any instead of _CountingAttr so that e.g. `make_class('Foo', [attr.ib()])` is valid
def make_class(
name: str,
attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]],
bases: Tuple[type, ...] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
) -> type: ...
# _funcs --
# TODO: add support for returning TypedDict from the mypy plugin
# FIXME: asdict/astuple do not honor their factory args. waiting on one of these:
# https://github.com/python/mypy/issues/4236
# https://github.com/python/typing/issues/253
def asdict(
inst: Any,
recurse: bool = ...,
filter: Optional[_FilterType] = ...,
dict_factory: Type[Mapping[Any, Any]] = ...,
retain_collection_types: bool = ...,
) -> Dict[str, Any]: ...
# TODO: add support for returning NamedTuple from the mypy plugin
def astuple(
inst: Any,
recurse: bool = ...,
filter: Optional[_FilterType] = ...,
tuple_factory: Type[Sequence] = ...,
retain_collection_types: bool = ...,
) -> Tuple[Any, ...]: ...
def has(cls: type) -> bool: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
def evolve(inst: _T, **changes: Any) -> _T: ...
# _config --
def set_run_validators(run: bool) -> None: ...
def get_run_validators() -> bool: ...
# aliases --
s = attributes = attrs
ib = attr = attrib
dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;)

163
python/attr/_compat.py Normal file
@@ -0,0 +1,163 @@
from __future__ import absolute_import, division, print_function
import platform
import sys
import types
import warnings
PY2 = sys.version_info[0] == 2
PYPY = platform.python_implementation() == "PyPy"
if PYPY or sys.version_info[:2] >= (3, 6):
ordered_dict = dict
else:
from collections import OrderedDict
ordered_dict = OrderedDict
if PY2:
from UserDict import IterableUserDict
# We 'bundle' isclass instead of using inspect as importing inspect is
# fairly expensive (order of 10-15 ms for a modern machine in 2016)
def isclass(klass):
return isinstance(klass, (type, types.ClassType))
# TYPE is used in exceptions, repr(int) is different on Python 2 and 3.
TYPE = "type"
def iteritems(d):
return d.iteritems()
# Python 2 is bereft of a read-only dict proxy, so we make one!
class ReadOnlyDict(IterableUserDict):
"""
Best-effort read-only dict wrapper.
"""
def __setitem__(self, key, val):
# We gently pretend we're a Python 3 mappingproxy.
raise TypeError(
"'mappingproxy' object does not support item assignment"
)
def update(self, _):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'update'"
)
def __delitem__(self, _):
# We gently pretend we're a Python 3 mappingproxy.
raise TypeError(
"'mappingproxy' object does not support item deletion"
)
def clear(self):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'clear'"
)
def pop(self, key, default=None):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'pop'"
)
def popitem(self):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'popitem'"
)
def setdefault(self, key, default=None):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'setdefault'"
)
def __repr__(self):
# Override to be identical to the Python 3 version.
return "mappingproxy(" + repr(self.data) + ")"
def metadata_proxy(d):
res = ReadOnlyDict()
res.data.update(d) # We blocked update, so we have to do it like this.
return res
else:
def isclass(klass):
return isinstance(klass, type)
TYPE = "class"
def iteritems(d):
return d.items()
def metadata_proxy(d):
return types.MappingProxyType(dict(d))
def import_ctypes():
"""
Moved into a function for testability.
"""
import ctypes
return ctypes
if not PY2:
def just_warn(*args, **kw):
"""
We only warn on Python 3 because we are not aware of any concrete
consequences of not setting the cell on Python 2.
"""
warnings.warn(
"Missing ctypes. Some features like bare super() or accessing "
"__class__ will not work with slots classes.",
RuntimeWarning,
stacklevel=2,
)
else:
def just_warn(*args, **kw): # pragma: nocover
"""
We only warn on Python 3 because we are not aware of any concrete
consequences of not setting the cell on Python 2.
"""
def make_set_closure_cell():
"""
Moved into a function for testability.
"""
if PYPY: # pragma: no cover
def set_closure_cell(cell, value):
cell.__setstate__((value,))
else:
try:
ctypes = import_ctypes()
set_closure_cell = ctypes.pythonapi.PyCell_Set
set_closure_cell.argtypes = (ctypes.py_object, ctypes.py_object)
set_closure_cell.restype = ctypes.c_int
except Exception:
# We try best effort to set the cell, but sometimes it's not
# possible. For example on Jython or on GAE.
set_closure_cell = just_warn
return set_closure_cell
set_closure_cell = make_set_closure_cell()

23
python/attr/_config.py Normal file
@@ -0,0 +1,23 @@
from __future__ import absolute_import, division, print_function
__all__ = ["set_run_validators", "get_run_validators"]
_run_validators = True
def set_run_validators(run):
"""
Set whether or not validators are run. By default, they are run.
"""
if not isinstance(run, bool):
raise TypeError("'run' must be bool.")
global _run_validators
_run_validators = run
def get_run_validators():
"""
Return whether or not validators are run.
"""
return _run_validators
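
A small illustration of the global switch; the class is hypothetical:

import attr
from attr import set_run_validators
from attr.validators import instance_of

@attr.s
class C:
    x = attr.ib(validator=instance_of(int))

set_run_validators(False)
C("not an int")  # validator skipped while the switch is off
set_run_validators(True)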

290
python/attr/_funcs.py Normal file
@@ -0,0 +1,290 @@
from __future__ import absolute_import, division, print_function
import copy
from ._compat import iteritems
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError
def asdict(
inst,
recurse=True,
filter=None,
dict_factory=dict,
retain_collection_types=False,
):
"""
Return the ``attrs`` attribute values of *inst* as a dict.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the :class:`attr.Attribute` as the first argument and the
value as the second argument.
:param callable dict_factory: A callable to produce dictionaries from. For
example, to produce ordered dictionaries instead of normal Python
dictionaries, pass in ``collections.OrderedDict``.
:param bool retain_collection_types: Do not convert to ``list`` when
encountering an attribute whose type is ``tuple`` or ``set``. Only
meaningful if ``recurse`` is ``True``.
:rtype: return type of *dict_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.0.0 *dict_factory*
.. versionadded:: 16.1.0 *retain_collection_types*
"""
attrs = fields(inst.__class__)
rv = dict_factory()
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if recurse is True:
if has(v.__class__):
rv[a.name] = asdict(
v, True, filter, dict_factory, retain_collection_types
)
elif isinstance(v, (tuple, list, set)):
cf = v.__class__ if retain_collection_types is True else list
rv[a.name] = cf(
[
_asdict_anything(
i, filter, dict_factory, retain_collection_types
)
for i in v
]
)
elif isinstance(v, dict):
df = dict_factory
rv[a.name] = df(
(
_asdict_anything(
kk, filter, df, retain_collection_types
),
_asdict_anything(
vv, filter, df, retain_collection_types
),
)
for kk, vv in iteritems(v)
)
else:
rv[a.name] = v
else:
rv[a.name] = v
return rv
def _asdict_anything(val, filter, dict_factory, retain_collection_types):
"""
``asdict`` only works on attrs instances, this works on anything.
"""
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
# Attrs class.
rv = asdict(val, True, filter, dict_factory, retain_collection_types)
elif isinstance(val, (tuple, list, set)):
cf = val.__class__ if retain_collection_types is True else list
rv = cf(
[
_asdict_anything(
i, filter, dict_factory, retain_collection_types
)
for i in val
]
)
elif isinstance(val, dict):
df = dict_factory
rv = df(
(
_asdict_anything(kk, filter, df, retain_collection_types),
_asdict_anything(vv, filter, df, retain_collection_types),
)
for kk, vv in iteritems(val)
)
else:
rv = val
return rv
def astuple(
inst,
recurse=True,
filter=None,
tuple_factory=tuple,
retain_collection_types=False,
):
"""
Return the ``attrs`` attribute values of *inst* as a tuple.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the :class:`attr.Attribute` as the first argument and the
value as the second argument.
:param callable tuple_factory: A callable to produce tuples from. For
example, to produce lists instead of tuples.
:param bool retain_collection_types: Do not convert to ``list``
        or ``dict`` when encountering an attribute whose type is
``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
``True``.
:rtype: return type of *tuple_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.2.0
"""
attrs = fields(inst.__class__)
rv = []
retain = retain_collection_types # Very long. :/
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if recurse is True:
if has(v.__class__):
rv.append(
astuple(
v,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
)
elif isinstance(v, (tuple, list, set)):
cf = v.__class__ if retain is True else list
rv.append(
cf(
[
astuple(
j,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(j.__class__)
else j
for j in v
]
)
)
elif isinstance(v, dict):
df = v.__class__ if retain is True else dict
rv.append(
df(
(
astuple(
kk,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(kk.__class__)
else kk,
astuple(
vv,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(vv.__class__)
else vv,
)
for kk, vv in iteritems(v)
)
)
else:
rv.append(v)
else:
rv.append(v)
return rv if tuple_factory is list else tuple_factory(rv)
def has(cls):
"""
Check whether *cls* is a class with ``attrs`` attributes.
:param type cls: Class to introspect.
:raise TypeError: If *cls* is not a class.
:rtype: :class:`bool`
"""
return getattr(cls, "__attrs_attrs__", None) is not None
def assoc(inst, **changes):
"""
Copy *inst* and apply *changes*.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
be found on *cls*.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. deprecated:: 17.1.0
Use :func:`evolve` instead.
"""
import warnings
warnings.warn(
"assoc is deprecated and will be removed after 2018/01.",
DeprecationWarning,
stacklevel=2,
)
new = copy.copy(inst)
attrs = fields(inst.__class__)
for k, v in iteritems(changes):
a = getattr(attrs, k, NOTHING)
if a is NOTHING:
raise AttrsAttributeNotFoundError(
"{k} is not an attrs attribute on {cl}.".format(
k=k, cl=new.__class__
)
)
_obj_setattr(new, k, v)
return new
def evolve(inst, **changes):
"""
Create a new instance, based on *inst* with *changes* applied.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise TypeError: If *attr_name* couldn't be found in the class
``__init__``.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 17.1.0
"""
cls = inst.__class__
attrs = fields(cls)
for a in attrs:
if not a.init:
continue
attr_name = a.name # To deal with private attributes.
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
if init_name not in changes:
changes[init_name] = getattr(inst, attr_name)
return cls(**changes)
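
A hedged sketch of asdict and evolve on nested classes (names hypothetical):

import attr

@attr.s
class Child:
    name: str = attr.ib()

@attr.s
class Parent:
    child: Child = attr.ib()
    tags: list = attr.ib()

p = Parent(Child("a"), ["x"])
print(attr.asdict(p))           # {'child': {'name': 'a'}, 'tags': ['x']}
print(attr.evolve(p, tags=[]))  # new Parent instance; p is unchanged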

2034
python/attr/_make.py Normal file

File diff suppressed because it is too large.

78
python/attr/converters.py Normal file
@@ -0,0 +1,78 @@
"""
Commonly useful converters.
"""
from __future__ import absolute_import, division, print_function
from ._make import NOTHING, Factory
def optional(converter):
"""
A converter that allows an attribute to be optional. An optional attribute
is one which can be set to ``None``.
:param callable converter: the converter that is used for non-``None``
values.
.. versionadded:: 17.1.0
"""
def optional_converter(val):
if val is None:
return None
return converter(val)
return optional_converter
def default_if_none(default=NOTHING, factory=None):
"""
A converter that allows to replace ``None`` values by *default* or the
result of *factory*.
:param default: Value to be used if ``None`` is passed. Passing an instance
of :class:`attr.Factory` is supported, however the ``takes_self`` option
is *not*.
    :param callable factory: A callable that takes no parameters whose result
is used if ``None`` is passed.
:raises TypeError: If **neither** *default* or *factory* is passed.
:raises TypeError: If **both** *default* and *factory* are passed.
:raises ValueError: If an instance of :class:`attr.Factory` is passed with
``takes_self=True``.
.. versionadded:: 18.2.0
"""
if default is NOTHING and factory is None:
raise TypeError("Must pass either `default` or `factory`.")
if default is not NOTHING and factory is not None:
raise TypeError(
"Must pass either `default` or `factory` but not both."
)
if factory is not None:
default = Factory(factory)
if isinstance(default, Factory):
if default.takes_self:
raise ValueError(
"`takes_self` is not supported by default_if_none."
)
def default_if_none_converter(val):
if val is not None:
return val
return default.factory()
else:
def default_if_none_converter(val):
if val is not None:
return val
return default
return default_if_none_converter
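
A short sketch of both converters (field names hypothetical); note that converters also run on defaults, which is what makes a default of None work here:

import attr
from attr.converters import optional, default_if_none

@attr.s
class Config:
    port = attr.ib(converter=optional(int), default=None)
    name = attr.ib(converter=default_if_none("anonymous"), default=None)

print(Config(port="8080"))  # Config(port=8080, name='anonymous')
print(Config())             # Config(port=None, name='anonymous')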

12
python/attr/converters.pyi Normal file
@@ -0,0 +1,12 @@
from typing import TypeVar, Optional, Callable, overload
from . import _ConverterType
_T = TypeVar("_T")
def optional(
converter: _ConverterType[_T]
) -> _ConverterType[Optional[_T]]: ...
@overload
def default_if_none(default: _T) -> _ConverterType[_T]: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType[_T]: ...

57
python/attr/exceptions.py Normal file
@@ -0,0 +1,57 @@
from __future__ import absolute_import, division, print_function
class FrozenInstanceError(AttributeError):
"""
A frozen/immutable instance has been attempted to be modified.
It mirrors the behavior of ``namedtuples`` by using the same error message
and subclassing :exc:`AttributeError`.
.. versionadded:: 16.1.0
"""
msg = "can't set attribute"
args = [msg]
class AttrsAttributeNotFoundError(ValueError):
"""
An ``attrs`` function couldn't find an attribute that the user asked for.
.. versionadded:: 16.2.0
"""
class NotAnAttrsClassError(ValueError):
"""
A non-``attrs`` class has been passed into an ``attrs`` function.
.. versionadded:: 16.2.0
"""
class DefaultAlreadySetError(RuntimeError):
"""
A default has been set using ``attr.ib()`` and is attempted to be reset
using the decorator.
.. versionadded:: 17.1.0
"""
class UnannotatedAttributeError(RuntimeError):
"""
A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
annotation.
.. versionadded:: 17.3.0
"""
class PythonTooOldError(RuntimeError):
"""
An ``attrs`` feature requiring a more recent python version has been used.
.. versionadded:: 18.2.0
"""

7
python/attr/exceptions.pyi Normal file
@@ -0,0 +1,7 @@
class FrozenInstanceError(AttributeError):
msg: str = ...
class AttrsAttributeNotFoundError(ValueError): ...
class NotAnAttrsClassError(ValueError): ...
class DefaultAlreadySetError(RuntimeError): ...
class UnannotatedAttributeError(RuntimeError): ...
class PythonTooOldError(RuntimeError): ...

52
python/attr/filters.py Normal file
@@ -0,0 +1,52 @@
"""
Commonly useful filters for :func:`attr.asdict`.
"""
from __future__ import absolute_import, division, print_function
from ._compat import isclass
from ._make import Attribute
def _split_what(what):
"""
Returns a tuple of `frozenset`s of classes and attributes.
"""
return (
frozenset(cls for cls in what if isclass(cls)),
frozenset(cls for cls in what if isinstance(cls, Attribute)),
)
def include(*what):
"""
Whitelist *what*.
:param what: What to whitelist.
:type what: :class:`list` of :class:`type` or :class:`attr.Attribute`\\ s
:rtype: :class:`callable`
"""
cls, attrs = _split_what(what)
def include_(attribute, value):
return value.__class__ in cls or attribute in attrs
return include_
def exclude(*what):
"""
Blacklist *what*.
:param what: What to blacklist.
:type what: :class:`list` of classes or :class:`attr.Attribute`\\ s.
:rtype: :class:`callable`
"""
cls, attrs = _split_what(what)
def exclude_(attribute, value):
return value.__class__ not in cls and attribute not in attrs
return exclude_
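
A sketch of exclude used as an asdict filter (class hypothetical):

import attr
from attr.filters import exclude

@attr.s
class User:
    name: str = attr.ib()
    password: str = attr.ib()

u = User("alice", "s3cret")
print(attr.asdict(u, filter=exclude(attr.fields(User).password)))
# {'name': 'alice'}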

5
python/attr/filters.pyi Normal file
@@ -0,0 +1,5 @@
from typing import Union
from . import Attribute, _FilterType
def include(*what: Union[type, Attribute]) -> _FilterType: ...
def exclude(*what: Union[type, Attribute]) -> _FilterType: ...

0
python/attr/py.typed Normal file

170
python/attr/validators.py Normal file
@@ -0,0 +1,170 @@
"""
Commonly useful validators.
"""
from __future__ import absolute_import, division, print_function
from ._make import _AndValidator, and_, attrib, attrs
__all__ = ["and_", "in_", "instance_of", "optional", "provides"]
@attrs(repr=False, slots=True, hash=True)
class _InstanceOfValidator(object):
type = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not isinstance(value, self.type):
raise TypeError(
"'{name}' must be {type!r} (got {value!r} that is a "
"{actual!r}).".format(
name=attr.name,
type=self.type,
actual=value.__class__,
value=value,
),
attr,
self.type,
value,
)
def __repr__(self):
return "<instance_of validator for type {type!r}>".format(
type=self.type
)
def instance_of(type):
"""
A validator that raises a :exc:`TypeError` if the initializer is called
with a wrong type for this particular attribute (checks are performed using
:func:`isinstance` therefore it's also valid to pass a tuple of types).
:param type: The type to check for.
:type type: type or tuple of types
:raises TypeError: With a human readable error message, the attribute
(of type :class:`attr.Attribute`), the expected type, and the value it
got.
"""
return _InstanceOfValidator(type)
@attrs(repr=False, slots=True, hash=True)
class _ProvidesValidator(object):
interface = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.interface.providedBy(value):
raise TypeError(
"'{name}' must provide {interface!r} which {value!r} "
"doesn't.".format(
name=attr.name, interface=self.interface, value=value
),
attr,
self.interface,
value,
)
def __repr__(self):
return "<provides validator for interface {interface!r}>".format(
interface=self.interface
)
def provides(interface):
"""
A validator that raises a :exc:`TypeError` if the initializer is called
with an object that does not provide the requested *interface* (checks are
performed using ``interface.providedBy(value)`` (see `zope.interface
<https://zopeinterface.readthedocs.io/en/latest/>`_).
:param zope.interface.Interface interface: The interface to check for.
:raises TypeError: With a human readable error message, the attribute
(of type :class:`attr.Attribute`), the expected interface, and the
value it got.
"""
return _ProvidesValidator(interface)
@attrs(repr=False, slots=True, hash=True)
class _OptionalValidator(object):
validator = attrib()
def __call__(self, inst, attr, value):
if value is None:
return
self.validator(inst, attr, value)
def __repr__(self):
return "<optional validator for {what} or None>".format(
what=repr(self.validator)
)
def optional(validator):
"""
A validator that makes an attribute optional. An optional attribute is one
which can be set to ``None`` in addition to satisfying the requirements of
the sub-validator.
:param validator: A validator (or a list of validators) that is used for
non-``None`` values.
:type validator: callable or :class:`list` of callables.
.. versionadded:: 15.1.0
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
"""
if isinstance(validator, list):
return _OptionalValidator(_AndValidator(validator))
return _OptionalValidator(validator)
@attrs(repr=False, slots=True, hash=True)
class _InValidator(object):
options = attrib()
def __call__(self, inst, attr, value):
try:
in_options = value in self.options
except TypeError as e: # e.g. `1 in "abc"`
in_options = False
if not in_options:
raise ValueError(
"'{name}' must be in {options!r} (got {value!r})".format(
name=attr.name, options=self.options, value=value
)
)
def __repr__(self):
return "<in_ validator with options {options!r}>".format(
options=self.options
)
def in_(options):
"""
A validator that raises a :exc:`ValueError` if the initializer is called
with a value that does not belong in the options provided. The check is
performed using ``value in options``.
:param options: Allowed options.
:type options: list, tuple, :class:`enum.Enum`, ...
:raises ValueError: With a human readable error message, the attribute (of
type :class:`attr.Attribute`), the expected options, and the value it
got.
.. versionadded:: 17.1.0
"""
return _InValidator(options)
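
A combined sketch of the validators above (class and values hypothetical):

import attr
from attr.validators import instance_of, optional, in_

@attr.s
class Job:
    name: str = attr.ib(validator=instance_of(str))
    priority: int = attr.ib(validator=in_(range(1, 6)), default=3)
    owner = attr.ib(validator=optional(instance_of(str)), default=None)

Job("build")  # passes all three validators
try:
    Job("build", priority=9)
except ValueError as exc:
    print(exc)  # 'priority' must be in range(1, 6) (got 9)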

14
python/attr/validators.pyi Normal file
@@ -0,0 +1,14 @@
from typing import Container, List, Union, TypeVar, Type, Any, Optional, Tuple
from . import _ValidatorType
_T = TypeVar("_T")
def instance_of(
type: Union[Tuple[Type[_T], ...], Type[_T]]
) -> _ValidatorType[_T]: ...
def provides(interface: Any) -> _ValidatorType[Any]: ...
def optional(
validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]]
) -> _ValidatorType[Optional[_T]]: ...
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
from ._version import VERSION as __version__

@@ -0,0 +1,34 @@
"""
Common code used in multiple modules.
"""
class weekday(object):
__slots__ = ["weekday", "n"]
def __init__(self, weekday, n=None):
self.weekday = weekday
self.n = n
def __call__(self, n):
if n == self.n:
return self
else:
return self.__class__(self.weekday, n)
def __eq__(self, other):
try:
if self.weekday != other.weekday or self.n != other.n:
return False
except AttributeError:
return False
return True
__hash__ = None
def __repr__(self):
s = ("MO", "TU", "WE", "TH", "FR", "SA", "SU")[self.weekday]
if not self.n:
return s
else:
return "%s(%+d)" % (s, self.n)

@@ -0,0 +1,10 @@
"""
Contains information about the dateutil version.
"""
VERSION_MAJOR = 2
VERSION_MINOR = 6
VERSION_PATCH = 1
VERSION_TUPLE = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH)
VERSION = '.'.join(map(str, VERSION_TUPLE))

python/dateutil/easter.py Normal file (89 lines)
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
"""
This module offers a generic easter computing method for any given year, using
Western, Orthodox or Julian algorithms.
"""
import datetime
__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"]
EASTER_JULIAN = 1
EASTER_ORTHODOX = 2
EASTER_WESTERN = 3
def easter(year, method=EASTER_WESTERN):
"""
This method was ported from the work done by GM Arts,
on top of the algorithm by Claus Tondering, which was
based in part on the algorithm of Oudin (1940), as
quoted in "Explanatory Supplement to the Astronomical
Almanac", P. Kenneth Seidelmann, editor.
This algorithm implements three different easter
calculation methods:
1 - Original calculation in Julian calendar, valid in
dates after 326 AD
2 - Original method, with date converted to Gregorian
calendar, valid in years 1583 to 4099
3 - Revised method, in Gregorian calendar, valid in
years 1583 to 4099 as well
These methods are represented by the constants:
* ``EASTER_JULIAN = 1``
* ``EASTER_ORTHODOX = 2``
* ``EASTER_WESTERN = 3``
The default method is method 3.
More about the algorithm may be found at:
http://users.chariot.net.au/~gmarts/eastalg.htm
and
http://www.tondering.dk/claus/calendar.html
"""
if not (1 <= method <= 3):
raise ValueError("invalid method")
# g - Golden year - 1
# c - Century
# h - (23 - Epact) mod 30
# i - Number of days from March 21 to Paschal Full Moon
# j - Weekday for PFM (0=Sunday, etc)
# p - Number of days from March 21 to Sunday on or before PFM
# (-6 to 28 for methods 1 & 3, up to 56 for method 2)
# e - Extra days to add for method 2 (converting Julian
# date to Gregorian date)
y = year
g = y % 19
e = 0
if method < 3:
# Old method
i = (19*g + 15) % 30
j = (y + y//4 + i) % 7
if method == 2:
# Extra days to convert Julian to Gregorian date
e = 10
if y > 1600:
e = e + y//100 - 16 - (y//100 - 16)//4
else:
# New method
c = y//100
h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30
i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11))
j = (y + y//4 + i + 2 - c + c//4) % 7
# p can be from -6 to 56 corresponding to dates 22 March to 23 May
# (later dates apply to method 2, although 23 May never actually occurs)
p = i - j + e
d = 1 + (p + 27 + (p + 6)//40) % 31
m = 3 + (p + 26)//30
return datetime.date(int(y), int(m), int(d))
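For example (the module is importable as dateutil.easter; the dates below follow from the algorithm above):

from dateutil.easter import easter, EASTER_ORTHODOX

print(easter(2019))                   # 2019-04-21 (Western, the default)
print(easter(2019, EASTER_ORTHODOX))  # 2019-04-28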

python/dateutil/parser.py Normal file (1374 lines)
File diff suppressed because it is too large.

@@ -0,0 +1,549 @@
# -*- coding: utf-8 -*-
import datetime
import calendar
import operator
from math import copysign
from six import integer_types
from warnings import warn
from ._common import weekday
MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7))
__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"]
class relativedelta(object):
"""
The relativedelta type is based on the specification of the excellent
work done by M.-A. Lemburg in his
`mx.DateTime <http://www.egenix.com/files/python/mxDateTime.html>`_ extension.
However, notice that this type does *NOT* implement the same algorithm as
his work. Do *NOT* expect it to behave like mx.DateTime's counterpart.
There are two different ways to build a relativedelta instance. The
first one is passing it two date/datetime classes::
relativedelta(datetime1, datetime2)
The second one is passing it any number of the following keyword arguments::
relativedelta(arg1=x,arg2=y,arg3=z...)
year, month, day, hour, minute, second, microsecond:
Absolute information (argument is singular); adding or subtracting a
relativedelta with absolute information does not perform an arithmetic
operation, but rather REPLACES the corresponding value in the
original datetime with the value(s) in relativedelta.
years, months, weeks, days, hours, minutes, seconds, microseconds:
Relative information, may be negative (argument is plural); adding
or subtracting a relativedelta with relative information performs
the corresponding arithmetic operation on the original datetime value
with the information in the relativedelta.
weekday:
One of the weekday instances (MO, TU, etc). These instances may
receive a parameter N, specifying the Nth weekday, which could
be positive or negative (like MO(+1) or MO(-2)). Not specifying
it is the same as specifying +1. You can also use an integer,
where 0=MO.
leapdays:
Will add given days to the date found, if the year is a leap
year and the date found is after 28 February.
yearday, nlyearday:
Set the yearday or the non-leap year day (jump leap days).
These are converted to day/month/leapdays information.
Here is the behavior of operations with relativedelta:
1. Calculate the absolute year, using the 'year' argument, or the
original datetime year, if the argument is not present.
2. Add the relative 'years' argument to the absolute year.
3. Do steps 1 and 2 for month/months.
4. Calculate the absolute day, using the 'day' argument, or the
original datetime day, if the argument is not present. Then,
subtract from the day until it fits in the year and month
found after their operations.
5. Add the relative 'days' argument to the absolute day. Notice
that the 'weeks' argument is multiplied by 7 and added to
'days'.
6. Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds,
microsecond/microseconds.
7. If the 'weekday' argument is present, calculate the weekday,
with the given (wday, nth) tuple. wday is the index of the
weekday (0-6, 0=Mon), and nth is the number of weeks to add
forward or backward, depending on its sign. Notice that if
the calculated date is already Monday, for example, using
(0, 1) or (0, -1) won't change the day.
"""
def __init__(self, dt1=None, dt2=None,
years=0, months=0, days=0, leapdays=0, weeks=0,
hours=0, minutes=0, seconds=0, microseconds=0,
year=None, month=None, day=None, weekday=None,
yearday=None, nlyearday=None,
hour=None, minute=None, second=None, microsecond=None):
# Check for non-integer values in integer-only quantities
if any(x is not None and x != int(x) for x in (years, months)):
raise ValueError("Non-integer years and months are "
"ambiguous and not currently supported.")
if dt1 and dt2:
# datetime is a subclass of date. So both must be date
if not (isinstance(dt1, datetime.date) and
isinstance(dt2, datetime.date)):
raise TypeError("relativedelta only diffs datetime/date")
# We allow two dates, or two datetimes, so we coerce them to be
# of the same type
if (isinstance(dt1, datetime.datetime) !=
isinstance(dt2, datetime.datetime)):
if not isinstance(dt1, datetime.datetime):
dt1 = datetime.datetime.fromordinal(dt1.toordinal())
elif not isinstance(dt2, datetime.datetime):
dt2 = datetime.datetime.fromordinal(dt2.toordinal())
self.years = 0
self.months = 0
self.days = 0
self.leapdays = 0
self.hours = 0
self.minutes = 0
self.seconds = 0
self.microseconds = 0
self.year = None
self.month = None
self.day = None
self.weekday = None
self.hour = None
self.minute = None
self.second = None
self.microsecond = None
self._has_time = 0
# Get year / month delta between the two
months = (dt1.year - dt2.year) * 12 + (dt1.month - dt2.month)
self._set_months(months)
# Remove the year/month delta so the timedelta is just well-defined
# time units (seconds, days and microseconds)
dtm = self.__radd__(dt2)
# If we've overshot our target, make an adjustment
if dt1 < dt2:
compare = operator.gt
increment = 1
else:
compare = operator.lt
increment = -1
while compare(dt1, dtm):
months += increment
self._set_months(months)
dtm = self.__radd__(dt2)
# Get the timedelta between the "months-adjusted" date and dt1
delta = dt1 - dtm
self.seconds = delta.seconds + delta.days * 86400
self.microseconds = delta.microseconds
else:
# Relative information
self.years = years
self.months = months
self.days = days + weeks * 7
self.leapdays = leapdays
self.hours = hours
self.minutes = minutes
self.seconds = seconds
self.microseconds = microseconds
# Absolute information
self.year = year
self.month = month
self.day = day
self.hour = hour
self.minute = minute
self.second = second
self.microsecond = microsecond
if any(x is not None and int(x) != x
for x in (year, month, day, hour,
minute, second, microsecond)):
# For now we'll deprecate floats - later it'll be an error.
warn("Non-integer value passed as absolute information. " +
"This is not a well-defined condition and will raise " +
"errors in future versions.", DeprecationWarning)
if isinstance(weekday, integer_types):
self.weekday = weekdays[weekday]
else:
self.weekday = weekday
yday = 0
if nlyearday:
yday = nlyearday
elif yearday:
yday = yearday
if yearday > 59:
self.leapdays = -1
if yday:
ydayidx = [31, 59, 90, 120, 151, 181, 212,
243, 273, 304, 334, 366]
for idx, ydays in enumerate(ydayidx):
if yday <= ydays:
self.month = idx+1
if idx == 0:
self.day = yday
else:
self.day = yday-ydayidx[idx-1]
break
else:
raise ValueError("invalid year day (%d)" % yday)
self._fix()
def _fix(self):
if abs(self.microseconds) > 999999:
s = _sign(self.microseconds)
div, mod = divmod(self.microseconds * s, 1000000)
self.microseconds = mod * s
self.seconds += div * s
if abs(self.seconds) > 59:
s = _sign(self.seconds)
div, mod = divmod(self.seconds * s, 60)
self.seconds = mod * s
self.minutes += div * s
if abs(self.minutes) > 59:
s = _sign(self.minutes)
div, mod = divmod(self.minutes * s, 60)
self.minutes = mod * s
self.hours += div * s
if abs(self.hours) > 23:
s = _sign(self.hours)
div, mod = divmod(self.hours * s, 24)
self.hours = mod * s
self.days += div * s
if abs(self.months) > 11:
s = _sign(self.months)
div, mod = divmod(self.months * s, 12)
self.months = mod * s
self.years += div * s
if (self.hours or self.minutes or self.seconds or self.microseconds
or self.hour is not None or self.minute is not None or
self.second is not None or self.microsecond is not None):
self._has_time = 1
else:
self._has_time = 0
@property
def weeks(self):
return self.days // 7
@weeks.setter
def weeks(self, value):
self.days = self.days - (self.weeks * 7) + value * 7
def _set_months(self, months):
self.months = months
if abs(self.months) > 11:
s = _sign(self.months)
div, mod = divmod(self.months * s, 12)
self.months = mod * s
self.years = div * s
else:
self.years = 0
def normalized(self):
"""
Return a version of this object represented entirely using integer
values for the relative attributes.
>>> relativedelta(days=1.5, hours=2).normalized()
relativedelta(days=1, hours=14)
:return:
Returns a :class:`dateutil.relativedelta.relativedelta` object.
"""
# Cascade remainders down (rounding each to roughly nearest microsecond)
days = int(self.days)
hours_f = round(self.hours + 24 * (self.days - days), 11)
hours = int(hours_f)
minutes_f = round(self.minutes + 60 * (hours_f - hours), 10)
minutes = int(minutes_f)
seconds_f = round(self.seconds + 60 * (minutes_f - minutes), 8)
seconds = int(seconds_f)
microseconds = round(self.microseconds + 1e6 * (seconds_f - seconds))
# Constructor carries overflow back up with call to _fix()
return self.__class__(years=self.years, months=self.months,
days=days, hours=hours, minutes=minutes,
seconds=seconds, microseconds=microseconds,
leapdays=self.leapdays, year=self.year,
month=self.month, day=self.day,
weekday=self.weekday, hour=self.hour,
minute=self.minute, second=self.second,
microsecond=self.microsecond)
def __add__(self, other):
if isinstance(other, relativedelta):
return self.__class__(years=other.years + self.years,
months=other.months + self.months,
days=other.days + self.days,
hours=other.hours + self.hours,
minutes=other.minutes + self.minutes,
seconds=other.seconds + self.seconds,
microseconds=(other.microseconds +
self.microseconds),
leapdays=other.leapdays or self.leapdays,
year=(other.year if other.year is not None
else self.year),
month=(other.month if other.month is not None
else self.month),
day=(other.day if other.day is not None
else self.day),
weekday=(other.weekday if other.weekday is not None
else self.weekday),
hour=(other.hour if other.hour is not None
else self.hour),
minute=(other.minute if other.minute is not None
else self.minute),
second=(other.second if other.second is not None
else self.second),
microsecond=(other.microsecond if other.microsecond
is not None else
self.microsecond))
if isinstance(other, datetime.timedelta):
return self.__class__(years=self.years,
months=self.months,
days=self.days + other.days,
hours=self.hours,
minutes=self.minutes,
seconds=self.seconds + other.seconds,
microseconds=self.microseconds + other.microseconds,
leapdays=self.leapdays,
year=self.year,
month=self.month,
day=self.day,
weekday=self.weekday,
hour=self.hour,
minute=self.minute,
second=self.second,
microsecond=self.microsecond)
if not isinstance(other, datetime.date):
return NotImplemented
elif self._has_time and not isinstance(other, datetime.datetime):
other = datetime.datetime.fromordinal(other.toordinal())
year = (self.year or other.year)+self.years
month = self.month or other.month
if self.months:
assert 1 <= abs(self.months) <= 12
month += self.months
if month > 12:
year += 1
month -= 12
elif month < 1:
year -= 1
month += 12
day = min(calendar.monthrange(year, month)[1],
self.day or other.day)
repl = {"year": year, "month": month, "day": day}
for attr in ["hour", "minute", "second", "microsecond"]:
value = getattr(self, attr)
if value is not None:
repl[attr] = value
days = self.days
if self.leapdays and month > 2 and calendar.isleap(year):
days += self.leapdays
ret = (other.replace(**repl)
+ datetime.timedelta(days=days,
hours=self.hours,
minutes=self.minutes,
seconds=self.seconds,
microseconds=self.microseconds))
if self.weekday:
weekday, nth = self.weekday.weekday, self.weekday.n or 1
jumpdays = (abs(nth) - 1) * 7
if nth > 0:
jumpdays += (7 - ret.weekday() + weekday) % 7
else:
jumpdays += (ret.weekday() - weekday) % 7
jumpdays *= -1
ret += datetime.timedelta(days=jumpdays)
return ret
def __radd__(self, other):
return self.__add__(other)
def __rsub__(self, other):
return self.__neg__().__radd__(other)
def __sub__(self, other):
if not isinstance(other, relativedelta):
return NotImplemented # In case the other object defines __rsub__
return self.__class__(years=self.years - other.years,
months=self.months - other.months,
days=self.days - other.days,
hours=self.hours - other.hours,
minutes=self.minutes - other.minutes,
seconds=self.seconds - other.seconds,
microseconds=self.microseconds - other.microseconds,
leapdays=self.leapdays or other.leapdays,
year=(self.year if self.year is not None
else other.year),
month=(self.month if self.month is not None else
other.month),
day=(self.day if self.day is not None else
other.day),
weekday=(self.weekday if self.weekday is not None else
other.weekday),
hour=(self.hour if self.hour is not None else
other.hour),
minute=(self.minute if self.minute is not None else
other.minute),
second=(self.second if self.second is not None else
other.second),
microsecond=(self.microsecond if self.microsecond
is not None else
other.microsecond))
def __neg__(self):
return self.__class__(years=-self.years,
months=-self.months,
days=-self.days,
hours=-self.hours,
minutes=-self.minutes,
seconds=-self.seconds,
microseconds=-self.microseconds,
leapdays=self.leapdays,
year=self.year,
month=self.month,
day=self.day,
weekday=self.weekday,
hour=self.hour,
minute=self.minute,
second=self.second,
microsecond=self.microsecond)
def __bool__(self):
return not (not self.years and
not self.months and
not self.days and
not self.hours and
not self.minutes and
not self.seconds and
not self.microseconds and
not self.leapdays and
self.year is None and
self.month is None and
self.day is None and
self.weekday is None and
self.hour is None and
self.minute is None and
self.second is None and
self.microsecond is None)
# Compatibility with Python 2.x
__nonzero__ = __bool__
def __mul__(self, other):
try:
f = float(other)
except TypeError:
return NotImplemented
return self.__class__(years=int(self.years * f),
months=int(self.months * f),
days=int(self.days * f),
hours=int(self.hours * f),
minutes=int(self.minutes * f),
seconds=int(self.seconds * f),
microseconds=int(self.microseconds * f),
leapdays=self.leapdays,
year=self.year,
month=self.month,
day=self.day,
weekday=self.weekday,
hour=self.hour,
minute=self.minute,
second=self.second,
microsecond=self.microsecond)
__rmul__ = __mul__
def __eq__(self, other):
if not isinstance(other, relativedelta):
return NotImplemented
if self.weekday or other.weekday:
if not self.weekday or not other.weekday:
return False
if self.weekday.weekday != other.weekday.weekday:
return False
n1, n2 = self.weekday.n, other.weekday.n
if n1 != n2 and not ((not n1 or n1 == 1) and (not n2 or n2 == 1)):
return False
return (self.years == other.years and
self.months == other.months and
self.days == other.days and
self.hours == other.hours and
self.minutes == other.minutes and
self.seconds == other.seconds and
self.microseconds == other.microseconds and
self.leapdays == other.leapdays and
self.year == other.year and
self.month == other.month and
self.day == other.day and
self.hour == other.hour and
self.minute == other.minute and
self.second == other.second and
self.microsecond == other.microsecond)
__hash__ = None
def __ne__(self, other):
return not self.__eq__(other)
def __div__(self, other):
try:
reciprocal = 1 / float(other)
except TypeError:
return NotImplemented
return self.__mul__(reciprocal)
__truediv__ = __div__
def __repr__(self):
l = []
for attr in ["years", "months", "days", "leapdays",
"hours", "minutes", "seconds", "microseconds"]:
value = getattr(self, attr)
if value:
l.append("{attr}={value:+g}".format(attr=attr, value=value))
for attr in ["year", "month", "day", "weekday",
"hour", "minute", "second", "microsecond"]:
value = getattr(self, attr)
if value is not None:
l.append("{attr}={value}".format(attr=attr, value=repr(value)))
return "{classname}({attrs})".format(classname=self.__class__.__name__,
attrs=", ".join(l))
def _sign(x):
return int(copysign(1, x))
# vim:ts=4:sw=4:et
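A short usage sketch for the class above, illustrating the replace-vs-add rule and the weekday jump described in the docstring:

import datetime
from dateutil.relativedelta import relativedelta, MO

dt = datetime.datetime(2019, 9, 6)          # a Friday
print(dt + relativedelta(months=+1))        # relative: 2019-10-06 00:00:00
print(dt + relativedelta(month=1, day=31))  # absolute replace: 2019-01-31
print(dt + relativedelta(weekday=MO(+1)))   # next Monday: 2019-09-09
print(relativedelta(datetime.datetime(2020, 1, 1), dt))  # two-argument form: the difference as months/days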

python/dateutil/rrule.py Normal file (1610 lines)
File diff suppressed because it is too large.

@@ -0,0 +1,5 @@
from .tz import *
__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange",
"tzstr", "tzical", "tzwin", "tzwinlocal", "gettz",
"enfold", "datetime_ambiguous", "datetime_exists"]

@@ -0,0 +1,394 @@
from six import PY3
from functools import wraps
from datetime import datetime, timedelta, tzinfo
ZERO = timedelta(0)
__all__ = ['tzname_in_python2', 'enfold']
def tzname_in_python2(namefunc):
"""Change unicode output into bytestrings in Python 2
tzname() API changed in Python 3. It used to return bytes, but was changed
to unicode strings.
"""
def adjust_encoding(*args, **kwargs):
name = namefunc(*args, **kwargs)
if name is not None and not PY3:
name = name.encode()
return name
return adjust_encoding
# The following is adapted from Alexander Belopolsky's tz library
# https://github.com/abalkin/tz
if hasattr(datetime, 'fold'):
# Python 3.6+ (PEP 495): datetime natively supports the fold attribute
def enfold(dt, fold=1):
"""
Provides a unified interface for assigning the ``fold`` attribute to
datetimes both before and after the implementation of PEP-495.
:param fold:
The value for the ``fold`` attribute in the returned datetime. This
should be either 0 or 1.
:return:
Returns an object for which ``getattr(dt, 'fold', 0)`` returns
``fold`` for all versions of Python. In versions prior to
Python 3.6, this is a ``_DatetimeWithFold`` object, which is a
subclass of :py:class:`datetime.datetime` with the ``fold``
attribute added, if ``fold`` is 1.
.. versionadded:: 2.6.0
"""
return dt.replace(fold=fold)
else:
class _DatetimeWithFold(datetime):
"""
This is a class designed to provide a PEP 495-compliant interface for
Python versions before 3.6. It is used only for dates in a fold, so
the ``fold`` attribute is fixed at ``1``.
.. versionadded:: 2.6.0
"""
__slots__ = ()
@property
def fold(self):
return 1
def enfold(dt, fold=1):
"""
Provides a unified interface for assigning the ``fold`` attribute to
datetimes both before and after the implementation of PEP-495.
:param fold:
The value for the ``fold`` attribute in the returned datetime. This
should be either 0 or 1.
:return:
Returns an object for which ``getattr(dt, 'fold', 0)`` returns
``fold`` for all versions of Python. In versions prior to
Python 3.6, this is a ``_DatetimeWithFold`` object, which is a
subclass of :py:class:`datetime.datetime` with the ``fold``
attribute added, if ``fold`` is 1.
.. versionadded:: 2.6.0
"""
if getattr(dt, 'fold', 0) == fold:
return dt
args = dt.timetuple()[:6]
args += (dt.microsecond, dt.tzinfo)
if fold:
return _DatetimeWithFold(*args)
else:
return datetime(*args)
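As an illustration (not part of the module), enfold() on an ambiguous wall time at the end of US Eastern DST; tz.gettz and the enfold re-export come from dateutil.tz:

from datetime import datetime
from dateutil import tz

eastern = tz.gettz('America/New_York')
ambiguous = datetime(2019, 11, 3, 1, 30, tzinfo=eastern)
print(ambiguous.utcoffset())             # EDT offset, the first occurrence
print(tz.enfold(ambiguous).utcoffset())  # EST offset, the second occurrence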
def _validate_fromutc_inputs(f):
"""
The CPython version of ``fromutc`` checks that the input is a ``datetime``
object and that ``self`` is attached as its ``tzinfo``.
"""
@wraps(f)
def fromutc(self, dt):
if not isinstance(dt, datetime):
raise TypeError("fromutc() requires a datetime argument")
if dt.tzinfo is not self:
raise ValueError("dt.tzinfo is not self")
return f(self, dt)
return fromutc
class _tzinfo(tzinfo):
"""
Base class for all ``dateutil`` ``tzinfo`` objects.
"""
def is_ambiguous(self, dt):
"""
Whether or not the "wall time" of a given datetime is ambiguous in this
zone.
:param dt:
A :py:class:`datetime.datetime`, naive or time zone aware.
:return:
Returns ``True`` if ambiguous, ``False`` otherwise.
.. versionadded:: 2.6.0
"""
dt = dt.replace(tzinfo=self)
wall_0 = enfold(dt, fold=0)
wall_1 = enfold(dt, fold=1)
same_offset = wall_0.utcoffset() == wall_1.utcoffset()
same_dt = wall_0.replace(tzinfo=None) == wall_1.replace(tzinfo=None)
return same_dt and not same_offset
def _fold_status(self, dt_utc, dt_wall):
"""
Determine the fold status of a "wall" datetime, given a representation
of the same datetime as a (naive) UTC datetime. This is calculated based
on the assumption that ``dt.utcoffset() - dt.dst()`` is constant for all
datetimes, and that this offset is the actual number of hours separating
``dt_utc`` and ``dt_wall``.
:param dt_utc:
Representation of the datetime as UTC
:param dt_wall:
Representation of the datetime as "wall time". This parameter must
either have a `fold` attribute or have a fold-naive
:class:`datetime.tzinfo` attached, otherwise the calculation may
fail.
"""
if self.is_ambiguous(dt_wall):
delta_wall = dt_wall - dt_utc
_fold = int(delta_wall == (dt_utc.utcoffset() - dt_utc.dst()))
else:
_fold = 0
return _fold
def _fold(self, dt):
return getattr(dt, 'fold', 0)
def _fromutc(self, dt):
"""
Given a timezone-aware datetime in a given timezone, calculates a
timezone-aware datetime in a new timezone.
Since this is the one time that we *know* we have an unambiguous
datetime object, we take this opportunity to determine whether the
datetime is ambiguous and in a "fold" state (e.g. if it's the first
occurrence, chronologically, of the ambiguous datetime).
:param dt:
A timezone-aware :class:`datetime.datetime` object.
"""
# Re-implement the algorithm from Python's datetime.py
dtoff = dt.utcoffset()
if dtoff is None:
raise ValueError("fromutc() requires a non-None utcoffset() "
"result")
# The original datetime.py code assumes that `dst()` defaults to
# zero during ambiguous times. PEP 495 inverts this presumption, so
# for pre-PEP 495 versions of python, we need to tweak the algorithm.
dtdst = dt.dst()
if dtdst is None:
raise ValueError("fromutc() requires a non-None dst() result")
delta = dtoff - dtdst
dt += delta
# Set fold=1 so we can default to being in the fold for
# ambiguous dates.
dtdst = enfold(dt, fold=1).dst()
if dtdst is None:
raise ValueError("fromutc(): dt.dst gave inconsistent "
"results; cannot convert")
return dt + dtdst
@_validate_fromutc_inputs
def fromutc(self, dt):
"""
Given a timezone-aware datetime in a given timezone, calculates a
timezone-aware datetime in a new timezone.
Since this is the one time that we *know* we have an unambiguous
datetime object, we take this opportunity to determine whether the
datetime is ambiguous and in a "fold" state (e.g. if it's the first
occurrence, chronologically, of the ambiguous datetime).
:param dt:
A timezone-aware :class:`datetime.datetime` object.
"""
dt_wall = self._fromutc(dt)
# Calculate the fold status given the two datetimes.
_fold = self._fold_status(dt, dt_wall)
# Set the default fold value for ambiguous dates
return enfold(dt_wall, fold=_fold)
class tzrangebase(_tzinfo):
"""
This is an abstract base class for time zones represented by an annual
transition into and out of DST. Child classes should implement the following
methods:
* ``__init__(self, *args, **kwargs)``
* ``transitions(self, year)`` - this is expected to return a tuple of
datetimes representing the DST on and off transitions in standard
time.
A fully initialized ``tzrangebase`` subclass should also provide the
following attributes:
* ``hasdst``: Boolean whether or not the zone uses DST.
* ``_dst_offset`` / ``_std_offset``: :class:`datetime.timedelta` objects
representing the respective UTC offsets.
* ``_dst_abbr`` / ``_std_abbr``: Strings representing the timezone short
abbreviations in DST and STD, respectively.
.. versionadded:: 2.6.0
"""
def __init__(self):
raise NotImplementedError('tzrangebase is an abstract base class')
def utcoffset(self, dt):
isdst = self._isdst(dt)
if isdst is None:
return None
elif isdst:
return self._dst_offset
else:
return self._std_offset
def dst(self, dt):
isdst = self._isdst(dt)
if isdst is None:
return None
elif isdst:
return self._dst_base_offset
else:
return ZERO
@tzname_in_python2
def tzname(self, dt):
if self._isdst(dt):
return self._dst_abbr
else:
return self._std_abbr
def fromutc(self, dt):
""" Given a datetime in UTC, return local time """
if not isinstance(dt, datetime):
raise TypeError("fromutc() requires a datetime argument")
if dt.tzinfo is not self:
raise ValueError("dt.tzinfo is not self")
# Get transitions - if there are none, fixed offset
transitions = self.transitions(dt.year)
if transitions is None:
return dt + self.utcoffset(dt)
# Get the transition times in UTC
dston, dstoff = transitions
dston -= self._std_offset
dstoff -= self._std_offset
utc_transitions = (dston, dstoff)
dt_utc = dt.replace(tzinfo=None)
isdst = self._naive_isdst(dt_utc, utc_transitions)
if isdst:
dt_wall = dt + self._dst_offset
else:
dt_wall = dt + self._std_offset
_fold = int(not isdst and self.is_ambiguous(dt_wall))
return enfold(dt_wall, fold=_fold)
def is_ambiguous(self, dt):
"""
Whether or not the "wall time" of a given datetime is ambiguous in this
zone.
:param dt:
A :py:class:`datetime.datetime`, naive or time zone aware.
:return:
Returns ``True`` if ambiguous, ``False`` otherwise.
.. versionadded:: 2.6.0
"""
if not self.hasdst:
return False
start, end = self.transitions(dt.year)
dt = dt.replace(tzinfo=None)
return (end <= dt < end + self._dst_base_offset)
def _isdst(self, dt):
if not self.hasdst:
return False
elif dt is None:
return None
transitions = self.transitions(dt.year)
if transitions is None:
return False
dt = dt.replace(tzinfo=None)
isdst = self._naive_isdst(dt, transitions)
# Handle ambiguous dates
if not isdst and self.is_ambiguous(dt):
return not self._fold(dt)
else:
return isdst
def _naive_isdst(self, dt, transitions):
dston, dstoff = transitions
dt = dt.replace(tzinfo=None)
if dston < dstoff:
isdst = dston <= dt < dstoff
else:
isdst = not dstoff <= dt < dston
return isdst
@property
def _dst_base_offset(self):
return self._dst_offset - self._std_offset
__hash__ = None
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return "%s(...)" % self.__class__.__name__
__reduce__ = object.__reduce__
def _total_seconds(td):
# Python 2.6 doesn't have a total_seconds() method on timedelta objects
return ((td.seconds + td.days * 86400) * 1000000 +
td.microseconds) // 1000000
_total_seconds = getattr(timedelta, 'total_seconds', _total_seconds)
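tzrangebase itself is abstract; a concrete subclass such as dateutil.tz.tzstr shows the transitions() contract described above:

from dateutil import tz

est5edt = tz.tzstr('EST5EDT,M3.2.0/2,M11.1.0/2')
# (DST start, DST end) for the year, as naive datetimes on the standard-time side
print(est5edt.transitions(2019))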

python/dateutil/tz/tz.py Normal file (1511 lines)
File diff suppressed because it is too large.

python/dateutil/tz/win.py Normal file (332 lines)
@@ -0,0 +1,332 @@
# This code was originally contributed by Jeffrey Harris.
import datetime
import struct
from six.moves import winreg
from six import text_type
try:
import ctypes
from ctypes import wintypes
except ValueError:
# ValueError is raised on non-Windows systems for some horrible reason.
raise ImportError("Running tzwin on non-Windows system")
from ._common import tzrangebase
__all__ = ["tzwin", "tzwinlocal", "tzres"]
ONEWEEK = datetime.timedelta(7)
TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones"
TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones"
TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation"
def _settzkeyname():
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try:
winreg.OpenKey(handle, TZKEYNAMENT).Close()
TZKEYNAME = TZKEYNAMENT
except WindowsError:
TZKEYNAME = TZKEYNAME9X
handle.Close()
return TZKEYNAME
TZKEYNAME = _settzkeyname()
class tzres(object):
"""
Class for accessing `tzres.dll`, which contains timezone name related
resources.
.. versionadded:: 2.5.0
"""
p_wchar = ctypes.POINTER(wintypes.WCHAR) # Pointer to a wide char
def __init__(self, tzres_loc='tzres.dll'):
# Load the user32 DLL so we can load strings from tzres
user32 = ctypes.WinDLL('user32')
# Specify the LoadStringW function
user32.LoadStringW.argtypes = (wintypes.HINSTANCE,
wintypes.UINT,
wintypes.LPWSTR,
ctypes.c_int)
self.LoadStringW = user32.LoadStringW
self._tzres = ctypes.WinDLL(tzres_loc)
self.tzres_loc = tzres_loc
def load_name(self, offset):
"""
Load a timezone name from a DLL offset (integer).
>>> from dateutil.tzwin import tzres
>>> tzr = tzres()
>>> print(tzr.load_name(112))
Eastern Standard Time
:param offset:
A positive integer value referring to a string from the tzres dll.
.. note::
Offsets found in the registry are generally of the form
`@tzres.dll,-114`. The offset in this case is 114, not -114.
"""
resource = self.p_wchar()
lpBuffer = ctypes.cast(ctypes.byref(resource), wintypes.LPWSTR)
nchar = self.LoadStringW(self._tzres._handle, offset, lpBuffer, 0)
return resource[:nchar]
def name_from_string(self, tzname_str):
"""
Parse strings as returned from the Windows registry into the time zone
name as defined in the registry.
>>> from dateutil.tzwin import tzres
>>> tzr = tzres()
>>> print(tzr.name_from_string('@tzres.dll,-251'))
Dateline Daylight Time
>>> print(tzr.name_from_string('Eastern Standard Time'))
Eastern Standard Time
:param tzname_str:
A timezone name string as returned from a Windows registry key.
:return:
Returns the localized timezone string from tzres.dll if the string
is of the form `@tzres.dll,-offset`, else returns the input string.
"""
if not tzname_str.startswith('@'):
return tzname_str
name_splt = tzname_str.split(',-')
try:
offset = int(name_splt[1])
except (IndexError, ValueError):
raise ValueError("Malformed timezone string.")
return self.load_name(offset)
class tzwinbase(tzrangebase):
"""tzinfo class based on win32's timezones available in the registry."""
def __init__(self):
raise NotImplementedError('tzwinbase is an abstract base class')
def __eq__(self, other):
# Compare on all relevant dimensions, including name.
if not isinstance(other, tzwinbase):
return NotImplemented
return (self._std_offset == other._std_offset and
self._dst_offset == other._dst_offset and
self._stddayofweek == other._stddayofweek and
self._dstdayofweek == other._dstdayofweek and
self._stdweeknumber == other._stdweeknumber and
self._dstweeknumber == other._dstweeknumber and
self._stdhour == other._stdhour and
self._dsthour == other._dsthour and
self._stdminute == other._stdminute and
self._dstminute == other._dstminute and
self._std_abbr == other._std_abbr and
self._dst_abbr == other._dst_abbr)
@staticmethod
def list():
"""Return a list of all time zones known to the system."""
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
with winreg.OpenKey(handle, TZKEYNAME) as tzkey:
result = [winreg.EnumKey(tzkey, i)
for i in range(winreg.QueryInfoKey(tzkey)[0])]
return result
def display(self):
return self._display
def transitions(self, year):
"""
For a given year, get the DST on and off transition times, expressed
always on the standard time side. For zones with no transitions, this
function returns ``None``.
:param year:
The year whose transitions you would like to query.
:return:
Returns a :class:`tuple` of :class:`datetime.datetime` objects,
``(dston, dstoff)`` for zones with an annual DST transition, or
``None`` for fixed offset zones.
"""
if not self.hasdst:
return None
dston = picknthweekday(year, self._dstmonth, self._dstdayofweek,
self._dsthour, self._dstminute,
self._dstweeknumber)
dstoff = picknthweekday(year, self._stdmonth, self._stddayofweek,
self._stdhour, self._stdminute,
self._stdweeknumber)
# Ambiguous dates default to the STD side
dstoff -= self._dst_base_offset
return dston, dstoff
def _get_hasdst(self):
return self._dstmonth != 0
@property
def _dst_base_offset(self):
return self._dst_base_offset_
class tzwin(tzwinbase):
def __init__(self, name):
self._name = name
# multiple contexts only possible in 2.7 and 3.1, we still support 2.6
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
tzkeyname = text_type("{kn}\\{name}").format(kn=TZKEYNAME, name=name)
with winreg.OpenKey(handle, tzkeyname) as tzkey:
keydict = valuestodict(tzkey)
self._std_abbr = keydict["Std"]
self._dst_abbr = keydict["Dlt"]
self._display = keydict["Display"]
# See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm
tup = struct.unpack("=3l16h", keydict["TZI"])
stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1
dstoffset = stdoffset-tup[2] # + DaylightBias * -1
self._std_offset = datetime.timedelta(minutes=stdoffset)
self._dst_offset = datetime.timedelta(minutes=dstoffset)
# for the meaning see the win32 TIME_ZONE_INFORMATION structure docs
# http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx
(self._stdmonth,
self._stddayofweek, # Sunday = 0
self._stdweeknumber, # Last = 5
self._stdhour,
self._stdminute) = tup[4:9]
(self._dstmonth,
self._dstdayofweek, # Sunday = 0
self._dstweeknumber, # Last = 5
self._dsthour,
self._dstminute) = tup[12:17]
self._dst_base_offset_ = self._dst_offset - self._std_offset
self.hasdst = self._get_hasdst()
def __repr__(self):
return "tzwin(%s)" % repr(self._name)
def __reduce__(self):
return (self.__class__, (self._name,))
class tzwinlocal(tzwinbase):
def __init__(self):
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey:
keydict = valuestodict(tzlocalkey)
self._std_abbr = keydict["StandardName"]
self._dst_abbr = keydict["DaylightName"]
try:
tzkeyname = text_type('{kn}\\{sn}').format(kn=TZKEYNAME,
sn=self._std_abbr)
with winreg.OpenKey(handle, tzkeyname) as tzkey:
_keydict = valuestodict(tzkey)
self._display = _keydict["Display"]
except OSError:
self._display = None
stdoffset = -keydict["Bias"]-keydict["StandardBias"]
dstoffset = stdoffset-keydict["DaylightBias"]
self._std_offset = datetime.timedelta(minutes=stdoffset)
self._dst_offset = datetime.timedelta(minutes=dstoffset)
# For reasons unclear, in this particular key, the day of week has been
# moved to the END of the SYSTEMTIME structure.
tup = struct.unpack("=8h", keydict["StandardStart"])
(self._stdmonth,
self._stdweeknumber, # Last = 5
self._stdhour,
self._stdminute) = tup[1:5]
self._stddayofweek = tup[7]
tup = struct.unpack("=8h", keydict["DaylightStart"])
(self._dstmonth,
self._dstweeknumber, # Last = 5
self._dsthour,
self._dstminute) = tup[1:5]
self._dstdayofweek = tup[7]
self._dst_base_offset_ = self._dst_offset - self._std_offset
self.hasdst = self._get_hasdst()
def __repr__(self):
return "tzwinlocal()"
def __str__(self):
# str will return the standard name, not the daylight name.
return "tzwinlocal(%s)" % repr(self._std_abbr)
def __reduce__(self):
return (self.__class__, ())
def picknthweekday(year, month, dayofweek, hour, minute, whichweek):
""" dayofweek == 0 means Sunday, whichweek 5 means last instance """
first = datetime.datetime(year, month, 1, hour, minute)
# This will work if dayofweek is an ISO weekday (1-7) or Microsoft-style (0-6),
# because 7 % 7 == 0
weekdayone = first.replace(day=((dayofweek - first.isoweekday()) % 7) + 1)
wd = weekdayone + ((whichweek - 1) * ONEWEEK)
if (wd.month != month):
wd -= ONEWEEK
return wd
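For example, the second Sunday of March 2019 at 02:00 (dayofweek 0 meaning Sunday here):

print(picknthweekday(2019, 3, 0, 2, 0, 2))  # 2019-03-10 02:00:00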
def valuestodict(key):
"""Convert a registry key's values to a dictionary."""
dout = {}
size = winreg.QueryInfoKey(key)[1]
tz_res = None
for i in range(size):
key_name, value, dtype = winreg.EnumValue(key, i)
if dtype == winreg.REG_DWORD or dtype == winreg.REG_DWORD_LITTLE_ENDIAN:
# If it's a DWORD (32-bit integer), it's stored as unsigned - convert
# that to a proper signed integer
if value & (1 << 31):
value = value - (1 << 32)
elif dtype == winreg.REG_SZ:
# If it's a reference to the tzres DLL, load the actual string
if value.startswith('@tzres'):
tz_res = tz_res or tzres()
value = tz_res.name_from_string(value)
value = value.rstrip('\x00') # Remove trailing nulls
dout[key_name] = value
return dout
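A Windows-only usage sketch for the classes above (they read the registry keys listed at the top of this module, so this runs only on Windows):

from dateutil.tz.win import tzwin, tzwinbase, tzwinlocal

print(tzwinbase.list()[:3])       # a few zone names from the registry
eastern = tzwin('Eastern Standard Time')
print(eastern.display())
print(eastern.transitions(2019))  # (dston, dstoff) on the standard-time side
print(tzwinlocal())               # the machine's configured zone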

python/dateutil/tzwin.py Normal file (2 lines)
@@ -0,0 +1,2 @@
# tzwin has moved to dateutil.tz.win
from .tz.win import *

@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
import warnings
import json
from tarfile import TarFile
from pkgutil import get_data
from io import BytesIO
from contextlib import closing
from dateutil.tz import tzfile
__all__ = ["get_zonefile_instance", "gettz", "gettz_db_metadata", "rebuild"]
ZONEFILENAME = "dateutil-zoneinfo.tar.gz"
METADATA_FN = 'METADATA'
# python2.6 compatibility. Note that TarFile.__exit__ != TarFile.close, but
# it's close enough for python2.6
tar_open = TarFile.open
if not hasattr(TarFile, '__exit__'):
def tar_open(*args, **kwargs):
return closing(TarFile.open(*args, **kwargs))
class tzfile(tzfile):
def __reduce__(self):
return (gettz, (self._filename,))
def getzoneinfofile_stream():
try:
return BytesIO(get_data(__name__, ZONEFILENAME))
except IOError as e: # TODO switch to FileNotFoundError?
warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror))
return None
class ZoneInfoFile(object):
def __init__(self, zonefile_stream=None):
if zonefile_stream is not None:
with tar_open(fileobj=zonefile_stream, mode='r') as tf:
# dict comprehension does not work on python2.6
# TODO: get back to the nicer syntax when we ditch python2.6
# self.zones = {zf.name: tzfile(tf.extractfile(zf),
# filename = zf.name)
# for zf in tf.getmembers() if zf.isfile()}
self.zones = dict((zf.name, tzfile(tf.extractfile(zf),
filename=zf.name))
for zf in tf.getmembers()
if zf.isfile() and zf.name != METADATA_FN)
# deal with links: they'll point to their parent object, so less
# memory is wasted
# links = {zl.name: self.zones[zl.linkname]
# for zl in tf.getmembers() if zl.islnk() or zl.issym()}
links = dict((zl.name, self.zones[zl.linkname])
for zl in tf.getmembers() if
zl.islnk() or zl.issym())
self.zones.update(links)
try:
metadata_json = tf.extractfile(tf.getmember(METADATA_FN))
metadata_str = metadata_json.read().decode('UTF-8')
self.metadata = json.loads(metadata_str)
except KeyError:
# no metadata in tar file
self.metadata = None
else:
self.zones = dict()
self.metadata = None
def get(self, name, default=None):
"""
Wrapper for :func:`ZoneInfoFile.zones.get`. This is a convenience method
for retrieving zones from the zone dictionary.
:param name:
The name of the zone to retrieve. (Generally IANA zone names)
:param default:
The value to return in the event of a missing key.
.. versionadded:: 2.6.0
"""
return self.zones.get(name, default)
# The current API has gettz as a module function, although in fact it taps into
# a stateful class. So as a workaround for now, without changing the API, we
# will create a new "global" class instance the first time a user requests a
# timezone. Ugly, but adheres to the API.
#
# TODO: Remove after deprecation period.
_CLASS_ZONE_INSTANCE = list()
def get_zonefile_instance(new_instance=False):
"""
This is a convenience function which provides a :class:`ZoneInfoFile`
instance using the data provided by the ``dateutil`` package. By default, it
caches a single instance of the ZoneInfoFile object and returns that.
:param new_instance:
If ``True``, a new instance of :class:`ZoneInfoFile` is instantiated and
used as the cached instance for the next call. Otherwise, new instances
are created only as necessary.
:return:
Returns a :class:`ZoneInfoFile` object.
.. versionadded:: 2.6
"""
if new_instance:
zif = None
else:
zif = getattr(get_zonefile_instance, '_cached_instance', None)
if zif is None:
zif = ZoneInfoFile(getzoneinfofile_stream())
get_zonefile_instance._cached_instance = zif
return zif
def gettz(name):
"""
This retrieves a time zone from the local zoneinfo tarball that is packaged
with dateutil.
:param name:
An IANA-style time zone name, as found in the zoneinfo file.
:return:
Returns a :class:`dateutil.tz.tzfile` time zone object.
.. warning::
It is generally inadvisable to use this function, and it is only
provided for API compatibility with earlier versions. This is *not*
equivalent to ``dateutil.tz.gettz()``, which selects an appropriate
time zone based on the inputs, favoring system zoneinfo. This is ONLY
for accessing the dateutil-specific zoneinfo (which may be out of
date compared to the system zoneinfo).
.. deprecated:: 2.6
If you need to use a specific zoneinfo file over the system zoneinfo,
instantiate a :class:`dateutil.zoneinfo.ZoneInfoFile` object and call
:func:`dateutil.zoneinfo.ZoneInfoFile.get(name)` instead.
Use :func:`get_zonefile_instance` to retrieve an instance of the
dateutil-provided zoneinfo.
"""
warnings.warn("zoneinfo.gettz() will be removed in future versions, "
"to use the dateutil-provided zoneinfo files, instantiate a "
"ZoneInfoFile object and use ZoneInfoFile.zones.get() "
"instead. See the documentation for details.",
DeprecationWarning)
if len(_CLASS_ZONE_INSTANCE) == 0:
_CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream()))
return _CLASS_ZONE_INSTANCE[0].zones.get(name)
def gettz_db_metadata():
""" Get the zonefile metadata
See `zonefile_metadata`_
:returns:
A dictionary with the database metadata
.. deprecated:: 2.6
See deprecation warning in :func:`zoneinfo.gettz`. To get metadata,
query the attribute ``zoneinfo.ZoneInfoFile.metadata``.
"""
warnings.warn("zoneinfo.gettz_db_metadata() will be removed in future "
"versions, to use the dateutil-provided zoneinfo files, "
"ZoneInfoFile object and query the 'metadata' attribute "
"instead. See the documentation for details.",
DeprecationWarning)
if len(_CLASS_ZONE_INSTANCE) == 0:
_CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream()))
return _CLASS_ZONE_INSTANCE[0].metadata
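The non-deprecated entry point is get_zonefile_instance(), e.g.:

from dateutil.zoneinfo import get_zonefile_instance

zif = get_zonefile_instance()
print(len(zif.zones))               # number of bundled zone names
print(zif.get('America/New_York'))  # a tzfile instance, or None if absent
print(zif.metadata)                 # tzdata release info, if present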

Binary file not shown.

@@ -0,0 +1,52 @@
import logging
import os
import tempfile
import shutil
import json
from subprocess import check_call
from dateutil.zoneinfo import tar_open, METADATA_FN, ZONEFILENAME
def rebuild(filename, tag=None, format="gz", zonegroups=[], metadata=None):
"""Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar*
filename is the timezone tarball from ftp.iana.org/tz.
"""
tmpdir = tempfile.mkdtemp()
zonedir = os.path.join(tmpdir, "zoneinfo")
moduledir = os.path.dirname(__file__)
try:
with tar_open(filename) as tf:
for name in zonegroups:
tf.extract(name, tmpdir)
filepaths = [os.path.join(tmpdir, n) for n in zonegroups]
try:
check_call(["zic", "-d", zonedir] + filepaths)
except OSError as e:
_print_on_nosuchfile(e)
raise
# write metadata file
with open(os.path.join(zonedir, METADATA_FN), 'w') as f:
json.dump(metadata, f, indent=4, sort_keys=True)
target = os.path.join(moduledir, ZONEFILENAME)
with tar_open(target, "w:%s" % format) as tf:
for entry in os.listdir(zonedir):
entrypath = os.path.join(zonedir, entry)
tf.add(entrypath, entry)
finally:
shutil.rmtree(tmpdir)
def _print_on_nosuchfile(e):
"""Print helpful troubleshooting message
e is an exception raised by subprocess.check_call()
"""
if e.errno == 2:
logging.error(
"Could not find zic. Perhaps you need to install "
"libc-bin or some other package that provides it, "
"or it's not in your PATH?")

@@ -0,0 +1,112 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.etree.ElementTree facade
"""
from __future__ import print_function, absolute_import
import sys
from xml.etree.ElementTree import TreeBuilder as _TreeBuilder
from xml.etree.ElementTree import parse as _parse
from xml.etree.ElementTree import tostring
from .common import PY3
if PY3:
import importlib
else:
from xml.etree.ElementTree import XMLParser as _XMLParser
from xml.etree.ElementTree import iterparse as _iterparse
from xml.etree.ElementTree import ParseError
from .common import (DTDForbidden, EntitiesForbidden,
ExternalReferenceForbidden, _generate_etree_functions)
__origin__ = "xml.etree.ElementTree"
def _get_py3_cls():
"""Python 3.3 hides the pure Python code but defusedxml requires it.
The code is based on test.support.import_fresh_module().
"""
pymodname = "xml.etree.ElementTree"
cmodname = "_elementtree"
pymod = sys.modules.pop(pymodname, None)
cmod = sys.modules.pop(cmodname, None)
sys.modules[cmodname] = None
pure_pymod = importlib.import_module(pymodname)
if cmod is not None:
sys.modules[cmodname] = cmod
else:
sys.modules.pop(cmodname)
sys.modules[pymodname] = pymod
_XMLParser = pure_pymod.XMLParser
_iterparse = pure_pymod.iterparse
ParseError = pure_pymod.ParseError
return _XMLParser, _iterparse, ParseError
if PY3:
_XMLParser, _iterparse, ParseError = _get_py3_cls()
class DefusedXMLParser(_XMLParser):
def __init__(self, html=0, target=None, encoding=None,
forbid_dtd=False, forbid_entities=True,
forbid_external=True):
# Python 2.x old style class
_XMLParser.__init__(self, html, target, encoding)
self.forbid_dtd = forbid_dtd
self.forbid_entities = forbid_entities
self.forbid_external = forbid_external
if PY3:
parser = self.parser
else:
parser = self._parser
if self.forbid_dtd:
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
if self.forbid_entities:
parser.EntityDeclHandler = self.defused_entity_decl
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
if self.forbid_external:
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
def defused_start_doctype_decl(self, name, sysid, pubid,
has_internal_subset):
raise DTDForbidden(name, sysid, pubid)
def defused_entity_decl(self, name, is_parameter_entity, value, base,
sysid, pubid, notation_name):
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
notation_name):
# expat 1.2
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
def defused_external_entity_ref_handler(self, context, base, sysid,
pubid):
raise ExternalReferenceForbidden(context, base, sysid, pubid)
# aliases
XMLTreeBuilder = XMLParse = DefusedXMLParser
parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser,
_TreeBuilder, _parse,
_iterparse)
XML = fromstring
__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse',
'parse', 'tostring']
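A sketch of the generated functions rejecting an entity declaration (the defaults forbid entities but allow a plain DTD):

from defusedxml.ElementTree import fromstring
from defusedxml import EntitiesForbidden

payload = '<!DOCTYPE x [<!ENTITY a "boom">]><x>&a;</x>'
try:
    fromstring(payload)
except EntitiesForbidden as exc:
    print(exc)  # EntitiesForbidden(name='a', ...)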

@@ -0,0 +1,45 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defuse XML bomb denial of service vulnerabilities
"""
from __future__ import print_function, absolute_import
from .common import (DefusedXmlException, DTDForbidden, EntitiesForbidden,
ExternalReferenceForbidden, NotSupportedError,
_apply_defusing)
def defuse_stdlib():
"""Monkey patch and defuse all stdlib packages
:warning: The monkey patch is an EXPERIMENTAL feature.
"""
defused = {}
from . import cElementTree
from . import ElementTree
from . import minidom
from . import pulldom
from . import sax
from . import expatbuilder
from . import expatreader
from . import xmlrpc
xmlrpc.monkey_patch()
defused[xmlrpc] = None
for defused_mod in [cElementTree, ElementTree, minidom, pulldom, sax,
expatbuilder, expatreader]:
stdlib_mod = _apply_defusing(defused_mod)
defused[defused_mod] = stdlib_mod
return defused
__version__ = "0.5.0"
__all__ = ['DefusedXmlException', 'DTDForbidden', 'EntitiesForbidden',
'ExternalReferenceForbidden', 'NotSupportedError']
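An opt-in usage sketch (experimental, per the warning above, and tied to the stdlib modules of this era):

import defusedxml

defusedxml.defuse_stdlib()  # patches xml.etree, minidom, pulldom, sax, expat*, xmlrpc
import xml.etree.ElementTree as ET
ET.fromstring('<ok/>')      # plain parsing still works; DTD/entity tricks now raise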

@@ -0,0 +1,30 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.etree.cElementTree
"""
from __future__ import absolute_import
from xml.etree.cElementTree import TreeBuilder as _TreeBuilder
from xml.etree.cElementTree import parse as _parse
from xml.etree.cElementTree import tostring
# iterparse from ElementTree!
from xml.etree.ElementTree import iterparse as _iterparse
from .ElementTree import DefusedXMLParser
from .common import _generate_etree_functions
__origin__ = "xml.etree.cElementTree"
XMLTreeBuilder = XMLParse = DefusedXMLParser
parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser,
_TreeBuilder, _parse,
_iterparse)
XML = fromstring
__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse',
'parse', 'tostring']

python/defusedxml/common.py Normal file (120 lines)
@@ -0,0 +1,120 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
Common constants, exceptions and helper functions
"""
import sys
PY3 = sys.version_info[0] == 3
class DefusedXmlException(ValueError):
"""Base exception
"""
def __repr__(self):
return str(self)
class DTDForbidden(DefusedXmlException):
"""Document type definition is forbidden
"""
def __init__(self, name, sysid, pubid):
super(DTDForbidden, self).__init__()
self.name = name
self.sysid = sysid
self.pubid = pubid
def __str__(self):
tpl = "DTDForbidden(name='{}', system_id={!r}, public_id={!r})"
return tpl.format(self.name, self.sysid, self.pubid)
class EntitiesForbidden(DefusedXmlException):
"""Entity definition is forbidden
"""
def __init__(self, name, value, base, sysid, pubid, notation_name):
super(EntitiesForbidden, self).__init__()
self.name = name
self.value = value
self.base = base
self.sysid = sysid
self.pubid = pubid
self.notation_name = notation_name
def __str__(self):
tpl = "EntitiesForbidden(name='{}', system_id={!r}, public_id={!r})"
return tpl.format(self.name, self.sysid, self.pubid)
class ExternalReferenceForbidden(DefusedXmlException):
"""Resolving an external reference is forbidden
"""
def __init__(self, context, base, sysid, pubid):
super(ExternalReferenceForbidden, self).__init__()
self.context = context
self.base = base
self.sysid = sysid
self.pubid = pubid
def __str__(self):
tpl = "ExternalReferenceForbidden(system_id='{}', public_id={})"
return tpl.format(self.sysid, self.pubid)
class NotSupportedError(DefusedXmlException):
"""The operation is not supported
"""
def _apply_defusing(defused_mod):
assert defused_mod is sys.modules[defused_mod.__name__]
stdlib_name = defused_mod.__origin__
__import__(stdlib_name, {}, {}, ["*"])
stdlib_mod = sys.modules[stdlib_name]
stdlib_names = set(dir(stdlib_mod))
for name, obj in vars(defused_mod).items():
if name.startswith("_") or name not in stdlib_names:
continue
setattr(stdlib_mod, name, obj)
return stdlib_mod
def _generate_etree_functions(DefusedXMLParser, _TreeBuilder,
_parse, _iterparse):
"""Factory for functions needed by etree, dependent on whether
cElementTree or ElementTree is used."""
def parse(source, parser=None, forbid_dtd=False, forbid_entities=True,
forbid_external=True):
if parser is None:
parser = DefusedXMLParser(target=_TreeBuilder(),
forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
return _parse(source, parser)
def iterparse(source, events=None, parser=None, forbid_dtd=False,
forbid_entities=True, forbid_external=True):
if parser is None:
parser = DefusedXMLParser(target=_TreeBuilder(),
forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
return _iterparse(source, events, parser)
def fromstring(text, forbid_dtd=False, forbid_entities=True,
forbid_external=True):
parser = DefusedXMLParser(target=_TreeBuilder(),
forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
parser.feed(text)
return parser.close()
return parse, iterparse, fromstring
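All three generated functions share one pattern: build a DefusedXMLParser wired to a TreeBuilder, then delegate. A minimal by-hand equivalent:

from xml.etree.ElementTree import TreeBuilder
from defusedxml.ElementTree import DefusedXMLParser

parser = DefusedXMLParser(target=TreeBuilder(), forbid_dtd=True)
parser.feed('<root><leaf/></root>')
root = parser.close()  # returns the root Element from the TreeBuilder
print(root.tag)        # 'root'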

@@ -0,0 +1,110 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.dom.expatbuilder
"""
from __future__ import print_function, absolute_import
from xml.dom.expatbuilder import ExpatBuilder as _ExpatBuilder
from xml.dom.expatbuilder import Namespaces as _Namespaces
from .common import (DTDForbidden, EntitiesForbidden,
ExternalReferenceForbidden)
__origin__ = "xml.dom.expatbuilder"
class DefusedExpatBuilder(_ExpatBuilder):
"""Defused document builder"""
def __init__(self, options=None, forbid_dtd=False, forbid_entities=True,
forbid_external=True):
_ExpatBuilder.__init__(self, options)
self.forbid_dtd = forbid_dtd
self.forbid_entities = forbid_entities
self.forbid_external = forbid_external
def defused_start_doctype_decl(self, name, sysid, pubid,
has_internal_subset):
raise DTDForbidden(name, sysid, pubid)
def defused_entity_decl(self, name, is_parameter_entity, value, base,
sysid, pubid, notation_name):
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
notation_name):
# expat 1.2
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
def defused_external_entity_ref_handler(self, context, base, sysid,
pubid):
raise ExternalReferenceForbidden(context, base, sysid, pubid)
def install(self, parser):
_ExpatBuilder.install(self, parser)
if self.forbid_dtd:
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
if self.forbid_entities:
# if self._options.entities:
parser.EntityDeclHandler = self.defused_entity_decl
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
if self.forbid_external:
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
class DefusedExpatBuilderNS(_Namespaces, DefusedExpatBuilder):
"""Defused document builder that supports namespaces."""
def install(self, parser):
DefusedExpatBuilder.install(self, parser)
if self._options.namespace_declarations:
parser.StartNamespaceDeclHandler = (
self.start_namespace_decl_handler)
def reset(self):
DefusedExpatBuilder.reset(self)
self._initNamespaces()
def parse(file, namespaces=True, forbid_dtd=False, forbid_entities=True,
forbid_external=True):
"""Parse a document, returning the resulting Document node.
'file' may be either a file name or an open file object.
"""
if namespaces:
build_builder = DefusedExpatBuilderNS
else:
build_builder = DefusedExpatBuilder
builder = build_builder(forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
if isinstance(file, str):
fp = open(file, 'rb')
try:
result = builder.parseFile(fp)
finally:
fp.close()
else:
result = builder.parseFile(file)
return result
def parseString(string, namespaces=True, forbid_dtd=False,
forbid_entities=True, forbid_external=True):
"""Parse a document from a string, returning the resulting
Document node.
"""
if namespaces:
build_builder = DefusedExpatBuilderNS
else:
build_builder = DefusedExpatBuilder
builder = build_builder(forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
return builder.parseString(string)
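A sketch of defused DOM parsing through the builders above:

from defusedxml.expatbuilder import parseString

doc = parseString('<root><child/></root>')
print(doc.documentElement.tagName)  # 'root'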

@@ -0,0 +1,59 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.sax.expatreader
"""
from __future__ import print_function, absolute_import
from xml.sax.expatreader import ExpatParser as _ExpatParser
from .common import (DTDForbidden, EntitiesForbidden,
ExternalReferenceForbidden)
__origin__ = "xml.sax.expatreader"
class DefusedExpatParser(_ExpatParser):
"""Defused SAX driver for the pyexpat C module."""
def __init__(self, namespaceHandling=0, bufsize=2 ** 16 - 20,
forbid_dtd=False, forbid_entities=True,
forbid_external=True):
_ExpatParser.__init__(self, namespaceHandling, bufsize)
self.forbid_dtd = forbid_dtd
self.forbid_entities = forbid_entities
self.forbid_external = forbid_external
def defused_start_doctype_decl(self, name, sysid, pubid,
has_internal_subset):
raise DTDForbidden(name, sysid, pubid)
def defused_entity_decl(self, name, is_parameter_entity, value, base,
sysid, pubid, notation_name):
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
notation_name):
# expat 1.2
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
def defused_external_entity_ref_handler(self, context, base, sysid,
pubid):
raise ExternalReferenceForbidden(context, base, sysid, pubid)
def reset(self):
_ExpatParser.reset(self)
parser = self._parser
if self.forbid_dtd:
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
if self.forbid_entities:
parser.EntityDeclHandler = self.defused_entity_decl
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
if self.forbid_external:
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
def create_parser(*args, **kwargs):
return DefusedExpatParser(*args, **kwargs)

153
python/defusedxml/lxml.py Normal file

@ -0,0 +1,153 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Example code for lxml.etree protection
The code has NO protection against decompression bombs.
"""
from __future__ import print_function, absolute_import
import threading
from lxml import etree as _etree
from .common import DTDForbidden, EntitiesForbidden, NotSupportedError
LXML3 = _etree.LXML_VERSION[0] >= 3
__origin__ = "lxml.etree"
tostring = _etree.tostring
class RestrictedElement(_etree.ElementBase):
"""A restricted Element class that filters out instances of some classes
"""
__slots__ = ()
# blacklist = (etree._Entity, etree._ProcessingInstruction, etree._Comment)
blacklist = _etree._Entity
def _filter(self, iterator):
blacklist = self.blacklist
for child in iterator:
if isinstance(child, blacklist):
continue
yield child
def __iter__(self):
iterator = super(RestrictedElement, self).__iter__()
return self._filter(iterator)
def iterchildren(self, tag=None, reversed=False):
iterator = super(RestrictedElement, self).iterchildren(
tag=tag, reversed=reversed)
return self._filter(iterator)
def iter(self, tag=None, *tags):
iterator = super(RestrictedElement, self).iter(tag=tag, *tags)
return self._filter(iterator)
def iterdescendants(self, tag=None, *tags):
iterator = super(RestrictedElement,
self).iterdescendants(tag=tag, *tags)
return self._filter(iterator)
def itersiblings(self, tag=None, preceding=False):
iterator = super(RestrictedElement, self).itersiblings(
tag=tag, preceding=preceding)
return self._filter(iterator)
def getchildren(self):
iterator = super(RestrictedElement, self).__iter__()
return list(self._filter(iterator))
def getiterator(self, tag=None):
iterator = super(RestrictedElement, self).getiterator(tag)
return self._filter(iterator)
class GlobalParserTLS(threading.local):
"""Thread local context for custom parser instances
"""
parser_config = {
'resolve_entities': False,
# 'remove_comments': True,
# 'remove_pis': True,
}
element_class = RestrictedElement
def createDefaultParser(self):
parser = _etree.XMLParser(**self.parser_config)
element_class = self.element_class
if self.element_class is not None:
lookup = _etree.ElementDefaultClassLookup(element=element_class)
parser.set_element_class_lookup(lookup)
return parser
def setDefaultParser(self, parser):
self._default_parser = parser
def getDefaultParser(self):
parser = getattr(self, "_default_parser", None)
if parser is None:
parser = self.createDefaultParser()
self.setDefaultParser(parser)
return parser
_parser_tls = GlobalParserTLS()
getDefaultParser = _parser_tls.getDefaultParser
def check_docinfo(elementtree, forbid_dtd=False, forbid_entities=True):
"""Check docinfo of an element tree for DTD and entity declarations
The check for entity declarations needs lxml 3 or newer. lxml 2.x does
not support dtd.iterentities().
"""
docinfo = elementtree.docinfo
if docinfo.doctype:
if forbid_dtd:
raise DTDForbidden(docinfo.doctype,
docinfo.system_url,
docinfo.public_id)
if forbid_entities and not LXML3:
# lxml < 3 has no iterentities()
raise NotSupportedError("Unable to check for entity declarations "
"in lxml 2.x")
if forbid_entities:
for dtd in docinfo.internalDTD, docinfo.externalDTD:
if dtd is None:
continue
for entity in dtd.iterentities():
raise EntitiesForbidden(entity.name, entity.content, None,
None, None, None)
def parse(source, parser=None, base_url=None, forbid_dtd=False,
forbid_entities=True):
if parser is None:
parser = getDefaultParser()
elementtree = _etree.parse(source, parser, base_url=base_url)
check_docinfo(elementtree, forbid_dtd, forbid_entities)
return elementtree
def fromstring(text, parser=None, base_url=None, forbid_dtd=False,
forbid_entities=True):
if parser is None:
parser = getDefaultParser()
rootelement = _etree.fromstring(text, parser, base_url=base_url)
elementtree = rootelement.getroottree()
check_docinfo(elementtree, forbid_dtd, forbid_entities)
return rootelement
XML = fromstring
def iterparse(*args, **kwargs):
raise NotSupportedError("defused lxml.etree.iterparse not available")
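# Usage sketch (illustrative): parse()/fromstring() work like lxml.etree's,
# with check_docinfo() vetoing DTDs and entity declarations after parsing
# (entity checking needs lxml >= 3):
#   from defusedxml.lxml import fromstring
#   root = fromstring(b'<root>ok</root>')
#   fromstring(b'<!DOCTYPE r [<!ENTITY e "v">]><r/>')  # EntitiesForbidden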

42
python/defusedxml/minidom.py Normal file

@ -0,0 +1,42 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.dom.minidom
"""
from __future__ import print_function, absolute_import
from xml.dom.minidom import _do_pulldom_parse
from . import expatbuilder as _expatbuilder
from . import pulldom as _pulldom
__origin__ = "xml.dom.minidom"
def parse(file, parser=None, bufsize=None, forbid_dtd=False,
forbid_entities=True, forbid_external=True):
"""Parse a file into a DOM by filename or file object."""
if parser is None and not bufsize:
return _expatbuilder.parse(file, forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
else:
return _do_pulldom_parse(_pulldom.parse, (file,),
{'parser': parser, 'bufsize': bufsize,
'forbid_dtd': forbid_dtd, 'forbid_entities': forbid_entities,
'forbid_external': forbid_external})
def parseString(string, parser=None, forbid_dtd=False,
forbid_entities=True, forbid_external=True):
"""Parse a file into a DOM from a string."""
if parser is None:
return _expatbuilder.parseString(string, forbid_dtd=forbid_dtd,
forbid_entities=forbid_entities,
forbid_external=forbid_external)
else:
return _do_pulldom_parse(_pulldom.parseString, (string,),
{'parser': parser, 'forbid_dtd': forbid_dtd,
'forbid_entities': forbid_entities,
'forbid_external': forbid_external})
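# Usage sketch (illustrative): a drop-in for xml.dom.minidom on untrusted
# input:
#   from defusedxml.minidom import parseString
#   dom = parseString('<root><child/></root>')
#   print(dom.documentElement.tagName)  # root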

34
python/defusedxml/pulldom.py Normal file

@ -0,0 +1,34 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.dom.pulldom
"""
from __future__ import print_function, absolute_import
from xml.dom.pulldom import parse as _parse
from xml.dom.pulldom import parseString as _parseString
from .sax import make_parser
__origin__ = "xml.dom.pulldom"
def parse(stream_or_string, parser=None, bufsize=None, forbid_dtd=False,
forbid_entities=True, forbid_external=True):
if parser is None:
parser = make_parser()
parser.forbid_dtd = forbid_dtd
parser.forbid_entities = forbid_entities
parser.forbid_external = forbid_external
return _parse(stream_or_string, parser, bufsize)
def parseString(string, parser=None, forbid_dtd=False,
forbid_entities=True, forbid_external=True):
if parser is None:
parser = make_parser()
parser.forbid_dtd = forbid_dtd
parser.forbid_entities = forbid_entities
parser.forbid_external = forbid_external
return _parseString(string, parser)

49
python/defusedxml/sax.py Normal file

@ -0,0 +1,49 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xml.sax
"""
from __future__ import print_function, absolute_import
from xml.sax import InputSource as _InputSource
from xml.sax import ErrorHandler as _ErrorHandler
from . import expatreader
__origin__ = "xml.sax"
def parse(source, handler, errorHandler=_ErrorHandler(), forbid_dtd=False,
forbid_entities=True, forbid_external=True):
parser = make_parser()
parser.setContentHandler(handler)
parser.setErrorHandler(errorHandler)
parser.forbid_dtd = forbid_dtd
parser.forbid_entities = forbid_entities
parser.forbid_external = forbid_external
parser.parse(source)
def parseString(string, handler, errorHandler=_ErrorHandler(),
forbid_dtd=False, forbid_entities=True,
forbid_external=True):
from io import BytesIO
if errorHandler is None:
errorHandler = _ErrorHandler()
parser = make_parser()
parser.setContentHandler(handler)
parser.setErrorHandler(errorHandler)
parser.forbid_dtd = forbid_dtd
parser.forbid_entities = forbid_entities
parser.forbid_external = forbid_external
inpsrc = _InputSource()
inpsrc.setByteStream(BytesIO(string))
parser.parse(inpsrc)
def make_parser(parser_list=[]):
return expatreader.create_parser()
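# Usage sketch (illustrative): same shape as xml.sax, assuming Python 3 and
# a bytes payload (parseString wraps it in a BytesIO input source):
#   from xml.sax.handler import ContentHandler
#   from defusedxml.sax import parseString
#   class Tags(ContentHandler):
#       def startElement(self, name, attrs):
#           print(name)
#   parseString(b'<a><b/></a>', Tags())  # prints: a, b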

157
python/defusedxml/xmlrpc.py Normal file

@ -0,0 +1,157 @@
# defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
"""Defused xmlrpclib
Also defuses gzip bomb
"""
from __future__ import print_function, absolute_import
import io
from .common import (
DTDForbidden, EntitiesForbidden, ExternalReferenceForbidden, PY3)
if PY3:
__origin__ = "xmlrpc.client"
from xmlrpc.client import ExpatParser
from xmlrpc import client as xmlrpc_client
from xmlrpc import server as xmlrpc_server
from xmlrpc.client import gzip_decode as _orig_gzip_decode
from xmlrpc.client import GzipDecodedResponse as _OrigGzipDecodedResponse
else:
__origin__ = "xmlrpclib"
from xmlrpclib import ExpatParser
import xmlrpclib as xmlrpc_client
xmlrpc_server = None
from xmlrpclib import gzip_decode as _orig_gzip_decode
from xmlrpclib import GzipDecodedResponse as _OrigGzipDecodedResponse
try:
import gzip
except ImportError:
gzip = None
# Limit maximum request size to prevent resource exhaustion DoS
# Also used to limit maximum amount of gzip decoded data in order to prevent
# decompression bombs
# A value of -1 or smaller disables the limit
MAX_DATA = 30 * 1024 * 1024 # 30 MB
def defused_gzip_decode(data, limit=None):
"""gzip encoded data -> unencoded data
Decode data using the gzip content encoding as described in RFC 1952
"""
if not gzip:
raise NotImplementedError
if limit is None:
limit = MAX_DATA
f = io.BytesIO(data)
gzf = gzip.GzipFile(mode="rb", fileobj=f)
try:
if limit < 0: # no limit
decoded = gzf.read()
else:
decoded = gzf.read(limit + 1)
except IOError:
raise ValueError("invalid data")
f.close()
gzf.close()
if limit >= 0 and len(decoded) > limit:
raise ValueError("max gzipped payload length exceeded")
return decoded
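# Worked example (illustrative, Python 3): a 50 MB payload gzips down to a
# few hundred KB, but decoding is cut off once MAX_DATA is exceeded:
#   blob = gzip.compress(b'0' * (50 * 1024 * 1024))
#   defused_gzip_decode(blob)  # ValueError: max gzipped payload length exceeded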
class DefusedGzipDecodedResponse(gzip.GzipFile if gzip else object):
"""a file-like object to decode a response encoded with the gzip
method, as described in RFC 1952.
"""
def __init__(self, response, limit=None):
# response doesn't support tell() and read(), required by
# GzipFile
if not gzip:
raise NotImplementedError
self.limit = limit = limit if limit is not None else MAX_DATA
if limit < 0: # no limit
data = response.read()
self.readlength = None
else:
data = response.read(limit + 1)
self.readlength = 0
if limit >= 0 and len(data) > limit:
raise ValueError("max payload length exceeded")
self.stringio = io.BytesIO(data)
gzip.GzipFile.__init__(self, mode="rb", fileobj=self.stringio)
def read(self, n):
if self.limit >= 0:
left = self.limit - self.readlength
n = min(n, left + 1)
data = gzip.GzipFile.read(self, n)
self.readlength += len(data)
if self.readlength > self.limit:
raise ValueError("max payload length exceeded")
return data
else:
return gzip.GzipFile.read(self, n)
def close(self):
gzip.GzipFile.close(self)
self.stringio.close()
class DefusedExpatParser(ExpatParser):
def __init__(self, target, forbid_dtd=False, forbid_entities=True,
forbid_external=True):
ExpatParser.__init__(self, target)
self.forbid_dtd = forbid_dtd
self.forbid_entities = forbid_entities
self.forbid_external = forbid_external
parser = self._parser
if self.forbid_dtd:
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
if self.forbid_entities:
parser.EntityDeclHandler = self.defused_entity_decl
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
if self.forbid_external:
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
def defused_start_doctype_decl(self, name, sysid, pubid,
has_internal_subset):
raise DTDForbidden(name, sysid, pubid)
def defused_entity_decl(self, name, is_parameter_entity, value, base,
sysid, pubid, notation_name):
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
notation_name):
# expat 1.2
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
def defused_external_entity_ref_handler(self, context, base, sysid,
pubid):
raise ExternalReferenceForbidden(context, base, sysid, pubid)
def monkey_patch():
xmlrpc_client.FastParser = DefusedExpatParser
xmlrpc_client.GzipDecodedResponse = DefusedGzipDecodedResponse
xmlrpc_client.gzip_decode = defused_gzip_decode
if xmlrpc_server:
xmlrpc_server.gzip_decode = defused_gzip_decode
def unmonkey_patch():
xmlrpc_client.FastParser = None
xmlrpc_client.GzipDecodedResponse = _OrigGzipDecodedResponse
xmlrpc_client.gzip_decode = _orig_gzip_decode
if xmlrpc_server:
xmlrpc_server.gzip_decode = _orig_gzip_decode
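# Usage sketch (illustrative): call monkey_patch() once at startup so that
# xmlrpc.client (and the server module on Python 3) use the defused parser
# and the gzip limits above:
#   import defusedxml.xmlrpc
#   defusedxml.xmlrpc.monkey_patch()
#   # ... use xmlrpc.client.ServerProxy as usual ...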

891
python/six.py Normal file

@ -0,0 +1,891 @@
# Copyright (c) 2010-2017 Benjamin Peterson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Utilities for writing code that runs on Python 2 and 3"""
from __future__ import absolute_import
import functools
import itertools
import operator
import sys
import types
__author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.11.0"
# Useful for very coarse version differentiation.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[0:2] >= (3, 4)
if PY3:
string_types = str,
integer_types = int,
class_types = type,
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
else:
string_types = basestring,
integer_types = (int, long)
class_types = (type, types.ClassType)
text_type = unicode
binary_type = str
if sys.platform.startswith("java"):
# Jython always uses 32 bits.
MAXSIZE = int((1 << 31) - 1)
else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1)
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1)
del X
def _add_doc(func, doc):
"""Add documentation to a function."""
func.__doc__ = doc
def _import_module(name):
"""Import module, returning the module after the last dot."""
__import__(name)
return sys.modules[name]
class _LazyDescr(object):
def __init__(self, name):
self.name = name
def __get__(self, obj, tp):
result = self._resolve()
setattr(obj, self.name, result) # Invokes __set__.
try:
# This is a bit ugly, but it avoids running this again by
# removing this descriptor.
delattr(obj.__class__, self.name)
except AttributeError:
pass
return result
class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name)
if PY3:
if new is None:
new = name
self.mod = new
else:
self.mod = old
def _resolve(self):
return _import_module(self.mod)
def __getattr__(self, attr):
_module = self._resolve()
value = getattr(_module, attr)
setattr(self, attr, value)
return value
class _LazyModule(types.ModuleType):
def __init__(self, name):
super(_LazyModule, self).__init__(name)
self.__doc__ = self.__class__.__doc__
def __dir__(self):
attrs = ["__doc__", "__name__"]
attrs += [attr.name for attr in self._moved_attributes]
return attrs
# Subclasses should override this
_moved_attributes = []
class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name)
if PY3:
if new_mod is None:
new_mod = name
self.mod = new_mod
if new_attr is None:
if old_attr is None:
new_attr = name
else:
new_attr = old_attr
self.attr = new_attr
else:
self.mod = old_mod
if old_attr is None:
old_attr = name
self.attr = old_attr
def _resolve(self):
module = _import_module(self.mod)
return getattr(module, self.attr)
class _SixMetaPathImporter(object):
"""
A meta path importer to import six.moves and its submodules.
This class implements a PEP302 finder and loader. It should be compatible
with Python 2.5 and all existing versions of Python3
"""
def __init__(self, six_module_name):
self.name = six_module_name
self.known_modules = {}
def _add_module(self, mod, *fullnames):
for fullname in fullnames:
self.known_modules[self.name + "." + fullname] = mod
def _get_module(self, fullname):
return self.known_modules[self.name + "." + fullname]
def find_module(self, fullname, path=None):
if fullname in self.known_modules:
return self
return None
def __get_module(self, fullname):
try:
return self.known_modules[fullname]
except KeyError:
raise ImportError("This loader does not know module " + fullname)
def load_module(self, fullname):
try:
# in case of a reload
return sys.modules[fullname]
except KeyError:
pass
mod = self.__get_module(fullname)
if isinstance(mod, MovedModule):
mod = mod._resolve()
else:
mod.__loader__ = self
sys.modules[fullname] = mod
return mod
def is_package(self, fullname):
"""
Return true, if the named module is a package.
We need this method to get correct spec objects with
Python 3.4 (see PEP451)
"""
return hasattr(self.__get_module(fullname), "__path__")
def get_code(self, fullname):
"""Return None
Required, if is_package is implemented"""
self.__get_module(fullname) # eventually raises ImportError
return None
get_source = get_code # same as get_code
_importer = _SixMetaPathImporter(__name__)
class _MovedItems(_LazyModule):
"""Lazy loading of moved objects"""
__path__ = [] # mark as package
_moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("intern", "__builtin__", "sys"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
MovedAttribute("getoutput", "commands", "subprocess"),
MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
MovedAttribute("StringIO", "StringIO", "io"),
MovedAttribute("UserDict", "UserDict", "collections"),
MovedAttribute("UserList", "UserList", "collections"),
MovedAttribute("UserString", "UserString", "collections"),
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
MovedModule("builtins", "__builtin__"),
MovedModule("configparser", "ConfigParser"),
MovedModule("copyreg", "copy_reg"),
MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"),
MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
MovedModule("http_cookies", "Cookie", "http.cookies"),
MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"),
MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
MovedModule("cPickle", "cPickle", "pickle"),
MovedModule("queue", "Queue"),
MovedModule("reprlib", "repr"),
MovedModule("socketserver", "SocketServer"),
MovedModule("_thread", "thread", "_thread"),
MovedModule("tkinter", "Tkinter"),
MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule("tkinter_colorchooser", "tkColorChooser",
"tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog",
"tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
"tkinter.simpledialog"),
MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
]
# Add windows specific modules.
if sys.platform == "win32":
_moved_attributes += [
MovedModule("winreg", "_winreg"),
]
for attr in _moved_attributes:
setattr(_MovedItems, attr.name, attr)
if isinstance(attr, MovedModule):
_importer._add_module(attr, "moves." + attr.name)
del attr
_MovedItems._moved_attributes = _moved_attributes
moves = _MovedItems(__name__ + ".moves")
_importer._add_module(moves, "moves")
class Module_six_moves_urllib_parse(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_parse"""
_urllib_parse_moved_attributes = [
MovedAttribute("ParseResult", "urlparse", "urllib.parse"),
MovedAttribute("SplitResult", "urlparse", "urllib.parse"),
MovedAttribute("parse_qs", "urlparse", "urllib.parse"),
MovedAttribute("parse_qsl", "urlparse", "urllib.parse"),
MovedAttribute("urldefrag", "urlparse", "urllib.parse"),
MovedAttribute("urljoin", "urlparse", "urllib.parse"),
MovedAttribute("urlparse", "urlparse", "urllib.parse"),
MovedAttribute("urlsplit", "urlparse", "urllib.parse"),
MovedAttribute("urlunparse", "urlparse", "urllib.parse"),
MovedAttribute("urlunsplit", "urlparse", "urllib.parse"),
MovedAttribute("quote", "urllib", "urllib.parse"),
MovedAttribute("quote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote", "urllib", "urllib.parse"),
MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"),
MovedAttribute("urlencode", "urllib", "urllib.parse"),
MovedAttribute("splitquery", "urllib", "urllib.parse"),
MovedAttribute("splittag", "urllib", "urllib.parse"),
MovedAttribute("splituser", "urllib", "urllib.parse"),
MovedAttribute("splitvalue", "urllib", "urllib.parse"),
MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
MovedAttribute("uses_params", "urlparse", "urllib.parse"),
MovedAttribute("uses_query", "urlparse", "urllib.parse"),
MovedAttribute("uses_relative", "urlparse", "urllib.parse"),
]
for attr in _urllib_parse_moved_attributes:
setattr(Module_six_moves_urllib_parse, attr.name, attr)
del attr
Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes
_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
"moves.urllib_parse", "moves.urllib.parse")
class Module_six_moves_urllib_error(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_error"""
_urllib_error_moved_attributes = [
MovedAttribute("URLError", "urllib2", "urllib.error"),
MovedAttribute("HTTPError", "urllib2", "urllib.error"),
MovedAttribute("ContentTooShortError", "urllib", "urllib.error"),
]
for attr in _urllib_error_moved_attributes:
setattr(Module_six_moves_urllib_error, attr.name, attr)
del attr
Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes
_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
"moves.urllib_error", "moves.urllib.error")
class Module_six_moves_urllib_request(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_request"""
_urllib_request_moved_attributes = [
MovedAttribute("urlopen", "urllib2", "urllib.request"),
MovedAttribute("install_opener", "urllib2", "urllib.request"),
MovedAttribute("build_opener", "urllib2", "urllib.request"),
MovedAttribute("pathname2url", "urllib", "urllib.request"),
MovedAttribute("url2pathname", "urllib", "urllib.request"),
MovedAttribute("getproxies", "urllib", "urllib.request"),
MovedAttribute("Request", "urllib2", "urllib.request"),
MovedAttribute("OpenerDirector", "urllib2", "urllib.request"),
MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"),
MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"),
MovedAttribute("FileHandler", "urllib2", "urllib.request"),
MovedAttribute("FTPHandler", "urllib2", "urllib.request"),
MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"),
MovedAttribute("UnknownHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"),
MovedAttribute("urlretrieve", "urllib", "urllib.request"),
MovedAttribute("urlcleanup", "urllib", "urllib.request"),
MovedAttribute("URLopener", "urllib", "urllib.request"),
MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
]
for attr in _urllib_request_moved_attributes:
setattr(Module_six_moves_urllib_request, attr.name, attr)
del attr
Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes
_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
"moves.urllib_request", "moves.urllib.request")
class Module_six_moves_urllib_response(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_response"""
_urllib_response_moved_attributes = [
MovedAttribute("addbase", "urllib", "urllib.response"),
MovedAttribute("addclosehook", "urllib", "urllib.response"),
MovedAttribute("addinfo", "urllib", "urllib.response"),
MovedAttribute("addinfourl", "urllib", "urllib.response"),
]
for attr in _urllib_response_moved_attributes:
setattr(Module_six_moves_urllib_response, attr.name, attr)
del attr
Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes
_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
"moves.urllib_response", "moves.urllib.response")
class Module_six_moves_urllib_robotparser(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_robotparser"""
_urllib_robotparser_moved_attributes = [
MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
]
for attr in _urllib_robotparser_moved_attributes:
setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr
Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes
_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser", "moves.urllib.robotparser")
class Module_six_moves_urllib(types.ModuleType):
"""Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
__path__ = [] # mark as package
parse = _importer._get_module("moves.urllib_parse")
error = _importer._get_module("moves.urllib_error")
request = _importer._get_module("moves.urllib_request")
response = _importer._get_module("moves.urllib_response")
robotparser = _importer._get_module("moves.urllib_robotparser")
def __dir__(self):
return ['parse', 'error', 'request', 'response', 'robotparser']
_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
"moves.urllib")
def add_move(move):
"""Add an item to six.moves."""
setattr(_MovedItems, move.name, move)
def remove_move(name):
"""Remove item from six.moves."""
try:
delattr(_MovedItems, name)
except AttributeError:
try:
del moves.__dict__[name]
except KeyError:
raise AttributeError("no such move, %r" % (name,))
if PY3:
_meth_func = "__func__"
_meth_self = "__self__"
_func_closure = "__closure__"
_func_code = "__code__"
_func_defaults = "__defaults__"
_func_globals = "__globals__"
else:
_meth_func = "im_func"
_meth_self = "im_self"
_func_closure = "func_closure"
_func_code = "func_code"
_func_defaults = "func_defaults"
_func_globals = "func_globals"
try:
advance_iterator = next
except NameError:
def advance_iterator(it):
return it.next()
next = advance_iterator
try:
callable = callable
except NameError:
def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
if PY3:
def get_unbound_function(unbound):
return unbound
create_bound_method = types.MethodType
def create_unbound_method(func, cls):
return func
Iterator = object
else:
def get_unbound_function(unbound):
return unbound.im_func
def create_bound_method(func, obj):
return types.MethodType(func, obj, obj.__class__)
def create_unbound_method(func, cls):
return types.MethodType(func, None, cls)
class Iterator(object):
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(get_unbound_function,
"""Get the function out of a possibly unbound function""")
get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
get_function_closure = operator.attrgetter(_func_closure)
get_function_code = operator.attrgetter(_func_code)
get_function_defaults = operator.attrgetter(_func_defaults)
get_function_globals = operator.attrgetter(_func_globals)
if PY3:
def iterkeys(d, **kw):
return iter(d.keys(**kw))
def itervalues(d, **kw):
return iter(d.values(**kw))
def iteritems(d, **kw):
return iter(d.items(**kw))
def iterlists(d, **kw):
return iter(d.lists(**kw))
viewkeys = operator.methodcaller("keys")
viewvalues = operator.methodcaller("values")
viewitems = operator.methodcaller("items")
else:
def iterkeys(d, **kw):
return d.iterkeys(**kw)
def itervalues(d, **kw):
return d.itervalues(**kw)
def iteritems(d, **kw):
return d.iteritems(**kw)
def iterlists(d, **kw):
return d.iterlists(**kw)
viewkeys = operator.methodcaller("viewkeys")
viewvalues = operator.methodcaller("viewvalues")
viewitems = operator.methodcaller("viewitems")
_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems,
"Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(iterlists,
"Return an iterator over the (key, [values]) pairs of a dictionary.")
if PY3:
def b(s):
return s.encode("latin-1")
def u(s):
return s
unichr = chr
import struct
int2byte = struct.Struct(">B").pack
del struct
byte2int = operator.itemgetter(0)
indexbytes = operator.getitem
iterbytes = iter
import io
StringIO = io.StringIO
BytesIO = io.BytesIO
_assertCountEqual = "assertCountEqual"
if sys.version_info[1] <= 1:
_assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches"
else:
_assertRaisesRegex = "assertRaisesRegex"
_assertRegex = "assertRegex"
else:
def b(s):
return s
# Workaround for standalone backslash
def u(s):
return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
unichr = unichr
int2byte = chr
def byte2int(bs):
return ord(bs[0])
def indexbytes(buf, i):
return ord(buf[i])
iterbytes = functools.partial(itertools.imap, ord)
import StringIO
StringIO = BytesIO = StringIO.StringIO
_assertCountEqual = "assertItemsEqual"
_assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches"
_add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""")
def assertCountEqual(self, *args, **kwargs):
return getattr(self, _assertCountEqual)(*args, **kwargs)
def assertRaisesRegex(self, *args, **kwargs):
return getattr(self, _assertRaisesRegex)(*args, **kwargs)
def assertRegex(self, *args, **kwargs):
return getattr(self, _assertRegex)(*args, **kwargs)
if PY3:
exec_ = getattr(moves.builtins, "exec")
def reraise(tp, value, tb=None):
try:
if value is None:
value = tp()
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
finally:
value = None
tb = None
else:
def exec_(_code_, _globs_=None, _locs_=None):
"""Execute code in a namespace."""
if _globs_ is None:
frame = sys._getframe(1)
_globs_ = frame.f_globals
if _locs_ is None:
_locs_ = frame.f_locals
del frame
elif _locs_ is None:
_locs_ = _globs_
exec("""exec _code_ in _globs_, _locs_""")
exec_("""def reraise(tp, value, tb=None):
try:
raise tp, value, tb
finally:
tb = None
""")
if sys.version_info[:2] == (3, 2):
exec_("""def raise_from(value, from_value):
try:
if from_value is None:
raise value
raise value from from_value
finally:
value = None
""")
elif sys.version_info[:2] > (3, 2):
exec_("""def raise_from(value, from_value):
try:
raise value from from_value
finally:
value = None
""")
else:
def raise_from(value, from_value):
raise value
print_ = getattr(moves.builtins, "print", None)
if print_ is None:
def print_(*args, **kwargs):
"""The new-style print function for Python 2.4 and 2.5."""
fp = kwargs.pop("file", sys.stdout)
if fp is None:
return
def write(data):
if not isinstance(data, basestring):
data = str(data)
# If the file has an encoding, encode unicode with it.
if (isinstance(fp, file) and
isinstance(data, unicode) and
fp.encoding is not None):
errors = getattr(fp, "errors", None)
if errors is None:
errors = "strict"
data = data.encode(fp.encoding, errors)
fp.write(data)
want_unicode = False
sep = kwargs.pop("sep", None)
if sep is not None:
if isinstance(sep, unicode):
want_unicode = True
elif not isinstance(sep, str):
raise TypeError("sep must be None or a string")
end = kwargs.pop("end", None)
if end is not None:
if isinstance(end, unicode):
want_unicode = True
elif not isinstance(end, str):
raise TypeError("end must be None or a string")
if kwargs:
raise TypeError("invalid keyword arguments to print()")
if not want_unicode:
for arg in args:
if isinstance(arg, unicode):
want_unicode = True
break
if want_unicode:
newline = unicode("\n")
space = unicode(" ")
else:
newline = "\n"
space = " "
if sep is None:
sep = space
if end is None:
end = newline
for i, arg in enumerate(args):
if i:
write(sep)
write(arg)
write(end)
if sys.version_info[:2] < (3, 3):
_print = print_
def print_(*args, **kwargs):
fp = kwargs.get("file", sys.stdout)
flush = kwargs.pop("flush", False)
_print(*args, **kwargs)
if flush and fp is not None:
fp.flush()
_add_doc(reraise, """Reraise an exception.""")
if sys.version_info[0:2] < (3, 4):
def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES):
def wrapper(f):
f = functools.wraps(wrapped, assigned, updated)(f)
f.__wrapped__ = wrapped
return f
return wrapper
else:
wraps = functools.wraps
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
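# Example (illustrative): on both Python 2 and 3, MyClass is constructed by
# Meta, without needing version-specific metaclass syntax:
#   class Meta(type):
#       pass
#   class MyClass(with_metaclass(Meta, object)):
#       pass
#   assert type(MyClass) is Meta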
def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass."""
def wrapper(cls):
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
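# Example (illustrative): the decorator form achieves the same effect, and
# re-creates the class so __slots__ and existing attributes are preserved:
#   @add_metaclass(Meta)
#   class MyOtherClass(object):
#       pass
#   assert type(MyOtherClass) is Meta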
def python_2_unicode_compatible(klass):
"""
A decorator that defines __unicode__ and __str__ methods under Python 2.
Under Python 3 it does nothing.
To support Python 2 and 3 with a single code base, define a __str__ method
returning text and apply this decorator to the class.
"""
if PY2:
if '__str__' not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." %
klass.__name__)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
return klass
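# Example (illustrative): define only __str__ returning text; on Python 2
# the decorator renames it to __unicode__ and derives a UTF-8 __str__:
#   @python_2_unicode_compatible
#   class Greeting(object):
#       def __str__(self):
#           return u'h\xe9llo'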
# Complete the moves implementation.
# This code is at the end of this module to speed up module loading.
# Turn this module into a package.
__path__ = [] # required for PEP 302 and PEP 451
__package__ = __name__ # see PEP 366 @ReservedAssignment
if globals().get("__spec__") is not None:
__spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable
# Remove other six meta path importers, since they cause problems. This can
# happen if six is removed from sys.modules and then reloaded. (Setuptools does
# this for some reason.)
if sys.meta_path:
for i, importer in enumerate(sys.meta_path):
# Here's some real nastiness: Another "instance" of the six module might
# be floating around. Therefore, we can't use isinstance() to check for
# the six meta path importer, since the other six instance will have
# inserted an importer with different class.
if (type(importer).__name__ == "_SixMetaPathImporter" and
importer.name == __name__):
del sys.meta_path[i]
break
del i, importer
# Finally, add the importer to the meta path import hook.
sys.meta_path.append(_importer)

@ -6,7 +6,7 @@ from youtube import yt_app
from youtube import util
# these are just so the files get run - they import yt_app and add routes to it
-from youtube import watch, search, playlist, channel, local_playlist, comments, post_comment
+from youtube import watch, search, playlist, channel, local_playlist, comments, post_comment, subscriptions
import settings

@ -119,6 +119,12 @@ For security reasons, enabling this is not recommended.''',
    ],
}),
+('autocheck_subscriptions', {
+    'type': bool,
+    'default': 0,
+    'comment': '',
+}),
('gather_googlevideo_domains', {
    'type': bool,
    'default': False,

@ -1,5 +1,5 @@
import base64
-from youtube import util, yt_data_extract, local_playlist
+from youtube import util, yt_data_extract, local_playlist, subscriptions
from youtube import yt_app
import urllib
@ -83,13 +83,15 @@ def channel_ctoken(channel_id, page, sort, tab, view=1):
    return base64.urlsafe_b64encode(pointless_nest).decode('ascii')
-def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1):
+def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1, print_status=True):
    ctoken = channel_ctoken(channel_id, page, sort, tab, view).replace('=', '%3D')
    url = "https://www.youtube.com/browse_ajax?ctoken=" + ctoken
-    print("Sending channel tab ajax request")
+    if print_status:
+        print("Sending channel tab ajax request")
    content = util.fetch_url(url, util.desktop_ua + headers_1, debug_name='channel_tab')
-    print("Finished recieving channel tab response")
+    if print_status:
+        print("Finished receiving channel tab response")
    return content
@ -312,7 +314,7 @@ def get_channel_page(channel_id, tab='videos'):
        info['current_sort'] = sort
    elif tab == 'search':
        info['search_box_value'] = query
+    info['subscribed'] = subscriptions.is_subscribed(info['channel_id'])
    return flask.render_template('channel.html',
        parameters_dictionary = request.args,
@ -352,7 +354,7 @@ def get_channel_page_general_url(base_url, tab, request):
        info['current_sort'] = sort
    elif tab == 'search':
        info['search_box_value'] = query
+    info['subscribed'] = subscriptions.is_subscribed(info['channel_id'])
    return flask.render_template('channel.html',
        parameters_dictionary = request.args,

@ -34,33 +34,7 @@ def add_to_playlist(name, video_info_list):
            if id not in ids:
                file.write(info + "\n")
                missing_thumbnails.append(id)
-    gevent.spawn(download_thumbnails, name, missing_thumbnails)
+    gevent.spawn(util.download_thumbnails, os.path.join(thumbnails_directory, name), missing_thumbnails)
-def download_thumbnail(playlist_name, video_id):
-    url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
-    save_location = os.path.join(thumbnails_directory, playlist_name, video_id + ".jpg")
-    try:
-        thumbnail = util.fetch_url(url, report_text="Saved local playlist thumbnail: " + video_id)
-    except urllib.error.HTTPError as e:
-        print("Failed to download thumbnail for " + video_id + ": " + str(e))
-        return
-    try:
-        f = open(save_location, 'wb')
-    except FileNotFoundError:
-        os.makedirs(os.path.join(thumbnails_directory, playlist_name))
-        f = open(save_location, 'wb')
-    f.write(thumbnail)
-    f.close()
-def download_thumbnails(playlist_name, ids):
-    # only do 5 at a time
-    # do the n where n is divisible by 5
-    i = -1
-    for i in range(0, int(len(ids)/5) - 1 ):
-        gevent.joinall([gevent.spawn(download_thumbnail, playlist_name, ids[j]) for j in range(i*5, i*5 + 5)])
-    # do the remainders (< 5)
-    gevent.joinall([gevent.spawn(download_thumbnail, playlist_name, ids[j]) for j in range(i*5 + 5, len(ids))])
def get_local_playlist_videos(name, offset=0, amount=50):
@ -88,7 +62,7 @@ def get_local_playlist_videos(name, offset=0, amount=50):
        except json.decoder.JSONDecodeError:
            if not video_json.strip() == '':
                print('Corrupt playlist video entry: ' + video_json)
-    gevent.spawn(download_thumbnails, name, missing_thumbnails)
+    gevent.spawn(util.download_thumbnails, os.path.join(thumbnails_directory, name), missing_thumbnails)
    return videos[offset:offset+amount], len(videos)
def get_playlist_names():

@ -9,7 +9,7 @@ a:link {
}
a:visited {
-    color: ##7755ff;
+    color: #7755ff;
}
a:not([href]){
@ -23,3 +23,16 @@ a:not([href]){
.setting-item{
    background-color: #444444;
}
+.muted{
+    background-color: #111111;
+    color: gray;
+}
+.muted a:link {
+    color: #10547f;
+}

@ -11,3 +11,7 @@ body{
.setting-item{
    background-color: #eeeeee;
}
+.muted{
+    background-color: #888888;
+}

@ -8,7 +8,11 @@ body{
    color: #000000;
}
.setting-item{
    background-color: #f8f8f8;
}
+.muted{
+    background-color: #888888;
+}

@ -1,7 +1,10 @@
+* {
+    box-sizing: border-box;
+}
h1, h2, h3, h4, h5, h6, div, button{
    margin:0;
    padding:0;
}
address{

youtube/subscriptions.py

@ -1,18 +1,828 @@
from youtube import util, yt_data_extract, channel
from youtube import yt_app
import settings
import sqlite3
import os
import time
import gevent
import json
import traceback
import contextlib
import defusedxml.ElementTree
import urllib
import math
import secrets
import collections
import calendar  # needed because of https://bugs.python.org/issue6280
import flask
from flask import request

-with open("subscriptions.txt", 'r', encoding='utf-8') as file:
-    subscriptions = file.read()
-# Line format: "channel_id channel_name"
-# Example:
-# UCYO_jab_esuFRV4b17AJtAw 3Blue1Brown
-subscriptions = ((line[0:24], line[25: ]) for line in subscriptions.splitlines())
-def get_new_videos():
-    for channel_id, channel_name in subscriptions:

thumbnails_directory = os.path.join(settings.data_dir, "subscription_thumbnails")
# https://stackabuse.com/a-sqlite-tutorial-with-python/
database_path = os.path.join(settings.data_dir, "subscriptions.sqlite")
def open_database():
if not os.path.exists(settings.data_dir):
os.makedirs(settings.data_dir)
connection = sqlite3.connect(database_path, check_same_thread=False)
try:
cursor = connection.cursor()
cursor.execute('''PRAGMA foreign_keys = 1''')
# Create tables if they don't exist
cursor.execute('''CREATE TABLE IF NOT EXISTS subscribed_channels (
id integer PRIMARY KEY,
yt_channel_id text UNIQUE NOT NULL,
channel_name text NOT NULL,
time_last_checked integer DEFAULT 0,
next_check_time integer DEFAULT 0,
muted integer DEFAULT 0
)''')
cursor.execute('''CREATE TABLE IF NOT EXISTS videos (
id integer PRIMARY KEY,
sql_channel_id integer NOT NULL REFERENCES subscribed_channels(id) ON UPDATE CASCADE ON DELETE CASCADE,
video_id text UNIQUE NOT NULL,
title text NOT NULL,
duration text,
time_published integer NOT NULL,
is_time_published_exact integer DEFAULT 0,
time_noticed integer NOT NULL,
description text,
watched integer default 0
)''')
cursor.execute('''CREATE TABLE IF NOT EXISTS tag_associations (
id integer PRIMARY KEY,
tag text NOT NULL,
sql_channel_id integer NOT NULL REFERENCES subscribed_channels(id) ON UPDATE CASCADE ON DELETE CASCADE,
UNIQUE(tag, sql_channel_id)
)''')
cursor.execute('''CREATE TABLE IF NOT EXISTS db_info (
version integer DEFAULT 1
)''')
connection.commit()
except:
connection.rollback()
connection.close()
raise
# https://stackoverflow.com/questions/19522505/using-sqlite3-in-python-with-with-keyword
return contextlib.closing(connection)
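# Usage sketch (illustrative): contextlib.closing() guarantees close(), while
# sqlite3's own context manager commits on success and rolls back on error:
#   with open_database() as connection:
#       with connection as cursor:  # the target is the connection itself
#           cursor.execute('SELECT 1')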
def with_open_db(function, *args, **kwargs):
with open_database() as connection:
with connection as cursor:
return function(cursor, *args, **kwargs)
def _is_subscribed(cursor, channel_id):
result = cursor.execute('''SELECT EXISTS(
SELECT 1
FROM subscribed_channels
WHERE yt_channel_id=?
LIMIT 1
)''', [channel_id]).fetchone()
return bool(result[0])
def is_subscribed(channel_id):
if not os.path.exists(database_path):
return False
return with_open_db(_is_subscribed, channel_id)
def _subscribe(channels):
''' channels is a list of (channel_id, channel_name) '''
channels = list(channels)
with open_database() as connection:
with connection as cursor:
channel_ids_to_check = [channel[0] for channel in channels if not _is_subscribed(cursor, channel[0])]
rows = ( (channel_id, channel_name, 0, 0) for channel_id, channel_name in channels)
cursor.executemany('''INSERT OR IGNORE INTO subscribed_channels (yt_channel_id, channel_name, time_last_checked, next_check_time)
VALUES (?, ?, ?, ?)''', rows)
if settings.autocheck_subscriptions:
# important that this is after the changes have been committed to database
# otherwise the autochecker (other thread) tries checking the channel before it's in the database
channel_names.update(channels)
check_channels_if_necessary(channel_ids_to_check)
def delete_thumbnails(to_delete):
for thumbnail in to_delete:
try:
video_id = thumbnail[0:-4]
if video_id in existing_thumbnails:
os.remove(os.path.join(thumbnails_directory, thumbnail))
existing_thumbnails.remove(video_id)
except Exception:
print('Failed to delete thumbnail: ' + thumbnail)
traceback.print_exc()
def _unsubscribe(cursor, channel_ids):
''' channel_ids is a list of channel_ids '''
to_delete = []
for channel_id in channel_ids:
rows = cursor.execute('''SELECT video_id
FROM videos
WHERE sql_channel_id = (
SELECT id
FROM subscribed_channels
WHERE yt_channel_id=?
)''', (channel_id,)).fetchall()
to_delete += [row[0] + '.jpg' for row in rows]
gevent.spawn(delete_thumbnails, to_delete)
cursor.executemany("DELETE FROM subscribed_channels WHERE yt_channel_id=?", ((channel_id, ) for channel_id in channel_ids))
def _get_videos(cursor, number_per_page, offset, tag = None):
'''Returns a full page of videos with an offset, and a value good enough to be used as the total number of videos'''
    # We ask the database for the next 9 pages' worth of videos
    # The actual length of the result then tells us whether more than 9 pages remain, and if not, how many actually do
    # This is done because only 9 page buttons are displayed at a time
    # If more than 9 pages remain, we return a stand-in for the number of results an unlimited query would have produced
    # This stand-in is sufficient to make the page-button generation macro display 9 page buttons
    # If we ever want to display more buttons, this logic must change
    # We cannot use tricks with the sql id of the video, since filters and other restrictions are frequently in place on the results anyway
    # TODO: This is probably not the ideal solution
if tag is not None:
db_videos = cursor.execute('''SELECT video_id, title, duration, time_published, is_time_published_exact, channel_name
FROM videos
INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id
INNER JOIN tag_associations on videos.sql_channel_id = tag_associations.sql_channel_id
WHERE tag = ? AND muted = 0
ORDER BY time_noticed DESC, time_published DESC
LIMIT ? OFFSET ?''', (tag, number_per_page*9, offset)).fetchall()
else:
db_videos = cursor.execute('''SELECT video_id, title, duration, time_published, is_time_published_exact, channel_name
FROM videos
INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id
WHERE muted = 0
ORDER BY time_noticed DESC, time_published DESC
LIMIT ? OFFSET ?''', (number_per_page*9, offset)).fetchall()
pseudo_number_of_videos = offset + len(db_videos)
videos = []
for db_video in db_videos[0:number_per_page]:
videos.append({
'id': db_video[0],
'title': db_video[1],
'duration': db_video[2],
'published': exact_timestamp(db_video[3]) if db_video[4] else posix_to_dumbed_down(db_video[3]),
'author': db_video[5],
})
return videos, pseudo_number_of_videos
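# Worked example (illustrative): with number_per_page=30 and offset=0 the
# query asks for up to 270 rows. If only 45 exist, pseudo_number_of_videos
# is the true total (45); if all 270 come back, it is 270, which is enough
# to make the pagination macro draw its maximum of 9 page buttons.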
def _get_subscribed_channels(cursor):
for item in cursor.execute('''SELECT channel_name, yt_channel_id, muted
FROM subscribed_channels
ORDER BY channel_name COLLATE NOCASE'''):
yield item
def _add_tags(cursor, channel_ids, tags):
pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids]
cursor.executemany('''INSERT OR IGNORE INTO tag_associations (tag, sql_channel_id)
SELECT ?, id FROM subscribed_channels WHERE yt_channel_id = ? ''', pairs)
def _remove_tags(cursor, channel_ids, tags):
pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids]
cursor.executemany('''DELETE FROM tag_associations
WHERE tag = ? AND sql_channel_id = (
SELECT id FROM subscribed_channels WHERE yt_channel_id = ?
)''', pairs)
def _get_tags(cursor, channel_id):
return [row[0] for row in cursor.execute('''SELECT tag
FROM tag_associations
WHERE sql_channel_id = (
SELECT id FROM subscribed_channels WHERE yt_channel_id = ?
)''', (channel_id,))]
def _get_all_tags(cursor):
return [row[0] for row in cursor.execute('''SELECT DISTINCT tag FROM tag_associations''')]
def _get_channel_names(cursor, channel_ids):
''' returns list of (channel_id, channel_name) '''
result = []
for channel_id in channel_ids:
row = cursor.execute('''SELECT channel_name
FROM subscribed_channels
WHERE yt_channel_id = ?''', (channel_id,)).fetchone()
result.append( (channel_id, row[0]) )
return result
def _channels_with_tag(cursor, tag, order=False, exclude_muted=False, include_muted_status=False):
''' returns list of (channel_id, channel_name) '''
statement = '''SELECT yt_channel_id, channel_name'''
if include_muted_status:
statement += ''', muted'''
statement += '''
FROM subscribed_channels
WHERE subscribed_channels.id IN (
SELECT tag_associations.sql_channel_id FROM tag_associations WHERE tag=?
)
'''
if exclude_muted:
statement += '''AND muted != 1\n'''
if order:
statement += '''ORDER BY channel_name COLLATE NOCASE'''
return cursor.execute(statement, [tag]).fetchall()
def _schedule_checking(cursor, channel_id, next_check_time):
cursor.execute('''UPDATE subscribed_channels SET next_check_time = ? WHERE yt_channel_id = ?''', [int(next_check_time), channel_id])
def _is_muted(cursor, channel_id):
return bool(cursor.execute('''SELECT muted FROM subscribed_channels WHERE yt_channel_id=?''', [channel_id]).fetchone()[0])
units = collections.OrderedDict([
('year', 31536000), # 365*24*3600
('month', 2592000), # 30*24*3600
('week', 604800), # 7*24*3600
('day', 86400), # 24*3600
('hour', 3600),
('minute', 60),
('second', 1),
])
def youtube_timestamp_to_posix(dumb_timestamp):
    ''' Given a dumbed-down timestamp such as "1 year ago" or "3 hours ago",
    approximates the unix time (seconds since 1/1/1970) '''
dumb_timestamp = dumb_timestamp.lower()
now = time.time()
if dumb_timestamp == "just now":
return now
split = dumb_timestamp.split(' ')
quantifier, unit = int(split[0]), split[1]
if quantifier > 1:
unit = unit[:-1] # remove s from end
return now - quantifier*units[unit]
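# Hedged sanity examples (illustrative only, assuming time.time() is held fixed):
# youtube_timestamp_to_posix('3 hours ago') == time.time() - 3*3600 ('hours' loses its 's')
# youtube_timestamp_to_posix('1 hour ago') == time.time() - 3600 (singular, nothing stripped)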
def posix_to_dumbed_down(posix_time):
'''Inverse of youtube_timestamp_to_posix.'''
delta = int(time.time() - posix_time)
assert delta >= 0
if delta == 0:
return '0 seconds ago'
for unit_name, unit_time in units.items():
if delta >= unit_time:
quantifier = round(delta/unit_time)
if quantifier == 1:
return '1 ' + unit_name + ' ago'
else:
return str(quantifier) + ' ' + unit_name + 's ago'
else:
raise Exception('no unit matched for delta of ' + str(delta) + ' seconds') # unreachable while units contains ('second', 1)
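# Hedged example: for a posix_time 90 seconds in the past, the loop stops at
# 'minute' (90 >= 60), and round(90/60) == round(1.5) == 2 under Python's
# banker's rounding, so the result is '2 minutes ago'.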
def exact_timestamp(posix_time):
result = time.strftime('%I:%M %p %m/%d/%y', time.localtime(posix_time))
if result[0] == '0': # remove leading 0 from the hour (e.g. 01:00 PM)
return result[1:]
return result
try:
existing_thumbnails = set(os.path.splitext(name)[0] for name in os.listdir(thumbnails_directory))
except FileNotFoundError:
existing_thumbnails = set()
# --- Manual checking system. Rate limited to support checking very large numbers of channels ---
# Auto checking system plugs into this for convenience, though it doesn't really need the rate limiting
check_channels_queue = util.RateLimitedQueue()
checking_channels = set()
# Only used for printing channel checking status to the console without opening the database
channel_names = dict()
def check_channel_worker():
while True:
channel_id = check_channels_queue.get()
try:
_get_upstream_videos(channel_id)
finally:
checking_channels.remove(channel_id)
for i in range(0,5):
gevent.spawn(check_channel_worker)
# ----------------------------
# --- Auto checking system - Spaghetti code ---
if settings.autocheck_subscriptions:
# job application format: dict with keys (channel_id, channel_name, next_check_time)
autocheck_job_application = gevent.queue.Queue() # only really meant to hold 1 item, just reusing gevent's wait and timeout machinery
autocheck_jobs = [] # list of dicts with the keys (channel_id, channel_name, next_check_time). Stores all the channels that need to be autochecked and when to check them
with open_database() as connection:
with connection as cursor:
now = time.time()
for row in cursor.execute('''SELECT yt_channel_id, channel_name, next_check_time FROM subscribed_channels WHERE muted != 1''').fetchall():
if row[2] is None:
next_check_time = 0
else:
next_check_time = row[2]
# expired, check randomly within the next hour
# note: even if the job isn't scheduled in the past right now, it may end up being if it's due soon
# and we don't start dispatching by then; see below where time_until_earliest_job is negative
if next_check_time < now:
next_check_time = now + 3600*secrets.randbelow(60)/60
row = (row[0], row[1], next_check_time)
_schedule_checking(cursor, row[0], next_check_time)
autocheck_jobs.append({'channel_id': row[0], 'channel_name': row[1], 'next_check_time': next_check_time})
def autocheck_dispatcher():
'''Scans the auto_check_list. Sleeps until the earliest job is due, then adds that channel to the checking queue above. Can be sent a new job through autocheck_job_application'''
while True:
if len(autocheck_jobs) == 0:
new_job = autocheck_job_application.get()
autocheck_jobs.append(new_job)
else:
earliest_job_index = min(range(0, len(autocheck_jobs)), key=lambda index: autocheck_jobs[index]['next_check_time']) # https://stackoverflow.com/a/11825864
earliest_job = autocheck_jobs[earliest_job_index]
time_until_earliest_job = earliest_job['next_check_time'] - time.time()
if time_until_earliest_job <= -5: # should not happen unless we're running extremely slow
print('ERROR: autocheck_dispatcher got job scheduled in the past, skipping and rescheduling: ' + earliest_job['channel_id'] + ', ' + earliest_job['channel_name'] + ', ' + str(earliest_job['next_check_time']))
next_check_time = time.time() + 3600*secrets.randbelow(60)/60
with_open_db(_schedule_checking, earliest_job['channel_id'], next_check_time)
autocheck_jobs[earliest_job_index]['next_check_time'] = next_check_time
continue
# make sure it's not muted
if with_open_db(_is_muted, earliest_job['channel_id']):
del autocheck_jobs[earliest_job_index]
continue
if time_until_earliest_job > 0: # it can become less than zero (in the past) when it's set to go off while the dispatcher is doing something else at that moment
try:
new_job = autocheck_job_application.get(timeout = time_until_earliest_job) # sleep for time_until_earliest_job time, but allow to be interrupted by new jobs
except gevent.queue.Empty: # no new jobs
pass
else: # new job, add it to the list
autocheck_jobs.append(new_job)
continue
# no new jobs, time to execute the earliest job
channel_names[earliest_job['channel_id']] = earliest_job['channel_name']
checking_channels.add(earliest_job['channel_id'])
check_channels_queue.put(earliest_job['channel_id'])
del autocheck_jobs[earliest_job_index]
gevent.spawn(autocheck_dispatcher)
# ----------------------------
def check_channels_if_necessary(channel_ids):
for channel_id in channel_ids:
if channel_id not in checking_channels:
checking_channels.add(channel_id)
check_channels_queue.put(channel_id)
def _get_upstream_videos(channel_id):
try:
channel_status_name = channel_names[channel_id]
except KeyError:
channel_status_name = channel_id
print("Checking channel: " + channel_status_name)
tasks = (
gevent.spawn(channel.get_channel_tab, channel_id, print_status=False), # channel page, needed for video durations
gevent.spawn(util.fetch_url, 'https://www.youtube.com/feeds/videos.xml?channel_id=' + channel_id) # atom feed, needed for exact publication times
)
gevent.joinall(tasks)
channel_tab, feed = tasks[0].value, tasks[1].value
# extract published times from the atom feed
times_published = {}
try:
def remove_bullshit(tag):
'''Remove XML namespace bullshit from tagname. https://bugs.python.org/issue18304'''
if '}' in tag:
return tag[tag.rfind('}')+1:]
return tag
def find_element(base, tag_name):
for element in base:
if remove_bullshit(element.tag) == tag_name:
return element
return None
root = defusedxml.ElementTree.fromstring(feed.decode('utf-8'))
assert remove_bullshit(root.tag) == 'feed'
for entry in root:
if (remove_bullshit(entry.tag) != 'entry'):
continue
# it's yt:videoId in the xml but the yt: is turned into a namespace which is removed by remove_bullshit
video_id_element = find_element(entry, 'videoId')
time_published_element = find_element(entry, 'published')
assert video_id_element is not None
assert time_published_element is not None
time_published = int(calendar.timegm(time.strptime(time_published_element.text, '%Y-%m-%dT%H:%M:%S+00:00')))
times_published[video_id_element.text] = time_published
except (AssertionError, defusedxml.ElementTree.ParseError) as e:
print('Failed to read atom feed for ' + channel_status_name)
traceback.print_exc()
with open_database() as connection:
with connection as cursor:
is_first_check = cursor.execute('''SELECT time_last_checked FROM subscribed_channels WHERE yt_channel_id=?''', [channel_id]).fetchone()[0] in (None, 0)
video_add_time = int(time.time())
videos = []
channel_videos = channel.extract_info(json.loads(channel_tab), 'videos')['items']
for i, video_item in enumerate(channel_videos):
if 'description' not in video_item:
video_item['description'] = ''
if video_item['id'] in times_published:
time_published = times_published[video_item['id']]
is_time_published_exact = True
else:
is_time_published_exact = False
try:
time_published = youtube_timestamp_to_posix(video_item['published']) - i # subtract i seconds so the videos stay in feed order despite the coarse timestamps
except KeyError:
print(video_item)
raise # time_published would be unbound below, so fail loudly on the malformed item
if is_first_check:
time_noticed = time_published # don't want a crazy ordering on first check, since we're ordering by time_noticed
else:
time_noticed = video_add_time
videos.append((channel_id, video_item['id'], video_item['title'], video_item['duration'], time_published, is_time_published_exact, time_noticed, video_item['description']))
if len(videos) == 0:
average_upload_period = 4*7*24*3600 # assume 4 weeks for a channel with no videos
elif len(videos) < 5:
average_upload_period = int((time.time() - videos[-1][4])/len(videos))
else:
average_upload_period = int((time.time() - videos[4][4])/5) # equivalent to averaging the time between videos for the last 5 videos
# calculate when to check next for auto checking
# add some quantization and randomness to make pattern analysis by Youtube slightly harder
quantized_upload_period = average_upload_period - (average_upload_period % (4*3600)) + 4*3600 # round up to the next multiple of 4 hours
randomized_upload_period = quantized_upload_period*(1 + secrets.randbelow(50)/50*0.5) # randomly between 1x and 1.5x
next_check_delay = randomized_upload_period/10 # check at 10x the channel posting rate. might want to fine tune this number
next_check_time = int(time.time() + next_check_delay)
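# Worked example with made-up numbers: if the 5th newest video is 50 days old,
# average_upload_period is 10 days. Quantizing bumps that to 10 days + 4 hours,
# the random factor scales it to between ~10.2 and ~15.2 days, and dividing by
# 10 schedules the next check roughly 1 to 1.5 days out.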
# calculate how many new videos there are
row = cursor.execute('''SELECT video_id
FROM videos
INNER JOIN subscribed_channels ON videos.sql_channel_id = subscribed_channels.id
WHERE yt_channel_id=?
ORDER BY time_published DESC
LIMIT 1''', [channel_id]).fetchone()
if row is None:
number_of_new_videos = len(videos)
else:
latest_video_id = row[0]
index = 0
for video in videos:
if video[1] == latest_video_id:
break
index += 1
number_of_new_videos = index
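# Hedged example: if the newest stored video_id sits at index 2 of the freshly
# fetched list, the two videos before it are new, so number_of_new_videos == 2.
# If it no longer appears at all, the loop falls through and every fetched video
# counts as new.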
cursor.executemany('''INSERT OR IGNORE INTO videos (sql_channel_id, video_id, title, duration, time_published, is_time_published_exact, time_noticed, description)
VALUES ((SELECT id FROM subscribed_channels WHERE yt_channel_id=?), ?, ?, ?, ?, ?, ?, ?)''', videos)
cursor.execute('''UPDATE subscribed_channels
SET time_last_checked = ?, next_check_time = ?
WHERE yt_channel_id=?''', [int(time.time()), next_check_time, channel_id])
if settings.autocheck_subscriptions:
if not _is_muted(cursor, channel_id):
autocheck_job_application.put({'channel_id': channel_id, 'channel_name': channel_names[channel_id], 'next_check_time': next_check_time})
if number_of_new_videos == 0:
print('No new videos from ' + channel_status_name)
elif number_of_new_videos == 1:
print('1 new video from ' + channel_status_name)
else:
print(str(number_of_new_videos) + ' new videos from ' + channel_status_name)
def check_all_channels():
with open_database() as connection:
with connection as cursor:
channel_id_name_list = cursor.execute('''SELECT yt_channel_id, channel_name
FROM subscribed_channels
WHERE muted != 1''').fetchall()
channel_names.update(channel_id_name_list)
check_channels_if_necessary([item[0] for item in channel_id_name_list])
def check_tags(tags):
channel_id_name_list = []
with open_database() as connection:
with connection as cursor:
for tag in tags:
channel_id_name_list += _channels_with_tag(cursor, tag, exclude_muted=True)
channel_names.update(channel_id_name_list)
check_channels_if_necessary([item[0] for item in channel_id_name_list])
def check_specific_channels(channel_ids):
with open_database() as connection:
with connection as cursor:
channel_id_name_list = []
for channel_id in channel_ids:
channel_id_name_list += cursor.execute('''SELECT yt_channel_id, channel_name
FROM subscribed_channels
WHERE yt_channel_id=?''', [channel_id]).fetchall()
channel_names.update(channel_id_name_list)
check_channels_if_necessary(channel_ids)
@yt_app.route('/import_subscriptions', methods=['POST'])
def import_subscriptions():
# check if the post request has the file part
if 'subscriptions_file' not in request.files:
#flash('No file part')
return flask.redirect(util.URL_ORIGIN + request.full_path)
file = request.files['subscriptions_file']
# if the user does not select a file, the browser may
# submit an empty part without a filename
if file.filename == '':
#flash('No selected file')
return flask.redirect(util.URL_ORIGIN + request.full_path)
mime_type = file.mimetype
if mime_type == 'application/json':
file = file.read().decode('utf-8')
try:
file = json.loads(file)
except json.decoder.JSONDecodeError:
traceback.print_exc()
return '400 Bad Request: Invalid json file', 400
try:
channels = [(item['snippet']['resourceId']['channelId'], item['snippet']['title']) for item in file] # a list rather than a generator, so the except below can actually catch bad entries
except (KeyError, IndexError):
traceback.print_exc()
return '400 Bad Request: Unknown json structure', 400
elif mime_type in ('application/xml', 'text/xml', 'text/x-opml'):
file = file.read().decode('utf-8')
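# The parsing below assumes an export shaped roughly like this (a sketch inferred
# from the root[0][0] indexing, not a verified sample of Youtube's OPML export):
# <opml version="1.1">
#   <body>
#     <outline text="YouTube Subscriptions" title="YouTube Subscriptions">
#       <outline text="Channel Name" type="rss" xmlUrl="https://www.youtube.com/feeds/videos.xml?channel_id=UC..."/>
#     </outline>
#   </body>
# </opml>
# so root[0][0] is the outer <outline> and its children are the channel entries.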
try:
root = defusedxml.ElementTree.fromstring(file)
assert root.tag == 'opml'
channels = []
for outline_element in root[0][0]:
if (outline_element.tag != 'outline') or ('xmlUrl' not in outline_element.attrib):
continue
channel_name = outline_element.attrib['text']
channel_rss_url = outline_element.attrib['xmlUrl']
channel_id = channel_rss_url[channel_rss_url.find('channel_id=')+11:].strip()
channels.append( (channel_id, channel_name) )
except (AssertionError, IndexError, defusedxml.ElementTree.ParseError) as e:
return '400 Bad Request: Unable to read opml xml file, or the file is not the expected format', 400
else:
return '400 Bad Request: Unsupported file format: ' + mime_type + '. Only subscription.json files (from Google Takeout) and XML OPML files exported from Youtube\'s subscription manager page are supported', 400
_subscribe(channels)
return flask.redirect(util.URL_ORIGIN + '/subscription_manager', 303)
@yt_app.route('/subscription_manager', methods=['GET'])
def get_subscription_manager_page():
group_by_tags = request.args.get('group_by_tags', '0') == '1'
with open_database() as connection:
with connection as cursor:
if group_by_tags:
tag_groups = []
for tag in _get_all_tags(cursor):
sub_list = []
for channel_id, channel_name, muted in _channels_with_tag(cursor, tag, order=True, include_muted_status=True):
sub_list.append({
'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
'channel_name': channel_name,
'channel_id': channel_id,
'muted': muted,
'tags': [t for t in _get_tags(cursor, channel_id) if t != tag],
})
tag_groups.append( (tag, sub_list) )
# Channels with no tags
channel_list = cursor.execute('''SELECT yt_channel_id, channel_name, muted
FROM subscribed_channels
WHERE id NOT IN (
SELECT sql_channel_id FROM tag_associations
)
ORDER BY channel_name COLLATE NOCASE''').fetchall()
if channel_list:
sub_list = []
for channel_id, channel_name, muted in channel_list:
sub_list.append({
'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
'channel_name': channel_name,
'channel_id': channel_id,
'muted': muted,
'tags': [],
})
tag_groups.append( ('No tags', sub_list) )
else:
sub_list = []
for channel_name, channel_id, muted in _get_subscribed_channels(cursor):
sub_list.append({
'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
'channel_name': channel_name,
'channel_id': channel_id,
'muted': muted,
'tags': _get_tags(cursor, channel_id),
})
if group_by_tags:
return flask.render_template('subscription_manager.html',
group_by_tags = True,
tag_groups = tag_groups,
)
else:
return flask.render_template('subscription_manager.html',
group_by_tags = False,
sub_list = sub_list,
)
def list_from_comma_separated_tags(string):
return [tag.strip() for tag in string.split(',') if tag.strip()]
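# Hedged example: list_from_comma_separated_tags(' news, tech music ,, ')
# returns ['news', 'tech music'] (whitespace trimmed, empty entries dropped).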
@yt_app.route('/subscription_manager', methods=['POST'])
def post_subscription_manager_page():
action = request.values['action']
with open_database() as connection:
with connection as cursor:
if action == 'add_tags':
_add_tags(cursor, request.values.getlist('channel_ids'), [tag.lower() for tag in list_from_comma_separated_tags(request.values['tags'])])
elif action == 'remove_tags':
_remove_tags(cursor, request.values.getlist('channel_ids'), [tag.lower() for tag in list_from_comma_separated_tags(request.values['tags'])])
elif action == 'unsubscribe':
_unsubscribe(cursor, request.values.getlist('channel_ids'))
elif action == 'unsubscribe_verify':
unsubscribe_list = _get_channel_names(cursor, request.values.getlist('channel_ids'))
return flask.render_template('unsubscribe_verify.html', unsubscribe_list = unsubscribe_list)
elif action == 'mute':
cursor.executemany('''UPDATE subscribed_channels
SET muted = 1
WHERE yt_channel_id = ?''', [(ci,) for ci in request.values.getlist('channel_ids')])
elif action == 'unmute':
cursor.executemany('''UPDATE subscribed_channels
SET muted = 0
WHERE yt_channel_id = ?''', [(ci,) for ci in request.values.getlist('channel_ids')])
else:
flask.abort(400)
return flask.redirect(util.URL_ORIGIN + request.full_path, 303)
@yt_app.route('/subscriptions', methods=['GET'])
@yt_app.route('/feed/subscriptions', methods=['GET'])
def get_subscriptions_page():
page = int(request.args.get('page', 1))
with open_database() as connection:
with connection as cursor:
tag = request.args.get('tag', None)
videos, number_of_videos_in_db = _get_videos(cursor, 60, (page - 1)*60, tag)
for video in videos:
video['thumbnail'] = util.URL_ORIGIN + '/data/subscription_thumbnails/' + video['id'] + '.jpg'
video['type'] = 'video'
video['item_size'] = 'small'
yt_data_extract.add_extra_html_info(video)
tags = _get_all_tags(cursor)
subscription_list = []
for channel_name, channel_id, muted in _get_subscribed_channels(cursor):
subscription_list.append({
'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
'channel_name': channel_name,
'channel_id': channel_id,
'muted': muted,
})
return flask.render_template('subscriptions.html',
videos = videos,
num_pages = math.ceil(number_of_videos_in_db/60),
parameters_dictionary = request.args,
tags = tags,
current_tag = tag,
subscription_list = subscription_list,
)
@yt_app.route('/subscriptions', methods=['POST'])
@yt_app.route('/feed/subscriptions', methods=['POST'])
def post_subscriptions_page():
action = request.values['action']
if action == 'subscribe':
if len(request.values.getlist('channel_id')) != len(request.values.getlist('channel_name')):
return '400 Bad Request, length of channel_id != length of channel_name', 400
_subscribe(zip(request.values.getlist('channel_id'), request.values.getlist('channel_name')))
elif action == 'unsubscribe':
with_open_db(_unsubscribe, request.values.getlist('channel_id'))
elif action == 'refresh':
type = request.values['type']
if type == 'all':
check_all_channels()
elif type == 'tag':
check_tags(request.values.getlist('tag_name'))
elif type == 'channel':
check_specific_channels(request.values.getlist('channel_id'))
else:
flask.abort(400)
else:
flask.abort(400)
return '', 204
@yt_app.route('/data/subscription_thumbnails/<thumbnail>')
def serve_subscription_thumbnail(thumbnail):
'''Serves thumbnail from disk if it's been saved already. If not, downloads the thumbnail, saves to disk, and serves it.'''
assert thumbnail[-4:] == '.jpg'
video_id = thumbnail[0:-4]
thumbnail_path = os.path.join(thumbnails_directory, thumbnail)
if video_id in existing_thumbnails:
try:
f = open(thumbnail_path, 'rb')
except FileNotFoundError:
existing_thumbnails.remove(video_id)
else:
image = f.read()
f.close()
return flask.Response(image, mimetype='image/jpeg')
url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
try:
image = util.fetch_url(url, report_text="Saved thumbnail: " + video_id)
except urllib.error.HTTPError as e:
print("Failed to download thumbnail for " + video_id + ": " + str(e))
flask.abort(e.code)
try:
f = open(thumbnail_path, 'wb')
except FileNotFoundError:
os.makedirs(thumbnails_directory, exist_ok = True)
f = open(thumbnail_path, 'wb')
f.write(image)
f.close()
existing_thumbnails.add(video_id)
return flask.Response(image, mimetype='image/jpeg')


@ -23,6 +23,9 @@
grid-column:2;
margin-left: 5px;
}
.summary .subscribe-unsubscribe, .summary .short-description{
margin-top: 10px;
}
main .channel-tabs{
grid-row:2;
grid-column: 1 / span 2;
@ -89,6 +92,12 @@
<div class="summary"> <div class="summary">
<h2 class="title">{{ channel_name }}</h2> <h2 class="title">{{ channel_name }}</h2>
<p class="short-description">{{ short_description }}</p> <p class="short-description">{{ short_description }}</p>
<form method="POST" action="/youtube.com/subscriptions" class="subscribe-unsubscribe">
<input type="submit" value="{{ 'Unsubscribe' if subscribed else 'Subscribe' }}">
<input type="hidden" name="channel_id" value="{{ channel_id }}">
<input type="hidden" name="channel_name" value="{{ channel_name }}">
<input type="hidden" name="action" value="{{ 'unsubscribe' if subscribed else 'subscribe' }}">
</form>
</div>
<nav class="channel-tabs">
{% for tab_name in ('Videos', 'Playlists', 'About') %}


@ -0,0 +1,139 @@
{% set page_title = 'Subscription Manager' %}
{% extends "base.html" %}
{% block style %}
.import-export{
display: flex;
flex-direction: row;
}
.subscriptions-import-form{
background-color: var(--interface-color);
display: flex;
flex-direction: column;
align-items: flex-start;
max-width: 300px;
padding:10px;
}
.subscriptions-import-form h2{
font-size: 20px;
margin-bottom: 10px;
}
.import-submit-button{
margin-top:15px;
align-self: flex-end;
}
.subscriptions-export-links{
margin: 0px 0px 0px 20px;
background-color: var(--interface-color);
list-style: none;
max-width: 300px;
padding:10px;
}
.sub-list-controls{
background-color: var(--interface-color);
padding:10px;
}
.tag-group-list{
list-style: none;
margin-left: 10px;
margin-right: 10px;
padding: 0px;
}
.tag-group{
border-style: solid;
margin-bottom: 10px;
}
.sub-list{
list-style: none;
padding:10px;
column-width: 300px;
column-gap: 40px;
}
.sub-list-item{
display:flex;
margin-bottom: 10px;
break-inside:avoid;
background-color: var(--interface-color);
}
.tag-list{
margin-left:15px;
font-weight:bold;
}
.sub-list-item-name{
margin-left:15px;
}
.sub-list-checkbox{
height: 1.5em;
min-width: 1.5em; /* need min-width, otherwise the browser doesn't respect the width and squishes the checkbox when there are too many tags */
}
{% endblock style %}
{% macro subscription_list(sub_list) %}
{% for subscription in sub_list %}
<li class="sub-list-item {{ 'muted' if subscription['muted'] else '' }}">
<input class="sub-list-checkbox" name="channel_ids" value="{{ subscription['channel_id'] }}" form="subscription-manager-form" type="checkbox">
<a href="{{ subscription['channel_url'] }}" class="sub-list-item-name" title="{{ subscription['channel_name'] }}">{{ subscription['channel_name'] }}</a>
<span class="tag-list">{{ ', '.join(subscription['tags']) }}</span>
</li>
{% endfor %}
{% endmacro %}
{% block main %}
<div class="import-export">
<form class="subscriptions-import-form" enctype="multipart/form-data" action="/youtube.com/import_subscriptions" method="POST">
<h2>Import subscriptions</h2>
<input type="file" id="subscriptions-import" accept="application/json, application/xml, text/x-opml" name="subscriptions_file">
<input type="submit" value="Import" class="import-submit-button">
</form>
<ul class="subscriptions-export-links">
<li><a href="/youtube.com/subscriptions.opml">Export subscriptions (OPML)</a></li>
<li><a href="/youtube.com/subscriptions.xml">Export subscriptions (RSS)</a></li>
</ul>
</div>
<hr>
<form id="subscription-manager-form" class="sub-list-controls" method="POST">
{% if group_by_tags %}
<a class="sort-button" href="/https://www.youtube.com/subscription_manager?group_by_tags=0">Don't group</a>
{% else %}
<a class="sort-button" href="/https://www.youtube.com/subscription_manager?group_by_tags=1">Group by tags</a>
{% endif %}
<input type="text" name="tags">
<button type="submit" name="action" value="add_tags">Add tags</button>
<button type="submit" name="action" value="remove_tags">Remove tags</button>
<button type="submit" name="action" value="unsubscribe_verify">Unsubscribe</button>
<button type="submit" name="action" value="mute">Mute</button>
<button type="submit" name="action" value="unmute">Unmute</button>
<input type="reset" value="Clear Selection">
</form>
{% if group_by_tags %}
<ul class="tag-group-list">
{% for tag_name, sub_list in tag_groups %}
<li class="tag-group">
<h2 class="tag-group-name">{{ tag_name }}</h2>
<ol class="sub-list">
{{ subscription_list(sub_list) }}
</ol>
</li>
{% endfor %}
</ul>
{% else %}
<ol class="sub-list">
{{ subscription_list(sub_list) }}
</ol>
{% endif %}
{% endblock main %}


@ -0,0 +1,113 @@
{% set page_title = 'Subscriptions' %}
{% extends "base.html" %}
{% import "common_elements.html" as common_elements %}
{% block style %}
main{
display:flex;
flex-direction: row;
}
.video-section{
flex-grow: 1;
}
.video-section .page-button-row{
justify-content: center;
}
.subscriptions-sidebar{
flex-basis: 300px;
background-color: var(--interface-color);
border-left: 2px;
}
.sidebar-links{
display:flex;
justify-content: space-between;
padding-left:10px;
padding-right: 10px;
}
.sidebar-list{
list-style: none;
padding-left:10px;
padding-right: 10px;
}
.sidebar-list-item{
display:flex;
justify-content: space-between;
margin-bottom: 5px;
}
.sub-refresh-list .sidebar-item-name{
text-overflow: clip;
white-space: nowrap;
overflow: hidden;
max-width: 200px;
}
{% endblock style %}
{% block main %}
<div class="video-section">
<nav class="item-grid">
{% for video_info in videos %}
{{ common_elements.item(video_info, include_author=false) }}
{% endfor %}
</nav>
<nav class="page-button-row">
{{ common_elements.page_buttons(num_pages, '/youtube.com/subscriptions', parameters_dictionary) }}
</nav>
</div>
<div class="subscriptions-sidebar">
<div class="sidebar-links">
<a href="/youtube.com/subscription_manager" class="sub-manager-link">Subscription Manager</a>
<form method="POST" class="refresh-all">
<input type="submit" value="Check All">
<input type="hidden" name="action" value="refresh">
<input type="hidden" name="type" value="all">
</form>
</div>
<hr>
<ol class="sidebar-list tags">
{% if current_tag %}
<li class="sidebar-list-item">
<a href="/youtube.com/subscriptions" class="sidebar-item-name">Any tag</a>
</li>
{% endif %}
{% for tag in tags %}
<li class="sidebar-list-item">
{% if tag == current_tag %}
<span class="sidebar-item-name">{{ tag }}</span>
{% else %}
<a href="?tag={{ tag|urlencode }}" class="sidebar-item-name">{{ tag }}</a>
{% endif %}
<form method="POST" class="sidebar-item-refresh">
<input type="submit" value="Check">
<input type="hidden" name="action" value="refresh">
<input type="hidden" name="type" value="tag">
<input type="hidden" name="tag_name" value="{{ tag }}">
</form>
</li>
{% endfor %}
</ol>
<hr>
<ol class="sidebar-list sub-refresh-list">
{% for subscription in subscription_list %}
<li class="sidebar-list-item {{ 'muted' if subscription['muted'] else '' }}">
<a href="{{ subscription['channel_url'] }}" class="sidebar-item-name" title="{{ subscription['channel_name'] }}">{{ subscription['channel_name'] }}</a>
<form method="POST" class="sidebar-item-refresh">
<input type="submit" value="Check">
<input type="hidden" name="action" value="refresh">
<input type="hidden" name="type" value="channel">
<input type="hidden" name="channel_id" value="{{ subscription['channel_id'] }}">
</form>
</li>
{% endfor %}
</ol>
</div>
{% endblock main %}


@ -0,0 +1,19 @@
{% set page_title = 'Unsubscribe?' %}
{% extends "base.html" %}
{% block main %}
<span>Are you sure you want to unsubscribe from these channels?</span>
<form class="subscriptions-import-form" action="/youtube.com/subscription_manager" method="POST">
{% for channel_id, channel_name in unsubscribe_list %}
<input type="hidden" name="channel_ids" value="{{ channel_id }}">
{% endfor %}
<input type="hidden" name="action" value="unsubscribe">
<input type="submit" value="Yes, unsubscribe">
</form>
<ul>
{% for channel_id, channel_name in unsubscribe_list %}
<li><a href="{{ '/https://www.youtube.com/channel/' + channel_id }}" title="{{ channel_name }}">{{ channel_name }}</a></li>
{% endfor %}
</ul>
{% endblock main %}


@ -6,6 +6,9 @@ import urllib.parse
import re
import time
import os
import gevent
import gevent.queue
import gevent.lock
# The trouble with the requests library: It ships its own certificate bundle via certifi
# instead of using the system certificate store, meaning self-signed certificates
@ -183,6 +186,84 @@ desktop_ua = (('User-Agent', desktop_user_agent),)
class RateLimitedQueue(gevent.queue.Queue):
''' Does initial_burst (default 30) gets at first, then alternates between waiting waiting_period (default 5) seconds and allowing subsequent_bursts (default 10) more gets. After 5 seconds with nothing left in the queue, the rate limiting resets. '''
def __init__(self, initial_burst=30, waiting_period=5, subsequent_bursts=10):
self.initial_burst = initial_burst
self.waiting_period = waiting_period
self.subsequent_bursts = subsequent_bursts
self.count_since_last_wait = 0
self.surpassed_initial = False
self.lock = gevent.lock.BoundedSemaphore(1)
self.currently_empty = False
self.empty_start = 0
gevent.queue.Queue.__init__(self)
def get(self):
self.lock.acquire() # blocks if another greenlet currently has the lock
if self.count_since_last_wait >= self.subsequent_bursts and self.surpassed_initial:
gevent.sleep(self.waiting_period)
self.count_since_last_wait = 0
elif self.count_since_last_wait >= self.initial_burst and not self.surpassed_initial:
self.surpassed_initial = True
gevent.sleep(self.waiting_period)
self.count_since_last_wait = 0
self.count_since_last_wait += 1
if not self.currently_empty and self.empty():
self.currently_empty = True
self.empty_start = time.monotonic()
item = gevent.queue.Queue.get(self) # blocks when nothing left
if self.currently_empty:
if time.monotonic() - self.empty_start >= self.waiting_period:
self.count_since_last_wait = 0
self.surpassed_initial = False
self.currently_empty = False
self.lock.release()
return item
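# Hedged usage sketch (handle_job and 'some job' are placeholders, not part of this module):
# queue = RateLimitedQueue()
# def worker():
#     while True:
#         handle_job(queue.get()) # the semaphore in get() rate limits all workers collectively
# for _ in range(5):
#     gevent.spawn(worker)
# queue.put('some job') # the first 30 gets flow freely, then 10 per waiting_period pause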
def download_thumbnail(save_directory, video_id):
url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
save_location = os.path.join(save_directory, video_id + ".jpg")
try:
thumbnail = fetch_url(url, report_text="Saved thumbnail: " + video_id)
except urllib.error.HTTPError as e:
print("Failed to download thumbnail for " + video_id + ": " + str(e))
return False
try:
f = open(save_location, 'wb')
except FileNotFoundError:
os.makedirs(save_directory, exist_ok = True)
f = open(save_location, 'wb')
f.write(thumbnail)
f.close()
return True
def download_thumbnails(save_directory, ids):
if not isinstance(ids, (list, tuple)):
ids = list(ids)
# download in batches of 5 concurrent requests
for i in range(0, len(ids), 5):
gevent.joinall([gevent.spawn(download_thumbnail, save_directory, video_id) for video_id in ids[i:i+5]])
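# Hedged usage sketch (the directory and ids are placeholders):
# download_thumbnails('data/subscription_thumbnails', ['dQw4w9WgXcQ', 'jNQXAC9IVRw'])
# saves each thumbnail, making at most 5 concurrent requests per batch.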