Merge subscriptions into master

commit ac32b24b2a

python/atoma/__init__.py (Normal file, 12 lines)
@@ -0,0 +1,12 @@
from .atom import parse_atom_file, parse_atom_bytes
from .rss import parse_rss_file, parse_rss_bytes
from .json_feed import (
    parse_json_feed, parse_json_feed_file, parse_json_feed_bytes
)
from .opml import parse_opml_file, parse_opml_bytes
from .exceptions import (
    FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError
)
from .const import VERSION

__version__ = VERSION
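For reference, a minimal usage sketch of the API re-exported above (not part of this commit; the file path is hypothetical):

    import atoma

    feed = atoma.parse_rss_file('news.xml')  # hypothetical local file
    print(feed.title)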
python/atoma/atom.py (Normal file, 284 lines)
@@ -0,0 +1,284 @@
from datetime import datetime
import enum
from io import BytesIO
from typing import Optional, List
from xml.etree.ElementTree import Element

import attr

from .utils import (
    parse_xml, get_child, get_text, get_datetime, FeedParseError, ns
)


class AtomTextType(enum.Enum):
    text = "text"
    html = "html"
    xhtml = "xhtml"


@attr.s
class AtomTextConstruct:
    text_type: str = attr.ib()
    lang: Optional[str] = attr.ib()
    value: str = attr.ib()


@attr.s
class AtomEntry:
    title: AtomTextConstruct = attr.ib()
    id_: str = attr.ib()

    # Should be mandatory but many feeds use published instead
    updated: Optional[datetime] = attr.ib()

    authors: List['AtomPerson'] = attr.ib()
    contributors: List['AtomPerson'] = attr.ib()
    links: List['AtomLink'] = attr.ib()
    categories: List['AtomCategory'] = attr.ib()
    published: Optional[datetime] = attr.ib()
    rights: Optional[AtomTextConstruct] = attr.ib()
    summary: Optional[AtomTextConstruct] = attr.ib()
    content: Optional[AtomTextConstruct] = attr.ib()
    source: Optional['AtomFeed'] = attr.ib()


@attr.s
class AtomFeed:
    title: Optional[AtomTextConstruct] = attr.ib()
    id_: str = attr.ib()

    # Should be mandatory but many feeds do not include it
    updated: Optional[datetime] = attr.ib()

    authors: List['AtomPerson'] = attr.ib()
    contributors: List['AtomPerson'] = attr.ib()
    links: List['AtomLink'] = attr.ib()
    categories: List['AtomCategory'] = attr.ib()
    generator: Optional['AtomGenerator'] = attr.ib()
    subtitle: Optional[AtomTextConstruct] = attr.ib()
    rights: Optional[AtomTextConstruct] = attr.ib()
    icon: Optional[str] = attr.ib()
    logo: Optional[str] = attr.ib()

    entries: List[AtomEntry] = attr.ib()


@attr.s
class AtomPerson:
    name: str = attr.ib()
    uri: Optional[str] = attr.ib()
    email: Optional[str] = attr.ib()


@attr.s
class AtomLink:
    href: str = attr.ib()
    rel: Optional[str] = attr.ib()
    type_: Optional[str] = attr.ib()
    hreflang: Optional[str] = attr.ib()
    title: Optional[str] = attr.ib()
    length: Optional[int] = attr.ib()


@attr.s
class AtomCategory:
    term: str = attr.ib()
    scheme: Optional[str] = attr.ib()
    label: Optional[str] = attr.ib()


@attr.s
class AtomGenerator:
    name: str = attr.ib()
    uri: Optional[str] = attr.ib()
    version: Optional[str] = attr.ib()


def _get_generator(element: Element, name,
                   optional: bool=True) -> Optional[AtomGenerator]:
    child = get_child(element, name, optional)
    if child is None:
        return None

    return AtomGenerator(
        child.text.strip(),
        child.attrib.get('uri'),
        child.attrib.get('version'),
    )


def _get_text_construct(element: Element, name,
                        optional: bool=True) -> Optional[AtomTextConstruct]:
    child = get_child(element, name, optional)
    if child is None:
        return None

    try:
        text_type = AtomTextType(child.attrib['type'])
    except KeyError:
        text_type = AtomTextType.text

    try:
        lang = child.lang
    except AttributeError:
        lang = None

    if child.text is None:
        if optional:
            return None

        raise FeedParseError(
            'Could not parse atom feed: "{}" text is required but is empty'
            .format(name)
        )

    return AtomTextConstruct(
        text_type,
        lang,
        child.text.strip()
    )


def _get_person(element: Element) -> Optional[AtomPerson]:
    try:
        return AtomPerson(
            get_text(element, 'feed:name', optional=False),
            get_text(element, 'feed:uri'),
            get_text(element, 'feed:email')
        )
    except FeedParseError:
        return None


def _get_link(element: Element) -> AtomLink:
    length = element.attrib.get('length')
    length = int(length) if length else None
    return AtomLink(
        element.attrib['href'],
        element.attrib.get('rel'),
        element.attrib.get('type'),
        element.attrib.get('hreflang'),
        element.attrib.get('title'),
        length
    )


def _get_category(element: Element) -> AtomCategory:
    return AtomCategory(
        element.attrib['term'],
        element.attrib.get('scheme'),
        element.attrib.get('label'),
    )


def _get_entry(element: Element,
               default_authors: List[AtomPerson]) -> AtomEntry:
    root = element

    # Mandatory
    title = _get_text_construct(root, 'feed:title')
    id_ = get_text(root, 'feed:id')

    # Optional
    try:
        source = _parse_atom(get_child(root, 'feed:source', optional=False),
                             parse_entries=False)
    except FeedParseError:
        source = None
        source_authors = []
    else:
        source_authors = source.authors

    authors = [_get_person(e)
               for e in root.findall('feed:author', ns)] or default_authors
    authors = [a for a in authors if a is not None]
    authors = authors or default_authors or source_authors

    contributors = [_get_person(e)
                    for e in root.findall('feed:contributor', ns) if e]
    contributors = [c for c in contributors if c is not None]

    links = [_get_link(e) for e in root.findall('feed:link', ns)]
    categories = [_get_category(e) for e in root.findall('feed:category', ns)]

    updated = get_datetime(root, 'feed:updated')
    published = get_datetime(root, 'feed:published')
    rights = _get_text_construct(root, 'feed:rights')
    summary = _get_text_construct(root, 'feed:summary')
    content = _get_text_construct(root, 'feed:content')

    return AtomEntry(
        title,
        id_,
        updated,
        authors,
        contributors,
        links,
        categories,
        published,
        rights,
        summary,
        content,
        source
    )


def _parse_atom(root: Element, parse_entries: bool=True) -> AtomFeed:
    # Mandatory
    id_ = get_text(root, 'feed:id', optional=False)

    # Optional
    title = _get_text_construct(root, 'feed:title')
    updated = get_datetime(root, 'feed:updated')
    authors = [_get_person(e)
               for e in root.findall('feed:author', ns) if e]
    authors = [a for a in authors if a is not None]
    contributors = [_get_person(e)
                    for e in root.findall('feed:contributor', ns) if e]
    contributors = [c for c in contributors if c is not None]
    links = [_get_link(e)
             for e in root.findall('feed:link', ns)]
    categories = [_get_category(e)
                  for e in root.findall('feed:category', ns)]

    generator = _get_generator(root, 'feed:generator')
    subtitle = _get_text_construct(root, 'feed:subtitle')
    rights = _get_text_construct(root, 'feed:rights')
    icon = get_text(root, 'feed:icon')
    logo = get_text(root, 'feed:logo')

    if parse_entries:
        entries = [_get_entry(e, authors)
                   for e in root.findall('feed:entry', ns)]
    else:
        entries = []

    atom_feed = AtomFeed(
        title,
        id_,
        updated,
        authors,
        contributors,
        links,
        categories,
        generator,
        subtitle,
        rights,
        icon,
        logo,
        entries
    )
    return atom_feed


def parse_atom_file(filename: str) -> AtomFeed:
    """Parse an Atom feed from a local XML file."""
    root = parse_xml(filename).getroot()
    return _parse_atom(root)


def parse_atom_bytes(data: bytes) -> AtomFeed:
    """Parse an Atom feed from a byte-string containing XML data."""
    root = parse_xml(BytesIO(data)).getroot()
    return _parse_atom(root)
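A minimal sketch of parse_atom_bytes on a hand-written Atom document (not part of this commit; feed:id is the only element this parser treats as mandatory):

    from atoma.atom import parse_atom_bytes

    data = b"""
    <feed xmlns="http://www.w3.org/2005/Atom">
      <id>urn:example:feed</id>
      <title>Example</title>
      <updated>2018-10-01T12:00:00Z</updated>
    </feed>"""

    feed = parse_atom_bytes(data)
    print(feed.id_, feed.title.value)  # urn:example:feed Example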
python/atoma/const.py (Normal file, 1 line)
@@ -0,0 +1 @@
VERSION = '0.0.13'
python/atoma/exceptions.py (Normal file, 14 lines)
@@ -0,0 +1,14 @@
class FeedParseError(Exception):
    """Document is an invalid feed."""


class FeedDocumentError(Exception):
    """Document is not a supported file."""


class FeedXMLError(FeedDocumentError):
    """Document is not valid XML."""


class FeedJSONError(FeedDocumentError):
    """Document is not valid JSON."""
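The hierarchy above suggests a catch order; a sketch (not part of this commit): FeedXMLError and FeedJSONError subclass FeedDocumentError, while FeedParseError signals a document that parsed but is not a valid feed.

    from atoma import parse_rss_bytes, FeedDocumentError, FeedParseError

    try:
        parse_rss_bytes(b'not xml at all')
    except FeedDocumentError:
        print('not XML or JSON')      # FeedXMLError is caught here
    except FeedParseError:
        print('well-formed, but not a valid feed')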
python/atoma/json_feed.py (Normal file, 223 lines)
@@ -0,0 +1,223 @@
from datetime import datetime, timedelta
import json
from typing import Optional, List

import attr

from .exceptions import FeedParseError, FeedJSONError
from .utils import try_parse_date


@attr.s
class JSONFeedAuthor:

    name: Optional[str] = attr.ib()
    url: Optional[str] = attr.ib()
    avatar: Optional[str] = attr.ib()


@attr.s
class JSONFeedAttachment:

    url: str = attr.ib()
    mime_type: str = attr.ib()
    title: Optional[str] = attr.ib()
    size_in_bytes: Optional[int] = attr.ib()
    duration: Optional[timedelta] = attr.ib()


@attr.s
class JSONFeedItem:

    id_: str = attr.ib()
    url: Optional[str] = attr.ib()
    external_url: Optional[str] = attr.ib()
    title: Optional[str] = attr.ib()
    content_html: Optional[str] = attr.ib()
    content_text: Optional[str] = attr.ib()
    summary: Optional[str] = attr.ib()
    image: Optional[str] = attr.ib()
    banner_image: Optional[str] = attr.ib()
    date_published: Optional[datetime] = attr.ib()
    date_modified: Optional[datetime] = attr.ib()
    author: Optional[JSONFeedAuthor] = attr.ib()

    tags: List[str] = attr.ib()
    attachments: List[JSONFeedAttachment] = attr.ib()


@attr.s
class JSONFeed:

    version: str = attr.ib()
    title: str = attr.ib()
    home_page_url: Optional[str] = attr.ib()
    feed_url: Optional[str] = attr.ib()
    description: Optional[str] = attr.ib()
    user_comment: Optional[str] = attr.ib()
    next_url: Optional[str] = attr.ib()
    icon: Optional[str] = attr.ib()
    favicon: Optional[str] = attr.ib()
    author: Optional[JSONFeedAuthor] = attr.ib()
    expired: bool = attr.ib()

    items: List[JSONFeedItem] = attr.ib()


def _get_items(root: dict) -> List[JSONFeedItem]:
    rv = []
    items = root.get('items', [])
    if not items:
        return rv

    for item in items:
        rv.append(_get_item(item))

    return rv


def _get_item(item_dict: dict) -> JSONFeedItem:
    return JSONFeedItem(
        id_=_get_text(item_dict, 'id', optional=False),
        url=_get_text(item_dict, 'url'),
        external_url=_get_text(item_dict, 'external_url'),
        title=_get_text(item_dict, 'title'),
        content_html=_get_text(item_dict, 'content_html'),
        content_text=_get_text(item_dict, 'content_text'),
        summary=_get_text(item_dict, 'summary'),
        image=_get_text(item_dict, 'image'),
        banner_image=_get_text(item_dict, 'banner_image'),
        date_published=_get_datetime(item_dict, 'date_published'),
        date_modified=_get_datetime(item_dict, 'date_modified'),
        author=_get_author(item_dict),
        tags=_get_tags(item_dict, 'tags'),
        attachments=_get_attachments(item_dict, 'attachments')
    )


def _get_attachments(root, name) -> List[JSONFeedAttachment]:
    rv = list()
    for attachment_dict in root.get(name, []):
        rv.append(JSONFeedAttachment(
            _get_text(attachment_dict, 'url', optional=False),
            _get_text(attachment_dict, 'mime_type', optional=False),
            _get_text(attachment_dict, 'title'),
            _get_int(attachment_dict, 'size_in_bytes'),
            _get_duration(attachment_dict, 'duration_in_seconds')
        ))
    return rv


def _get_tags(root, name) -> List[str]:
    tags = root.get(name, [])
    return [tag for tag in tags if isinstance(tag, str)]


def _get_datetime(root: dict, name, optional: bool=True) -> Optional[datetime]:
    text = _get_text(root, name, optional)
    if text is None:
        return None

    return try_parse_date(text)


def _get_expired(root: dict) -> bool:
    if root.get('expired') is True:
        return True

    return False


def _get_author(root: dict) -> Optional[JSONFeedAuthor]:
    author_dict = root.get('author')
    if not author_dict:
        return None

    rv = JSONFeedAuthor(
        name=_get_text(author_dict, 'name'),
        url=_get_text(author_dict, 'url'),
        avatar=_get_text(author_dict, 'avatar'),
    )
    if rv.name is None and rv.url is None and rv.avatar is None:
        return None

    return rv


def _get_int(root: dict, name: str, optional: bool=True) -> Optional[int]:
    rv = root.get(name)
    if not optional and rv is None:
        raise FeedParseError('Could not parse feed: "{}" int is required but '
                             'is empty'.format(name))

    if optional and rv is None:
        return None

    if not isinstance(rv, int):
        raise FeedParseError('Could not parse feed: "{}" is not an int'
                             .format(name))

    return rv


def _get_duration(root: dict, name: str,
                  optional: bool=True) -> Optional[timedelta]:
    duration = _get_int(root, name, optional)
    if duration is None:
        return None

    return timedelta(seconds=duration)


def _get_text(root: dict, name: str, optional: bool=True) -> Optional[str]:
    rv = root.get(name)
    if not optional and rv is None:
        raise FeedParseError('Could not parse feed: "{}" text is required but '
                             'is empty'.format(name))

    if optional and rv is None:
        return None

    if not isinstance(rv, str):
        raise FeedParseError('Could not parse feed: "{}" is not a string'
                             .format(name))

    return rv


def parse_json_feed(root: dict) -> JSONFeed:
    return JSONFeed(
        version=_get_text(root, 'version', optional=False),
        title=_get_text(root, 'title', optional=False),
        home_page_url=_get_text(root, 'home_page_url'),
        feed_url=_get_text(root, 'feed_url'),
        description=_get_text(root, 'description'),
        user_comment=_get_text(root, 'user_comment'),
        next_url=_get_text(root, 'next_url'),
        icon=_get_text(root, 'icon'),
        favicon=_get_text(root, 'favicon'),
        author=_get_author(root),
        expired=_get_expired(root),
        items=_get_items(root)
    )


def parse_json_feed_file(filename: str) -> JSONFeed:
    """Parse a JSON feed from a local json file."""
    with open(filename) as f:
        try:
            root = json.load(f)
        except json.decoder.JSONDecodeError:
            raise FeedJSONError('Not a valid JSON document')

    return parse_json_feed(root)


def parse_json_feed_bytes(data: bytes) -> JSONFeed:
    """Parse a JSON feed from a byte-string containing JSON data."""
    try:
        root = json.loads(data)
    except json.decoder.JSONDecodeError:
        raise FeedJSONError('Not a valid JSON document')

    return parse_json_feed(root)
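A minimal sketch of parse_json_feed_bytes (not part of this commit); only "version", "title" and each item's "id" are mandatory for this parser:

    from atoma.json_feed import parse_json_feed_bytes

    data = b'''{
        "version": "https://jsonfeed.org/version/1",
        "title": "Example",
        "items": [{"id": "1", "content_text": "Hello"}]
    }'''

    feed = parse_json_feed_bytes(data)
    print(feed.title, feed.items[0].content_text)  # Example Hello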
python/atoma/opml.py (Normal file, 107 lines)
@@ -0,0 +1,107 @@
from datetime import datetime
from io import BytesIO
from typing import Optional, List
from xml.etree.ElementTree import Element

import attr

from .utils import parse_xml, get_text, get_int, get_datetime


@attr.s
class OPMLOutline:
    text: Optional[str] = attr.ib()
    type: Optional[str] = attr.ib()
    xml_url: Optional[str] = attr.ib()
    description: Optional[str] = attr.ib()
    html_url: Optional[str] = attr.ib()
    language: Optional[str] = attr.ib()
    title: Optional[str] = attr.ib()
    version: Optional[str] = attr.ib()

    outlines: List['OPMLOutline'] = attr.ib()


@attr.s
class OPML:
    title: Optional[str] = attr.ib()
    owner_name: Optional[str] = attr.ib()
    owner_email: Optional[str] = attr.ib()
    date_created: Optional[datetime] = attr.ib()
    date_modified: Optional[datetime] = attr.ib()
    expansion_state: Optional[str] = attr.ib()

    vertical_scroll_state: Optional[int] = attr.ib()
    window_top: Optional[int] = attr.ib()
    window_left: Optional[int] = attr.ib()
    window_bottom: Optional[int] = attr.ib()
    window_right: Optional[int] = attr.ib()

    outlines: List[OPMLOutline] = attr.ib()


def _get_outlines(element: Element) -> List[OPMLOutline]:
    rv = list()

    for outline in element.findall('outline'):
        rv.append(OPMLOutline(
            outline.attrib.get('text'),
            outline.attrib.get('type'),
            outline.attrib.get('xmlUrl'),
            outline.attrib.get('description'),
            outline.attrib.get('htmlUrl'),
            outline.attrib.get('language'),
            outline.attrib.get('title'),
            outline.attrib.get('version'),
            _get_outlines(outline)
        ))

    return rv


def _parse_opml(root: Element) -> OPML:
    head = root.find('head')
    body = root.find('body')

    return OPML(
        get_text(head, 'title'),
        get_text(head, 'ownerName'),
        get_text(head, 'ownerEmail'),
        get_datetime(head, 'dateCreated'),
        get_datetime(head, 'dateModified'),
        get_text(head, 'expansionState'),
        get_int(head, 'vertScrollState'),
        get_int(head, 'windowTop'),
        get_int(head, 'windowLeft'),
        get_int(head, 'windowBottom'),
        get_int(head, 'windowRight'),
        outlines=_get_outlines(body)
    )


def parse_opml_file(filename: str) -> OPML:
    """Parse an OPML document from a local XML file."""
    root = parse_xml(filename).getroot()
    return _parse_opml(root)


def parse_opml_bytes(data: bytes) -> OPML:
    """Parse an OPML document from a byte-string containing XML data."""
    root = parse_xml(BytesIO(data)).getroot()
    return _parse_opml(root)


def get_feed_list(opml_obj: OPML) -> List[str]:
    """Walk an OPML document to extract the list of feeds it contains."""
    rv = list()

    def collect(obj):
        for outline in obj.outlines:
            if outline.type == 'rss' and outline.xml_url:
                rv.append(outline.xml_url)

            if outline.outlines:
                collect(outline)

    collect(opml_obj)
    return rv
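A sketch of get_feed_list walking a small OPML document (not part of this commit; the URL is illustrative):

    from atoma.opml import parse_opml_bytes, get_feed_list

    data = b"""
    <opml version="2.0">
      <head><title>Subscriptions</title></head>
      <body>
        <outline type="rss" text="Example"
                 xmlUrl="https://example.com/feed.xml"/>
      </body>
    </opml>"""

    print(get_feed_list(parse_opml_bytes(data)))
    # ['https://example.com/feed.xml']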
python/atoma/rss.py (Normal file, 221 lines)
@@ -0,0 +1,221 @@
from datetime import datetime
from io import BytesIO
from typing import Optional, List
from xml.etree.ElementTree import Element

import attr

from .utils import (
    parse_xml, get_child, get_text, get_int, get_datetime, FeedParseError
)


@attr.s
class RSSImage:
    url: str = attr.ib()
    title: Optional[str] = attr.ib()
    link: str = attr.ib()
    width: int = attr.ib()
    height: int = attr.ib()
    description: Optional[str] = attr.ib()


@attr.s
class RSSEnclosure:
    url: str = attr.ib()
    length: Optional[int] = attr.ib()
    type: Optional[str] = attr.ib()


@attr.s
class RSSSource:
    title: str = attr.ib()
    url: Optional[str] = attr.ib()


@attr.s
class RSSItem:
    title: Optional[str] = attr.ib()
    link: Optional[str] = attr.ib()
    description: Optional[str] = attr.ib()
    author: Optional[str] = attr.ib()
    categories: List[str] = attr.ib()
    comments: Optional[str] = attr.ib()
    enclosures: List[RSSEnclosure] = attr.ib()
    guid: Optional[str] = attr.ib()
    pub_date: Optional[datetime] = attr.ib()
    source: Optional[RSSSource] = attr.ib()

    # Extension
    content_encoded: Optional[str] = attr.ib()


@attr.s
class RSSChannel:
    title: Optional[str] = attr.ib()
    link: Optional[str] = attr.ib()
    description: Optional[str] = attr.ib()
    language: Optional[str] = attr.ib()
    copyright: Optional[str] = attr.ib()
    managing_editor: Optional[str] = attr.ib()
    web_master: Optional[str] = attr.ib()
    pub_date: Optional[datetime] = attr.ib()
    last_build_date: Optional[datetime] = attr.ib()
    categories: List[str] = attr.ib()
    generator: Optional[str] = attr.ib()
    docs: Optional[str] = attr.ib()
    ttl: Optional[int] = attr.ib()
    image: Optional[RSSImage] = attr.ib()

    items: List[RSSItem] = attr.ib()

    # Extension
    content_encoded: Optional[str] = attr.ib()


def _get_image(element: Element, name,
               optional: bool=True) -> Optional[RSSImage]:
    child = get_child(element, name, optional)
    if child is None:
        return None

    return RSSImage(
        get_text(child, 'url', optional=False),
        get_text(child, 'title'),
        get_text(child, 'link', optional=False),
        get_int(child, 'width') or 88,
        get_int(child, 'height') or 31,
        get_text(child, 'description')
    )


def _get_source(element: Element, name,
                optional: bool=True) -> Optional[RSSSource]:
    child = get_child(element, name, optional)
    if child is None:
        return None

    return RSSSource(
        child.text.strip(),
        child.attrib.get('url'),
    )


def _get_enclosure(element: Element) -> RSSEnclosure:
    length = element.attrib.get('length')
    try:
        length = int(length)
    except (TypeError, ValueError):
        length = None

    return RSSEnclosure(
        element.attrib['url'],
        length,
        element.attrib.get('type'),
    )


def _get_link(element: Element) -> Optional[str]:
    """Attempt to retrieve item link.

    Use the GUID as a fallback if it is a permalink.
    """
    link = get_text(element, 'link')
    if link is not None:
        return link

    guid = get_child(element, 'guid')
    if guid is not None and guid.attrib.get('isPermaLink') == 'true':
        return get_text(element, 'guid')

    return None


def _get_item(element: Element) -> RSSItem:
    root = element

    title = get_text(root, 'title')
    link = _get_link(root)
    description = get_text(root, 'description')
    author = get_text(root, 'author')
    categories = [e.text for e in root.findall('category')]
    comments = get_text(root, 'comments')
    enclosure = [_get_enclosure(e) for e in root.findall('enclosure')]
    guid = get_text(root, 'guid')
    pub_date = get_datetime(root, 'pubDate')
    source = _get_source(root, 'source')

    content_encoded = get_text(root, 'content:encoded')

    return RSSItem(
        title,
        link,
        description,
        author,
        categories,
        comments,
        enclosure,
        guid,
        pub_date,
        source,
        content_encoded
    )


def _parse_rss(root: Element) -> RSSChannel:
    rss_version = root.get('version')
    if rss_version != '2.0':
        raise FeedParseError('Cannot process RSS feed version "{}"'
                             .format(rss_version))

    root = root.find('channel')

    title = get_text(root, 'title')
    link = get_text(root, 'link')
    description = get_text(root, 'description')
    language = get_text(root, 'language')
    copyright = get_text(root, 'copyright')
    managing_editor = get_text(root, 'managingEditor')
    web_master = get_text(root, 'webMaster')
    pub_date = get_datetime(root, 'pubDate')
    last_build_date = get_datetime(root, 'lastBuildDate')
    categories = [e.text for e in root.findall('category')]
    generator = get_text(root, 'generator')
    docs = get_text(root, 'docs')
    ttl = get_int(root, 'ttl')

    image = _get_image(root, 'image')
    items = [_get_item(e) for e in root.findall('item')]

    content_encoded = get_text(root, 'content:encoded')

    return RSSChannel(
        title,
        link,
        description,
        language,
        copyright,
        managing_editor,
        web_master,
        pub_date,
        last_build_date,
        categories,
        generator,
        docs,
        ttl,
        image,
        items,
        content_encoded
    )


def parse_rss_file(filename: str) -> RSSChannel:
    """Parse an RSS feed from a local XML file."""
    root = parse_xml(filename).getroot()
    return _parse_rss(root)


def parse_rss_bytes(data: bytes) -> RSSChannel:
    """Parse an RSS feed from a byte-string containing XML data."""
    root = parse_xml(BytesIO(data)).getroot()
    return _parse_rss(root)
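A sketch of parse_rss_bytes (not part of this commit); note that _parse_rss rejects anything whose version attribute is not "2.0":

    from atoma.rss import parse_rss_bytes

    data = b"""
    <rss version="2.0">
      <channel>
        <title>Example</title>
        <item><title>First post</title><guid>1</guid></item>
      </channel>
    </rss>"""

    channel = parse_rss_bytes(data)
    print(channel.title, len(channel.items))  # Example 1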
python/atoma/simple.py (Normal file, 224 lines)
@@ -0,0 +1,224 @@
"""Simple API that abstracts away the differences between feed types."""

from datetime import datetime, timedelta
import html
import os
from typing import Optional, List, Tuple
import urllib.parse

import attr

from . import atom, rss, json_feed
from .exceptions import (
    FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError
)


@attr.s
class Attachment:
    link: str = attr.ib()
    mime_type: Optional[str] = attr.ib()
    title: Optional[str] = attr.ib()
    size_in_bytes: Optional[int] = attr.ib()
    duration: Optional[timedelta] = attr.ib()


@attr.s
class Article:
    id: str = attr.ib()
    title: Optional[str] = attr.ib()
    link: Optional[str] = attr.ib()
    content: str = attr.ib()
    published_at: Optional[datetime] = attr.ib()
    updated_at: Optional[datetime] = attr.ib()
    attachments: List[Attachment] = attr.ib()


@attr.s
class Feed:
    title: str = attr.ib()
    subtitle: Optional[str] = attr.ib()
    link: Optional[str] = attr.ib()
    updated_at: Optional[datetime] = attr.ib()
    articles: List[Article] = attr.ib()


def _adapt_atom_feed(atom_feed: atom.AtomFeed) -> Feed:
    articles = list()
    for entry in atom_feed.entries:
        if entry.content is not None:
            content = entry.content.value
        elif entry.summary is not None:
            content = entry.summary.value
        else:
            content = ''
        published_at, updated_at = _get_article_dates(entry.published,
                                                      entry.updated)
        # Find article link and attachments
        article_link = None
        attachments = list()
        for candidate_link in entry.links:
            if candidate_link.rel in ('alternate', None):
                article_link = candidate_link.href
            elif candidate_link.rel == 'enclosure':
                attachments.append(Attachment(
                    title=_get_attachment_title(candidate_link.title,
                                                candidate_link.href),
                    link=candidate_link.href,
                    mime_type=candidate_link.type_,
                    size_in_bytes=candidate_link.length,
                    duration=None
                ))

        if entry.title is None:
            entry_title = None
        elif entry.title.text_type in (atom.AtomTextType.html,
                                       atom.AtomTextType.xhtml):
            entry_title = html.unescape(entry.title.value).strip()
        else:
            entry_title = entry.title.value

        articles.append(Article(
            entry.id_,
            entry_title,
            article_link,
            content,
            published_at,
            updated_at,
            attachments
        ))

    # Find feed link
    link = None
    for candidate_link in atom_feed.links:
        if candidate_link.rel == 'self':
            link = candidate_link.href
            break

    return Feed(
        atom_feed.title.value if atom_feed.title else atom_feed.id_,
        atom_feed.subtitle.value if atom_feed.subtitle else None,
        link,
        atom_feed.updated,
        articles
    )


def _adapt_rss_channel(rss_channel: rss.RSSChannel) -> Feed:
    articles = list()
    for item in rss_channel.items:
        attachments = [
            Attachment(link=e.url, mime_type=e.type, size_in_bytes=e.length,
                       title=_get_attachment_title(None, e.url), duration=None)
            for e in item.enclosures
        ]
        articles.append(Article(
            item.guid or item.link,
            item.title,
            item.link,
            item.content_encoded or item.description or '',
            item.pub_date,
            None,
            attachments
        ))

    if rss_channel.title is None and rss_channel.link is None:
        raise FeedParseError('RSS feed does not have a title nor a link')

    return Feed(
        rss_channel.title if rss_channel.title else rss_channel.link,
        rss_channel.description,
        rss_channel.link,
        rss_channel.pub_date,
        articles
    )


def _adapt_json_feed(json_feed: json_feed.JSONFeed) -> Feed:
    articles = list()
    for item in json_feed.items:
        attachments = [
            Attachment(a.url, a.mime_type,
                       _get_attachment_title(a.title, a.url),
                       a.size_in_bytes, a.duration)
            for a in item.attachments
        ]
        articles.append(Article(
            item.id_,
            item.title,
            item.url,
            item.content_html or item.content_text or '',
            item.date_published,
            item.date_modified,
            attachments
        ))

    return Feed(
        json_feed.title,
        json_feed.description,
        json_feed.feed_url,
        None,
        articles
    )


def _get_article_dates(published_at: Optional[datetime],
                       updated_at: Optional[datetime]
                       ) -> Tuple[Optional[datetime], Optional[datetime]]:
    if published_at and updated_at:
        return published_at, updated_at

    if updated_at:
        return updated_at, None

    if published_at:
        return published_at, None

    raise FeedParseError('Article does not have proper dates')


def _get_attachment_title(attachment_title: Optional[str], link: str) -> str:
    if attachment_title:
        return attachment_title

    parsed_link = urllib.parse.urlparse(link)
    return os.path.basename(parsed_link.path)


def _simple_parse(pairs, content) -> Feed:
    is_xml = True
    is_json = True
    for parser, adapter in pairs:
        try:
            return adapter(parser(content))
        except FeedXMLError:
            is_xml = False
        except FeedJSONError:
            is_json = False
        except FeedParseError:
            continue

    if not is_xml and not is_json:
        raise FeedDocumentError('File is not a supported feed type')

    raise FeedParseError('File is not a valid supported feed')


def simple_parse_file(filename: str) -> Feed:
    """Parse an Atom, RSS or JSON feed from a local file."""
    pairs = (
        (rss.parse_rss_file, _adapt_rss_channel),
        (atom.parse_atom_file, _adapt_atom_feed),
        (json_feed.parse_json_feed_file, _adapt_json_feed)
    )
    return _simple_parse(pairs, filename)


def simple_parse_bytes(data: bytes) -> Feed:
    """Parse an Atom, RSS or JSON feed from a byte-string containing data."""
    pairs = (
        (rss.parse_rss_bytes, _adapt_rss_channel),
        (atom.parse_atom_bytes, _adapt_atom_feed),
        (json_feed.parse_json_feed_bytes, _adapt_json_feed)
    )
    return _simple_parse(pairs, data)
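A sketch of the format-agnostic entry point (not part of this commit); the same call accepts RSS, Atom or JSON Feed bytes because _simple_parse tries each parser/adapter pair in turn:

    from atoma.simple import simple_parse_bytes

    data = b"""
    <rss version="2.0">
      <channel>
        <title>Example</title>
        <item><title>First post</title><guid>1</guid></item>
      </channel>
    </rss>"""

    feed = simple_parse_bytes(data)
    for article in feed.articles:
        print(article.id, article.title)  # 1 First post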
python/atoma/utils.py (Normal file, 84 lines)
@@ -0,0 +1,84 @@
from datetime import datetime, timezone
from xml.etree.ElementTree import Element
from typing import Optional

import dateutil.parser
from defusedxml.ElementTree import parse as defused_xml_parse, ParseError

from .exceptions import FeedXMLError, FeedParseError

ns = {
    'content': 'http://purl.org/rss/1.0/modules/content/',
    'feed': 'http://www.w3.org/2005/Atom'
}


def parse_xml(xml_content):
    try:
        return defused_xml_parse(xml_content)
    except ParseError:
        raise FeedXMLError('Not a valid XML document')


def get_child(element: Element, name,
              optional: bool=True) -> Optional[Element]:
    child = element.find(name, namespaces=ns)

    if child is None and not optional:
        raise FeedParseError(
            'Could not parse feed: "{}" does not have a "{}"'
            .format(element.tag, name)
        )

    elif child is None:
        return None

    return child


def get_text(element: Element, name, optional: bool=True) -> Optional[str]:
    child = get_child(element, name, optional)
    if child is None:
        return None

    if child.text is None:
        if optional:
            return None

        raise FeedParseError(
            'Could not parse feed: "{}" text is required but is empty'
            .format(name)
        )

    return child.text.strip()


def get_int(element: Element, name, optional: bool=True) -> Optional[int]:
    text = get_text(element, name, optional)
    if text is None:
        return None

    return int(text)


def get_datetime(element: Element, name,
                 optional: bool=True) -> Optional[datetime]:
    text = get_text(element, name, optional)
    if text is None:
        return None

    return try_parse_date(text)


def try_parse_date(date_str: str) -> Optional[datetime]:
    try:
        date = dateutil.parser.parse(date_str, fuzzy=True)
    except (ValueError, OverflowError):
        return None

    if date.tzinfo is None:
        # TZ naive datetime, make it a TZ aware datetime by assuming it
        # contains UTC time
        date = date.replace(tzinfo=timezone.utc)

    return date
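A sketch of try_parse_date's normalization (not part of this commit): timezone-naive results are assumed to be UTC, and unparseable strings yield None:

    from atoma.utils import try_parse_date

    print(try_parse_date('Mon, 01 Oct 2018 12:00:00 GMT'))  # aware datetime
    print(try_parse_date('2018-10-01'))    # naive input, tagged as UTC
    print(try_parse_date('no date here'))  # None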
python/attr/__init__.py (Normal file, 65 lines)
@@ -0,0 +1,65 @@
from __future__ import absolute_import, division, print_function

from functools import partial

from . import converters, exceptions, filters, validators
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has
from ._make import (
    NOTHING,
    Attribute,
    Factory,
    attrib,
    attrs,
    fields,
    fields_dict,
    make_class,
    validate,
)


__version__ = "18.2.0"

__title__ = "attrs"
__description__ = "Classes Without Boilerplate"
__url__ = "https://www.attrs.org/"
__uri__ = __url__
__doc__ = __description__ + " <" + __uri__ + ">"

__author__ = "Hynek Schlawack"
__email__ = "hs@ox.cx"

__license__ = "MIT"
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"


s = attributes = attrs
ib = attr = attrib
dataclass = partial(attrs, auto_attribs=True)  # happy Easter ;)

__all__ = [
    "Attribute",
    "Factory",
    "NOTHING",
    "asdict",
    "assoc",
    "astuple",
    "attr",
    "attrib",
    "attributes",
    "attrs",
    "converters",
    "evolve",
    "exceptions",
    "fields",
    "fields_dict",
    "filters",
    "get_run_validators",
    "has",
    "ib",
    "make_class",
    "s",
    "set_run_validators",
    "validate",
    "validators",
]
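A minimal attrs sketch using the aliases defined above (not part of this commit):

    import attr

    @attr.s
    class Point:
        x = attr.ib(default=0)
        y = attr.ib(default=0)

    p = attr.evolve(Point(1, 2), y=5)  # new instance with one field replaced
    print(attr.asdict(p))  # {'x': 1, 'y': 5}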
python/attr/__init__.pyi (Normal file, 252 lines)
@@ -0,0 +1,252 @@
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Mapping,
    Tuple,
    Type,
    TypeVar,
    Union,
    overload,
)

# `import X as X` is required to make these public
from . import exceptions as exceptions
from . import filters as filters
from . import converters as converters
from . import validators as validators

_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)

_ValidatorType = Callable[[Any, Attribute, _T], Any]
_ConverterType = Callable[[Any], _T]
_FilterType = Callable[[Attribute, Any], bool]
# FIXME: in reality, if multiple validators are passed they must be in a list or tuple,
# but those are invariant and so would prevent subtypes of _ValidatorType from working
# when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]

# _make --

NOTHING: object

# NOTE: Factory lies about its return type to make this possible: `x: List[int] = Factory(list)`
# Work around mypy issue #4554 in the common case by using an overload.
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
    factory: Union[Callable[[Any], _T], Callable[[], _T]],
    takes_self: bool = ...,
) -> _T: ...

class Attribute(Generic[_T]):
    name: str
    default: Optional[_T]
    validator: Optional[_ValidatorType[_T]]
    repr: bool
    cmp: bool
    hash: Optional[bool]
    init: bool
    converter: Optional[_ConverterType[_T]]
    metadata: Dict[Any, Any]
    type: Optional[Type[_T]]
    kw_only: bool
    def __lt__(self, x: Attribute) -> bool: ...
    def __le__(self, x: Attribute) -> bool: ...
    def __gt__(self, x: Attribute) -> bool: ...
    def __ge__(self, x: Attribute) -> bool: ...

# NOTE: We had several choices for the annotation to use for type arg:
# 1) Type[_T]
#    - Pros: Handles simple cases correctly
#    - Cons: Might produce less informative errors in the case of conflicting TypeVars
#      e.g. `attr.ib(default='bad', type=int)`
# 2) Callable[..., _T]
#    - Pros: Better error messages than #1 for conflicting TypeVars
#    - Cons: Terrible error messages for validator checks.
#      e.g. attr.ib(type=int, validator=validate_str)
#      -> error: Cannot infer function type argument
# 3) type (and do all of the work in the mypy plugin)
#    - Pros: Simple here, and we could customize the plugin with our own errors.
#    - Cons: Would need to write mypy plugin code to handle all the cases.
# We chose option #1.

# `attr` lies about its return type to make the following possible:
#     attr()    -> Any
#     attr(8)   -> int
#     attr(validator=<some callable>)  -> Whatever the callable expects.
# This makes this type of assignments possible:
#     x: int = attr(8)
#
# This form catches explicit None or no default but with no other arguments returns Any.
@overload
def attrib(
    default: None = ...,
    validator: None = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    convert: None = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: None = ...,
    converter: None = ...,
    factory: None = ...,
    kw_only: bool = ...,
) -> Any: ...

# This form catches an explicit None or no default and infers the type from the other arguments.
@overload
def attrib(
    default: None = ...,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    convert: Optional[_ConverterType[_T]] = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: Optional[Type[_T]] = ...,
    converter: Optional[_ConverterType[_T]] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
) -> _T: ...

# This form catches an explicit default argument.
@overload
def attrib(
    default: _T,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    convert: Optional[_ConverterType[_T]] = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: Optional[Type[_T]] = ...,
    converter: Optional[_ConverterType[_T]] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
) -> _T: ...

# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(
    default: Optional[_T] = ...,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    convert: Optional[_ConverterType[_T]] = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: object = ...,
    converter: Optional[_ConverterType[_T]] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
) -> Any: ...
@overload
def attrs(
    maybe_cls: _C,
    these: Optional[Dict[str, Any]] = ...,
    repr_ns: Optional[str] = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
) -> _C: ...
@overload
def attrs(
    maybe_cls: None = ...,
    these: Optional[Dict[str, Any]] = ...,
    repr_ns: Optional[str] = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
) -> Callable[[_C], _C]: ...

# TODO: add support for returning NamedTuple from the mypy plugin
class _Fields(Tuple[Attribute, ...]):
    def __getattr__(self, name: str) -> Attribute: ...

def fields(cls: type) -> _Fields: ...
def fields_dict(cls: type) -> Dict[str, Attribute]: ...
def validate(inst: Any) -> None: ...

# TODO: add support for returning a proper attrs class from the mypy plugin
# we use Any instead of _CountingAttr so that e.g. `make_class('Foo', [attr.ib()])` is valid
def make_class(
    name: str,
    attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]],
    bases: Tuple[type, ...] = ...,
    repr_ns: Optional[str] = ...,
    repr: bool = ...,
    cmp: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
) -> type: ...

# _funcs --

# TODO: add support for returning TypedDict from the mypy plugin
# FIXME: asdict/astuple do not honor their factory args. waiting on one of these:
# https://github.com/python/mypy/issues/4236
# https://github.com/python/typing/issues/253
def asdict(
    inst: Any,
    recurse: bool = ...,
    filter: Optional[_FilterType] = ...,
    dict_factory: Type[Mapping[Any, Any]] = ...,
    retain_collection_types: bool = ...,
) -> Dict[str, Any]: ...

# TODO: add support for returning NamedTuple from the mypy plugin
def astuple(
    inst: Any,
    recurse: bool = ...,
    filter: Optional[_FilterType] = ...,
    tuple_factory: Type[Sequence] = ...,
    retain_collection_types: bool = ...,
) -> Tuple[Any, ...]: ...
def has(cls: type) -> bool: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
def evolve(inst: _T, **changes: Any) -> _T: ...

# _config --

def set_run_validators(run: bool) -> None: ...
def get_run_validators() -> bool: ...

# aliases --

s = attributes = attrs
ib = attr = attrib
dataclass = attrs  # Technically, partial(attrs, auto_attribs=True) ;)
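The Factory overload above deliberately "lies" about its return type so that annotated attributes type-check; a sketch (not part of this commit):

    import attr

    @attr.s(auto_attribs=True)
    class Server:
        host: str
        port: int = 8080
        tags: list = attr.Factory(list)  # checks as list thanks to the stub

    print(attr.fields(Server).port.default)  # 8080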
python/attr/_compat.py (Normal file, 163 lines)
@@ -0,0 +1,163 @@
from __future__ import absolute_import, division, print_function

import platform
import sys
import types
import warnings


PY2 = sys.version_info[0] == 2
PYPY = platform.python_implementation() == "PyPy"


if PYPY or sys.version_info[:2] >= (3, 6):
    ordered_dict = dict
else:
    from collections import OrderedDict

    ordered_dict = OrderedDict


if PY2:
    from UserDict import IterableUserDict

    # We 'bundle' isclass instead of using inspect as importing inspect is
    # fairly expensive (order of 10-15 ms for a modern machine in 2016)
    def isclass(klass):
        return isinstance(klass, (type, types.ClassType))

    # TYPE is used in exceptions, repr(int) is different on Python 2 and 3.
    TYPE = "type"

    def iteritems(d):
        return d.iteritems()

    # Python 2 is bereft of a read-only dict proxy, so we make one!
    class ReadOnlyDict(IterableUserDict):
        """
        Best-effort read-only dict wrapper.
        """

        def __setitem__(self, key, val):
            # We gently pretend we're a Python 3 mappingproxy.
            raise TypeError(
                "'mappingproxy' object does not support item assignment"
            )

        def update(self, _):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'update'"
            )

        def __delitem__(self, _):
            # We gently pretend we're a Python 3 mappingproxy.
            raise TypeError(
                "'mappingproxy' object does not support item deletion"
            )

        def clear(self):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'clear'"
            )

        def pop(self, key, default=None):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'pop'"
            )

        def popitem(self):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'popitem'"
            )

        def setdefault(self, key, default=None):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'setdefault'"
            )

        def __repr__(self):
            # Override to be identical to the Python 3 version.
            return "mappingproxy(" + repr(self.data) + ")"

    def metadata_proxy(d):
        res = ReadOnlyDict()
        res.data.update(d)  # We blocked update, so we have to do it like this.
        return res


else:

    def isclass(klass):
        return isinstance(klass, type)

    TYPE = "class"

    def iteritems(d):
        return d.items()

    def metadata_proxy(d):
        return types.MappingProxyType(dict(d))


def import_ctypes():
    """
    Moved into a function for testability.
    """
    import ctypes

    return ctypes


if not PY2:

    def just_warn(*args, **kw):
        """
        We only warn on Python 3 because we are not aware of any concrete
        consequences of not setting the cell on Python 2.
        """
        warnings.warn(
            "Missing ctypes. Some features like bare super() or accessing "
            "__class__ will not work with slots classes.",
            RuntimeWarning,
            stacklevel=2,
        )


else:

    def just_warn(*args, **kw):  # pragma: nocover
        """
        We only warn on Python 3 because we are not aware of any concrete
        consequences of not setting the cell on Python 2.
        """


def make_set_closure_cell():
    """
    Moved into a function for testability.
    """
    if PYPY:  # pragma: no cover

        def set_closure_cell(cell, value):
            cell.__setstate__((value,))

    else:
        try:
            ctypes = import_ctypes()

            set_closure_cell = ctypes.pythonapi.PyCell_Set
            set_closure_cell.argtypes = (ctypes.py_object, ctypes.py_object)
            set_closure_cell.restype = ctypes.c_int
        except Exception:
            # We try best effort to set the cell, but sometimes it's not
            # possible. For example on Jython or on GAE.
            set_closure_cell = just_warn
    return set_closure_cell


set_closure_cell = make_set_closure_cell()
python/attr/_config.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import absolute_import, division, print_function


__all__ = ["set_run_validators", "get_run_validators"]

_run_validators = True


def set_run_validators(run):
    """
    Set whether or not validators are run. By default, they are run.
    """
    if not isinstance(run, bool):
        raise TypeError("'run' must be bool.")
    global _run_validators
    _run_validators = run


def get_run_validators():
    """
    Return whether or not validators are run.
    """
    return _run_validators
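A sketch of the global switch defined here (not part of this commit; attr.validators.instance_of comes from the bundled validators module):

    import attr

    @attr.s
    class C:
        x = attr.ib(validator=attr.validators.instance_of(int))

    attr.set_run_validators(False)
    C('not an int')  # accepted: validators are skipped globally
    attr.set_run_validators(True)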
290
python/attr/_funcs.py
Normal file
290
python/attr/_funcs.py
Normal file
@ -0,0 +1,290 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import copy
|
||||
|
||||
from ._compat import iteritems
|
||||
from ._make import NOTHING, _obj_setattr, fields
|
||||
from .exceptions import AttrsAttributeNotFoundError
|
||||
|
||||
|
||||
def asdict(
|
||||
inst,
|
||||
recurse=True,
|
||||
filter=None,
|
||||
dict_factory=dict,
|
||||
retain_collection_types=False,
|
||||
):
|
||||
"""
|
||||
Return the ``attrs`` attribute values of *inst* as a dict.
|
||||
|
||||
Optionally recurse into other ``attrs``-decorated classes.
|
||||
|
||||
:param inst: Instance of an ``attrs``-decorated class.
|
||||
:param bool recurse: Recurse into classes that are also
|
||||
``attrs``-decorated.
|
||||
:param callable filter: A callable whose return code determines whether an
|
||||
attribute or element is included (``True``) or dropped (``False``). Is
|
||||
called with the :class:`attr.Attribute` as the first argument and the
|
||||
value as the second argument.
|
||||
:param callable dict_factory: A callable to produce dictionaries from. For
|
||||
example, to produce ordered dictionaries instead of normal Python
|
||||
dictionaries, pass in ``collections.OrderedDict``.
|
||||
:param bool retain_collection_types: Do not convert to ``list`` when
|
||||
encountering an attribute whose type is ``tuple`` or ``set``. Only
|
||||
meaningful if ``recurse`` is ``True``.
|
||||
|
||||
:rtype: return type of *dict_factory*
|
||||
|
||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
||||
class.
|
||||
|
||||
.. versionadded:: 16.0.0 *dict_factory*
|
||||
.. versionadded:: 16.1.0 *retain_collection_types*
|
||||
"""
|
||||
attrs = fields(inst.__class__)
|
||||
rv = dict_factory()
|
||||
for a in attrs:
|
||||
v = getattr(inst, a.name)
|
||||
if filter is not None and not filter(a, v):
|
||||
continue
|
||||
if recurse is True:
|
||||
if has(v.__class__):
|
||||
rv[a.name] = asdict(
|
||||
v, True, filter, dict_factory, retain_collection_types
|
||||
)
|
||||
elif isinstance(v, (tuple, list, set)):
|
||||
cf = v.__class__ if retain_collection_types is True else list
|
||||
rv[a.name] = cf(
|
||||
[
|
||||
_asdict_anything(
|
||||
i, filter, dict_factory, retain_collection_types
|
||||
)
|
||||
for i in v
|
||||
]
|
||||
)
|
||||
elif isinstance(v, dict):
|
||||
df = dict_factory
|
||||
rv[a.name] = df(
|
||||
(
|
||||
_asdict_anything(
|
||||
kk, filter, df, retain_collection_types
|
||||
),
|
||||
_asdict_anything(
|
||||
vv, filter, df, retain_collection_types
|
||||
),
|
||||
)
|
||||
for kk, vv in iteritems(v)
|
||||
)
|
||||
else:
|
||||
rv[a.name] = v
|
||||
else:
|
||||
rv[a.name] = v
|
||||
return rv
|
||||
|
||||
|
||||
def _asdict_anything(val, filter, dict_factory, retain_collection_types):
|
||||
"""
|
||||
``asdict`` only works on attrs instances, this works on anything.
|
||||
"""
|
||||
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
|
||||
# Attrs class.
|
||||
rv = asdict(val, True, filter, dict_factory, retain_collection_types)
|
||||
elif isinstance(val, (tuple, list, set)):
|
||||
cf = val.__class__ if retain_collection_types is True else list
|
||||
rv = cf(
|
||||
[
|
||||
_asdict_anything(
|
||||
i, filter, dict_factory, retain_collection_types
|
||||
)
|
||||
for i in val
|
||||
]
|
||||
)
|
||||
elif isinstance(val, dict):
|
||||
df = dict_factory
|
||||
rv = df(
|
||||
(
|
||||
_asdict_anything(kk, filter, df, retain_collection_types),
|
||||
_asdict_anything(vv, filter, df, retain_collection_types),
|
||||
)
|
||||
for kk, vv in iteritems(val)
|
||||
)
|
||||
else:
|
||||
rv = val
|
||||
return rv
|
||||
|
||||
|
||||
def astuple(
|
||||
inst,
|
||||
recurse=True,
|
||||
filter=None,
|
||||
tuple_factory=tuple,
|
||||
retain_collection_types=False,
|
||||
):
|
||||
"""
|
||||
Return the ``attrs`` attribute values of *inst* as a tuple.
|
||||
|
||||
Optionally recurse into other ``attrs``-decorated classes.
|
||||
|
||||
:param inst: Instance of an ``attrs``-decorated class.
|
||||
:param bool recurse: Recurse into classes that are also
|
||||
``attrs``-decorated.
|
||||
:param callable filter: A callable whose return code determines whether an
|
||||
attribute or element is included (``True``) or dropped (``False``). Is
|
||||
called with the :class:`attr.Attribute` as the first argument and the
|
||||
value as the second argument.
|
||||
:param callable tuple_factory: A callable to produce tuples from. For
|
||||
example, to produce lists instead of tuples.
|
||||
:param bool retain_collection_types: Do not convert to ``list``
|
||||
or ``dict`` when encountering an attribute which type is
|
||||
``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
|
||||
``True``.
|
||||
|
||||
:rtype: return type of *tuple_factory*
|
||||
|
||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
||||
class.
|
||||
|
||||
.. versionadded:: 16.2.0
|
||||
"""
|
||||
attrs = fields(inst.__class__)
|
||||
rv = []
|
||||
retain = retain_collection_types # Very long. :/
|
||||
for a in attrs:
|
||||
v = getattr(inst, a.name)
|
||||
if filter is not None and not filter(a, v):
|
||||
continue
|
||||
if recurse is True:
|
||||
if has(v.__class__):
|
||||
rv.append(
|
||||
astuple(
|
||||
v,
|
||||
recurse=True,
|
||||
filter=filter,
|
||||
tuple_factory=tuple_factory,
|
||||
retain_collection_types=retain,
|
||||
)
|
||||
)
|
||||
elif isinstance(v, (tuple, list, set)):
|
||||
cf = v.__class__ if retain is True else list
|
||||
rv.append(
|
||||
cf(
|
||||
[
|
||||
astuple(
|
||||
j,
|
||||
recurse=True,
|
||||
filter=filter,
|
||||
tuple_factory=tuple_factory,
|
||||
retain_collection_types=retain,
|
||||
)
|
||||
if has(j.__class__)
|
||||
else j
|
||||
for j in v
|
||||
]
|
||||
)
|
||||
)
|
||||
elif isinstance(v, dict):
|
||||
df = v.__class__ if retain is True else dict
|
||||
rv.append(
|
||||
df(
|
||||
(
|
||||
astuple(
|
||||
kk,
|
||||
tuple_factory=tuple_factory,
|
||||
retain_collection_types=retain,
|
||||
)
|
||||
if has(kk.__class__)
|
||||
else kk,
|
||||
astuple(
|
||||
vv,
|
||||
tuple_factory=tuple_factory,
|
||||
retain_collection_types=retain,
|
||||
)
|
||||
if has(vv.__class__)
|
||||
else vv,
|
||||
)
|
||||
for kk, vv in iteritems(v)
|
||||
)
|
||||
)
|
||||
else:
|
||||
rv.append(v)
|
||||
else:
|
||||
rv.append(v)
|
||||
return rv if tuple_factory is list else tuple_factory(rv)
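# Illustrative sketch (reusing the hypothetical ``C`` from the ``asdict``
# example above)::
#
#     >>> astuple(C(1, 2))
#     (1, 2)
#     >>> astuple(C(1, 2), tuple_factory=list)
#     [1, 2]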
|
||||
|
||||
|
||||
def has(cls):
|
||||
"""
|
||||
Check whether *cls* is a class with ``attrs`` attributes.
|
||||
|
||||
:param type cls: Class to introspect.
|
||||
:raise TypeError: If *cls* is not a class.
|
||||
|
||||
:rtype: :class:`bool`
|
||||
"""
|
||||
return getattr(cls, "__attrs_attrs__", None) is not None
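# E.g. (with the hypothetical attrs class ``C`` from above)::
#
#     >>> has(C)
#     True
#     >>> has(object)
#     False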
|
||||
|
||||
|
||||
def assoc(inst, **changes):
|
||||
"""
|
||||
Copy *inst* and apply *changes*.
|
||||
|
||||
:param inst: Instance of a class with ``attrs`` attributes.
|
||||
:param changes: Keyword changes in the new copy.
|
||||
|
||||
:return: A copy of inst with *changes* incorporated.
|
||||
|
||||
:raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
|
||||
be found on *cls*.
|
||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
||||
class.
|
||||
|
||||
.. deprecated:: 17.1.0
|
||||
Use :func:`evolve` instead.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"assoc is deprecated and will be removed after 2018/01.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
new = copy.copy(inst)
|
||||
attrs = fields(inst.__class__)
|
||||
for k, v in iteritems(changes):
|
||||
a = getattr(attrs, k, NOTHING)
|
||||
if a is NOTHING:
|
||||
raise AttrsAttributeNotFoundError(
|
||||
"{k} is not an attrs attribute on {cl}.".format(
|
||||
k=k, cl=new.__class__
|
||||
)
|
||||
)
|
||||
_obj_setattr(new, k, v)
|
||||
return new
|
||||
|
||||
|
||||
def evolve(inst, **changes):
|
||||
"""
|
||||
Create a new instance, based on *inst* with *changes* applied.
|
||||
|
||||
:param inst: Instance of a class with ``attrs`` attributes.
|
||||
:param changes: Keyword changes in the new copy.
|
||||
|
||||
:return: A copy of inst with *changes* incorporated.
|
||||
|
||||
:raise TypeError: If *attr_name* couldn't be found in the class
|
||||
``__init__``.
|
||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
||||
class.
|
||||
|
||||
.. versionadded:: 17.1.0
|
||||
"""
|
||||
cls = inst.__class__
|
||||
attrs = fields(cls)
|
||||
for a in attrs:
|
||||
if not a.init:
|
||||
continue
|
||||
attr_name = a.name # To deal with private attributes.
|
||||
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
|
||||
if init_name not in changes:
|
||||
changes[init_name] = getattr(inst, attr_name)
|
||||
return cls(**changes)
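# Illustrative sketch (hypothetical ``C`` again); note the private-name
# handling above means an attribute ``_x`` is passed to ``__init__`` as
# ``x``::
#
#     >>> evolve(C(1, 2), y=3)
#     C(x=1, y=3)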
|
2034
python/attr/_make.py
Normal file
File diff suppressed because it is too large
78
python/attr/converters.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""
|
||||
Commonly useful converters.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from ._make import NOTHING, Factory
|
||||
|
||||
|
||||
def optional(converter):
|
||||
"""
|
||||
A converter that allows an attribute to be optional. An optional attribute
|
||||
is one which can be set to ``None``.
|
||||
|
||||
:param callable converter: the converter that is used for non-``None``
|
||||
values.
|
||||
|
||||
.. versionadded:: 17.1.0
|
||||
"""
|
||||
|
||||
def optional_converter(val):
|
||||
if val is None:
|
||||
return None
|
||||
return converter(val)
|
||||
|
||||
return optional_converter
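# Illustrative sketch (``Point`` is a hypothetical attrs class assumed for
# the example)::
#
#     >>> import attr
#     >>> @attr.s
#     ... class Point(object):
#     ...     x = attr.ib(converter=optional(int))
#     >>> Point("1").x
#     1
#     >>> Point(None).x is None
#     True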
|
||||
|
||||
|
||||
def default_if_none(default=NOTHING, factory=None):
|
||||
"""
|
||||
A converter that allows replacing ``None`` values with *default* or the
|
||||
result of *factory*.
|
||||
|
||||
:param default: Value to be used if ``None`` is passed. Passing an instance
|
||||
of :class:`attr.Factory` is supported, however the ``takes_self`` option
|
||||
is *not*.
|
||||
:param callable factory: A callable that takes no parameters and whose result
|
||||
is used if ``None`` is passed.
|
||||
|
||||
:raises TypeError: If **neither** *default* **nor** *factory* is passed.
|
||||
:raises TypeError: If **both** *default* and *factory* are passed.
|
||||
:raises ValueError: If an instance of :class:`attr.Factory` is passed with
|
||||
``takes_self=True``.
|
||||
|
||||
.. versionadded:: 18.2.0
|
||||
"""
|
||||
if default is NOTHING and factory is None:
|
||||
raise TypeError("Must pass either `default` or `factory`.")
|
||||
|
||||
if default is not NOTHING and factory is not None:
|
||||
raise TypeError(
|
||||
"Must pass either `default` or `factory` but not both."
|
||||
)
|
||||
|
||||
if factory is not None:
|
||||
default = Factory(factory)
|
||||
|
||||
if isinstance(default, Factory):
|
||||
if default.takes_self:
|
||||
raise ValueError(
|
||||
"`takes_self` is not supported by default_if_none."
|
||||
)
|
||||
|
||||
def default_if_none_converter(val):
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
return default.factory()
|
||||
|
||||
else:
|
||||
|
||||
def default_if_none_converter(val):
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
return default
|
||||
|
||||
return default_if_none_converter
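# Illustrative sketch: the returned converter passes non-``None`` values
# through untouched::
#
#     >>> conv = default_if_none(factory=list)
#     >>> conv(None)
#     []
#     >>> conv([1, 2])
#     [1, 2]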
|
12
python/attr/converters.pyi
Normal file
@ -0,0 +1,12 @@
|
||||
from typing import TypeVar, Optional, Callable, overload
|
||||
from . import _ConverterType
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
def optional(
|
||||
converter: _ConverterType[_T]
|
||||
) -> _ConverterType[Optional[_T]]: ...
|
||||
@overload
|
||||
def default_if_none(default: _T) -> _ConverterType[_T]: ...
|
||||
@overload
|
||||
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType[_T]: ...
|
57
python/attr/exceptions.py
Normal file
@ -0,0 +1,57 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
class FrozenInstanceError(AttributeError):
|
||||
"""
|
||||
An attempt has been made to modify a frozen/immutable instance.
|
||||
|
||||
It mirrors the behavior of ``namedtuples`` by using the same error message
|
||||
and subclassing :exc:`AttributeError`.
|
||||
|
||||
.. versionadded:: 16.1.0
|
||||
"""
|
||||
|
||||
msg = "can't set attribute"
|
||||
args = [msg]
|
||||
|
||||
|
||||
class AttrsAttributeNotFoundError(ValueError):
|
||||
"""
|
||||
An ``attrs`` function couldn't find an attribute that the user asked for.
|
||||
|
||||
.. versionadded:: 16.2.0
|
||||
"""
|
||||
|
||||
|
||||
class NotAnAttrsClassError(ValueError):
|
||||
"""
|
||||
A non-``attrs`` class has been passed into an ``attrs`` function.
|
||||
|
||||
.. versionadded:: 16.2.0
|
||||
"""
|
||||
|
||||
|
||||
class DefaultAlreadySetError(RuntimeError):
|
||||
"""
|
||||
A default has been set using ``attr.ib()`` and an attempt is made to reset it
|
||||
using the decorator.
|
||||
|
||||
.. versionadded:: 17.1.0
|
||||
"""
|
||||
|
||||
|
||||
class UnannotatedAttributeError(RuntimeError):
|
||||
"""
|
||||
A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
|
||||
annotation.
|
||||
|
||||
.. versionadded:: 17.3.0
|
||||
"""
|
||||
|
||||
|
||||
class PythonTooOldError(RuntimeError):
|
||||
"""
|
||||
An ``attrs`` feature requiring a more recent python version has been used.
|
||||
|
||||
.. versionadded:: 18.2.0
|
||||
"""
|
7
python/attr/exceptions.pyi
Normal file
@ -0,0 +1,7 @@
|
||||
class FrozenInstanceError(AttributeError):
|
||||
msg: str = ...
|
||||
|
||||
class AttrsAttributeNotFoundError(ValueError): ...
|
||||
class NotAnAttrsClassError(ValueError): ...
|
||||
class DefaultAlreadySetError(RuntimeError): ...
|
||||
class UnannotatedAttributeError(RuntimeError): ...
|
52
python/attr/filters.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
Commonly useful filters for :func:`attr.asdict`.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from ._compat import isclass
|
||||
from ._make import Attribute
|
||||
|
||||
|
||||
def _split_what(what):
|
||||
"""
|
||||
Returns a tuple of `frozenset`s of classes and attributes.
|
||||
"""
|
||||
return (
|
||||
frozenset(cls for cls in what if isclass(cls)),
|
||||
frozenset(cls for cls in what if isinstance(cls, Attribute)),
|
||||
)
|
||||
|
||||
|
||||
def include(*what):
|
||||
"""
|
||||
Whitelist *what*.
|
||||
|
||||
:param what: What to whitelist.
|
||||
:type what: :class:`list` of :class:`type` or :class:`attr.Attribute`\\ s
|
||||
|
||||
:rtype: :class:`callable`
|
||||
"""
|
||||
cls, attrs = _split_what(what)
|
||||
|
||||
def include_(attribute, value):
|
||||
return value.__class__ in cls or attribute in attrs
|
||||
|
||||
return include_
|
||||
|
||||
|
||||
def exclude(*what):
|
||||
"""
|
||||
Blacklist *what*.
|
||||
|
||||
:param what: What to blacklist.
|
||||
:type what: :class:`list` of classes or :class:`attr.Attribute`\\ s.
|
||||
|
||||
:rtype: :class:`callable`
|
||||
"""
|
||||
cls, attrs = _split_what(what)
|
||||
|
||||
def exclude_(attribute, value):
|
||||
return value.__class__ not in cls and attribute not in attrs
|
||||
|
||||
return exclude_
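# Illustrative sketch with :func:`attr.asdict` (``C`` is a hypothetical
# attrs class with an int ``x`` and a str ``y``)::
#
#     >>> attr.asdict(C(1, "2"), filter=exclude(str))
#     {'x': 1}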
|
5
python/attr/filters.pyi
Normal file
@ -0,0 +1,5 @@
|
||||
from typing import Union
|
||||
from . import Attribute, _FilterType
|
||||
|
||||
def include(*what: Union[type, Attribute]) -> _FilterType: ...
|
||||
def exclude(*what: Union[type, Attribute]) -> _FilterType: ...
|
0
python/attr/py.typed
Normal file
170
python/attr/validators.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""
|
||||
Commonly useful validators.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from ._make import _AndValidator, and_, attrib, attrs
|
||||
|
||||
|
||||
__all__ = ["and_", "in_", "instance_of", "optional", "provides"]
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _InstanceOfValidator(object):
|
||||
type = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
"""
|
||||
We use a callable class to be able to change the ``__repr__``.
|
||||
"""
|
||||
if not isinstance(value, self.type):
|
||||
raise TypeError(
|
||||
"'{name}' must be {type!r} (got {value!r} that is a "
|
||||
"{actual!r}).".format(
|
||||
name=attr.name,
|
||||
type=self.type,
|
||||
actual=value.__class__,
|
||||
value=value,
|
||||
),
|
||||
attr,
|
||||
self.type,
|
||||
value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<instance_of validator for type {type!r}>".format(
|
||||
type=self.type
|
||||
)
|
||||
|
||||
|
||||
def instance_of(type):
|
||||
"""
|
||||
A validator that raises a :exc:`TypeError` if the initializer is called
|
||||
with a wrong type for this particular attribute (checks are performed using
|
||||
:func:`isinstance`, so it is also valid to pass a tuple of types).
|
||||
|
||||
:param type: The type to check for.
|
||||
:type type: type or tuple of types
|
||||
|
||||
:raises TypeError: With a human readable error message, the attribute
|
||||
(of type :class:`attr.Attribute`), the expected type, and the value it
|
||||
got.
|
||||
"""
|
||||
return _InstanceOfValidator(type)
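# Illustrative sketch (``C`` is a hypothetical attrs class)::
#
#     >>> import attr
#     >>> @attr.s
#     ... class C(object):
#     ...     x = attr.ib(validator=instance_of(int))
#     >>> C(42).x
#     42
#     >>> C("42")  # raises TypeError as described above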
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _ProvidesValidator(object):
|
||||
interface = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
"""
|
||||
We use a callable class to be able to change the ``__repr__``.
|
||||
"""
|
||||
if not self.interface.providedBy(value):
|
||||
raise TypeError(
|
||||
"'{name}' must provide {interface!r} which {value!r} "
|
||||
"doesn't.".format(
|
||||
name=attr.name, interface=self.interface, value=value
|
||||
),
|
||||
attr,
|
||||
self.interface,
|
||||
value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<provides validator for interface {interface!r}>".format(
|
||||
interface=self.interface
|
||||
)
|
||||
|
||||
|
||||
def provides(interface):
|
||||
"""
|
||||
A validator that raises a :exc:`TypeError` if the initializer is called
|
||||
with an object that does not provide the requested *interface* (checks are
|
||||
performed using ``interface.providedBy(value)``; see `zope.interface
|
||||
<https://zopeinterface.readthedocs.io/en/latest/>`_).
|
||||
|
||||
:param zope.interface.Interface interface: The interface to check for.
|
||||
|
||||
:raises TypeError: With a human readable error message, the attribute
|
||||
(of type :class:`attr.Attribute`), the expected interface, and the
|
||||
value it got.
|
||||
"""
|
||||
return _ProvidesValidator(interface)
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _OptionalValidator(object):
|
||||
validator = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
self.validator(inst, attr, value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<optional validator for {what} or None>".format(
|
||||
what=repr(self.validator)
|
||||
)
|
||||
|
||||
|
||||
def optional(validator):
|
||||
"""
|
||||
A validator that makes an attribute optional. An optional attribute is one
|
||||
which can be set to ``None`` in addition to satisfying the requirements of
|
||||
the sub-validator.
|
||||
|
||||
:param validator: A validator (or a list of validators) that is used for
|
||||
non-``None`` values.
|
||||
:type validator: callable or :class:`list` of callables.
|
||||
|
||||
.. versionadded:: 15.1.0
|
||||
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
|
||||
"""
|
||||
if isinstance(validator, list):
|
||||
return _OptionalValidator(_AndValidator(validator))
|
||||
return _OptionalValidator(validator)
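# Illustrative sketch: ``x`` must be an int or ``None``::
#
#     >>> x = attr.ib(validator=optional(instance_of(int)))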
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _InValidator(object):
|
||||
options = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
try:
|
||||
in_options = value in self.options
|
||||
except TypeError:  # e.g. `1 in "abc"`
|
||||
in_options = False
|
||||
|
||||
if not in_options:
|
||||
raise ValueError(
|
||||
"'{name}' must be in {options!r} (got {value!r})".format(
|
||||
name=attr.name, options=self.options, value=value
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<in_ validator with options {options!r}>".format(
|
||||
options=self.options
|
||||
)
|
||||
|
||||
|
||||
def in_(options):
|
||||
"""
|
||||
A validator that raises a :exc:`ValueError` if the initializer is called
|
||||
with a value that does not belong in the options provided. The check is
|
||||
performed using ``value in options``.
|
||||
|
||||
:param options: Allowed options.
|
||||
:type options: list, tuple, :class:`enum.Enum`, ...
|
||||
|
||||
:raises ValueError: With a human readable error message, the attribute (of
|
||||
type :class:`attr.Attribute`), the expected options, and the value it
|
||||
got.
|
||||
|
||||
.. versionadded:: 17.1.0
|
||||
"""
|
||||
return _InValidator(options)
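# Illustrative sketch with a hypothetical enum::
#
#     >>> import enum
#     >>> class State(enum.Enum):
#     ...     ON = "on"
#     ...     OFF = "off"
#     >>> x = attr.ib(validator=in_(State))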
|
14
python/attr/validators.pyi
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import Container, List, Union, TypeVar, Type, Any, Optional, Tuple
|
||||
from . import _ValidatorType
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
def instance_of(
|
||||
type: Union[Tuple[Type[_T], ...], Type[_T]]
|
||||
) -> _ValidatorType[_T]: ...
|
||||
def provides(interface: Any) -> _ValidatorType[Any]: ...
|
||||
def optional(
|
||||
validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]]
|
||||
) -> _ValidatorType[Optional[_T]]: ...
|
||||
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
|
||||
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
|
2
python/dateutil/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from ._version import VERSION as __version__
|
34
python/dateutil/_common.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""
|
||||
Common code used in multiple modules.
|
||||
"""
|
||||
|
||||
|
||||
class weekday(object):
|
||||
__slots__ = ["weekday", "n"]
|
||||
|
||||
def __init__(self, weekday, n=None):
|
||||
self.weekday = weekday
|
||||
self.n = n
|
||||
|
||||
def __call__(self, n):
|
||||
if n == self.n:
|
||||
return self
|
||||
else:
|
||||
return self.__class__(self.weekday, n)
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
if self.weekday != other.weekday or self.n != other.n:
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
__hash__ = None
|
||||
|
||||
def __repr__(self):
|
||||
s = ("MO", "TU", "WE", "TH", "FR", "SA", "SU")[self.weekday]
|
||||
if not self.n:
|
||||
return s
|
||||
else:
|
||||
return "%s(%+d)" % (s, self.n)
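# Illustrative sketch: calling an instance derives the "nth weekday" form
# used by relativedelta::
#
#     >>> MO = weekday(0)
#     >>> MO(+2)
#     MO(+2)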
|
10
python/dateutil/_version.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""
|
||||
Contains information about the dateutil version.
|
||||
"""
|
||||
|
||||
VERSION_MAJOR = 2
|
||||
VERSION_MINOR = 6
|
||||
VERSION_PATCH = 1
|
||||
|
||||
VERSION_TUPLE = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH)
|
||||
VERSION = '.'.join(map(str, VERSION_TUPLE))
|
89
python/dateutil/easter.py
Normal file
@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This module offers a generic easter computing method for any given year, using
|
||||
Western, Orthodox or Julian algorithms.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"]
|
||||
|
||||
EASTER_JULIAN = 1
|
||||
EASTER_ORTHODOX = 2
|
||||
EASTER_WESTERN = 3
|
||||
|
||||
|
||||
def easter(year, method=EASTER_WESTERN):
|
||||
"""
|
||||
This method was ported from the work done by GM Arts,
|
||||
on top of the algorithm by Claus Tondering, which was
|
||||
based in part on the algorithm of Ouding (1940), as
|
||||
quoted in "Explanatory Supplement to the Astronomical
|
||||
Almanac", P. Kenneth Seidelmann, editor.
|
||||
|
||||
This algorithm implements three different easter
|
||||
calculation methods:
|
||||
|
||||
1 - Original calculation in Julian calendar, valid in
|
||||
dates after 326 AD
|
||||
2 - Original method, with date converted to Gregorian
|
||||
calendar, valid in years 1583 to 4099
|
||||
3 - Revised method, in Gregorian calendar, valid in
|
||||
years 1583 to 4099 as well
|
||||
|
||||
These methods are represented by the constants:
|
||||
|
||||
* ``EASTER_JULIAN = 1``
|
||||
* ``EASTER_ORTHODOX = 2``
|
||||
* ``EASTER_WESTERN = 3``
|
||||
|
||||
The default method is method 3.
|
||||
|
||||
More about the algorithm may be found at:
|
||||
|
||||
http://users.chariot.net.au/~gmarts/eastalg.htm
|
||||
|
||||
and
|
||||
|
||||
http://www.tondering.dk/claus/calendar.html
|
||||
|
||||
"""
|
||||
|
||||
if not (1 <= method <= 3):
|
||||
raise ValueError("invalid method")
|
||||
|
||||
# g - Golden year - 1
|
||||
# c - Century
|
||||
# h - (23 - Epact) mod 30
|
||||
# i - Number of days from March 21 to Paschal Full Moon
|
||||
# j - Weekday for PFM (0=Sunday, etc)
|
||||
# p - Number of days from March 21 to Sunday on or before PFM
|
||||
# (-6 to 28 methods 1 & 3, to 56 for method 2)
|
||||
# e - Extra days to add for method 2 (converting Julian
|
||||
# date to Gregorian date)
|
||||
|
||||
y = year
|
||||
g = y % 19
|
||||
e = 0
|
||||
if method < 3:
|
||||
# Old method
|
||||
i = (19*g + 15) % 30
|
||||
j = (y + y//4 + i) % 7
|
||||
if method == 2:
|
||||
# Extra days to convert Julian to Gregorian date
|
||||
e = 10
|
||||
if y > 1600:
|
||||
e = e + y//100 - 16 - (y//100 - 16)//4
|
||||
else:
|
||||
# New method
|
||||
c = y//100
|
||||
h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30
|
||||
i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11))
|
||||
j = (y + y//4 + i + 2 - c + c//4) % 7
|
||||
|
||||
# p can be from -6 to 56 corresponding to dates 22 March to 23 May
|
||||
# (later dates apply to method 2, although 23 May never actually occurs)
|
||||
p = i - j + e
|
||||
d = 1 + (p + 27 + (p + 6)//40) % 31
|
||||
m = 3 + (p + 26)//30
|
||||
return datetime.date(int(y), int(m), int(d))
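# Illustrative sketch (Western Easter for the year 2000)::
#
#     >>> easter(2000)
#     datetime.date(2000, 4, 23)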
|
1374
python/dateutil/parser.py
Normal file
File diff suppressed because it is too large
549
python/dateutil/relativedelta.py
Normal file
@ -0,0 +1,549 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import calendar
|
||||
|
||||
import operator
|
||||
from math import copysign
|
||||
|
||||
from six import integer_types
|
||||
from warnings import warn
|
||||
|
||||
from ._common import weekday
|
||||
|
||||
MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7))
|
||||
|
||||
__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"]
|
||||
|
||||
|
||||
class relativedelta(object):
|
||||
"""
|
||||
The relativedelta type is based on the specification of the excellent
|
||||
work done by M.-A. Lemburg in his
|
||||
`mx.DateTime <http://www.egenix.com/files/python/mxDateTime.html>`_ extension.
|
||||
However, notice that this type does *NOT* implement the same algorithm as
|
||||
his work. Do *NOT* expect it to behave like mx.DateTime's counterpart.
|
||||
|
||||
There are two different ways to build a relativedelta instance. The
|
||||
first one is passing it two date/datetime classes::
|
||||
|
||||
relativedelta(datetime1, datetime2)
|
||||
|
||||
The second one is passing it any number of the following keyword arguments::
|
||||
|
||||
relativedelta(arg1=x,arg2=y,arg3=z...)
|
||||
|
||||
year, month, day, hour, minute, second, microsecond:
|
||||
Absolute information (argument is singular); adding or subtracting a
|
||||
relativedelta with absolute information does not perform an arithmetic
|
||||
operation, but rather REPLACES the corresponding value in the
|
||||
original datetime with the value(s) in relativedelta.
|
||||
|
||||
years, months, weeks, days, hours, minutes, seconds, microseconds:
|
||||
Relative information, may be negative (argument is plural); adding
|
||||
or subtracting a relativedelta with relative information performs
|
||||
the corresponding arithmetic operation on the original datetime value
|
||||
with the information in the relativedelta.
|
||||
|
||||
weekday:
|
||||
One of the weekday instances (MO, TU, etc). These instances may
|
||||
receive a parameter N, specifying the Nth weekday, which could
|
||||
be positive or negative (like MO(+1) or MO(-2)). Not specifying
|
||||
it is the same as specifying +1. You can also use an integer,
|
||||
where 0=MO.
|
||||
|
||||
leapdays:
|
||||
Will add given days to the date found, if year is a leap
|
||||
year, and the date found is after 28 February.
|
||||
|
||||
yearday, nlyearday:
|
||||
Set the yearday or the non-leap year day (jump leap days).
|
||||
These are converted to day/month/leapdays information.
|
||||
|
||||
Here is the behavior of operations with relativedelta:
|
||||
|
||||
1. Calculate the absolute year, using the 'year' argument, or the
|
||||
original datetime year, if the argument is not present.
|
||||
|
||||
2. Add the relative 'years' argument to the absolute year.
|
||||
|
||||
3. Do steps 1 and 2 for month/months.
|
||||
|
||||
4. Calculate the absolute day, using the 'day' argument, or the
|
||||
original datetime day, if the argument is not present. Then,
|
||||
subtract from the day until it fits in the year and month
|
||||
found after their operations.
|
||||
|
||||
5. Add the relative 'days' argument to the absolute day. Notice
|
||||
that the 'weeks' argument is multiplied by 7 and added to
|
||||
'days'.
|
||||
|
||||
6. Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds,
|
||||
microsecond/microseconds.
|
||||
|
||||
7. If the 'weekday' argument is present, calculate the weekday,
|
||||
with the given (wday, nth) tuple. wday is the index of the
|
||||
weekday (0-6, 0=Mon), and nth is the number of weeks to add
|
||||
forward or backward, depending on its sign. Notice that if
|
||||
the calculated date is already Monday, for example, using
|
||||
(0, 1) or (0, -1) won't change the day.
|
||||
"""
|
||||
|
||||
def __init__(self, dt1=None, dt2=None,
|
||||
years=0, months=0, days=0, leapdays=0, weeks=0,
|
||||
hours=0, minutes=0, seconds=0, microseconds=0,
|
||||
year=None, month=None, day=None, weekday=None,
|
||||
yearday=None, nlyearday=None,
|
||||
hour=None, minute=None, second=None, microsecond=None):
|
||||
|
||||
# Check for non-integer values in integer-only quantities
|
||||
if any(x is not None and x != int(x) for x in (years, months)):
|
||||
raise ValueError("Non-integer years and months are "
|
||||
"ambiguous and not currently supported.")
|
||||
|
||||
if dt1 and dt2:
|
||||
# datetime is a subclass of date. So both must be date
|
||||
if not (isinstance(dt1, datetime.date) and
|
||||
isinstance(dt2, datetime.date)):
|
||||
raise TypeError("relativedelta only diffs datetime/date")
|
||||
|
||||
# We allow two dates, or two datetimes, so we coerce them to be
|
||||
# of the same type
|
||||
if (isinstance(dt1, datetime.datetime) !=
|
||||
isinstance(dt2, datetime.datetime)):
|
||||
if not isinstance(dt1, datetime.datetime):
|
||||
dt1 = datetime.datetime.fromordinal(dt1.toordinal())
|
||||
elif not isinstance(dt2, datetime.datetime):
|
||||
dt2 = datetime.datetime.fromordinal(dt2.toordinal())
|
||||
|
||||
self.years = 0
|
||||
self.months = 0
|
||||
self.days = 0
|
||||
self.leapdays = 0
|
||||
self.hours = 0
|
||||
self.minutes = 0
|
||||
self.seconds = 0
|
||||
self.microseconds = 0
|
||||
self.year = None
|
||||
self.month = None
|
||||
self.day = None
|
||||
self.weekday = None
|
||||
self.hour = None
|
||||
self.minute = None
|
||||
self.second = None
|
||||
self.microsecond = None
|
||||
self._has_time = 0
|
||||
|
||||
# Get year / month delta between the two
|
||||
months = (dt1.year - dt2.year) * 12 + (dt1.month - dt2.month)
|
||||
self._set_months(months)
|
||||
|
||||
# Remove the year/month delta so the timedelta is just well-defined
|
||||
# time units (seconds, days and microseconds)
|
||||
dtm = self.__radd__(dt2)
|
||||
|
||||
# If we've overshot our target, make an adjustment
|
||||
if dt1 < dt2:
|
||||
compare = operator.gt
|
||||
increment = 1
|
||||
else:
|
||||
compare = operator.lt
|
||||
increment = -1
|
||||
|
||||
while compare(dt1, dtm):
|
||||
months += increment
|
||||
self._set_months(months)
|
||||
dtm = self.__radd__(dt2)
|
||||
|
||||
# Get the timedelta between the "months-adjusted" date and dt1
|
||||
delta = dt1 - dtm
|
||||
self.seconds = delta.seconds + delta.days * 86400
|
||||
self.microseconds = delta.microseconds
|
||||
else:
|
||||
# Relative information
|
||||
self.years = years
|
||||
self.months = months
|
||||
self.days = days + weeks * 7
|
||||
self.leapdays = leapdays
|
||||
self.hours = hours
|
||||
self.minutes = minutes
|
||||
self.seconds = seconds
|
||||
self.microseconds = microseconds
|
||||
|
||||
# Absolute information
|
||||
self.year = year
|
||||
self.month = month
|
||||
self.day = day
|
||||
self.hour = hour
|
||||
self.minute = minute
|
||||
self.second = second
|
||||
self.microsecond = microsecond
|
||||
|
||||
if any(x is not None and int(x) != x
|
||||
for x in (year, month, day, hour,
|
||||
minute, second, microsecond)):
|
||||
# For now we'll deprecate floats - later it'll be an error.
|
||||
warn("Non-integer value passed as absolute information. " +
|
||||
"This is not a well-defined condition and will raise " +
|
||||
"errors in future versions.", DeprecationWarning)
|
||||
|
||||
if isinstance(weekday, integer_types):
|
||||
self.weekday = weekdays[weekday]
|
||||
else:
|
||||
self.weekday = weekday
|
||||
|
||||
yday = 0
|
||||
if nlyearday:
|
||||
yday = nlyearday
|
||||
elif yearday:
|
||||
yday = yearday
|
||||
if yearday > 59:
|
||||
self.leapdays = -1
|
||||
if yday:
|
||||
ydayidx = [31, 59, 90, 120, 151, 181, 212,
|
||||
243, 273, 304, 334, 366]
|
||||
for idx, ydays in enumerate(ydayidx):
|
||||
if yday <= ydays:
|
||||
self.month = idx+1
|
||||
if idx == 0:
|
||||
self.day = yday
|
||||
else:
|
||||
self.day = yday-ydayidx[idx-1]
|
||||
break
|
||||
else:
|
||||
raise ValueError("invalid year day (%d)" % yday)
|
||||
|
||||
self._fix()
|
||||
|
||||
def _fix(self):
|
||||
if abs(self.microseconds) > 999999:
|
||||
s = _sign(self.microseconds)
|
||||
div, mod = divmod(self.microseconds * s, 1000000)
|
||||
self.microseconds = mod * s
|
||||
self.seconds += div * s
|
||||
if abs(self.seconds) > 59:
|
||||
s = _sign(self.seconds)
|
||||
div, mod = divmod(self.seconds * s, 60)
|
||||
self.seconds = mod * s
|
||||
self.minutes += div * s
|
||||
if abs(self.minutes) > 59:
|
||||
s = _sign(self.minutes)
|
||||
div, mod = divmod(self.minutes * s, 60)
|
||||
self.minutes = mod * s
|
||||
self.hours += div * s
|
||||
if abs(self.hours) > 23:
|
||||
s = _sign(self.hours)
|
||||
div, mod = divmod(self.hours * s, 24)
|
||||
self.hours = mod * s
|
||||
self.days += div * s
|
||||
if abs(self.months) > 11:
|
||||
s = _sign(self.months)
|
||||
div, mod = divmod(self.months * s, 12)
|
||||
self.months = mod * s
|
||||
self.years += div * s
|
||||
if (self.hours or self.minutes or self.seconds or self.microseconds
|
||||
or self.hour is not None or self.minute is not None or
|
||||
self.second is not None or self.microsecond is not None):
|
||||
self._has_time = 1
|
||||
else:
|
||||
self._has_time = 0
|
||||
|
||||
@property
|
||||
def weeks(self):
|
||||
return self.days // 7
|
||||
|
||||
@weeks.setter
|
||||
def weeks(self, value):
|
||||
self.days = self.days - (self.weeks * 7) + value * 7
|
||||
|
||||
def _set_months(self, months):
|
||||
self.months = months
|
||||
if abs(self.months) > 11:
|
||||
s = _sign(self.months)
|
||||
div, mod = divmod(self.months * s, 12)
|
||||
self.months = mod * s
|
||||
self.years = div * s
|
||||
else:
|
||||
self.years = 0
|
||||
|
||||
def normalized(self):
|
||||
"""
|
||||
Return a version of this object represented entirely using integer
|
||||
values for the relative attributes.
|
||||
|
||||
>>> relativedelta(days=1.5, hours=2).normalized()
|
||||
relativedelta(days=1, hours=14)
|
||||
|
||||
:return:
|
||||
Returns a :class:`dateutil.relativedelta.relativedelta` object.
|
||||
"""
|
||||
# Cascade remainders down (rounding each to roughly nearest microsecond)
|
||||
days = int(self.days)
|
||||
|
||||
hours_f = round(self.hours + 24 * (self.days - days), 11)
|
||||
hours = int(hours_f)
|
||||
|
||||
minutes_f = round(self.minutes + 60 * (hours_f - hours), 10)
|
||||
minutes = int(minutes_f)
|
||||
|
||||
seconds_f = round(self.seconds + 60 * (minutes_f - minutes), 8)
|
||||
seconds = int(seconds_f)
|
||||
|
||||
microseconds = round(self.microseconds + 1e6 * (seconds_f - seconds))
|
||||
|
||||
# Constructor carries overflow back up with call to _fix()
|
||||
return self.__class__(years=self.years, months=self.months,
|
||||
days=days, hours=hours, minutes=minutes,
|
||||
seconds=seconds, microseconds=microseconds,
|
||||
leapdays=self.leapdays, year=self.year,
|
||||
month=self.month, day=self.day,
|
||||
weekday=self.weekday, hour=self.hour,
|
||||
minute=self.minute, second=self.second,
|
||||
microsecond=self.microsecond)
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, relativedelta):
|
||||
return self.__class__(years=other.years + self.years,
|
||||
months=other.months + self.months,
|
||||
days=other.days + self.days,
|
||||
hours=other.hours + self.hours,
|
||||
minutes=other.minutes + self.minutes,
|
||||
seconds=other.seconds + self.seconds,
|
||||
microseconds=(other.microseconds +
|
||||
self.microseconds),
|
||||
leapdays=other.leapdays or self.leapdays,
|
||||
year=(other.year if other.year is not None
|
||||
else self.year),
|
||||
month=(other.month if other.month is not None
|
||||
else self.month),
|
||||
day=(other.day if other.day is not None
|
||||
else self.day),
|
||||
weekday=(other.weekday if other.weekday is not None
|
||||
else self.weekday),
|
||||
hour=(other.hour if other.hour is not None
|
||||
else self.hour),
|
||||
minute=(other.minute if other.minute is not None
|
||||
else self.minute),
|
||||
second=(other.second if other.second is not None
|
||||
else self.second),
|
||||
microsecond=(other.microsecond if other.microsecond
|
||||
is not None else
|
||||
self.microsecond))
|
||||
if isinstance(other, datetime.timedelta):
|
||||
return self.__class__(years=self.years,
|
||||
months=self.months,
|
||||
days=self.days + other.days,
|
||||
hours=self.hours,
|
||||
minutes=self.minutes,
|
||||
seconds=self.seconds + other.seconds,
|
||||
microseconds=self.microseconds + other.microseconds,
|
||||
leapdays=self.leapdays,
|
||||
year=self.year,
|
||||
month=self.month,
|
||||
day=self.day,
|
||||
weekday=self.weekday,
|
||||
hour=self.hour,
|
||||
minute=self.minute,
|
||||
second=self.second,
|
||||
microsecond=self.microsecond)
|
||||
if not isinstance(other, datetime.date):
|
||||
return NotImplemented
|
||||
elif self._has_time and not isinstance(other, datetime.datetime):
|
||||
other = datetime.datetime.fromordinal(other.toordinal())
|
||||
year = (self.year or other.year)+self.years
|
||||
month = self.month or other.month
|
||||
if self.months:
|
||||
assert 1 <= abs(self.months) <= 12
|
||||
month += self.months
|
||||
if month > 12:
|
||||
year += 1
|
||||
month -= 12
|
||||
elif month < 1:
|
||||
year -= 1
|
||||
month += 12
|
||||
day = min(calendar.monthrange(year, month)[1],
|
||||
self.day or other.day)
|
||||
repl = {"year": year, "month": month, "day": day}
|
||||
for attr in ["hour", "minute", "second", "microsecond"]:
|
||||
value = getattr(self, attr)
|
||||
if value is not None:
|
||||
repl[attr] = value
|
||||
days = self.days
|
||||
if self.leapdays and month > 2 and calendar.isleap(year):
|
||||
days += self.leapdays
|
||||
ret = (other.replace(**repl)
|
||||
+ datetime.timedelta(days=days,
|
||||
hours=self.hours,
|
||||
minutes=self.minutes,
|
||||
seconds=self.seconds,
|
||||
microseconds=self.microseconds))
|
||||
if self.weekday:
|
||||
weekday, nth = self.weekday.weekday, self.weekday.n or 1
|
||||
jumpdays = (abs(nth) - 1) * 7
|
||||
if nth > 0:
|
||||
jumpdays += (7 - ret.weekday() + weekday) % 7
|
||||
else:
|
||||
jumpdays += (ret.weekday() - weekday) % 7
|
||||
jumpdays *= -1
|
||||
ret += datetime.timedelta(days=jumpdays)
|
||||
return ret
|
||||
|
||||
def __radd__(self, other):
|
||||
return self.__add__(other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return self.__neg__().__radd__(other)
|
||||
|
||||
def __sub__(self, other):
|
||||
if not isinstance(other, relativedelta):
|
||||
return NotImplemented # In case the other object defines __rsub__
|
||||
return self.__class__(years=self.years - other.years,
|
||||
months=self.months - other.months,
|
||||
days=self.days - other.days,
|
||||
hours=self.hours - other.hours,
|
||||
minutes=self.minutes - other.minutes,
|
||||
seconds=self.seconds - other.seconds,
|
||||
microseconds=self.microseconds - other.microseconds,
|
||||
leapdays=self.leapdays or other.leapdays,
|
||||
year=(self.year if self.year is not None
|
||||
else other.year),
|
||||
month=(self.month if self.month is not None else
|
||||
other.month),
|
||||
day=(self.day if self.day is not None else
|
||||
other.day),
|
||||
weekday=(self.weekday if self.weekday is not None else
|
||||
other.weekday),
|
||||
hour=(self.hour if self.hour is not None else
|
||||
other.hour),
|
||||
minute=(self.minute if self.minute is not None else
|
||||
other.minute),
|
||||
second=(self.second if self.second is not None else
|
||||
other.second),
|
||||
microsecond=(self.microsecond if self.microsecond
|
||||
is not None else
|
||||
other.microsecond))
|
||||
|
||||
def __neg__(self):
|
||||
return self.__class__(years=-self.years,
|
||||
months=-self.months,
|
||||
days=-self.days,
|
||||
hours=-self.hours,
|
||||
minutes=-self.minutes,
|
||||
seconds=-self.seconds,
|
||||
microseconds=-self.microseconds,
|
||||
leapdays=self.leapdays,
|
||||
year=self.year,
|
||||
month=self.month,
|
||||
day=self.day,
|
||||
weekday=self.weekday,
|
||||
hour=self.hour,
|
||||
minute=self.minute,
|
||||
second=self.second,
|
||||
microsecond=self.microsecond)
|
||||
|
||||
def __bool__(self):
|
||||
return not (not self.years and
|
||||
not self.months and
|
||||
not self.days and
|
||||
not self.hours and
|
||||
not self.minutes and
|
||||
not self.seconds and
|
||||
not self.microseconds and
|
||||
not self.leapdays and
|
||||
self.year is None and
|
||||
self.month is None and
|
||||
self.day is None and
|
||||
self.weekday is None and
|
||||
self.hour is None and
|
||||
self.minute is None and
|
||||
self.second is None and
|
||||
self.microsecond is None)
|
||||
# Compatibility with Python 2.x
|
||||
__nonzero__ = __bool__
|
||||
|
||||
def __mul__(self, other):
|
||||
try:
|
||||
f = float(other)
|
||||
except TypeError:
|
||||
return NotImplemented
|
||||
|
||||
return self.__class__(years=int(self.years * f),
|
||||
months=int(self.months * f),
|
||||
days=int(self.days * f),
|
||||
hours=int(self.hours * f),
|
||||
minutes=int(self.minutes * f),
|
||||
seconds=int(self.seconds * f),
|
||||
microseconds=int(self.microseconds * f),
|
||||
leapdays=self.leapdays,
|
||||
year=self.year,
|
||||
month=self.month,
|
||||
day=self.day,
|
||||
weekday=self.weekday,
|
||||
hour=self.hour,
|
||||
minute=self.minute,
|
||||
second=self.second,
|
||||
microsecond=self.microsecond)
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, relativedelta):
|
||||
return NotImplemented
|
||||
if self.weekday or other.weekday:
|
||||
if not self.weekday or not other.weekday:
|
||||
return False
|
||||
if self.weekday.weekday != other.weekday.weekday:
|
||||
return False
|
||||
n1, n2 = self.weekday.n, other.weekday.n
|
||||
if n1 != n2 and not ((not n1 or n1 == 1) and (not n2 or n2 == 1)):
|
||||
return False
|
||||
return (self.years == other.years and
|
||||
self.months == other.months and
|
||||
self.days == other.days and
|
||||
self.hours == other.hours and
|
||||
self.minutes == other.minutes and
|
||||
self.seconds == other.seconds and
|
||||
self.microseconds == other.microseconds and
|
||||
self.leapdays == other.leapdays and
|
||||
self.year == other.year and
|
||||
self.month == other.month and
|
||||
self.day == other.day and
|
||||
self.hour == other.hour and
|
||||
self.minute == other.minute and
|
||||
self.second == other.second and
|
||||
self.microsecond == other.microsecond)
|
||||
|
||||
__hash__ = None
|
||||
|
||||
def __ne__(self, other):
|
||||
result = self.__eq__(other)
return NotImplemented if result is NotImplemented else not result
|
||||
|
||||
def __div__(self, other):
|
||||
try:
|
||||
reciprocal = 1 / float(other)
|
||||
except TypeError:
|
||||
return NotImplemented
|
||||
|
||||
return self.__mul__(reciprocal)
|
||||
|
||||
__truediv__ = __div__
|
||||
|
||||
def __repr__(self):
|
||||
l = []
|
||||
for attr in ["years", "months", "days", "leapdays",
|
||||
"hours", "minutes", "seconds", "microseconds"]:
|
||||
value = getattr(self, attr)
|
||||
if value:
|
||||
l.append("{attr}={value:+g}".format(attr=attr, value=value))
|
||||
for attr in ["year", "month", "day", "weekday",
|
||||
"hour", "minute", "second", "microsecond"]:
|
||||
value = getattr(self, attr)
|
||||
if value is not None:
|
||||
l.append("{attr}={value}".format(attr=attr, value=repr(value)))
|
||||
return "{classname}({attrs})".format(classname=self.__class__.__name__,
|
||||
attrs=", ".join(l))
|
||||
|
||||
|
||||
def _sign(x):
|
||||
return int(copysign(1, x))
|
||||
|
||||
# vim:ts=4:sw=4:et
|
1610
python/dateutil/rrule.py
Normal file
File diff suppressed because it is too large
5
python/dateutil/tz/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .tz import *
|
||||
|
||||
__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange",
|
||||
"tzstr", "tzical", "tzwin", "tzwinlocal", "gettz",
|
||||
"enfold", "datetime_ambiguous", "datetime_exists"]
|
394
python/dateutil/tz/_common.py
Normal file
@ -0,0 +1,394 @@
|
||||
from six import PY3
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from datetime import datetime, timedelta, tzinfo
|
||||
|
||||
|
||||
ZERO = timedelta(0)
|
||||
|
||||
__all__ = ['tzname_in_python2', 'enfold']
|
||||
|
||||
|
||||
def tzname_in_python2(namefunc):
|
||||
"""Change unicode output into bytestrings in Python 2
|
||||
|
||||
tzname() API changed in Python 3. It used to return bytes, but was changed
|
||||
to unicode strings
|
||||
"""
|
||||
def adjust_encoding(*args, **kwargs):
|
||||
name = namefunc(*args, **kwargs)
|
||||
if name is not None and not PY3:
|
||||
name = name.encode()
|
||||
|
||||
return name
|
||||
|
||||
return adjust_encoding
|
||||
|
||||
|
||||
# The following is adapted from Alexander Belopolsky's tz library
|
||||
# https://github.com/abalkin/tz
|
||||
if hasattr(datetime, 'fold'):
|
||||
# Python 3.6+: datetime has a native `fold` attribute (PEP 495)
|
||||
def enfold(dt, fold=1):
|
||||
"""
|
||||
Provides a unified interface for assigning the ``fold`` attribute to
|
||||
datetimes both before and after the implementation of PEP-495.
|
||||
|
||||
:param fold:
|
||||
The value for the ``fold`` attribute in the returned datetime. This
|
||||
should be either 0 or 1.
|
||||
|
||||
:return:
|
||||
Returns an object for which ``getattr(dt, 'fold', 0)`` returns
|
||||
``fold`` for all versions of Python. In versions prior to
|
||||
Python 3.6, this is a ``_DatetimeWithFold`` object, which is a
|
||||
subclass of :py:class:`datetime.datetime` with the ``fold``
|
||||
attribute added, if ``fold`` is 1.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
"""
|
||||
return dt.replace(fold=fold)
|
||||
|
||||
else:
|
||||
class _DatetimeWithFold(datetime):
|
||||
"""
|
||||
This is a class designed to provide a PEP 495-compliant interface for
|
||||
Python versions before 3.6. It is used only for dates in a fold, so
|
||||
the ``fold`` attribute is fixed at ``1``.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
@property
|
||||
def fold(self):
|
||||
return 1
|
||||
|
||||
def enfold(dt, fold=1):
|
||||
"""
|
||||
Provides a unified interface for assigning the ``fold`` attribute to
|
||||
datetimes both before and after the implementation of PEP-495.
|
||||
|
||||
:param fold:
|
||||
The value for the ``fold`` attribute in the returned datetime. This
|
||||
should be either 0 or 1.
|
||||
|
||||
:return:
|
||||
Returns an object for which ``getattr(dt, 'fold', 0)`` returns
|
||||
``fold`` for all versions of Python. In versions prior to
|
||||
Python 3.6, this is a ``_DatetimeWithFold`` object, which is a
|
||||
subclass of :py:class:`datetime.datetime` with the ``fold``
|
||||
attribute added, if ``fold`` is 1.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
"""
|
||||
if getattr(dt, 'fold', 0) == fold:
|
||||
return dt
|
||||
|
||||
args = dt.timetuple()[:6]
|
||||
args += (dt.microsecond, dt.tzinfo)
|
||||
|
||||
if fold:
|
||||
return _DatetimeWithFold(*args)
|
||||
else:
|
||||
return datetime(*args)
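# Whichever branch was taken above, usage is the same (illustrative)::
#
#     >>> dt = enfold(datetime(2017, 11, 5, 1, 30), fold=1)
#     >>> getattr(dt, 'fold', 0)
#     1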
|
||||
|
||||
|
||||
def _validate_fromutc_inputs(f):
|
||||
"""
|
||||
The CPython version of ``fromutc`` checks that the input is a ``datetime``
|
||||
object and that ``self`` is attached as its ``tzinfo``.
|
||||
"""
|
||||
@wraps(f)
|
||||
def fromutc(self, dt):
|
||||
if not isinstance(dt, datetime):
|
||||
raise TypeError("fromutc() requires a datetime argument")
|
||||
if dt.tzinfo is not self:
|
||||
raise ValueError("dt.tzinfo is not self")
|
||||
|
||||
return f(self, dt)
|
||||
|
||||
return fromutc
|
||||
|
||||
|
||||
class _tzinfo(tzinfo):
|
||||
"""
|
||||
Base class for all ``dateutil`` ``tzinfo`` objects.
|
||||
"""
|
||||
|
||||
def is_ambiguous(self, dt):
|
||||
"""
|
||||
Whether or not the "wall time" of a given datetime is ambiguous in this
|
||||
zone.
|
||||
|
||||
:param dt:
|
||||
A :py:class:`datetime.datetime`, naive or time zone aware.
|
||||
|
||||
|
||||
:return:
|
||||
Returns ``True`` if ambiguous, ``False`` otherwise.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
"""
|
||||
|
||||
dt = dt.replace(tzinfo=self)
|
||||
|
||||
wall_0 = enfold(dt, fold=0)
|
||||
wall_1 = enfold(dt, fold=1)
|
||||
|
||||
same_offset = wall_0.utcoffset() == wall_1.utcoffset()
|
||||
same_dt = wall_0.replace(tzinfo=None) == wall_1.replace(tzinfo=None)
|
||||
|
||||
return same_dt and not same_offset
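# Illustrative sketch, assuming the package's ``gettz`` and a zone with a
# DST-end fold::
#
#     >>> from dateutil import tz
#     >>> NYC = tz.gettz('America/New_York')
#     >>> NYC.is_ambiguous(datetime(2017, 11, 5, 1, 30))
#     True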
|
||||
|
||||
def _fold_status(self, dt_utc, dt_wall):
|
||||
"""
|
||||
Determine the fold status of a "wall" datetime, given a representation
|
||||
of the same datetime as a (naive) UTC datetime. This is calculated based
|
||||
on the assumption that ``dt.utcoffset() - dt.dst()`` is constant for all
|
||||
datetimes, and that this offset is the actual number of hours separating
|
||||
``dt_utc`` and ``dt_wall``.
|
||||
|
||||
:param dt_utc:
|
||||
Representation of the datetime as UTC
|
||||
|
||||
:param dt_wall:
|
||||
Representation of the datetime as "wall time". This parameter must
|
||||
either have a `fold` attribute or have a fold-naive
|
||||
:class:`datetime.tzinfo` attached, otherwise the calculation may
|
||||
fail.
|
||||
"""
|
||||
if self.is_ambiguous(dt_wall):
|
||||
delta_wall = dt_wall - dt_utc
|
||||
_fold = int(delta_wall == (dt_utc.utcoffset() - dt_utc.dst()))
|
||||
else:
|
||||
_fold = 0
|
||||
|
||||
return _fold
|
||||
|
||||
def _fold(self, dt):
|
||||
return getattr(dt, 'fold', 0)
|
||||
|
||||
def _fromutc(self, dt):
|
||||
"""
|
||||
Given a timezone-aware datetime in a given timezone, calculates a
|
||||
timezone-aware datetime in a new timezone.
|
||||
|
||||
Since this is the one time that we *know* we have an unambiguous
|
||||
datetime object, we take this opportunity to determine whether the
|
||||
datetime is ambiguous and in a "fold" state (e.g. if it's the first
|
||||
occurrence, chronologically, of the ambiguous datetime).
|
||||
|
||||
:param dt:
|
||||
A timezone-aware :class:`datetime.datetime` object.
|
||||
"""
|
||||
|
||||
# Re-implement the algorithm from Python's datetime.py
|
||||
dtoff = dt.utcoffset()
|
||||
if dtoff is None:
|
||||
raise ValueError("fromutc() requires a non-None utcoffset() "
|
||||
"result")
|
||||
|
||||
# The original datetime.py code assumes that `dst()` defaults to
|
||||
# zero during ambiguous times. PEP 495 inverts this presumption, so
|
||||
# for pre-PEP 495 versions of python, we need to tweak the algorithm.
|
||||
dtdst = dt.dst()
|
||||
if dtdst is None:
|
||||
raise ValueError("fromutc() requires a non-None dst() result")
|
||||
delta = dtoff - dtdst
|
||||
|
||||
dt += delta
|
||||
# Set fold=1 so we can default to being in the fold for
|
||||
# ambiguous dates.
|
||||
dtdst = enfold(dt, fold=1).dst()
|
||||
if dtdst is None:
|
||||
raise ValueError("fromutc(): dt.dst gave inconsistent "
|
||||
"results; cannot convert")
|
||||
return dt + dtdst
|
||||
|
||||
@_validate_fromutc_inputs
|
||||
def fromutc(self, dt):
|
||||
"""
|
||||
Given a timezone-aware datetime in a given timezone, calculates a
|
||||
timezone-aware datetime in a new timezone.
|
||||
|
||||
Since this is the one time that we *know* we have an unambiguous
|
||||
datetime object, we take this opportunity to determine whether the
|
||||
datetime is ambiguous and in a "fold" state (e.g. if it's the first
|
||||
occurrence, chronologically, of the ambiguous datetime).
|
||||
|
||||
:param dt:
|
||||
A timezone-aware :class:`datetime.datetime` object.
|
||||
"""
|
||||
dt_wall = self._fromutc(dt)
|
||||
|
||||
# Calculate the fold status given the two datetimes.
|
||||
_fold = self._fold_status(dt, dt_wall)
|
||||
|
||||
# Set the default fold value for ambiguous dates
|
||||
return enfold(dt_wall, fold=_fold)
|
||||
|
||||
|
||||
class tzrangebase(_tzinfo):
|
||||
"""
|
||||
This is an abstract base class for time zones represented by an annual
|
||||
transition into and out of DST. Child classes should implement the following
|
||||
methods:
|
||||
|
||||
* ``__init__(self, *args, **kwargs)``
|
||||
* ``transitions(self, year)`` - this is expected to return a tuple of
|
||||
datetimes representing the DST on and off transitions in standard
|
||||
time.
|
||||
|
||||
A fully initialized ``tzrangebase`` subclass should also provide the
|
||||
following attributes:
|
||||
* ``hasdst``: Boolean whether or not the zone uses DST.
|
||||
* ``_dst_offset`` / ``_std_offset``: :class:`datetime.timedelta` objects
|
||||
representing the respective UTC offsets.
|
||||
* ``_dst_abbr`` / ``_std_abbr``: Strings representing the timezone short
|
||||
abbreviations in DST and STD, respectively.
|
||||
* ``_hasdst``: Whether or not the zone has DST.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
"""
|
||||
def __init__(self):
|
||||
raise NotImplementedError('tzrangebase is an abstract base class')
|
||||
|
||||
def utcoffset(self, dt):
|
||||
isdst = self._isdst(dt)
|
||||
|
||||
if isdst is None:
|
||||
return None
|
||||
elif isdst:
|
||||
return self._dst_offset
|
||||
else:
|
||||
return self._std_offset
|
||||
|
||||
def dst(self, dt):
|
||||
isdst = self._isdst(dt)
|
||||
|
||||
if isdst is None:
|
||||
return None
|
||||
elif isdst:
|
||||
return self._dst_base_offset
|
||||
else:
|
||||
return ZERO
|
||||
|
||||
@tzname_in_python2
|
||||
def tzname(self, dt):
|
||||
if self._isdst(dt):
|
||||
return self._dst_abbr
|
||||
else:
|
||||
return self._std_abbr
|
||||
|
||||
def fromutc(self, dt):
|
||||
""" Given a datetime in UTC, return local time """
|
||||
if not isinstance(dt, datetime):
|
||||
raise TypeError("fromutc() requires a datetime argument")
|
||||
|
||||
if dt.tzinfo is not self:
|
||||
raise ValueError("dt.tzinfo is not self")
|
||||
|
||||
# Get transitions - if there are none, fixed offset
|
||||
transitions = self.transitions(dt.year)
|
||||
if transitions is None:
|
||||
return dt + self.utcoffset(dt)
|
||||
|
||||
# Get the transition times in UTC
|
||||
dston, dstoff = transitions
|
||||
|
||||
dston -= self._std_offset
|
||||
dstoff -= self._std_offset
|
||||
|
||||
utc_transitions = (dston, dstoff)
|
||||
dt_utc = dt.replace(tzinfo=None)
|
||||
|
||||
isdst = self._naive_isdst(dt_utc, utc_transitions)
|
||||
|
||||
if isdst:
|
||||
dt_wall = dt + self._dst_offset
|
||||
else:
|
||||
dt_wall = dt + self._std_offset
|
||||
|
||||
_fold = int(not isdst and self.is_ambiguous(dt_wall))
|
||||
|
||||
return enfold(dt_wall, fold=_fold)
|
||||
|
||||
def is_ambiguous(self, dt):
|
||||
"""
|
||||
Whether or not the "wall time" of a given datetime is ambiguous in this
|
||||
zone.
|
||||
|
||||
:param dt:
|
||||
A :py:class:`datetime.datetime`, naive or time zone aware.
|
||||
|
||||
|
||||
:return:
|
||||
Returns ``True`` if ambiguous, ``False`` otherwise.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
"""
|
||||
if not self.hasdst:
|
||||
return False
|
||||
|
||||
start, end = self.transitions(dt.year)
|
||||
|
||||
dt = dt.replace(tzinfo=None)
|
||||
return (end <= dt < end + self._dst_base_offset)
|
||||
|
||||
def _isdst(self, dt):
|
||||
if not self.hasdst:
|
||||
return False
|
||||
elif dt is None:
|
||||
return None
|
||||
|
||||
transitions = self.transitions(dt.year)
|
||||
|
||||
if transitions is None:
|
||||
return False
|
||||
|
||||
dt = dt.replace(tzinfo=None)
|
||||
|
||||
isdst = self._naive_isdst(dt, transitions)
|
||||
|
||||
# Handle ambiguous dates
|
||||
if not isdst and self.is_ambiguous(dt):
|
||||
return not self._fold(dt)
|
||||
else:
|
||||
return isdst
|
||||
|
||||
def _naive_isdst(self, dt, transitions):
|
||||
dston, dstoff = transitions
|
||||
|
||||
dt = dt.replace(tzinfo=None)
|
||||
|
||||
if dston < dstoff:
|
||||
isdst = dston <= dt < dstoff
|
||||
else:
|
||||
isdst = not dstoff <= dt < dston
|
||||
|
||||
return isdst
|
||||
|
||||
@property
|
||||
def _dst_base_offset(self):
|
||||
return self._dst_offset - self._std_offset
|
||||
|
||||
__hash__ = None
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(...)" % self.__class__.__name__
|
||||
|
||||
__reduce__ = object.__reduce__
|
||||
|
||||
|
||||
def _total_seconds(td):
|
||||
# Python 2.6 doesn't have a total_seconds() method on timedelta objects
|
||||
return ((td.seconds + td.days * 86400) * 1000000 +
|
||||
td.microseconds) // 1000000
|
||||
|
||||
|
||||
_total_seconds = getattr(timedelta, 'total_seconds', _total_seconds)
|
1511
python/dateutil/tz/tz.py
Normal file
File diff suppressed because it is too large
332
python/dateutil/tz/win.py
Normal file
@ -0,0 +1,332 @@
|
||||
# This code was originally contributed by Jeffrey Harris.
|
||||
import datetime
|
||||
import struct
|
||||
|
||||
from six.moves import winreg
|
||||
from six import text_type
|
||||
|
||||
try:
|
||||
import ctypes
|
||||
from ctypes import wintypes
|
||||
except ValueError:
|
||||
# ValueError is raised on non-Windows systems for some horrible reason.
|
||||
raise ImportError("Running tzwin on non-Windows system")
|
||||
|
||||
from ._common import tzrangebase
|
||||
|
||||
__all__ = ["tzwin", "tzwinlocal", "tzres"]
|
||||
|
||||
ONEWEEK = datetime.timedelta(7)
|
||||
|
||||
TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones"
|
||||
TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones"
|
||||
TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation"
|
||||
|
||||
|
||||
def _settzkeyname():
|
||||
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
|
||||
try:
|
||||
winreg.OpenKey(handle, TZKEYNAMENT).Close()
|
||||
TZKEYNAME = TZKEYNAMENT
|
||||
except WindowsError:
|
||||
TZKEYNAME = TZKEYNAME9X
|
||||
handle.Close()
|
||||
return TZKEYNAME
|
||||
|
||||
|
||||
TZKEYNAME = _settzkeyname()
|
||||
|
||||
|
||||
class tzres(object):
|
||||
"""
|
||||
Class for accessing `tzres.dll`, which contains timezone name related
|
||||
resources.
|
||||
|
||||
.. versionadded:: 2.5.0
|
||||
"""
|
||||
p_wchar = ctypes.POINTER(wintypes.WCHAR) # Pointer to a wide char
|
||||
|
||||
def __init__(self, tzres_loc='tzres.dll'):
|
||||
# Load the user32 DLL so we can load strings from tzres
|
||||
user32 = ctypes.WinDLL('user32')
|
||||
|
||||
# Specify the LoadStringW function
|
||||
user32.LoadStringW.argtypes = (wintypes.HINSTANCE,
|
||||
wintypes.UINT,
|
||||
wintypes.LPWSTR,
|
||||
ctypes.c_int)
|
||||
|
||||
self.LoadStringW = user32.LoadStringW
|
||||
self._tzres = ctypes.WinDLL(tzres_loc)
|
||||
self.tzres_loc = tzres_loc
|
||||
|
||||
def load_name(self, offset):
|
||||
"""
|
||||
Load a timezone name from a DLL offset (integer).
|
||||
|
||||
>>> from dateutil.tzwin import tzres
|
||||
>>> tzr = tzres()
|
||||
>>> print(tzr.load_name(112))
|
||||
'Eastern Standard Time'
|
||||
|
||||
:param offset:
|
||||
A positive integer value referring to a string from the tzres dll.
|
||||
|
||||
.. note::
|
||||
Offsets found in the registry are generally of the form
|
||||
`@tzres.dll,-114`. The offset in this case is 114, not -114.
|
||||
|
||||
"""
|
||||
resource = self.p_wchar()
|
||||
lpBuffer = ctypes.cast(ctypes.byref(resource), wintypes.LPWSTR)
|
||||
nchar = self.LoadStringW(self._tzres._handle, offset, lpBuffer, 0)
|
||||
return resource[:nchar]
|
||||
|
||||
def name_from_string(self, tzname_str):
|
||||
"""
|
||||
Parse strings as returned from the Windows registry into the time zone
|
||||
name as defined in the registry.
|
||||
|
||||
>>> from dateutil.tzwin import tzres
|
||||
>>> tzr = tzres()
|
||||
>>> print(tzr.name_from_string('@tzres.dll,-251'))
|
||||
'Dateline Daylight Time'
|
||||
>>> print(tzr.name_from_string('Eastern Standard Time'))
|
||||
'Eastern Standard Time'
|
||||
|
||||
:param tzname_str:
|
||||
A timezone name string as returned from a Windows registry key.
|
||||
|
||||
:return:
|
||||
Returns the localized timezone string from tzres.dll if the string
|
||||
is of the form `@tzres.dll,-offset`, else returns the input string.
|
||||
"""
|
||||
if not tzname_str.startswith('@'):
|
||||
return tzname_str
|
||||
|
||||
name_splt = tzname_str.split(',-')
|
||||
try:
|
||||
offset = int(name_splt[1])
|
||||
        except (IndexError, ValueError):
|
||||
raise ValueError("Malformed timezone string.")
|
||||
|
||||
return self.load_name(offset)
|
||||
|
||||
|
||||
class tzwinbase(tzrangebase):
|
||||
"""tzinfo class based on win32's timezones available in the registry."""
|
||||
def __init__(self):
|
||||
raise NotImplementedError('tzwinbase is an abstract base class')
|
||||
|
||||
def __eq__(self, other):
|
||||
# Compare on all relevant dimensions, including name.
|
||||
if not isinstance(other, tzwinbase):
|
||||
return NotImplemented
|
||||
|
||||
return (self._std_offset == other._std_offset and
|
||||
self._dst_offset == other._dst_offset and
|
||||
self._stddayofweek == other._stddayofweek and
|
||||
self._dstdayofweek == other._dstdayofweek and
|
||||
self._stdweeknumber == other._stdweeknumber and
|
||||
self._dstweeknumber == other._dstweeknumber and
|
||||
self._stdhour == other._stdhour and
|
||||
self._dsthour == other._dsthour and
|
||||
self._stdminute == other._stdminute and
|
||||
self._dstminute == other._dstminute and
|
||||
self._std_abbr == other._std_abbr and
|
||||
self._dst_abbr == other._dst_abbr)
|
||||
|
||||
@staticmethod
|
||||
def list():
|
||||
"""Return a list of all time zones known to the system."""
|
||||
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
|
||||
with winreg.OpenKey(handle, TZKEYNAME) as tzkey:
|
||||
result = [winreg.EnumKey(tzkey, i)
|
||||
for i in range(winreg.QueryInfoKey(tzkey)[0])]
|
||||
return result
|
||||
|
||||
def display(self):
|
||||
return self._display
|
||||
|
||||
def transitions(self, year):
|
||||
"""
|
||||
For a given year, get the DST on and off transition times, expressed
|
||||
always on the standard time side. For zones with no transitions, this
|
||||
function returns ``None``.
|
||||
|
||||
:param year:
|
||||
The year whose transitions you would like to query.
|
||||
|
||||
:return:
|
||||
Returns a :class:`tuple` of :class:`datetime.datetime` objects,
|
||||
``(dston, dstoff)`` for zones with an annual DST transition, or
|
||||
``None`` for fixed offset zones.
|
||||
"""
|
||||
|
||||
if not self.hasdst:
|
||||
return None
|
||||
|
||||
dston = picknthweekday(year, self._dstmonth, self._dstdayofweek,
|
||||
self._dsthour, self._dstminute,
|
||||
self._dstweeknumber)
|
||||
|
||||
dstoff = picknthweekday(year, self._stdmonth, self._stddayofweek,
|
||||
self._stdhour, self._stdminute,
|
||||
self._stdweeknumber)
|
||||
|
||||
# Ambiguous dates default to the STD side
|
||||
dstoff -= self._dst_base_offset
|
||||
|
||||
return dston, dstoff
|
||||
|
||||
def _get_hasdst(self):
|
||||
return self._dstmonth != 0
|
||||
|
||||
@property
|
||||
def _dst_base_offset(self):
|
||||
return self._dst_base_offset_
|
||||
|
||||
|
||||
class tzwin(tzwinbase):
|
||||
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
# multiple contexts only possible in 2.7 and 3.1, we still support 2.6
|
||||
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
|
||||
tzkeyname = text_type("{kn}\\{name}").format(kn=TZKEYNAME, name=name)
|
||||
with winreg.OpenKey(handle, tzkeyname) as tzkey:
|
||||
keydict = valuestodict(tzkey)
|
||||
|
||||
self._std_abbr = keydict["Std"]
|
||||
self._dst_abbr = keydict["Dlt"]
|
||||
|
||||
self._display = keydict["Display"]
|
||||
|
||||
        # See http://ww1.jsiinc.com/SUBA/tip0300/rh0398.htm
|
||||
tup = struct.unpack("=3l16h", keydict["TZI"])
|
||||
stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1
|
||||
dstoffset = stdoffset-tup[2] # + DaylightBias * -1
|
||||
self._std_offset = datetime.timedelta(minutes=stdoffset)
|
||||
self._dst_offset = datetime.timedelta(minutes=dstoffset)
|
||||
|
||||
# for the meaning see the win32 TIME_ZONE_INFORMATION structure docs
|
||||
# http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx
|
||||
(self._stdmonth,
|
||||
self._stddayofweek, # Sunday = 0
|
||||
self._stdweeknumber, # Last = 5
|
||||
self._stdhour,
|
||||
self._stdminute) = tup[4:9]
|
||||
|
||||
(self._dstmonth,
|
||||
self._dstdayofweek, # Sunday = 0
|
||||
self._dstweeknumber, # Last = 5
|
||||
self._dsthour,
|
||||
self._dstminute) = tup[12:17]
|
||||
|
||||
self._dst_base_offset_ = self._dst_offset - self._std_offset
|
||||
self.hasdst = self._get_hasdst()
|
||||
|
||||
def __repr__(self):
|
||||
return "tzwin(%s)" % repr(self._name)
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self._name,))
|
||||
|
||||
|
||||
class tzwinlocal(tzwinbase):
|
||||
def __init__(self):
|
||||
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
|
||||
with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey:
|
||||
keydict = valuestodict(tzlocalkey)
|
||||
|
||||
self._std_abbr = keydict["StandardName"]
|
||||
self._dst_abbr = keydict["DaylightName"]
|
||||
|
||||
try:
|
||||
tzkeyname = text_type('{kn}\\{sn}').format(kn=TZKEYNAME,
|
||||
sn=self._std_abbr)
|
||||
with winreg.OpenKey(handle, tzkeyname) as tzkey:
|
||||
_keydict = valuestodict(tzkey)
|
||||
self._display = _keydict["Display"]
|
||||
except OSError:
|
||||
self._display = None
|
||||
|
||||
stdoffset = -keydict["Bias"]-keydict["StandardBias"]
|
||||
dstoffset = stdoffset-keydict["DaylightBias"]
|
||||
|
||||
self._std_offset = datetime.timedelta(minutes=stdoffset)
|
||||
self._dst_offset = datetime.timedelta(minutes=dstoffset)
|
||||
|
||||
# For reasons unclear, in this particular key, the day of week has been
|
||||
# moved to the END of the SYSTEMTIME structure.
|
||||
tup = struct.unpack("=8h", keydict["StandardStart"])
|
||||
|
||||
(self._stdmonth,
|
||||
self._stdweeknumber, # Last = 5
|
||||
self._stdhour,
|
||||
self._stdminute) = tup[1:5]
|
||||
|
||||
self._stddayofweek = tup[7]
|
||||
|
||||
tup = struct.unpack("=8h", keydict["DaylightStart"])
|
||||
|
||||
(self._dstmonth,
|
||||
self._dstweeknumber, # Last = 5
|
||||
self._dsthour,
|
||||
self._dstminute) = tup[1:5]
|
||||
|
||||
self._dstdayofweek = tup[7]
|
||||
|
||||
self._dst_base_offset_ = self._dst_offset - self._std_offset
|
||||
self.hasdst = self._get_hasdst()
|
||||
|
||||
def __repr__(self):
|
||||
return "tzwinlocal()"
|
||||
|
||||
def __str__(self):
|
||||
# str will return the standard name, not the daylight name.
|
||||
return "tzwinlocal(%s)" % repr(self._std_abbr)
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, ())
|
||||
|
||||
|
||||
def picknthweekday(year, month, dayofweek, hour, minute, whichweek):
|
||||
""" dayofweek == 0 means Sunday, whichweek 5 means last instance """
|
||||
first = datetime.datetime(year, month, 1, hour, minute)
|
||||
|
||||
# This will work if dayofweek is ISO weekday (1-7) or Microsoft-style (0-6),
|
||||
    # because 7 % 7 == 0
|
||||
weekdayone = first.replace(day=((dayofweek - first.isoweekday()) % 7) + 1)
|
||||
wd = weekdayone + ((whichweek - 1) * ONEWEEK)
|
||||
if (wd.month != month):
|
||||
wd -= ONEWEEK
|
||||
|
||||
return wd
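# Worked example (editor's addition): picknthweekday(2018, 3, 0, 2, 0, 2)
# requests the 2nd Sunday (dayofweek=0) of March 2018 at 02:00. March 1st is
# a Thursday (isoweekday 4), so weekdayone = day (0 - 4) % 7 + 1 = 4, i.e.
# Sunday March 4; adding one week yields March 11, the US DST-on instant.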
|
||||
|
||||
|
||||
def valuestodict(key):
|
||||
"""Convert a registry key's values to a dictionary."""
|
||||
dout = {}
|
||||
size = winreg.QueryInfoKey(key)[1]
|
||||
tz_res = None
|
||||
|
||||
for i in range(size):
|
||||
key_name, value, dtype = winreg.EnumValue(key, i)
|
||||
if dtype == winreg.REG_DWORD or dtype == winreg.REG_DWORD_LITTLE_ENDIAN:
|
||||
# If it's a DWORD (32-bit integer), it's stored as unsigned - convert
|
||||
# that to a proper signed integer
|
||||
if value & (1 << 31):
|
||||
value = value - (1 << 32)
|
||||
elif dtype == winreg.REG_SZ:
|
||||
# If it's a reference to the tzres DLL, load the actual string
|
||||
if value.startswith('@tzres'):
|
||||
tz_res = tz_res or tzres()
|
||||
value = tz_res.name_from_string(value)
|
||||
|
||||
value = value.rstrip('\x00') # Remove trailing nulls
|
||||
|
||||
dout[key_name] = value
|
||||
|
||||
return dout
|
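# Illustrative usage sketch (editor's addition, not part of win.py). Assumes
# a Windows machine; "Eastern Standard Time" is a typical registry key name,
# not guaranteed to exist on every system.
from datetime import datetime
from dateutil.tz.win import tzwin, tzwinlocal

eastern = tzwin("Eastern Standard Time")    # zone read from the registry
local = tzwinlocal()                        # the machine's configured zone
dt = datetime(2018, 7, 1, 12, 0, tzinfo=eastern)
print(dt.tzname())                          # 'Eastern Daylight Time' (in DST)
print(eastern.transitions(2018))            # (dston, dstoff) datetimes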
2
python/dateutil/tzwin.py
Normal file
@ -0,0 +1,2 @@
|
||||
# tzwin has moved to dateutil.tz.win
|
||||
from .tz.win import *
|
183
python/dateutil/zoneinfo/__init__.py
Normal file
@ -0,0 +1,183 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import warnings
|
||||
import json
|
||||
|
||||
from tarfile import TarFile
|
||||
from pkgutil import get_data
|
||||
from io import BytesIO
|
||||
from contextlib import closing
|
||||
|
||||
from dateutil.tz import tzfile
|
||||
|
||||
__all__ = ["get_zonefile_instance", "gettz", "gettz_db_metadata", "rebuild"]
|
||||
|
||||
ZONEFILENAME = "dateutil-zoneinfo.tar.gz"
|
||||
METADATA_FN = 'METADATA'
|
||||
|
||||
# python2.6 compatibility. Note that TarFile.__exit__ != TarFile.close, but
|
||||
# it's close enough for python2.6
|
||||
tar_open = TarFile.open
|
||||
if not hasattr(TarFile, '__exit__'):
|
||||
def tar_open(*args, **kwargs):
|
||||
return closing(TarFile.open(*args, **kwargs))
|
||||
|
||||
|
||||
class tzfile(tzfile):
|
||||
def __reduce__(self):
|
||||
return (gettz, (self._filename,))
|
||||
|
||||
|
||||
def getzoneinfofile_stream():
|
||||
try:
|
||||
return BytesIO(get_data(__name__, ZONEFILENAME))
|
||||
except IOError as e: # TODO switch to FileNotFoundError?
|
||||
warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror))
|
||||
return None
|
||||
|
||||
|
||||
class ZoneInfoFile(object):
|
||||
def __init__(self, zonefile_stream=None):
|
||||
if zonefile_stream is not None:
|
||||
with tar_open(fileobj=zonefile_stream, mode='r') as tf:
|
||||
# dict comprehension does not work on python2.6
|
||||
# TODO: get back to the nicer syntax when we ditch python2.6
|
||||
# self.zones = {zf.name: tzfile(tf.extractfile(zf),
|
||||
# filename = zf.name)
|
||||
# for zf in tf.getmembers() if zf.isfile()}
|
||||
self.zones = dict((zf.name, tzfile(tf.extractfile(zf),
|
||||
filename=zf.name))
|
||||
for zf in tf.getmembers()
|
||||
if zf.isfile() and zf.name != METADATA_FN)
|
||||
# deal with links: They'll point to their parent object. Less
|
||||
# waste of memory
|
||||
# links = {zl.name: self.zones[zl.linkname]
|
||||
# for zl in tf.getmembers() if zl.islnk() or zl.issym()}
|
||||
links = dict((zl.name, self.zones[zl.linkname])
|
||||
for zl in tf.getmembers() if
|
||||
zl.islnk() or zl.issym())
|
||||
self.zones.update(links)
|
||||
try:
|
||||
metadata_json = tf.extractfile(tf.getmember(METADATA_FN))
|
||||
metadata_str = metadata_json.read().decode('UTF-8')
|
||||
self.metadata = json.loads(metadata_str)
|
||||
except KeyError:
|
||||
# no metadata in tar file
|
||||
self.metadata = None
|
||||
else:
|
||||
self.zones = dict()
|
||||
self.metadata = None
|
||||
|
||||
def get(self, name, default=None):
|
||||
"""
|
||||
Wrapper for :func:`ZoneInfoFile.zones.get`. This is a convenience method
|
||||
for retrieving zones from the zone dictionary.
|
||||
|
||||
:param name:
|
||||
The name of the zone to retrieve. (Generally IANA zone names)
|
||||
|
||||
:param default:
|
||||
The value to return in the event of a missing key.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
|
||||
"""
|
||||
return self.zones.get(name, default)
|
||||
|
||||
|
||||
# The current API has gettz as a module function, although in fact it taps into
|
||||
# a stateful class. So as a workaround for now, without changing the API, we
|
||||
# will create a new "global" class instance the first time a user requests a
|
||||
# timezone. Ugly, but adheres to the api.
|
||||
#
|
||||
# TODO: Remove after deprecation period.
|
||||
_CLASS_ZONE_INSTANCE = list()
|
||||
|
||||
|
||||
def get_zonefile_instance(new_instance=False):
|
||||
"""
|
||||
This is a convenience function which provides a :class:`ZoneInfoFile`
|
||||
instance using the data provided by the ``dateutil`` package. By default, it
|
||||
caches a single instance of the ZoneInfoFile object and returns that.
|
||||
|
||||
:param new_instance:
|
||||
If ``True``, a new instance of :class:`ZoneInfoFile` is instantiated and
|
||||
used as the cached instance for the next call. Otherwise, new instances
|
||||
are created only as necessary.
|
||||
|
||||
:return:
|
||||
Returns a :class:`ZoneInfoFile` object.
|
||||
|
||||
.. versionadded:: 2.6
|
||||
"""
|
||||
if new_instance:
|
||||
zif = None
|
||||
else:
|
||||
zif = getattr(get_zonefile_instance, '_cached_instance', None)
|
||||
|
||||
if zif is None:
|
||||
zif = ZoneInfoFile(getzoneinfofile_stream())
|
||||
|
||||
get_zonefile_instance._cached_instance = zif
|
||||
|
||||
return zif
|
||||
|
||||
|
||||
def gettz(name):
|
||||
"""
|
||||
This retrieves a time zone from the local zoneinfo tarball that is packaged
|
||||
with dateutil.
|
||||
|
||||
:param name:
|
||||
An IANA-style time zone name, as found in the zoneinfo file.
|
||||
|
||||
:return:
|
||||
Returns a :class:`dateutil.tz.tzfile` time zone object.
|
||||
|
||||
.. warning::
|
||||
It is generally inadvisable to use this function, and it is only
|
||||
provided for API compatibility with earlier versions. This is *not*
|
||||
equivalent to ``dateutil.tz.gettz()``, which selects an appropriate
|
||||
time zone based on the inputs, favoring system zoneinfo. This is ONLY
|
||||
for accessing the dateutil-specific zoneinfo (which may be out of
|
||||
date compared to the system zoneinfo).
|
||||
|
||||
.. deprecated:: 2.6
|
||||
If you need to use a specific zoneinfofile over the system zoneinfo,
|
||||
instantiate a :class:`dateutil.zoneinfo.ZoneInfoFile` object and call
|
||||
:func:`dateutil.zoneinfo.ZoneInfoFile.get(name)` instead.
|
||||
|
||||
Use :func:`get_zonefile_instance` to retrieve an instance of the
|
||||
dateutil-provided zoneinfo.
|
||||
"""
|
||||
warnings.warn("zoneinfo.gettz() will be removed in future versions, "
|
||||
"to use the dateutil-provided zoneinfo files, instantiate a "
|
||||
"ZoneInfoFile object and use ZoneInfoFile.zones.get() "
|
||||
"instead. See the documentation for details.",
|
||||
DeprecationWarning)
|
||||
|
||||
if len(_CLASS_ZONE_INSTANCE) == 0:
|
||||
_CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream()))
|
||||
return _CLASS_ZONE_INSTANCE[0].zones.get(name)
|
||||
|
||||
|
||||
def gettz_db_metadata():
|
||||
""" Get the zonefile metadata
|
||||
|
||||
See `zonefile_metadata`_
|
||||
|
||||
:returns:
|
||||
A dictionary with the database metadata
|
||||
|
||||
.. deprecated:: 2.6
|
||||
See deprecation warning in :func:`zoneinfo.gettz`. To get metadata,
|
||||
query the attribute ``zoneinfo.ZoneInfoFile.metadata``.
|
||||
"""
|
||||
warnings.warn("zoneinfo.gettz_db_metadata() will be removed in future "
|
||||
"versions, to use the dateutil-provided zoneinfo files, "
|
||||
"ZoneInfoFile object and query the 'metadata' attribute "
|
||||
"instead. See the documentation for details.",
|
||||
DeprecationWarning)
|
||||
|
||||
if len(_CLASS_ZONE_INSTANCE) == 0:
|
||||
_CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream()))
|
||||
return _CLASS_ZONE_INSTANCE[0].metadata
|
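# Illustrative usage sketch (editor's addition): the non-deprecated way to
# read the zoneinfo bundled with dateutil, per the warnings above.
from dateutil.zoneinfo import get_zonefile_instance

zif = get_zonefile_instance()            # cached ZoneInfoFile instance
nyc = zif.get('America/New_York')        # a tzfile, or None if unknown
print(nyc is not None, zif.metadata)     # metadata dict from the tarball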
BIN
python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz
Normal file
Binary file not shown.
52
python/dateutil/zoneinfo/rebuild.py
Normal file
@ -0,0 +1,52 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import json
|
||||
from subprocess import check_call
|
||||
|
||||
from dateutil.zoneinfo import tar_open, METADATA_FN, ZONEFILENAME
|
||||
|
||||
|
||||
def rebuild(filename, tag=None, format="gz", zonegroups=[], metadata=None):
|
||||
"""Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar*
|
||||
|
||||
filename is the timezone tarball from ftp.iana.org/tz.
|
||||
|
||||
"""
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
zonedir = os.path.join(tmpdir, "zoneinfo")
|
||||
moduledir = os.path.dirname(__file__)
|
||||
try:
|
||||
with tar_open(filename) as tf:
|
||||
for name in zonegroups:
|
||||
tf.extract(name, tmpdir)
|
||||
filepaths = [os.path.join(tmpdir, n) for n in zonegroups]
|
||||
try:
|
||||
check_call(["zic", "-d", zonedir] + filepaths)
|
||||
except OSError as e:
|
||||
_print_on_nosuchfile(e)
|
||||
raise
|
||||
# write metadata file
|
||||
with open(os.path.join(zonedir, METADATA_FN), 'w') as f:
|
||||
json.dump(metadata, f, indent=4, sort_keys=True)
|
||||
target = os.path.join(moduledir, ZONEFILENAME)
|
||||
with tar_open(target, "w:%s" % format) as tf:
|
||||
for entry in os.listdir(zonedir):
|
||||
entrypath = os.path.join(zonedir, entry)
|
||||
tf.add(entrypath, entry)
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
def _print_on_nosuchfile(e):
|
||||
"""Print helpful troubleshooting message
|
||||
|
||||
e is an exception raised by subprocess.check_call()
|
||||
|
||||
"""
|
||||
if e.errno == 2:
|
||||
logging.error(
|
||||
"Could not find zic. Perhaps you need to install "
|
||||
"libc-bin or some other package that provides it, "
|
||||
"or it's not in your PATH?")
|
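# Illustrative call (editor's addition): 'tzdata-2018e.tar.gz' is a
# hypothetical IANA release download, and 'zic' must be on the PATH.
from dateutil.zoneinfo.rebuild import rebuild

rebuild('tzdata-2018e.tar.gz',
        zonegroups=['europe', 'northamerica', 'southamerica'],
        metadata={'tzversion': '2018e'})  # stored in the METADATA member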
112
python/defusedxml/ElementTree.py
Normal file
@ -0,0 +1,112 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.etree.ElementTree facade
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
import sys
|
||||
from xml.etree.ElementTree import TreeBuilder as _TreeBuilder
|
||||
from xml.etree.ElementTree import parse as _parse
|
||||
from xml.etree.ElementTree import tostring
|
||||
|
||||
from .common import PY3
|
||||
|
||||
|
||||
if PY3:
|
||||
import importlib
|
||||
else:
|
||||
from xml.etree.ElementTree import XMLParser as _XMLParser
|
||||
from xml.etree.ElementTree import iterparse as _iterparse
|
||||
from xml.etree.ElementTree import ParseError
|
||||
|
||||
|
||||
from .common import (DTDForbidden, EntitiesForbidden,
|
||||
ExternalReferenceForbidden, _generate_etree_functions)
|
||||
|
||||
__origin__ = "xml.etree.ElementTree"
|
||||
|
||||
|
||||
def _get_py3_cls():
|
||||
"""Python 3.3 hides the pure Python code but defusedxml requires it.
|
||||
|
||||
The code is based on test.support.import_fresh_module().
|
||||
"""
|
||||
pymodname = "xml.etree.ElementTree"
|
||||
cmodname = "_elementtree"
|
||||
|
||||
pymod = sys.modules.pop(pymodname, None)
|
||||
cmod = sys.modules.pop(cmodname, None)
|
||||
|
||||
sys.modules[cmodname] = None
|
||||
pure_pymod = importlib.import_module(pymodname)
|
||||
if cmod is not None:
|
||||
sys.modules[cmodname] = cmod
|
||||
else:
|
||||
sys.modules.pop(cmodname)
|
||||
sys.modules[pymodname] = pymod
|
||||
|
||||
_XMLParser = pure_pymod.XMLParser
|
||||
_iterparse = pure_pymod.iterparse
|
||||
ParseError = pure_pymod.ParseError
|
||||
|
||||
return _XMLParser, _iterparse, ParseError
|
||||
|
||||
|
||||
if PY3:
|
||||
_XMLParser, _iterparse, ParseError = _get_py3_cls()
|
||||
|
||||
|
||||
class DefusedXMLParser(_XMLParser):
|
||||
|
||||
def __init__(self, html=0, target=None, encoding=None,
|
||||
forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
# Python 2.x old style class
|
||||
_XMLParser.__init__(self, html, target, encoding)
|
||||
self.forbid_dtd = forbid_dtd
|
||||
self.forbid_entities = forbid_entities
|
||||
self.forbid_external = forbid_external
|
||||
if PY3:
|
||||
parser = self.parser
|
||||
else:
|
||||
parser = self._parser
|
||||
if self.forbid_dtd:
|
||||
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
|
||||
if self.forbid_entities:
|
||||
parser.EntityDeclHandler = self.defused_entity_decl
|
||||
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
|
||||
if self.forbid_external:
|
||||
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
|
||||
|
||||
def defused_start_doctype_decl(self, name, sysid, pubid,
|
||||
has_internal_subset):
|
||||
raise DTDForbidden(name, sysid, pubid)
|
||||
|
||||
def defused_entity_decl(self, name, is_parameter_entity, value, base,
|
||||
sysid, pubid, notation_name):
|
||||
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
|
||||
notation_name):
|
||||
# expat 1.2
|
||||
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_external_entity_ref_handler(self, context, base, sysid,
|
||||
pubid):
|
||||
raise ExternalReferenceForbidden(context, base, sysid, pubid)
|
||||
|
||||
|
||||
# aliases
|
||||
XMLTreeBuilder = XMLParse = DefusedXMLParser
|
||||
|
||||
parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser,
|
||||
_TreeBuilder, _parse,
|
||||
_iterparse)
|
||||
XML = fromstring
|
||||
|
||||
|
||||
__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse',
|
||||
'parse', 'tostring']
|
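# Illustrative usage sketch (editor's addition): an inline entity declaration
# raises EntitiesForbidden instead of being expanded into the tree.
from defusedxml import EntitiesForbidden
from defusedxml.ElementTree import fromstring

try:
    fromstring('<!DOCTYPE x [<!ENTITY a "boom">]><x>&a;</x>')
except EntitiesForbidden as exc:
    print('rejected:', exc)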
45
python/defusedxml/__init__.py
Normal file
@ -0,0 +1,45 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defuse XML bomb denial of service vulnerabilities
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from .common import (DefusedXmlException, DTDForbidden, EntitiesForbidden,
|
||||
ExternalReferenceForbidden, NotSupportedError,
|
||||
_apply_defusing)
|
||||
|
||||
|
||||
def defuse_stdlib():
|
||||
"""Monkey patch and defuse all stdlib packages
|
||||
|
||||
    :warning: The monkey patch is an EXPERIMENTAL feature.
|
||||
"""
|
||||
defused = {}
|
||||
|
||||
from . import cElementTree
|
||||
from . import ElementTree
|
||||
from . import minidom
|
||||
from . import pulldom
|
||||
from . import sax
|
||||
from . import expatbuilder
|
||||
from . import expatreader
|
||||
from . import xmlrpc
|
||||
|
||||
xmlrpc.monkey_patch()
|
||||
defused[xmlrpc] = None
|
||||
|
||||
for defused_mod in [cElementTree, ElementTree, minidom, pulldom, sax,
|
||||
expatbuilder, expatreader]:
|
||||
stdlib_mod = _apply_defusing(defused_mod)
|
||||
defused[defused_mod] = stdlib_mod
|
||||
|
||||
return defused
|
||||
|
||||
|
||||
__version__ = "0.5.0"
|
||||
|
||||
__all__ = ['DefusedXmlException', 'DTDForbidden', 'EntitiesForbidden',
|
||||
'ExternalReferenceForbidden', 'NotSupportedError']
|
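# Illustrative usage sketch (editor's addition): after defuse_stdlib(), the
# patched stdlib parsers raise defusedxml exceptions on dangerous input.
import defusedxml

defusedxml.defuse_stdlib()

import xml.etree.ElementTree as ET
try:
    ET.fromstring('<!DOCTYPE x [<!ENTITY a "boom">]><x>&a;</x>')
except defusedxml.EntitiesForbidden:
    print('entity declaration blocked')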
30
python/defusedxml/cElementTree.py
Normal file
@ -0,0 +1,30 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.etree.cElementTree
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from xml.etree.cElementTree import TreeBuilder as _TreeBuilder
|
||||
from xml.etree.cElementTree import parse as _parse
|
||||
from xml.etree.cElementTree import tostring
|
||||
# iterparse from ElementTree!
|
||||
from xml.etree.ElementTree import iterparse as _iterparse
|
||||
|
||||
from .ElementTree import DefusedXMLParser
|
||||
from .common import _generate_etree_functions
|
||||
|
||||
__origin__ = "xml.etree.cElementTree"
|
||||
|
||||
|
||||
XMLTreeBuilder = XMLParse = DefusedXMLParser
|
||||
|
||||
parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser,
|
||||
_TreeBuilder, _parse,
|
||||
_iterparse)
|
||||
XML = fromstring
|
||||
|
||||
__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse',
|
||||
'parse', 'tostring']
|
120
python/defusedxml/common.py
Normal file
@ -0,0 +1,120 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Common constants, exceptions and helpe functions
|
||||
"""
|
||||
import sys
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
|
||||
class DefusedXmlException(ValueError):
|
||||
"""Base exception
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class DTDForbidden(DefusedXmlException):
|
||||
"""Document type definition is forbidden
|
||||
"""
|
||||
|
||||
def __init__(self, name, sysid, pubid):
|
||||
super(DTDForbidden, self).__init__()
|
||||
self.name = name
|
||||
self.sysid = sysid
|
||||
self.pubid = pubid
|
||||
|
||||
def __str__(self):
|
||||
tpl = "DTDForbidden(name='{}', system_id={!r}, public_id={!r})"
|
||||
return tpl.format(self.name, self.sysid, self.pubid)
|
||||
|
||||
|
||||
class EntitiesForbidden(DefusedXmlException):
|
||||
"""Entity definition is forbidden
|
||||
"""
|
||||
|
||||
def __init__(self, name, value, base, sysid, pubid, notation_name):
|
||||
super(EntitiesForbidden, self).__init__()
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.base = base
|
||||
self.sysid = sysid
|
||||
self.pubid = pubid
|
||||
self.notation_name = notation_name
|
||||
|
||||
def __str__(self):
|
||||
tpl = "EntitiesForbidden(name='{}', system_id={!r}, public_id={!r})"
|
||||
return tpl.format(self.name, self.sysid, self.pubid)
|
||||
|
||||
|
||||
class ExternalReferenceForbidden(DefusedXmlException):
|
||||
"""Resolving an external reference is forbidden
|
||||
"""
|
||||
|
||||
def __init__(self, context, base, sysid, pubid):
|
||||
super(ExternalReferenceForbidden, self).__init__()
|
||||
self.context = context
|
||||
self.base = base
|
||||
self.sysid = sysid
|
||||
self.pubid = pubid
|
||||
|
||||
def __str__(self):
|
||||
tpl = "ExternalReferenceForbidden(system_id='{}', public_id={})"
|
||||
return tpl.format(self.sysid, self.pubid)
|
||||
|
||||
|
||||
class NotSupportedError(DefusedXmlException):
|
||||
"""The operation is not supported
|
||||
"""
|
||||
|
||||
|
||||
def _apply_defusing(defused_mod):
|
||||
assert defused_mod is sys.modules[defused_mod.__name__]
|
||||
stdlib_name = defused_mod.__origin__
|
||||
__import__(stdlib_name, {}, {}, ["*"])
|
||||
stdlib_mod = sys.modules[stdlib_name]
|
||||
stdlib_names = set(dir(stdlib_mod))
|
||||
for name, obj in vars(defused_mod).items():
|
||||
if name.startswith("_") or name not in stdlib_names:
|
||||
continue
|
||||
setattr(stdlib_mod, name, obj)
|
||||
return stdlib_mod
|
||||
|
||||
|
||||
def _generate_etree_functions(DefusedXMLParser, _TreeBuilder,
|
||||
_parse, _iterparse):
|
||||
"""Factory for functions needed by etree, dependent on whether
|
||||
cElementTree or ElementTree is used."""
|
||||
|
||||
def parse(source, parser=None, forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
if parser is None:
|
||||
parser = DefusedXMLParser(target=_TreeBuilder(),
|
||||
forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
return _parse(source, parser)
|
||||
|
||||
def iterparse(source, events=None, parser=None, forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
if parser is None:
|
||||
parser = DefusedXMLParser(target=_TreeBuilder(),
|
||||
forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
return _iterparse(source, events, parser)
|
||||
|
||||
def fromstring(text, forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
parser = DefusedXMLParser(target=_TreeBuilder(),
|
||||
forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
parser.feed(text)
|
||||
return parser.close()
|
||||
|
||||
return parse, iterparse, fromstring
|
110
python/defusedxml/expatbuilder.py
Normal file
@ -0,0 +1,110 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.dom.expatbuilder
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from xml.dom.expatbuilder import ExpatBuilder as _ExpatBuilder
|
||||
from xml.dom.expatbuilder import Namespaces as _Namespaces
|
||||
|
||||
from .common import (DTDForbidden, EntitiesForbidden,
|
||||
ExternalReferenceForbidden)
|
||||
|
||||
__origin__ = "xml.dom.expatbuilder"
|
||||
|
||||
|
||||
class DefusedExpatBuilder(_ExpatBuilder):
|
||||
"""Defused document builder"""
|
||||
|
||||
def __init__(self, options=None, forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
_ExpatBuilder.__init__(self, options)
|
||||
self.forbid_dtd = forbid_dtd
|
||||
self.forbid_entities = forbid_entities
|
||||
self.forbid_external = forbid_external
|
||||
|
||||
def defused_start_doctype_decl(self, name, sysid, pubid,
|
||||
has_internal_subset):
|
||||
raise DTDForbidden(name, sysid, pubid)
|
||||
|
||||
def defused_entity_decl(self, name, is_parameter_entity, value, base,
|
||||
sysid, pubid, notation_name):
|
||||
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
|
||||
notation_name):
|
||||
# expat 1.2
|
||||
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_external_entity_ref_handler(self, context, base, sysid,
|
||||
pubid):
|
||||
raise ExternalReferenceForbidden(context, base, sysid, pubid)
|
||||
|
||||
def install(self, parser):
|
||||
_ExpatBuilder.install(self, parser)
|
||||
|
||||
if self.forbid_dtd:
|
||||
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
|
||||
if self.forbid_entities:
|
||||
# if self._options.entities:
|
||||
parser.EntityDeclHandler = self.defused_entity_decl
|
||||
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
|
||||
if self.forbid_external:
|
||||
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
|
||||
|
||||
|
||||
class DefusedExpatBuilderNS(_Namespaces, DefusedExpatBuilder):
|
||||
"""Defused document builder that supports namespaces."""
|
||||
|
||||
def install(self, parser):
|
||||
DefusedExpatBuilder.install(self, parser)
|
||||
if self._options.namespace_declarations:
|
||||
parser.StartNamespaceDeclHandler = (
|
||||
self.start_namespace_decl_handler)
|
||||
|
||||
def reset(self):
|
||||
DefusedExpatBuilder.reset(self)
|
||||
self._initNamespaces()
|
||||
|
||||
|
||||
def parse(file, namespaces=True, forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
"""Parse a document, returning the resulting Document node.
|
||||
|
||||
'file' may be either a file name or an open file object.
|
||||
"""
|
||||
if namespaces:
|
||||
build_builder = DefusedExpatBuilderNS
|
||||
else:
|
||||
build_builder = DefusedExpatBuilder
|
||||
builder = build_builder(forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
|
||||
if isinstance(file, str):
|
||||
fp = open(file, 'rb')
|
||||
try:
|
||||
result = builder.parseFile(fp)
|
||||
finally:
|
||||
fp.close()
|
||||
else:
|
||||
result = builder.parseFile(file)
|
||||
return result
|
||||
|
||||
|
||||
def parseString(string, namespaces=True, forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
"""Parse a document from a string, returning the resulting
|
||||
Document node.
|
||||
"""
|
||||
if namespaces:
|
||||
build_builder = DefusedExpatBuilderNS
|
||||
else:
|
||||
build_builder = DefusedExpatBuilder
|
||||
builder = build_builder(forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
return builder.parseString(string)
|
59
python/defusedxml/expatreader.py
Normal file
@ -0,0 +1,59 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.sax.expatreader
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from xml.sax.expatreader import ExpatParser as _ExpatParser
|
||||
|
||||
from .common import (DTDForbidden, EntitiesForbidden,
|
||||
ExternalReferenceForbidden)
|
||||
|
||||
__origin__ = "xml.sax.expatreader"
|
||||
|
||||
|
||||
class DefusedExpatParser(_ExpatParser):
|
||||
"""Defused SAX driver for the pyexpat C module."""
|
||||
|
||||
def __init__(self, namespaceHandling=0, bufsize=2 ** 16 - 20,
|
||||
forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
_ExpatParser.__init__(self, namespaceHandling, bufsize)
|
||||
self.forbid_dtd = forbid_dtd
|
||||
self.forbid_entities = forbid_entities
|
||||
self.forbid_external = forbid_external
|
||||
|
||||
def defused_start_doctype_decl(self, name, sysid, pubid,
|
||||
has_internal_subset):
|
||||
raise DTDForbidden(name, sysid, pubid)
|
||||
|
||||
def defused_entity_decl(self, name, is_parameter_entity, value, base,
|
||||
sysid, pubid, notation_name):
|
||||
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
|
||||
notation_name):
|
||||
# expat 1.2
|
||||
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_external_entity_ref_handler(self, context, base, sysid,
|
||||
pubid):
|
||||
raise ExternalReferenceForbidden(context, base, sysid, pubid)
|
||||
|
||||
def reset(self):
|
||||
_ExpatParser.reset(self)
|
||||
parser = self._parser
|
||||
if self.forbid_dtd:
|
||||
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
|
||||
if self.forbid_entities:
|
||||
parser.EntityDeclHandler = self.defused_entity_decl
|
||||
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
|
||||
if self.forbid_external:
|
||||
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
|
||||
|
||||
|
||||
def create_parser(*args, **kwargs):
|
||||
return DefusedExpatParser(*args, **kwargs)
|
153
python/defusedxml/lxml.py
Normal file
@ -0,0 +1,153 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Example code for lxml.etree protection
|
||||
|
||||
The code has NO protection against decompression bombs.
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
import threading
|
||||
from lxml import etree as _etree
|
||||
|
||||
from .common import DTDForbidden, EntitiesForbidden, NotSupportedError
|
||||
|
||||
LXML3 = _etree.LXML_VERSION[0] >= 3
|
||||
|
||||
__origin__ = "lxml.etree"
|
||||
|
||||
tostring = _etree.tostring
|
||||
|
||||
|
||||
class RestrictedElement(_etree.ElementBase):
|
||||
"""A restricted Element class that filters out instances of some classes
|
||||
"""
|
||||
__slots__ = ()
|
||||
# blacklist = (etree._Entity, etree._ProcessingInstruction, etree._Comment)
|
||||
blacklist = _etree._Entity
|
||||
|
||||
def _filter(self, iterator):
|
||||
blacklist = self.blacklist
|
||||
for child in iterator:
|
||||
if isinstance(child, blacklist):
|
||||
continue
|
||||
yield child
|
||||
|
||||
def __iter__(self):
|
||||
iterator = super(RestrictedElement, self).__iter__()
|
||||
return self._filter(iterator)
|
||||
|
||||
def iterchildren(self, tag=None, reversed=False):
|
||||
iterator = super(RestrictedElement, self).iterchildren(
|
||||
tag=tag, reversed=reversed)
|
||||
return self._filter(iterator)
|
||||
|
||||
def iter(self, tag=None, *tags):
|
||||
iterator = super(RestrictedElement, self).iter(tag=tag, *tags)
|
||||
return self._filter(iterator)
|
||||
|
||||
def iterdescendants(self, tag=None, *tags):
|
||||
iterator = super(RestrictedElement,
|
||||
self).iterdescendants(tag=tag, *tags)
|
||||
return self._filter(iterator)
|
||||
|
||||
def itersiblings(self, tag=None, preceding=False):
|
||||
iterator = super(RestrictedElement, self).itersiblings(
|
||||
tag=tag, preceding=preceding)
|
||||
return self._filter(iterator)
|
||||
|
||||
def getchildren(self):
|
||||
iterator = super(RestrictedElement, self).__iter__()
|
||||
return list(self._filter(iterator))
|
||||
|
||||
def getiterator(self, tag=None):
|
||||
iterator = super(RestrictedElement, self).getiterator(tag)
|
||||
return self._filter(iterator)
|
||||
|
||||
|
||||
class GlobalParserTLS(threading.local):
|
||||
"""Thread local context for custom parser instances
|
||||
"""
|
||||
parser_config = {
|
||||
'resolve_entities': False,
|
||||
# 'remove_comments': True,
|
||||
# 'remove_pis': True,
|
||||
}
|
||||
|
||||
element_class = RestrictedElement
|
||||
|
||||
def createDefaultParser(self):
|
||||
parser = _etree.XMLParser(**self.parser_config)
|
||||
element_class = self.element_class
|
||||
if self.element_class is not None:
|
||||
lookup = _etree.ElementDefaultClassLookup(element=element_class)
|
||||
parser.set_element_class_lookup(lookup)
|
||||
return parser
|
||||
|
||||
def setDefaultParser(self, parser):
|
||||
self._default_parser = parser
|
||||
|
||||
def getDefaultParser(self):
|
||||
parser = getattr(self, "_default_parser", None)
|
||||
if parser is None:
|
||||
parser = self.createDefaultParser()
|
||||
self.setDefaultParser(parser)
|
||||
return parser
|
||||
|
||||
|
||||
_parser_tls = GlobalParserTLS()
|
||||
getDefaultParser = _parser_tls.getDefaultParser
|
||||
|
||||
|
||||
def check_docinfo(elementtree, forbid_dtd=False, forbid_entities=True):
|
||||
"""Check docinfo of an element tree for DTD and entity declarations
|
||||
|
||||
The check for entity declarations needs lxml 3 or newer. lxml 2.x does
|
||||
not support dtd.iterentities().
|
||||
"""
|
||||
docinfo = elementtree.docinfo
|
||||
if docinfo.doctype:
|
||||
if forbid_dtd:
|
||||
raise DTDForbidden(docinfo.doctype,
|
||||
docinfo.system_url,
|
||||
docinfo.public_id)
|
||||
if forbid_entities and not LXML3:
|
||||
# lxml < 3 has no iterentities()
|
||||
raise NotSupportedError("Unable to check for entity declarations "
|
||||
"in lxml 2.x")
|
||||
|
||||
if forbid_entities:
|
||||
for dtd in docinfo.internalDTD, docinfo.externalDTD:
|
||||
if dtd is None:
|
||||
continue
|
||||
for entity in dtd.iterentities():
|
||||
raise EntitiesForbidden(entity.name, entity.content, None,
|
||||
None, None, None)
|
||||
|
||||
|
||||
def parse(source, parser=None, base_url=None, forbid_dtd=False,
|
||||
forbid_entities=True):
|
||||
if parser is None:
|
||||
parser = getDefaultParser()
|
||||
elementtree = _etree.parse(source, parser, base_url=base_url)
|
||||
check_docinfo(elementtree, forbid_dtd, forbid_entities)
|
||||
return elementtree
|
||||
|
||||
|
||||
def fromstring(text, parser=None, base_url=None, forbid_dtd=False,
|
||||
forbid_entities=True):
|
||||
if parser is None:
|
||||
parser = getDefaultParser()
|
||||
rootelement = _etree.fromstring(text, parser, base_url=base_url)
|
||||
elementtree = rootelement.getroottree()
|
||||
check_docinfo(elementtree, forbid_dtd, forbid_entities)
|
||||
return rootelement
|
||||
|
||||
|
||||
XML = fromstring
|
||||
|
||||
|
||||
def iterparse(*args, **kwargs):
|
||||
raise NotSupportedError("defused lxml.etree.iterparse not available")
|
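# Illustrative usage sketch (editor's addition): requires the lxml package.
from defusedxml import DTDForbidden
from defusedxml.lxml import fromstring

try:
    fromstring('<!DOCTYPE x><x/>', forbid_dtd=True)
except DTDForbidden as exc:
    print('DTD rejected:', exc)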
42
python/defusedxml/minidom.py
Normal file
@ -0,0 +1,42 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.dom.minidom
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from xml.dom.minidom import _do_pulldom_parse
|
||||
from . import expatbuilder as _expatbuilder
|
||||
from . import pulldom as _pulldom
|
||||
|
||||
__origin__ = "xml.dom.minidom"
|
||||
|
||||
|
||||
def parse(file, parser=None, bufsize=None, forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
"""Parse a file into a DOM by filename or file object."""
|
||||
if parser is None and not bufsize:
|
||||
return _expatbuilder.parse(file, forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
else:
|
||||
return _do_pulldom_parse(_pulldom.parse, (file,),
|
||||
{'parser': parser, 'bufsize': bufsize,
|
||||
'forbid_dtd': forbid_dtd, 'forbid_entities': forbid_entities,
|
||||
'forbid_external': forbid_external})
|
||||
|
||||
|
||||
def parseString(string, parser=None, forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
"""Parse a file into a DOM from a string."""
|
||||
if parser is None:
|
||||
return _expatbuilder.parseString(string, forbid_dtd=forbid_dtd,
|
||||
forbid_entities=forbid_entities,
|
||||
forbid_external=forbid_external)
|
||||
else:
|
||||
return _do_pulldom_parse(_pulldom.parseString, (string,),
|
||||
{'parser': parser, 'forbid_dtd': forbid_dtd,
|
||||
'forbid_entities': forbid_entities,
|
||||
'forbid_external': forbid_external})
|
34
python/defusedxml/pulldom.py
Normal file
@ -0,0 +1,34 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.dom.pulldom
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from xml.dom.pulldom import parse as _parse
|
||||
from xml.dom.pulldom import parseString as _parseString
|
||||
from .sax import make_parser
|
||||
|
||||
__origin__ = "xml.dom.pulldom"
|
||||
|
||||
|
||||
def parse(stream_or_string, parser=None, bufsize=None, forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
if parser is None:
|
||||
parser = make_parser()
|
||||
parser.forbid_dtd = forbid_dtd
|
||||
parser.forbid_entities = forbid_entities
|
||||
parser.forbid_external = forbid_external
|
||||
return _parse(stream_or_string, parser, bufsize)
|
||||
|
||||
|
||||
def parseString(string, parser=None, forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
if parser is None:
|
||||
parser = make_parser()
|
||||
parser.forbid_dtd = forbid_dtd
|
||||
parser.forbid_entities = forbid_entities
|
||||
parser.forbid_external = forbid_external
|
||||
return _parseString(string, parser)
|
49
python/defusedxml/sax.py
Normal file
@ -0,0 +1,49 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xml.sax
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from xml.sax import InputSource as _InputSource
|
||||
from xml.sax import ErrorHandler as _ErrorHandler
|
||||
|
||||
from . import expatreader
|
||||
|
||||
__origin__ = "xml.sax"
|
||||
|
||||
|
||||
def parse(source, handler, errorHandler=_ErrorHandler(), forbid_dtd=False,
|
||||
forbid_entities=True, forbid_external=True):
|
||||
parser = make_parser()
|
||||
parser.setContentHandler(handler)
|
||||
parser.setErrorHandler(errorHandler)
|
||||
parser.forbid_dtd = forbid_dtd
|
||||
parser.forbid_entities = forbid_entities
|
||||
parser.forbid_external = forbid_external
|
||||
parser.parse(source)
|
||||
|
||||
|
||||
def parseString(string, handler, errorHandler=_ErrorHandler(),
|
||||
forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
from io import BytesIO
|
||||
|
||||
if errorHandler is None:
|
||||
errorHandler = _ErrorHandler()
|
||||
parser = make_parser()
|
||||
parser.setContentHandler(handler)
|
||||
parser.setErrorHandler(errorHandler)
|
||||
parser.forbid_dtd = forbid_dtd
|
||||
parser.forbid_entities = forbid_entities
|
||||
parser.forbid_external = forbid_external
|
||||
|
||||
inpsrc = _InputSource()
|
||||
inpsrc.setByteStream(BytesIO(string))
|
||||
parser.parse(inpsrc)
|
||||
|
||||
|
||||
def make_parser(parser_list=[]):
|
||||
return expatreader.create_parser()
|
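# Illustrative usage sketch (editor's addition): count start tags with a
# plain SAX handler; the defused parser still blocks entity declarations.
from xml.sax.handler import ContentHandler
from defusedxml.sax import parseString

class TagCounter(ContentHandler):
    def __init__(self):
        ContentHandler.__init__(self)
        self.count = 0

    def startElement(self, name, attrs):
        self.count += 1

handler = TagCounter()
parseString(b'<root><child/><child/></root>', handler)
print(handler.count)  # -> 3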
157
python/defusedxml/xmlrpc.py
Normal file
@ -0,0 +1,157 @@
|
||||
# defusedxml
|
||||
#
|
||||
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
|
||||
# Licensed to PSF under a Contributor Agreement.
|
||||
# See http://www.python.org/psf/license for licensing details.
|
||||
"""Defused xmlrpclib
|
||||
|
||||
Also defuses gzip bombs.
|
||||
"""
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
import io
|
||||
|
||||
from .common import (
|
||||
DTDForbidden, EntitiesForbidden, ExternalReferenceForbidden, PY3)
|
||||
|
||||
if PY3:
|
||||
__origin__ = "xmlrpc.client"
|
||||
from xmlrpc.client import ExpatParser
|
||||
from xmlrpc import client as xmlrpc_client
|
||||
from xmlrpc import server as xmlrpc_server
|
||||
from xmlrpc.client import gzip_decode as _orig_gzip_decode
|
||||
from xmlrpc.client import GzipDecodedResponse as _OrigGzipDecodedResponse
|
||||
else:
|
||||
__origin__ = "xmlrpclib"
|
||||
from xmlrpclib import ExpatParser
|
||||
import xmlrpclib as xmlrpc_client
|
||||
xmlrpc_server = None
|
||||
from xmlrpclib import gzip_decode as _orig_gzip_decode
|
||||
from xmlrpclib import GzipDecodedResponse as _OrigGzipDecodedResponse
|
||||
|
||||
try:
|
||||
import gzip
|
||||
except ImportError:
|
||||
gzip = None
|
||||
|
||||
|
||||
# Limit maximum request size to prevent resource exhaustion DoS
|
||||
# Also used to limit maximum amount of gzip decoded data in order to prevent
|
||||
# decompression bombs
|
||||
# A value of -1 or smaller disables the limit
|
||||
MAX_DATA = 30 * 1024 * 1024 # 30 MB
|
||||
|
||||
|
||||
def defused_gzip_decode(data, limit=None):
|
||||
"""gzip encoded data -> unencoded data
|
||||
|
||||
Decode data using the gzip content encoding as described in RFC 1952
|
||||
"""
|
||||
if not gzip:
|
||||
raise NotImplementedError
|
||||
if limit is None:
|
||||
limit = MAX_DATA
|
||||
f = io.BytesIO(data)
|
||||
gzf = gzip.GzipFile(mode="rb", fileobj=f)
|
||||
try:
|
||||
if limit < 0: # no limit
|
||||
decoded = gzf.read()
|
||||
else:
|
||||
decoded = gzf.read(limit + 1)
|
||||
except IOError:
|
||||
raise ValueError("invalid data")
|
||||
f.close()
|
||||
gzf.close()
|
||||
if limit >= 0 and len(decoded) > limit:
|
||||
raise ValueError("max gzipped payload length exceeded")
|
||||
return decoded
|
||||
|
||||
|
||||
class DefusedGzipDecodedResponse(gzip.GzipFile if gzip else object):
|
||||
"""a file-like object to decode a response encoded with the gzip
|
||||
method, as described in RFC 1952.
|
||||
"""
|
||||
|
||||
def __init__(self, response, limit=None):
|
||||
# response doesn't support tell() and read(), required by
|
||||
# GzipFile
|
||||
if not gzip:
|
||||
raise NotImplementedError
|
||||
self.limit = limit = limit if limit is not None else MAX_DATA
|
||||
if limit < 0: # no limit
|
||||
data = response.read()
|
||||
self.readlength = None
|
||||
else:
|
||||
data = response.read(limit + 1)
|
||||
self.readlength = 0
|
||||
if limit >= 0 and len(data) > limit:
|
||||
raise ValueError("max payload length exceeded")
|
||||
self.stringio = io.BytesIO(data)
|
||||
gzip.GzipFile.__init__(self, mode="rb", fileobj=self.stringio)
|
||||
|
||||
def read(self, n):
|
||||
if self.limit >= 0:
|
||||
left = self.limit - self.readlength
|
||||
n = min(n, left + 1)
|
||||
data = gzip.GzipFile.read(self, n)
|
||||
self.readlength += len(data)
|
||||
if self.readlength > self.limit:
|
||||
raise ValueError("max payload length exceeded")
|
||||
return data
|
||||
else:
|
||||
return gzip.GzipFile.read(self, n)
|
||||
|
||||
def close(self):
|
||||
gzip.GzipFile.close(self)
|
||||
self.stringio.close()
|
||||
|
||||
|
||||
class DefusedExpatParser(ExpatParser):
|
||||
|
||||
def __init__(self, target, forbid_dtd=False, forbid_entities=True,
|
||||
forbid_external=True):
|
||||
ExpatParser.__init__(self, target)
|
||||
self.forbid_dtd = forbid_dtd
|
||||
self.forbid_entities = forbid_entities
|
||||
self.forbid_external = forbid_external
|
||||
parser = self._parser
|
||||
if self.forbid_dtd:
|
||||
parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl
|
||||
if self.forbid_entities:
|
||||
parser.EntityDeclHandler = self.defused_entity_decl
|
||||
parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl
|
||||
if self.forbid_external:
|
||||
parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler
|
||||
|
||||
def defused_start_doctype_decl(self, name, sysid, pubid,
|
||||
has_internal_subset):
|
||||
raise DTDForbidden(name, sysid, pubid)
|
||||
|
||||
def defused_entity_decl(self, name, is_parameter_entity, value, base,
|
||||
sysid, pubid, notation_name):
|
||||
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_unparsed_entity_decl(self, name, base, sysid, pubid,
|
||||
notation_name):
|
||||
# expat 1.2
|
||||
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
|
||||
|
||||
def defused_external_entity_ref_handler(self, context, base, sysid,
|
||||
pubid):
|
||||
raise ExternalReferenceForbidden(context, base, sysid, pubid)
|
||||
|
||||
|
||||
def monkey_patch():
|
||||
xmlrpc_client.FastParser = DefusedExpatParser
|
||||
xmlrpc_client.GzipDecodedResponse = DefusedGzipDecodedResponse
|
||||
xmlrpc_client.gzip_decode = defused_gzip_decode
|
||||
if xmlrpc_server:
|
||||
xmlrpc_server.gzip_decode = defused_gzip_decode
|
||||
|
||||
|
||||
def unmonkey_patch():
|
||||
xmlrpc_client.FastParser = None
|
||||
xmlrpc_client.GzipDecodedResponse = _OrigGzipDecodedResponse
|
||||
xmlrpc_client.gzip_decode = _orig_gzip_decode
|
||||
if xmlrpc_server:
|
||||
xmlrpc_server.gzip_decode = _orig_gzip_decode
|
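# Illustrative usage sketch (editor's addition): patch xmlrpc.client so its
# parser and gzip decoding use the defused, size-limited variants above.
from defusedxml import xmlrpc as defused_xmlrpc
from xmlrpc.client import ServerProxy  # Python 3 name; xmlrpclib on Python 2

defused_xmlrpc.monkey_patch()
proxy = ServerProxy('http://localhost:8000/')  # hypothetical endpoint
# ... responses larger than MAX_DATA or containing entity declarations now
# raise ValueError / EntitiesForbidden ...
defused_xmlrpc.unmonkey_patch()  # restore the stock implementations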
891
python/six.py
Normal file
@ -0,0 +1,891 @@
|
||||
# Copyright (c) 2010-2017 Benjamin Peterson
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""Utilities for writing code that runs on Python 2 and 3"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
import sys
|
||||
import types
|
||||
|
||||
__author__ = "Benjamin Peterson <benjamin@python.org>"
|
||||
__version__ = "1.11.0"
|
||||
|
||||
|
||||
# Useful for very coarse version differentiation.
|
||||
PY2 = sys.version_info[0] == 2
|
||||
PY3 = sys.version_info[0] == 3
|
||||
PY34 = sys.version_info[0:2] >= (3, 4)

if PY3:
    string_types = str,
    integer_types = int,
    class_types = type,
    text_type = str
    binary_type = bytes

    MAXSIZE = sys.maxsize
else:
    string_types = basestring,
    integer_types = (int, long)
    class_types = (type, types.ClassType)
    text_type = unicode
    binary_type = str

    if sys.platform.startswith("java"):
        # Jython always uses 32 bits.
        MAXSIZE = int((1 << 31) - 1)
    else:
        # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
        class X(object):

            def __len__(self):
                return 1 << 31
        try:
            len(X())
        except OverflowError:
            # 32-bit
            MAXSIZE = int((1 << 31) - 1)
        else:
            # 64-bit
            MAXSIZE = int((1 << 63) - 1)
        del X


def _add_doc(func, doc):
    """Add documentation to a function."""
    func.__doc__ = doc


def _import_module(name):
    """Import module, returning the module after the last dot."""
    __import__(name)
    return sys.modules[name]


class _LazyDescr(object):

    def __init__(self, name):
        self.name = name

    def __get__(self, obj, tp):
        result = self._resolve()
        setattr(obj, self.name, result)  # Invokes __set__.
        try:
            # This is a bit ugly, but it avoids running this again by
            # removing this descriptor.
            delattr(obj.__class__, self.name)
        except AttributeError:
            pass
        return result


class MovedModule(_LazyDescr):

    def __init__(self, name, old, new=None):
        super(MovedModule, self).__init__(name)
        if PY3:
            if new is None:
                new = name
            self.mod = new
        else:
            self.mod = old

    def _resolve(self):
        return _import_module(self.mod)

    def __getattr__(self, attr):
        _module = self._resolve()
        value = getattr(_module, attr)
        setattr(self, attr, value)
        return value


class _LazyModule(types.ModuleType):

    def __init__(self, name):
        super(_LazyModule, self).__init__(name)
        self.__doc__ = self.__class__.__doc__

    def __dir__(self):
        attrs = ["__doc__", "__name__"]
        attrs += [attr.name for attr in self._moved_attributes]
        return attrs

    # Subclasses should override this
    _moved_attributes = []


class MovedAttribute(_LazyDescr):

    def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
        super(MovedAttribute, self).__init__(name)
        if PY3:
            if new_mod is None:
                new_mod = name
            self.mod = new_mod
            if new_attr is None:
                if old_attr is None:
                    new_attr = name
                else:
                    new_attr = old_attr
            self.attr = new_attr
        else:
            self.mod = old_mod
            if old_attr is None:
                old_attr = name
            self.attr = old_attr

    def _resolve(self):
        module = _import_module(self.mod)
        return getattr(module, self.attr)


class _SixMetaPathImporter(object):

    """
    A meta path importer to import six.moves and its submodules.

    This class implements a PEP302 finder and loader. It should be compatible
    with Python 2.5 and all existing versions of Python3
    """

    def __init__(self, six_module_name):
        self.name = six_module_name
        self.known_modules = {}

    def _add_module(self, mod, *fullnames):
        for fullname in fullnames:
            self.known_modules[self.name + "." + fullname] = mod

    def _get_module(self, fullname):
        return self.known_modules[self.name + "." + fullname]

    def find_module(self, fullname, path=None):
        if fullname in self.known_modules:
            return self
        return None

    def __get_module(self, fullname):
        try:
            return self.known_modules[fullname]
        except KeyError:
            raise ImportError("This loader does not know module " + fullname)

    def load_module(self, fullname):
        try:
            # in case of a reload
            return sys.modules[fullname]
        except KeyError:
            pass
        mod = self.__get_module(fullname)
        if isinstance(mod, MovedModule):
            mod = mod._resolve()
        else:
            mod.__loader__ = self
        sys.modules[fullname] = mod
        return mod

    def is_package(self, fullname):
        """
        Return true, if the named module is a package.

        We need this method to get correct spec objects with
        Python 3.4 (see PEP451)
        """
        return hasattr(self.__get_module(fullname), "__path__")

    def get_code(self, fullname):
        """Return None

        Required, if is_package is implemented"""
        self.__get_module(fullname)  # eventually raises ImportError
        return None
    get_source = get_code  # same as get_code

_importer = _SixMetaPathImporter(__name__)
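
# A sketch of the import flow this sets up: once _importer is registered on
# sys.meta_path (at the bottom of this module), an import such as
# `from six.moves import queue` calls _importer.find_module("six.moves.queue"),
# which answers for any name registered in known_modules; load_module() then
# resolves the MovedModule to Queue (Python 2) or queue (Python 3) and caches
# the result in sys.modules, so later imports bypass the finder entirely.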


class _MovedItems(_LazyModule):

    """Lazy loading of moved objects"""
    __path__ = []  # mark as package


_moved_attributes = [
    MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
    MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
    MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
    MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
    MovedAttribute("intern", "__builtin__", "sys"),
    MovedAttribute("map", "itertools", "builtins", "imap", "map"),
    MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
    MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
    MovedAttribute("getoutput", "commands", "subprocess"),
    MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
    MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
    MovedAttribute("reduce", "__builtin__", "functools"),
    MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
    MovedAttribute("StringIO", "StringIO", "io"),
    MovedAttribute("UserDict", "UserDict", "collections"),
    MovedAttribute("UserList", "UserList", "collections"),
    MovedAttribute("UserString", "UserString", "collections"),
    MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
    MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
    MovedModule("builtins", "__builtin__"),
    MovedModule("configparser", "ConfigParser"),
    MovedModule("copyreg", "copy_reg"),
    MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
    MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"),
    MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
    MovedModule("http_cookies", "Cookie", "http.cookies"),
    MovedModule("html_entities", "htmlentitydefs", "html.entities"),
    MovedModule("html_parser", "HTMLParser", "html.parser"),
    MovedModule("http_client", "httplib", "http.client"),
    MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
    MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
    MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
    MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
    MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
    MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
    MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
    MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
    MovedModule("cPickle", "cPickle", "pickle"),
    MovedModule("queue", "Queue"),
    MovedModule("reprlib", "repr"),
    MovedModule("socketserver", "SocketServer"),
    MovedModule("_thread", "thread", "_thread"),
    MovedModule("tkinter", "Tkinter"),
    MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
    MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
    MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
    MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
    MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
    MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
    MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
    MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
    MovedModule("tkinter_colorchooser", "tkColorChooser",
                "tkinter.colorchooser"),
    MovedModule("tkinter_commondialog", "tkCommonDialog",
                "tkinter.commondialog"),
    MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
    MovedModule("tkinter_font", "tkFont", "tkinter.font"),
    MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
    MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
                "tkinter.simpledialog"),
    MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
    MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
    MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
    MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
    MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
    MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
]
# Add windows specific modules.
if sys.platform == "win32":
    _moved_attributes += [
        MovedModule("winreg", "_winreg"),
    ]

for attr in _moved_attributes:
    setattr(_MovedItems, attr.name, attr)
    if isinstance(attr, MovedModule):
        _importer._add_module(attr, "moves." + attr.name)
del attr

_MovedItems._moved_attributes = _moved_attributes

moves = _MovedItems(__name__ + ".moves")
_importer._add_module(moves, "moves")


class Module_six_moves_urllib_parse(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_parse"""


_urllib_parse_moved_attributes = [
    MovedAttribute("ParseResult", "urlparse", "urllib.parse"),
    MovedAttribute("SplitResult", "urlparse", "urllib.parse"),
    MovedAttribute("parse_qs", "urlparse", "urllib.parse"),
    MovedAttribute("parse_qsl", "urlparse", "urllib.parse"),
    MovedAttribute("urldefrag", "urlparse", "urllib.parse"),
    MovedAttribute("urljoin", "urlparse", "urllib.parse"),
    MovedAttribute("urlparse", "urlparse", "urllib.parse"),
    MovedAttribute("urlsplit", "urlparse", "urllib.parse"),
    MovedAttribute("urlunparse", "urlparse", "urllib.parse"),
    MovedAttribute("urlunsplit", "urlparse", "urllib.parse"),
    MovedAttribute("quote", "urllib", "urllib.parse"),
    MovedAttribute("quote_plus", "urllib", "urllib.parse"),
    MovedAttribute("unquote", "urllib", "urllib.parse"),
    MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
    MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"),
    MovedAttribute("urlencode", "urllib", "urllib.parse"),
    MovedAttribute("splitquery", "urllib", "urllib.parse"),
    MovedAttribute("splittag", "urllib", "urllib.parse"),
    MovedAttribute("splituser", "urllib", "urllib.parse"),
    MovedAttribute("splitvalue", "urllib", "urllib.parse"),
    MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
    MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
    MovedAttribute("uses_params", "urlparse", "urllib.parse"),
    MovedAttribute("uses_query", "urlparse", "urllib.parse"),
    MovedAttribute("uses_relative", "urlparse", "urllib.parse"),
]
for attr in _urllib_parse_moved_attributes:
    setattr(Module_six_moves_urllib_parse, attr.name, attr)
del attr

Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes

_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
                      "moves.urllib_parse", "moves.urllib.parse")


class Module_six_moves_urllib_error(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_error"""


_urllib_error_moved_attributes = [
    MovedAttribute("URLError", "urllib2", "urllib.error"),
    MovedAttribute("HTTPError", "urllib2", "urllib.error"),
    MovedAttribute("ContentTooShortError", "urllib", "urllib.error"),
]
for attr in _urllib_error_moved_attributes:
    setattr(Module_six_moves_urllib_error, attr.name, attr)
del attr

Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes

_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
                      "moves.urllib_error", "moves.urllib.error")


class Module_six_moves_urllib_request(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_request"""


_urllib_request_moved_attributes = [
    MovedAttribute("urlopen", "urllib2", "urllib.request"),
    MovedAttribute("install_opener", "urllib2", "urllib.request"),
    MovedAttribute("build_opener", "urllib2", "urllib.request"),
    MovedAttribute("pathname2url", "urllib", "urllib.request"),
    MovedAttribute("url2pathname", "urllib", "urllib.request"),
    MovedAttribute("getproxies", "urllib", "urllib.request"),
    MovedAttribute("Request", "urllib2", "urllib.request"),
    MovedAttribute("OpenerDirector", "urllib2", "urllib.request"),
    MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"),
    MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
    MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
    MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
    MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"),
    MovedAttribute("FileHandler", "urllib2", "urllib.request"),
    MovedAttribute("FTPHandler", "urllib2", "urllib.request"),
    MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"),
    MovedAttribute("UnknownHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"),
    MovedAttribute("urlretrieve", "urllib", "urllib.request"),
    MovedAttribute("urlcleanup", "urllib", "urllib.request"),
    MovedAttribute("URLopener", "urllib", "urllib.request"),
    MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
    MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
    MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
    MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
]
for attr in _urllib_request_moved_attributes:
    setattr(Module_six_moves_urllib_request, attr.name, attr)
del attr

Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes

_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
                      "moves.urllib_request", "moves.urllib.request")


class Module_six_moves_urllib_response(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_response"""


_urllib_response_moved_attributes = [
    MovedAttribute("addbase", "urllib", "urllib.response"),
    MovedAttribute("addclosehook", "urllib", "urllib.response"),
    MovedAttribute("addinfo", "urllib", "urllib.response"),
    MovedAttribute("addinfourl", "urllib", "urllib.response"),
]
for attr in _urllib_response_moved_attributes:
    setattr(Module_six_moves_urllib_response, attr.name, attr)
del attr

Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes

_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
                      "moves.urllib_response", "moves.urllib.response")


class Module_six_moves_urllib_robotparser(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_robotparser"""


_urllib_robotparser_moved_attributes = [
    MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
]
for attr in _urllib_robotparser_moved_attributes:
    setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr

Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes

_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
                      "moves.urllib_robotparser", "moves.urllib.robotparser")


class Module_six_moves_urllib(types.ModuleType):

    """Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
    __path__ = []  # mark as package
    parse = _importer._get_module("moves.urllib_parse")
    error = _importer._get_module("moves.urllib_error")
    request = _importer._get_module("moves.urllib_request")
    response = _importer._get_module("moves.urllib_response")
    robotparser = _importer._get_module("moves.urllib_robotparser")

    def __dir__(self):
        return ['parse', 'error', 'request', 'response', 'robotparser']

_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
                      "moves.urllib")


def add_move(move):
    """Add an item to six.moves."""
    setattr(_MovedItems, move.name, move)


def remove_move(name):
    """Remove item from six.moves."""
    try:
        delattr(_MovedItems, name)
    except AttributeError:
        try:
            del moves.__dict__[name]
        except KeyError:
            raise AttributeError("no such move, %r" % (name,))


if PY3:
    _meth_func = "__func__"
    _meth_self = "__self__"

    _func_closure = "__closure__"
    _func_code = "__code__"
    _func_defaults = "__defaults__"
    _func_globals = "__globals__"
else:
    _meth_func = "im_func"
    _meth_self = "im_self"

    _func_closure = "func_closure"
    _func_code = "func_code"
    _func_defaults = "func_defaults"
    _func_globals = "func_globals"


try:
    advance_iterator = next
except NameError:
    def advance_iterator(it):
        return it.next()
next = advance_iterator


try:
    callable = callable
except NameError:
    def callable(obj):
        return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)


if PY3:
    def get_unbound_function(unbound):
        return unbound

    create_bound_method = types.MethodType

    def create_unbound_method(func, cls):
        return func

    Iterator = object
else:
    def get_unbound_function(unbound):
        return unbound.im_func

    def create_bound_method(func, obj):
        return types.MethodType(func, obj, obj.__class__)

    def create_unbound_method(func, cls):
        return types.MethodType(func, None, cls)

    class Iterator(object):

        def next(self):
            return type(self).__next__(self)

    callable = callable
_add_doc(get_unbound_function,
         """Get the function out of a possibly unbound function""")


get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
get_function_closure = operator.attrgetter(_func_closure)
get_function_code = operator.attrgetter(_func_code)
get_function_defaults = operator.attrgetter(_func_defaults)
get_function_globals = operator.attrgetter(_func_globals)


if PY3:
    def iterkeys(d, **kw):
        return iter(d.keys(**kw))

    def itervalues(d, **kw):
        return iter(d.values(**kw))

    def iteritems(d, **kw):
        return iter(d.items(**kw))

    def iterlists(d, **kw):
        return iter(d.lists(**kw))

    viewkeys = operator.methodcaller("keys")

    viewvalues = operator.methodcaller("values")

    viewitems = operator.methodcaller("items")
else:
    def iterkeys(d, **kw):
        return d.iterkeys(**kw)

    def itervalues(d, **kw):
        return d.itervalues(**kw)

    def iteritems(d, **kw):
        return d.iteritems(**kw)

    def iterlists(d, **kw):
        return d.iterlists(**kw)

    viewkeys = operator.methodcaller("viewkeys")

    viewvalues = operator.methodcaller("viewvalues")

    viewitems = operator.methodcaller("viewitems")

_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems,
         "Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(iterlists,
         "Return an iterator over the (key, [values]) pairs of a dictionary.")


if PY3:
    def b(s):
        return s.encode("latin-1")

    def u(s):
        return s
    unichr = chr
    import struct
    int2byte = struct.Struct(">B").pack
    del struct
    byte2int = operator.itemgetter(0)
    indexbytes = operator.getitem
    iterbytes = iter
    import io
    StringIO = io.StringIO
    BytesIO = io.BytesIO
    _assertCountEqual = "assertCountEqual"
    if sys.version_info[1] <= 1:
        _assertRaisesRegex = "assertRaisesRegexp"
        _assertRegex = "assertRegexpMatches"
    else:
        _assertRaisesRegex = "assertRaisesRegex"
        _assertRegex = "assertRegex"
else:
    def b(s):
        return s
    # Workaround for standalone backslash

    def u(s):
        return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
    unichr = unichr
    int2byte = chr

    def byte2int(bs):
        return ord(bs[0])

    def indexbytes(buf, i):
        return ord(buf[i])
    iterbytes = functools.partial(itertools.imap, ord)
    import StringIO
    StringIO = BytesIO = StringIO.StringIO
    _assertCountEqual = "assertItemsEqual"
    _assertRaisesRegex = "assertRaisesRegexp"
    _assertRegex = "assertRegexpMatches"
_add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""")


def assertCountEqual(self, *args, **kwargs):
    return getattr(self, _assertCountEqual)(*args, **kwargs)


def assertRaisesRegex(self, *args, **kwargs):
    return getattr(self, _assertRaisesRegex)(*args, **kwargs)


def assertRegex(self, *args, **kwargs):
    return getattr(self, _assertRegex)(*args, **kwargs)


if PY3:
    exec_ = getattr(moves.builtins, "exec")

    def reraise(tp, value, tb=None):
        try:
            if value is None:
                value = tp()
            if value.__traceback__ is not tb:
                raise value.with_traceback(tb)
            raise value
        finally:
            value = None
            tb = None

else:
    def exec_(_code_, _globs_=None, _locs_=None):
        """Execute code in a namespace."""
        if _globs_ is None:
            frame = sys._getframe(1)
            _globs_ = frame.f_globals
            if _locs_ is None:
                _locs_ = frame.f_locals
            del frame
        elif _locs_ is None:
            _locs_ = _globs_
        exec("""exec _code_ in _globs_, _locs_""")

    exec_("""def reraise(tp, value, tb=None):
    try:
        raise tp, value, tb
    finally:
        tb = None
""")


if sys.version_info[:2] == (3, 2):
    exec_("""def raise_from(value, from_value):
    try:
        if from_value is None:
            raise value
        raise value from from_value
    finally:
        value = None
""")
elif sys.version_info[:2] > (3, 2):
    exec_("""def raise_from(value, from_value):
    try:
        raise value from from_value
    finally:
        value = None
""")
else:
    def raise_from(value, from_value):
        raise value


print_ = getattr(moves.builtins, "print", None)
if print_ is None:
    def print_(*args, **kwargs):
        """The new-style print function for Python 2.4 and 2.5."""
        fp = kwargs.pop("file", sys.stdout)
        if fp is None:
            return

        def write(data):
            if not isinstance(data, basestring):
                data = str(data)
            # If the file has an encoding, encode unicode with it.
            if (isinstance(fp, file) and
                    isinstance(data, unicode) and
                    fp.encoding is not None):
                errors = getattr(fp, "errors", None)
                if errors is None:
                    errors = "strict"
                data = data.encode(fp.encoding, errors)
            fp.write(data)
        want_unicode = False
        sep = kwargs.pop("sep", None)
        if sep is not None:
            if isinstance(sep, unicode):
                want_unicode = True
            elif not isinstance(sep, str):
                raise TypeError("sep must be None or a string")
        end = kwargs.pop("end", None)
        if end is not None:
            if isinstance(end, unicode):
                want_unicode = True
            elif not isinstance(end, str):
                raise TypeError("end must be None or a string")
        if kwargs:
            raise TypeError("invalid keyword arguments to print()")
        if not want_unicode:
            for arg in args:
                if isinstance(arg, unicode):
                    want_unicode = True
                    break
        if want_unicode:
            newline = unicode("\n")
            space = unicode(" ")
        else:
            newline = "\n"
            space = " "
        if sep is None:
            sep = space
        if end is None:
            end = newline
        for i, arg in enumerate(args):
            if i:
                write(sep)
            write(arg)
        write(end)
if sys.version_info[:2] < (3, 3):
    _print = print_

    def print_(*args, **kwargs):
        fp = kwargs.get("file", sys.stdout)
        flush = kwargs.pop("flush", False)
        _print(*args, **kwargs)
        if flush and fp is not None:
            fp.flush()

_add_doc(reraise, """Reraise an exception.""")

if sys.version_info[0:2] < (3, 4):
    def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
              updated=functools.WRAPPER_UPDATES):
        def wrapper(f):
            f = functools.wraps(wrapped, assigned, updated)(f)
            f.__wrapped__ = wrapped
            return f
        return wrapper
else:
    wraps = functools.wraps


def with_metaclass(meta, *bases):
    """Create a base class with a metaclass."""
    # This requires a bit of explanation: the basic idea is to make a dummy
    # metaclass for one level of class instantiation that replaces itself with
    # the actual metaclass.
    class metaclass(type):

        def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)

        @classmethod
        def __prepare__(cls, name, this_bases):
            return meta.__prepare__(name, bases)
    return type.__new__(metaclass, 'temporary_class', (), {})
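
# For example, `class MyClass(with_metaclass(Meta, Base)):` first creates the
# throwaway temporary_class whose type is the dummy metaclass; when Python then
# builds MyClass, metaclass.__new__ runs and returns Meta(name, (Base,), d)
# directly, so neither the dummy metaclass nor temporary_class survives in the
# resulting class.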


def add_metaclass(metaclass):
    """Class decorator for creating a class with a metaclass."""
    def wrapper(cls):
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get('__slots__')
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop('__dict__', None)
        orig_vars.pop('__weakref__', None)
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper
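
# Typical use, equivalent to the above but as a decorator:
#     @add_metaclass(Meta)
#     class MyClass(object):
#         pass
# The decorator rebuilds MyClass through Meta, copying the class dict while
# dropping __dict__, __weakref__ and any __slots__ descriptors so they are not
# duplicated on the new class.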


def python_2_unicode_compatible(klass):
    """
    A decorator that defines __unicode__ and __str__ methods under Python 2.
    Under Python 3 it does nothing.

    To support Python 2 and 3 with a single code base, define a __str__ method
    returning text and apply this decorator to the class.
    """
    if PY2:
        if '__str__' not in klass.__dict__:
            raise ValueError("@python_2_unicode_compatible cannot be applied "
                             "to %s because it doesn't define __str__()." %
                             klass.__name__)
        klass.__unicode__ = klass.__str__
        klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
    return klass


# Complete the moves implementation.
# This code is at the end of this module to speed up module loading.
# Turn this module into a package.
__path__ = []  # required for PEP 302 and PEP 451
__package__ = __name__  # see PEP 366 @ReservedAssignment
if globals().get("__spec__") is not None:
    __spec__.submodule_search_locations = []  # PEP 451 @UndefinedVariable
# Remove other six meta path importers, since they cause problems. This can
# happen if six is removed from sys.modules and then reloaded. (Setuptools does
# this for some reason.)
if sys.meta_path:
    for i, importer in enumerate(sys.meta_path):
        # Here's some real nastiness: Another "instance" of the six module might
        # be floating around. Therefore, we can't use isinstance() to check for
        # the six meta path importer, since the other six instance will have
        # inserted an importer with different class.
        if (type(importer).__name__ == "_SixMetaPathImporter" and
                importer.name == __name__):
            del sys.meta_path[i]
            break
    del i, importer
# Finally, add the importer to the meta path import hook.
sys.meta_path.append(_importer)
@ -6,7 +6,7 @@ from youtube import yt_app
from youtube import util

# these are just so the files get run - they import yt_app and add routes to it
-from youtube import watch, search, playlist, channel, local_playlist, comments, post_comment
from youtube import watch, search, playlist, channel, local_playlist, comments, post_comment, subscriptions

import settings

@ -119,6 +119,12 @@ For security reasons, enabling this is not recommended.''',
        ],
    }),

    ('autocheck_subscriptions', {
        'type': bool,
        'default': 0,
        'comment': '',
    }),

    ('gather_googlevideo_domains', {
        'type': bool,
        'default': False,
@ -1,5 +1,5 @@
import base64
-from youtube import util, yt_data_extract, local_playlist
from youtube import util, yt_data_extract, local_playlist, subscriptions
from youtube import yt_app

import urllib
@ -83,13 +83,15 @@ def channel_ctoken(channel_id, page, sort, tab, view=1):

    return base64.urlsafe_b64encode(pointless_nest).decode('ascii')

-def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1):
def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1, print_status=True):
    ctoken = channel_ctoken(channel_id, page, sort, tab, view).replace('=', '%3D')
    url = "https://www.youtube.com/browse_ajax?ctoken=" + ctoken

-    print("Sending channel tab ajax request")
    if print_status:
        print("Sending channel tab ajax request")
    content = util.fetch_url(url, util.desktop_ua + headers_1, debug_name='channel_tab')
-    print("Finished recieving channel tab response")
    if print_status:
        print("Finished recieving channel tab response")

    return content

@ -312,7 +314,7 @@ def get_channel_page(channel_id, tab='videos'):
        info['current_sort'] = sort
    elif tab == 'search':
        info['search_box_value'] = query

    info['subscribed'] = subscriptions.is_subscribed(info['channel_id'])

    return flask.render_template('channel.html',
        parameters_dictionary = request.args,
@ -352,7 +354,7 @@ def get_channel_page_general_url(base_url, tab, request):
        info['current_sort'] = sort
    elif tab == 'search':
        info['search_box_value'] = query

    info['subscribed'] = subscriptions.is_subscribed(info['channel_id'])

    return flask.render_template('channel.html',
        parameters_dictionary = request.args,

@ -34,33 +34,7 @@ def add_to_playlist(name, video_info_list):
        if id not in ids:
            file.write(info + "\n")
            missing_thumbnails.append(id)
-    gevent.spawn(download_thumbnails, name, missing_thumbnails)
-
-def download_thumbnail(playlist_name, video_id):
-    url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
-    save_location = os.path.join(thumbnails_directory, playlist_name, video_id + ".jpg")
-    try:
-        thumbnail = util.fetch_url(url, report_text="Saved local playlist thumbnail: " + video_id)
-    except urllib.error.HTTPError as e:
-        print("Failed to download thumbnail for " + video_id + ": " + str(e))
-        return
-    try:
-        f = open(save_location, 'wb')
-    except FileNotFoundError:
-        os.makedirs(os.path.join(thumbnails_directory, playlist_name))
-        f = open(save_location, 'wb')
-    f.write(thumbnail)
-    f.close()
-
-def download_thumbnails(playlist_name, ids):
-    # only do 5 at a time
-    # do the n where n is divisible by 5
-    i = -1
-    for i in range(0, int(len(ids)/5) - 1 ):
-        gevent.joinall([gevent.spawn(download_thumbnail, playlist_name, ids[j]) for j in range(i*5, i*5 + 5)])
-    # do the remainders (< 5)
-    gevent.joinall([gevent.spawn(download_thumbnail, playlist_name, ids[j]) for j in range(i*5 + 5, len(ids))])
-
    gevent.spawn(util.download_thumbnails, os.path.join(thumbnails_directory, name), missing_thumbnails)


def get_local_playlist_videos(name, offset=0, amount=50):
@ -88,7 +62,7 @@ def get_local_playlist_videos(name, offset=0, amount=50):
        except json.decoder.JSONDecodeError:
            if not video_json.strip() == '':
                print('Corrupt playlist video entry: ' + video_json)
-    gevent.spawn(download_thumbnails, name, missing_thumbnails)
    gevent.spawn(util.download_thumbnails, os.path.join(thumbnails_directory, name), missing_thumbnails)
    return videos[offset:offset+amount], len(videos)

def get_playlist_names():
@ -9,7 +9,7 @@ a:link {
}

a:visited {
-    color: ##7755ff;
    color: #7755ff;
}

a:not([href]){
@ -23,3 +23,16 @@ a:not([href]){
.setting-item{
    background-color: #444444;
}


.muted{
    background-color: #111111;
    color: gray;
}

.muted a:link {
    color: #10547f;
}

@ -11,3 +11,7 @@ body{
.setting-item{
    background-color: #eeeeee;
}

.muted{
    background-color: #888888;
}
@ -8,7 +8,11 @@ body{
    color: #000000;
}


.setting-item{
    background-color: #f8f8f8;
}

.muted{
    background-color: #888888;
}

@ -1,7 +1,10 @@
* {
    box-sizing: border-box;
}

h1, h2, h3, h4, h5, h6, div, button{
    margin:0;
    padding:0;

}

address{
@ -1,18 +1,828 @@
from youtube import util, yt_data_extract, channel
from youtube import yt_app
import settings

import sqlite3
import os
import time
import gevent
import json
import traceback
import contextlib
import defusedxml.ElementTree
import urllib
import math
import secrets
import collections
import calendar # bullshit! https://bugs.python.org/issue6280

-with open("subscriptions.txt", 'r', encoding='utf-8') as file:
-    subscriptions = file.read()
-
-# Line format: "channel_id channel_name"
-# Example:
-# UCYO_jab_esuFRV4b17AJtAw 3Blue1Brown
import flask
from flask import request

-subscriptions = ((line[0:24], line[25: ]) for line in subscriptions.splitlines())
-
-def get_new_videos():
-    for channel_id, channel_name in subscriptions:

thumbnails_directory = os.path.join(settings.data_dir, "subscription_thumbnails")

# https://stackabuse.com/a-sqlite-tutorial-with-python/

database_path = os.path.join(settings.data_dir, "subscriptions.sqlite")

def open_database():
    if not os.path.exists(settings.data_dir):
        os.makedirs(settings.data_dir)
    connection = sqlite3.connect(database_path, check_same_thread=False)

    try:
        cursor = connection.cursor()
        cursor.execute('''PRAGMA foreign_keys = 1''')
        # Create tables if they don't exist
        cursor.execute('''CREATE TABLE IF NOT EXISTS subscribed_channels (
                              id integer PRIMARY KEY,
                              yt_channel_id text UNIQUE NOT NULL,
                              channel_name text NOT NULL,
                              time_last_checked integer DEFAULT 0,
                              next_check_time integer DEFAULT 0,
                              muted integer DEFAULT 0
                          )''')
        cursor.execute('''CREATE TABLE IF NOT EXISTS videos (
                              id integer PRIMARY KEY,
                              sql_channel_id integer NOT NULL REFERENCES subscribed_channels(id) ON UPDATE CASCADE ON DELETE CASCADE,
                              video_id text UNIQUE NOT NULL,
                              title text NOT NULL,
                              duration text,
                              time_published integer NOT NULL,
                              is_time_published_exact integer DEFAULT 0,
                              time_noticed integer NOT NULL,
                              description text,
                              watched integer default 0
                          )''')
        cursor.execute('''CREATE TABLE IF NOT EXISTS tag_associations (
                              id integer PRIMARY KEY,
                              tag text NOT NULL,
                              sql_channel_id integer NOT NULL REFERENCES subscribed_channels(id) ON UPDATE CASCADE ON DELETE CASCADE,
                              UNIQUE(tag, sql_channel_id)
                          )''')
        cursor.execute('''CREATE TABLE IF NOT EXISTS db_info (
                              version integer DEFAULT 1
                          )''')

        connection.commit()
    except:
        connection.rollback()
        connection.close()
        raise

    # https://stackoverflow.com/questions/19522505/using-sqlite3-in-python-with-with-keyword
    return contextlib.closing(connection)

def with_open_db(function, *args, **kwargs):
    with open_database() as connection:
        with connection as cursor:
            return function(cursor, *args, **kwargs)
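
# Convenience wrapper used throughout this module: it opens the database, runs
# `function` inside a transaction (the inner `with connection` block commits on
# success and rolls back on an exception), and closes the connection via the
# contextlib.closing wrapper, e.g. with_open_db(_is_subscribed, channel_id).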

def _is_subscribed(cursor, channel_id):
    result = cursor.execute('''SELECT EXISTS(
                                   SELECT 1
                                   FROM subscribed_channels
                                   WHERE yt_channel_id=?
                                   LIMIT 1
                               )''', [channel_id]).fetchone()
    return bool(result[0])

def is_subscribed(channel_id):
    if not os.path.exists(database_path):
        return False

    return with_open_db(_is_subscribed, channel_id)

def _subscribe(channels):
    ''' channels is a list of (channel_id, channel_name) '''
    channels = list(channels)
    with open_database() as connection:
        with connection as cursor:
            channel_ids_to_check = [channel[0] for channel in channels if not _is_subscribed(cursor, channel[0])]

            rows = ( (channel_id, channel_name, 0, 0) for channel_id, channel_name in channels)
            cursor.executemany('''INSERT OR IGNORE INTO subscribed_channels (yt_channel_id, channel_name, time_last_checked, next_check_time)
                                  VALUES (?, ?, ?, ?)''', rows)

    if settings.autocheck_subscriptions:
        # important that this is after the changes have been committed to the database,
        # otherwise the autochecker (other thread) tries checking the channel before it's in the database
        channel_names.update(channels)
        check_channels_if_necessary(channel_ids_to_check)

def delete_thumbnails(to_delete):
    for thumbnail in to_delete:
        try:
            video_id = thumbnail[0:-4]
            if video_id in existing_thumbnails:
                os.remove(os.path.join(thumbnails_directory, thumbnail))
                existing_thumbnails.remove(video_id)
        except Exception:
            print('Failed to delete thumbnail: ' + thumbnail)
            traceback.print_exc()

def _unsubscribe(cursor, channel_ids):
    ''' channel_ids is a list of channel_ids '''
    to_delete = []
    for channel_id in channel_ids:
        rows = cursor.execute('''SELECT video_id
                                 FROM videos
                                 WHERE sql_channel_id = (
                                     SELECT id
                                     FROM subscribed_channels
                                     WHERE yt_channel_id=?
                                 )''', (channel_id,)).fetchall()
        to_delete += [row[0] + '.jpg' for row in rows]

    gevent.spawn(delete_thumbnails, to_delete)
    cursor.executemany("DELETE FROM subscribed_channels WHERE yt_channel_id=?", ((channel_id, ) for channel_id in channel_ids))

def _get_videos(cursor, number_per_page, offset, tag = None):
    '''Returns a full page of videos with an offset, and a value good enough to be used as the total number of videos'''
    # We ask the database for the next 9 pages of results.
    # The actual length of the results then tells us whether there are more than 9 pages left, and if not, how many there actually are.
    # This is done since there are only 9 page buttons on display at a time.
    # If there are more than 9 pages left, we give a fake value in place of the number of results an unrestricted query over the entire database would return.
    # This fake value is sufficient to get the page button generation macro to display 9 page buttons.
    # If we wish to display more buttons this logic must change.
    # We cannot use tricks with the sql id for the video since we frequently have filters and other restrictions in place on the results anyway.
    # TODO: This is probably not the ideal solution
    if tag is not None:
        db_videos = cursor.execute('''SELECT video_id, title, duration, time_published, is_time_published_exact, channel_name
                                      FROM videos
                                      INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id
                                      INNER JOIN tag_associations on videos.sql_channel_id = tag_associations.sql_channel_id
                                      WHERE tag = ? AND muted = 0
                                      ORDER BY time_noticed DESC, time_published DESC
                                      LIMIT ? OFFSET ?''', (tag, number_per_page*9, offset)).fetchall()
    else:
        db_videos = cursor.execute('''SELECT video_id, title, duration, time_published, is_time_published_exact, channel_name
                                      FROM videos
                                      INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id
                                      WHERE muted = 0
                                      ORDER BY time_noticed DESC, time_published DESC
                                      LIMIT ? OFFSET ?''', (number_per_page*9, offset)).fetchall()

    pseudo_number_of_videos = offset + len(db_videos)

    videos = []
    for db_video in db_videos[0:number_per_page]:
        videos.append({
            'id': db_video[0],
            'title': db_video[1],
            'duration': db_video[2],
            'published': exact_timestamp(db_video[3]) if db_video[4] else posix_to_dumbed_down(db_video[3]),
            'author': db_video[5],
        })

    return videos, pseudo_number_of_videos
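
# Worked example of the fake total above: with number_per_page=30 and offset=60
# (page 3), up to 270 rows are fetched; if all 270 come back, the caller sees
# offset + 270 = 330 as the "total", which is just enough for the page-button
# macro to keep showing its maximum of 9 buttons without ever running a real
# COUNT(*) over the filtered results.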



def _get_subscribed_channels(cursor):
    for item in cursor.execute('''SELECT channel_name, yt_channel_id, muted
                                  FROM subscribed_channels
                                  ORDER BY channel_name COLLATE NOCASE'''):
        yield item


def _add_tags(cursor, channel_ids, tags):
    pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids]
    cursor.executemany('''INSERT OR IGNORE INTO tag_associations (tag, sql_channel_id)
                          SELECT ?, id FROM subscribed_channels WHERE yt_channel_id = ? ''', pairs)


def _remove_tags(cursor, channel_ids, tags):
    pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids]
    cursor.executemany('''DELETE FROM tag_associations
                          WHERE tag = ? AND sql_channel_id = (
                              SELECT id FROM subscribed_channels WHERE yt_channel_id = ?
                          )''', pairs)


def _get_tags(cursor, channel_id):
    return [row[0] for row in cursor.execute('''SELECT tag
                                                FROM tag_associations
                                                WHERE sql_channel_id = (
                                                    SELECT id FROM subscribed_channels WHERE yt_channel_id = ?
                                                )''', (channel_id,))]

def _get_all_tags(cursor):
    return [row[0] for row in cursor.execute('''SELECT DISTINCT tag FROM tag_associations''')]

def _get_channel_names(cursor, channel_ids):
    ''' returns list of (channel_id, channel_name) '''
    result = []
    for channel_id in channel_ids:
        row = cursor.execute('''SELECT channel_name
                                FROM subscribed_channels
                                WHERE yt_channel_id = ?''', (channel_id,)).fetchone()
        result.append( (channel_id, row[0]) )
    return result


def _channels_with_tag(cursor, tag, order=False, exclude_muted=False, include_muted_status=False):
    ''' returns list of (channel_id, channel_name) '''

    statement = '''SELECT yt_channel_id, channel_name'''

    if include_muted_status:
        statement += ''', muted'''

    statement += '''
                   FROM subscribed_channels
                   WHERE subscribed_channels.id IN (
                       SELECT tag_associations.sql_channel_id FROM tag_associations WHERE tag=?
                   )
                 '''
    if exclude_muted:
        statement += '''AND muted != 1\n'''
    if order:
        statement += '''ORDER BY channel_name COLLATE NOCASE'''

    return cursor.execute(statement, [tag]).fetchall()

def _schedule_checking(cursor, channel_id, next_check_time):
    cursor.execute('''UPDATE subscribed_channels SET next_check_time = ? WHERE yt_channel_id = ?''', [int(next_check_time), channel_id])

def _is_muted(cursor, channel_id):
    return bool(cursor.execute('''SELECT muted FROM subscribed_channels WHERE yt_channel_id=?''', [channel_id]).fetchone()[0])

units = collections.OrderedDict([
    ('year', 31536000),   # 365*24*3600
    ('month', 2592000),   # 30*24*3600
    ('week', 604800),     # 7*24*3600
    ('day',  86400),      # 24*3600
    ('hour', 3600),
    ('minute', 60),
    ('second', 1),
])
def youtube_timestamp_to_posix(dumb_timestamp):
    ''' Given a dumbed down timestamp such as 1 year ago, 3 hours ago,
        approximates the unix time (seconds since 1/1/1970) '''
    dumb_timestamp = dumb_timestamp.lower()
    now = time.time()
    if dumb_timestamp == "just now":
        return now
    split = dumb_timestamp.split(' ')
    quantifier, unit = int(split[0]), split[1]
    if quantifier > 1:
        unit = unit[:-1]    # remove s from end
    return now - quantifier*units[unit]
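
# e.g. youtube_timestamp_to_posix("3 hours ago") == time.time() - 3*3600, and
# "1 week ago" maps to time.time() - 604800 (the trailing 's' is only stripped
# when the quantifier is greater than 1).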

def posix_to_dumbed_down(posix_time):
    '''Inverse of youtube_timestamp_to_posix.'''
    delta = int(time.time() - posix_time)
    assert delta >= 0

    if delta == 0:
        return '0 seconds ago'

    for unit_name, unit_time in units.items():
        if delta >= unit_time:
            quantifier = round(delta/unit_time)
            if quantifier == 1:
                return '1 ' + unit_name + ' ago'
            else:
                return str(quantifier) + ' ' + unit_name + 's ago'
    else:
        raise Exception()
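
# e.g. a delta of 7500 seconds skips 'year' through 'day', matches 'hour'
# (7500 >= 3600), rounds 7500/3600 to 2 and returns '2 hours ago'. The for-else
# raise is effectively unreachable: delta == 0 returns early and any delta >= 1
# matches the 'second' unit.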

def exact_timestamp(posix_time):
    result = time.strftime('%I:%M %p %m/%d/%y', time.localtime(posix_time))
    if result[0] == '0':    # remove 0 in front of hour (like 01:00 PM)
        return result[1:]
    return result

try:
    existing_thumbnails = set(os.path.splitext(name)[0] for name in os.listdir(thumbnails_directory))
except FileNotFoundError:
    existing_thumbnails = set()


# --- Manual checking system. Rate limited in order to support very large numbers of channels to be checked ---
# Auto checking system plugs into this for convenience, though it doesn't really need the rate limiting

check_channels_queue = util.RateLimitedQueue()
checking_channels = set()

# Just to use for printing channel checking status to console without opening database
channel_names = dict()

def check_channel_worker():
    while True:
        channel_id = check_channels_queue.get()
        try:
            _get_upstream_videos(channel_id)
        finally:
            checking_channels.remove(channel_id)

for i in range(0,5):
    gevent.spawn(check_channel_worker)
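
# Five worker greenlets drain the rate-limited queue concurrently, so at most
# five channels are checked at any one time; the checking_channels set doubles
# as a guard so check_channels_if_necessary() below never queues the same
# channel twice.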
# ----------------------------



# --- Auto checking system - Spaghetti code ---

if settings.autocheck_subscriptions:
    # job application format: dict with keys (channel_id, channel_name, next_check_time)
    autocheck_job_application = gevent.queue.Queue()    # only really meant to hold 1 item, just reusing gevent's wait and timeout machinery

    autocheck_jobs = []    # list of dicts with the keys (channel_id, channel_name, next_check_time). Stores all the channels that need to be autochecked and when to check them
    with open_database() as connection:
        with connection as cursor:
            now = time.time()
            for row in cursor.execute('''SELECT yt_channel_id, channel_name, next_check_time FROM subscribed_channels WHERE muted != 1''').fetchall():

                if row[2] is None:
                    next_check_time = 0
                else:
                    next_check_time = row[2]

                # expired, check randomly within the next hour
                # note: even if it isn't scheduled in the past right now, it might end up being if it's due soon and we don't start dispatching by then, see below where time_until_earliest_job is negative
                if next_check_time < now:
                    next_check_time = now + 3600*secrets.randbelow(60)/60
                    row = (row[0], row[1], next_check_time)
                    _schedule_checking(cursor, row[0], next_check_time)
                autocheck_jobs.append({'channel_id': row[0], 'channel_name': row[1], 'next_check_time': next_check_time})



    def autocheck_dispatcher():
        '''Scans the auto_check_list. Sleeps until the earliest job is due, then adds that channel to the checking queue above. Can be sent a new job through autocheck_job_application'''
        while True:
            if len(autocheck_jobs) == 0:
                new_job = autocheck_job_application.get()
                autocheck_jobs.append(new_job)
            else:
                earliest_job_index = min(range(0, len(autocheck_jobs)), key=lambda index: autocheck_jobs[index]['next_check_time'])    # https://stackoverflow.com/a/11825864
                earliest_job = autocheck_jobs[earliest_job_index]
                time_until_earliest_job = earliest_job['next_check_time'] - time.time()

                if time_until_earliest_job <= -5:    # should not happen unless we're running extremely slow
                    print('ERROR: autocheck_dispatcher got job scheduled in the past, skipping and rescheduling: ' + earliest_job['channel_id'] + ', ' + earliest_job['channel_name'] + ', ' + str(earliest_job['next_check_time']))
                    next_check_time = time.time() + 3600*secrets.randbelow(60)/60
                    with_open_db(_schedule_checking, earliest_job['channel_id'], next_check_time)
                    autocheck_jobs[earliest_job_index]['next_check_time'] = next_check_time
                    continue

                # make sure it's not muted
                if with_open_db(_is_muted, earliest_job['channel_id']):
                    del autocheck_jobs[earliest_job_index]
                    continue

                if time_until_earliest_job > 0:    # it can become less than zero (in the past) when it's set to go off while the dispatcher is doing something else at that moment
                    try:
                        new_job = autocheck_job_application.get(timeout = time_until_earliest_job)    # sleep for time_until_earliest_job time, but allow to be interrupted by new jobs
                    except gevent.queue.Empty:    # no new jobs
                        pass
                    else:    # new job, add it to the list
                        autocheck_jobs.append(new_job)
                        continue

                # no new jobs, time to execute the earliest job
                channel_names[earliest_job['channel_id']] = earliest_job['channel_name']
                checking_channels.add(earliest_job['channel_id'])
                check_channels_queue.put(earliest_job['channel_id'])
                del autocheck_jobs[earliest_job_index]


    gevent.spawn(autocheck_dispatcher)
# ----------------------------



def check_channels_if_necessary(channel_ids):
    for channel_id in channel_ids:
        if channel_id not in checking_channels:
            checking_channels.add(channel_id)
            check_channels_queue.put(channel_id)



def _get_upstream_videos(channel_id):
    try:
        channel_status_name = channel_names[channel_id]
    except KeyError:
        channel_status_name = channel_id

    print("Checking channel: " + channel_status_name)

    tasks = (
        gevent.spawn(channel.get_channel_tab, channel_id, print_status=False),    # channel page, need for video duration
        gevent.spawn(util.fetch_url, 'https://www.youtube.com/feeds/videos.xml?channel_id=' + channel_id)    # atoma feed, need for exact published time
    )
    gevent.joinall(tasks)

    channel_tab, feed = tasks[0].value, tasks[1].value

    # extract published times from atoma feed
    times_published = {}
    try:
        def remove_bullshit(tag):
            '''Remove XML namespace bullshit from tagname. https://bugs.python.org/issue18304'''
            if '}' in tag:
                return tag[tag.rfind('}')+1:]
            return tag
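
        # e.g. remove_bullshit('{http://www.w3.org/2005/Atom}entry') -> 'entry'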

        def find_element(base, tag_name):
            for element in base:
                if remove_bullshit(element.tag) == tag_name:
                    return element
            return None

        root = defusedxml.ElementTree.fromstring(feed.decode('utf-8'))
        assert remove_bullshit(root.tag) == 'feed'
        for entry in root:
            if (remove_bullshit(entry.tag) != 'entry'):
                continue

            # it's yt:videoId in the xml but the yt: is turned into a namespace which is removed by remove_bullshit
            video_id_element = find_element(entry, 'videoId')
            time_published_element = find_element(entry, 'published')
            assert video_id_element is not None
            assert time_published_element is not None

            time_published = int(calendar.timegm(time.strptime(time_published_element.text, '%Y-%m-%dT%H:%M:%S+00:00')))
            times_published[video_id_element.text] = time_published

    except (AssertionError, defusedxml.ElementTree.ParseError) as e:
        print('Failed to read atoma feed for ' + channel_status_name)
        traceback.print_exc()

    with open_database() as connection:
        with connection as cursor:
            is_first_check = cursor.execute('''SELECT time_last_checked FROM subscribed_channels WHERE yt_channel_id=?''', [channel_id]).fetchone()[0] in (None, 0)
            video_add_time = int(time.time())

            videos = []
            channel_videos = channel.extract_info(json.loads(channel_tab), 'videos')['items']
            for i, video_item in enumerate(channel_videos):
                if 'description' not in video_item:
                    video_item['description'] = ''

                if video_item['id'] in times_published:
                    time_published = times_published[video_item['id']]
                    is_time_published_exact = True
                else:
                    is_time_published_exact = False
                    try:
                        time_published = youtube_timestamp_to_posix(video_item['published']) - i    # subtract a few seconds off the videos so they will be in the right order
                    except KeyError:
                        print(video_item)
                if is_first_check:
                    time_noticed = time_published    # don't want a crazy ordering on first check, since we're ordering by time_noticed
                else:
                    time_noticed = video_add_time
                videos.append((channel_id, video_item['id'], video_item['title'], video_item['duration'], time_published, is_time_published_exact, time_noticed, video_item['description']))


            if len(videos) == 0:
                average_upload_period = 4*7*24*3600    # assume 1 month for channel with no videos
            elif len(videos) < 5:
                average_upload_period = int((time.time() - videos[len(videos)-1][4])/len(videos))
            else:
                average_upload_period = int((time.time() - videos[4][4])/5)    # equivalent to averaging the time between videos for the last 5 videos

            # calculate when to check next for auto checking
            # add some quantization and randomness to make pattern analysis by Youtube slightly harder
            quantized_upload_period = average_upload_period - (average_upload_period % (4*3600)) + 4*3600    # round up to nearest 4 hours
            randomized_upload_period = quantized_upload_period*(1 + secrets.randbelow(50)/50*0.5)    # randomly between 1x and 1.5x
            next_check_delay = randomized_upload_period/10    # check at 10x the channel posting rate. might want to fine tune this number
            next_check_time = int(time.time() + next_check_delay)
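
            # Worked example: a channel averaging one upload every 5 days
            # (432000 s) quantizes up to 446400 s (124 h, the next multiple of
            # 4 h), gets stretched by a random factor in [1.0, 1.5), and a
            # tenth of that (roughly 12.4 to 18.6 hours) becomes the delay
            # until the next automatic check.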
|
||||
|
||||
# calculate how many new videos there are
|
||||
row = cursor.execute('''SELECT video_id
|
||||
FROM videos
|
||||
INNER JOIN subscribed_channels ON videos.sql_channel_id = subscribed_channels.id
|
||||
WHERE yt_channel_id=?
|
||||
ORDER BY time_published DESC
|
||||
LIMIT 1''', [channel_id]).fetchone()
|
||||
if row is None:
|
||||
number_of_new_videos = len(videos)
|
||||
else:
|
||||
latest_video_id = row[0]
|
||||
index = 0
|
||||
for video in videos:
|
||||
if video[1] == latest_video_id:
|
||||
break
|
||||
index += 1
|
||||
number_of_new_videos = index
|
||||
|
||||
cursor.executemany('''INSERT OR IGNORE INTO videos (sql_channel_id, video_id, title, duration, time_published, is_time_published_exact, time_noticed, description)
|
||||
VALUES ((SELECT id FROM subscribed_channels WHERE yt_channel_id=?), ?, ?, ?, ?, ?, ?, ?)''', videos)
|
||||
cursor.execute('''UPDATE subscribed_channels
|
||||
SET time_last_checked = ?, next_check_time = ?
|
||||
WHERE yt_channel_id=?''', [int(time.time()), next_check_time, channel_id])
|
||||
|
||||
if settings.autocheck_subscriptions:
|
||||
if not _is_muted(cursor, channel_id):
|
||||
autocheck_job_application.put({'channel_id': channel_id, 'channel_name': channel_names[channel_id], 'next_check_time': next_check_time})
|
||||
|
||||
if number_of_new_videos == 0:
|
||||
print('No new videos from ' + channel_status_name)
|
||||
elif number_of_new_videos == 1:
|
||||
print('1 new video from ' + channel_status_name)
|
||||
else:
|
||||
print(str(number_of_new_videos) + ' new videos from ' + channel_status_name)
|
||||
|
||||
|
||||
|
||||
def check_all_channels():
    with open_database() as connection:
        with connection as cursor:
            channel_id_name_list = cursor.execute('''SELECT yt_channel_id, channel_name
                                                     FROM subscribed_channels
                                                     WHERE muted != 1''').fetchall()

            channel_names.update(channel_id_name_list)
            check_channels_if_necessary([item[0] for item in channel_id_name_list])


def check_tags(tags):
    channel_id_name_list = []
    with open_database() as connection:
        with connection as cursor:
            for tag in tags:
                channel_id_name_list += _channels_with_tag(cursor, tag, exclude_muted=True)

            channel_names.update(channel_id_name_list)
            check_channels_if_necessary([item[0] for item in channel_id_name_list])


def check_specific_channels(channel_ids):
    with open_database() as connection:
        with connection as cursor:
            channel_id_name_list = []
            for channel_id in channel_ids:
                channel_id_name_list += cursor.execute('''SELECT yt_channel_id, channel_name
                                                          FROM subscribed_channels
                                                          WHERE yt_channel_id=?''', [channel_id]).fetchall()
            channel_names.update(channel_id_name_list)
            check_channels_if_necessary(channel_ids)


@yt_app.route('/import_subscriptions', methods=['POST'])
def import_subscriptions():

    # Check if the POST request has the file part
    if 'subscriptions_file' not in request.files:
        #flash('No file part')
        return flask.redirect(util.URL_ORIGIN + request.full_path)
    file = request.files['subscriptions_file']
    # If the user does not select a file, the browser submits an
    # empty part without a filename
    if file.filename == '':
        #flash('No selected file')
        return flask.redirect(util.URL_ORIGIN + request.full_path)

    mime_type = file.mimetype

    if mime_type == 'application/json':
        file = file.read().decode('utf-8')
        try:
            file = json.loads(file)
        except json.decoder.JSONDecodeError:
            traceback.print_exc()
            return '400 Bad Request: Invalid json file', 400

        try:
            # A list rather than a generator, so that malformed entries raise
            # KeyError here, where the except clause can catch it
            channels = [(item['snippet']['resourceId']['channelId'], item['snippet']['title']) for item in file]
        except (KeyError, IndexError):
            traceback.print_exc()
            return '400 Bad Request: Unknown json structure', 400
    elif mime_type in ('application/xml', 'text/xml', 'text/x-opml'):
        file = file.read().decode('utf-8')
        try:
            root = defusedxml.ElementTree.fromstring(file)
            assert root.tag == 'opml'
            channels = []
            for outline_element in root[0][0]:
                if (outline_element.tag != 'outline') or ('xmlUrl' not in outline_element.attrib):
                    continue

                channel_name = outline_element.attrib['text']
                channel_rss_url = outline_element.attrib['xmlUrl']
                # The channel id is everything after 'channel_id=' in the RSS URL
                channel_id = channel_rss_url[channel_rss_url.find('channel_id=') + 11:].strip()
                channels.append((channel_id, channel_name))

        except (AssertionError, IndexError, defusedxml.ElementTree.ParseError):
            return '400 Bad Request: Unable to read opml xml file, or the file is not the expected format', 400
    else:
        return '400 Bad Request: Unsupported file format: ' + mime_type + '. Only subscription.json files (from Google Takeout) and XML OPML files exported from YouTube\'s subscription manager page are supported', 400

    _subscribe(channels)

    return flask.redirect(util.URL_ORIGIN + '/subscription_manager', 303)

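# For reference, a minimal sketch of the OPML shape the importer above assumes
# (root[0][0] is the body element's first outline's children); the id is made up.
import defusedxml.ElementTree

_sample_opml = '''<opml version="1.1">
  <body>
    <outline text="YouTube Subscriptions">
      <outline text="Example Channel" type="rss"
        xmlUrl="https://www.youtube.com/feeds/videos.xml?channel_id=UCxxxxxxxxxxxxxxxxxxxxxx"/>
    </outline>
  </body>
</opml>'''

_root = defusedxml.ElementTree.fromstring(_sample_opml)
for _outline in _root[0][0]:
    _url = _outline.attrib['xmlUrl']
    print(_outline.attrib['text'], _url[_url.find('channel_id=') + 11:].strip())
# prints: Example Channel UCxxxxxxxxxxxxxxxxxxxxxx
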
@yt_app.route('/subscription_manager', methods=['GET'])
def get_subscription_manager_page():
    group_by_tags = request.args.get('group_by_tags', '0') == '1'
    with open_database() as connection:
        with connection as cursor:
            if group_by_tags:
                tag_groups = []

                for tag in _get_all_tags(cursor):
                    sub_list = []
                    for channel_id, channel_name, muted in _channels_with_tag(cursor, tag, order=True, include_muted_status=True):
                        sub_list.append({
                            'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
                            'channel_name': channel_name,
                            'channel_id': channel_id,
                            'muted': muted,
                            'tags': [t for t in _get_tags(cursor, channel_id) if t != tag],
                        })

                    tag_groups.append((tag, sub_list))

                # Channels with no tags
                channel_list = cursor.execute('''SELECT yt_channel_id, channel_name, muted
                                                 FROM subscribed_channels
                                                 WHERE id NOT IN (
                                                     SELECT sql_channel_id FROM tag_associations
                                                 )
                                                 ORDER BY channel_name COLLATE NOCASE''').fetchall()
                if channel_list:
                    sub_list = []
                    for channel_id, channel_name, muted in channel_list:
                        sub_list.append({
                            'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
                            'channel_name': channel_name,
                            'channel_id': channel_id,
                            'muted': muted,
                            'tags': [],
                        })

                    tag_groups.append(('No tags', sub_list))
            else:
                sub_list = []
                for channel_name, channel_id, muted in _get_subscribed_channels(cursor):
                    sub_list.append({
                        'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
                        'channel_name': channel_name,
                        'channel_id': channel_id,
                        'muted': muted,
                        'tags': _get_tags(cursor, channel_id),
                    })

    if group_by_tags:
        return flask.render_template('subscription_manager.html',
            group_by_tags=True,
            tag_groups=tag_groups,
        )
    else:
        return flask.render_template('subscription_manager.html',
            group_by_tags=False,
            sub_list=sub_list,
        )


def list_from_comma_separated_tags(string):
    return [tag.strip() for tag in string.split(',') if tag.strip()]

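# For example (input hypothetical): list_from_comma_separated_tags(' news, tech ,, music ')
# returns ['news', 'tech', 'music']: surrounding whitespace is stripped and
# empty segments are dropped.
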
@yt_app.route('/subscription_manager', methods=['POST'])
def post_subscription_manager_page():
    action = request.values['action']

    with open_database() as connection:
        with connection as cursor:
            if action == 'add_tags':
                _add_tags(cursor, request.values.getlist('channel_ids'), [tag.lower() for tag in list_from_comma_separated_tags(request.values['tags'])])
            elif action == 'remove_tags':
                _remove_tags(cursor, request.values.getlist('channel_ids'), [tag.lower() for tag in list_from_comma_separated_tags(request.values['tags'])])
            elif action == 'unsubscribe':
                _unsubscribe(cursor, request.values.getlist('channel_ids'))
            elif action == 'unsubscribe_verify':
                unsubscribe_list = _get_channel_names(cursor, request.values.getlist('channel_ids'))
                return flask.render_template('unsubscribe_verify.html', unsubscribe_list=unsubscribe_list)

            elif action == 'mute':
                cursor.executemany('''UPDATE subscribed_channels
                                      SET muted = 1
                                      WHERE yt_channel_id = ?''', [(ci,) for ci in request.values.getlist('channel_ids')])
            elif action == 'unmute':
                cursor.executemany('''UPDATE subscribed_channels
                                      SET muted = 0
                                      WHERE yt_channel_id = ?''', [(ci,) for ci in request.values.getlist('channel_ids')])
            else:
                flask.abort(400)

    return flask.redirect(util.URL_ORIGIN + request.full_path, 303)

@yt_app.route('/subscriptions', methods=['GET'])
@yt_app.route('/feed/subscriptions', methods=['GET'])
def get_subscriptions_page():
    page = int(request.args.get('page', 1))
    with open_database() as connection:
        with connection as cursor:
            tag = request.args.get('tag', None)
            videos, number_of_videos_in_db = _get_videos(cursor, 60, (page - 1)*60, tag)
            for video in videos:
                video['thumbnail'] = util.URL_ORIGIN + '/data/subscription_thumbnails/' + video['id'] + '.jpg'
                video['type'] = 'video'
                video['item_size'] = 'small'
                yt_data_extract.add_extra_html_info(video)

            tags = _get_all_tags(cursor)

            subscription_list = []
            for channel_name, channel_id, muted in _get_subscribed_channels(cursor):
                subscription_list.append({
                    'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
                    'channel_name': channel_name,
                    'channel_id': channel_id,
                    'muted': muted,
                })

    return flask.render_template('subscriptions.html',
        videos=videos,
        num_pages=math.ceil(number_of_videos_in_db/60),
        parameters_dictionary=request.args,
        tags=tags,
        current_tag=tag,
        subscription_list=subscription_list,
    )

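# e.g. (illustrative) with 125 videos in the database: num_pages == ceil(125/60) == 3,
# and page 2 selects 60 videos starting at offset (2 - 1)*60 == 60.
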
@yt_app.route('/subscriptions', methods=['POST'])
@yt_app.route('/feed/subscriptions', methods=['POST'])
def post_subscriptions_page():
    action = request.values['action']
    if action == 'subscribe':
        if len(request.values.getlist('channel_id')) != len(request.values.getlist('channel_name')):
            return '400 Bad Request, length of channel_id != length of channel_name', 400
        _subscribe(zip(request.values.getlist('channel_id'), request.values.getlist('channel_name')))

    elif action == 'unsubscribe':
        with_open_db(_unsubscribe, request.values.getlist('channel_id'))

    elif action == 'refresh':
        type = request.values['type']
        if type == 'all':
            check_all_channels()
        elif type == 'tag':
            check_tags(request.values.getlist('tag_name'))
        elif type == 'channel':
            check_specific_channels(request.values.getlist('channel_id'))
        else:
            flask.abort(400)
    else:
        flask.abort(400)

    return '', 204

@yt_app.route('/data/subscription_thumbnails/<thumbnail>')
def serve_subscription_thumbnail(thumbnail):
    '''Serves the thumbnail from disk if it has been saved already. If not, downloads the thumbnail, saves it to disk, and serves it.'''
    assert thumbnail[-4:] == '.jpg'
    video_id = thumbnail[0:-4]
    thumbnail_path = os.path.join(thumbnails_directory, thumbnail)

    if video_id in existing_thumbnails:
        try:
            f = open(thumbnail_path, 'rb')
        except FileNotFoundError:
            existing_thumbnails.remove(video_id)
        else:
            image = f.read()
            f.close()
            return flask.Response(image, mimetype='image/jpeg')

    url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
    try:
        image = util.fetch_url(url, report_text="Saved thumbnail: " + video_id)
    except urllib.error.HTTPError as e:
        print("Failed to download thumbnail for " + video_id + ": " + str(e))
        flask.abort(e.code)
    try:
        f = open(thumbnail_path, 'wb')
    except FileNotFoundError:
        os.makedirs(thumbnails_directory, exist_ok=True)
        f = open(thumbnail_path, 'wb')
    f.write(image)
    f.close()
    existing_thumbnails.add(video_id)

    return flask.Response(image, mimetype='image/jpeg')

@ -23,6 +23,9 @@
    grid-column:2;
    margin-left: 5px;
}
.summary .subscribe-unsubscribe, .summary .short-description{
    margin-top: 10px;
}
main .channel-tabs{
    grid-row:2;
    grid-column: 1 / span 2;
@ -89,6 +92,12 @@
<div class="summary">
    <h2 class="title">{{ channel_name }}</h2>
    <p class="short-description">{{ short_description }}</p>
    <form method="POST" action="/youtube.com/subscriptions" class="subscribe-unsubscribe">
        <input type="submit" value="{{ 'Unsubscribe' if subscribed else 'Subscribe' }}">
        <input type="hidden" name="channel_id" value="{{ channel_id }}">
        <input type="hidden" name="channel_name" value="{{ channel_name }}">
        <input type="hidden" name="action" value="{{ 'unsubscribe' if subscribed else 'subscribe' }}">
    </form>
</div>
<nav class="channel-tabs">
    {% for tab_name in ('Videos', 'Playlists', 'About') %}
139
youtube/templates/subscription_manager.html
Normal file
@ -0,0 +1,139 @@
{% set page_title = 'Subscription Manager' %}
{% extends "base.html" %}
{% block style %}
    .import-export{
        display: flex;
        flex-direction: row;
    }
    .subscriptions-import-form{
        background-color: var(--interface-color);
        display: flex;
        flex-direction: column;
        align-items: flex-start;
        max-width: 300px;
        padding: 10px;
    }
    .subscriptions-import-form h2{
        font-size: 20px;
        margin-bottom: 10px;
    }

    .import-submit-button{
        margin-top: 15px;
        align-self: flex-end;
    }

    .subscriptions-export-links{
        margin: 0px 0px 0px 20px;
        background-color: var(--interface-color);
        list-style: none;
        max-width: 300px;
        padding: 10px;
    }

    .sub-list-controls{
        background-color: var(--interface-color);
        padding: 10px;
    }

    .tag-group-list{
        list-style: none;
        margin-left: 10px;
        margin-right: 10px;
        padding: 0px;
    }
    .tag-group{
        border-style: solid;
        margin-bottom: 10px;
    }

    .sub-list{
        list-style: none;
        padding: 10px;
        column-width: 300px;
        column-gap: 40px;
    }
    .sub-list-item{
        display: flex;
        margin-bottom: 10px;
        break-inside: avoid;
        background-color: var(--interface-color);
    }
    .tag-list{
        margin-left: 15px;
        font-weight: bold;
    }
    .sub-list-item-name{
        margin-left: 15px;
    }
    .sub-list-checkbox{
        height: 1.5em;
        /* Without min-width the browser doesn't respect the width and squishes
           the checkbox down when there are too many tags */
        min-width: 1.5em;
    }
{% endblock style %}


{% macro subscription_list(sub_list) %}
    {% for subscription in sub_list %}
        <li class="sub-list-item {{ 'muted' if subscription['muted'] else '' }}">
            <input class="sub-list-checkbox" name="channel_ids" value="{{ subscription['channel_id'] }}" form="subscription-manager-form" type="checkbox">
            <a href="{{ subscription['channel_url'] }}" class="sub-list-item-name" title="{{ subscription['channel_name'] }}">{{ subscription['channel_name'] }}</a>
            <span class="tag-list">{{ ', '.join(subscription['tags']) }}</span>
        </li>
    {% endfor %}
{% endmacro %}


{% block main %}
    <div class="import-export">
        <form class="subscriptions-import-form" enctype="multipart/form-data" action="/youtube.com/import_subscriptions" method="POST">
            <h2>Import subscriptions</h2>
            <input type="file" id="subscriptions-import" accept="application/json, application/xml, text/x-opml" name="subscriptions_file">
            <input type="submit" value="Import" class="import-submit-button">
        </form>

        <ul class="subscriptions-export-links">
            <li><a href="/youtube.com/subscriptions.opml">Export subscriptions (OPML)</a></li>
            <li><a href="/youtube.com/subscriptions.xml">Export subscriptions (RSS)</a></li>
        </ul>
    </div>

    <hr>

    <form id="subscription-manager-form" class="sub-list-controls" method="POST">
        {% if group_by_tags %}
            <a class="sort-button" href="/https://www.youtube.com/subscription_manager?group_by_tags=0">Don't group</a>
        {% else %}
            <a class="sort-button" href="/https://www.youtube.com/subscription_manager?group_by_tags=1">Group by tags</a>
        {% endif %}
        <input type="text" name="tags">
        <button type="submit" name="action" value="add_tags">Add tags</button>
        <button type="submit" name="action" value="remove_tags">Remove tags</button>
        <button type="submit" name="action" value="unsubscribe_verify">Unsubscribe</button>
        <button type="submit" name="action" value="mute">Mute</button>
        <button type="submit" name="action" value="unmute">Unmute</button>
        <input type="reset" value="Clear Selection">
    </form>


    {% if group_by_tags %}
        <ul class="tag-group-list">
            {% for tag_name, sub_list in tag_groups %}
                <li class="tag-group">
                    <h2 class="tag-group-name">{{ tag_name }}</h2>
                    <ol class="sub-list">
                        {{ subscription_list(sub_list) }}
                    </ol>
                </li>
            {% endfor %}
        </ul>
    {% else %}
        <ol class="sub-list">
            {{ subscription_list(sub_list) }}
        </ol>
    {% endif %}

{% endblock main %}
113
youtube/templates/subscriptions.html
Normal file
@ -0,0 +1,113 @@
{% set page_title = 'Subscriptions' %}
{% extends "base.html" %}
{% import "common_elements.html" as common_elements %}

{% block style %}
    main{
        display: flex;
        flex-direction: row;
    }
    .video-section{
        flex-grow: 1;
    }
    .video-section .page-button-row{
        justify-content: center;
    }
    .subscriptions-sidebar{
        flex-basis: 300px;
        background-color: var(--interface-color);
        border-left: 2px;
    }
    .sidebar-links{
        display: flex;
        justify-content: space-between;
        padding-left: 10px;
        padding-right: 10px;
    }

    .sidebar-list{
        list-style: none;
        padding-left: 10px;
        padding-right: 10px;
    }
    .sidebar-list-item{
        display: flex;
        justify-content: space-between;
        margin-bottom: 5px;
    }
    .sub-refresh-list .sidebar-item-name{
        text-overflow: clip;
        white-space: nowrap;
        overflow: hidden;
        max-width: 200px;
    }
{% endblock style %}

{% block main %}
    <div class="video-section">
        <nav class="item-grid">
            {% for video_info in videos %}
                {{ common_elements.item(video_info, include_author=false) }}
            {% endfor %}
        </nav>

        <nav class="page-button-row">
            {{ common_elements.page_buttons(num_pages, '/youtube.com/subscriptions', parameters_dictionary) }}
        </nav>
    </div>

    <div class="subscriptions-sidebar">
        <div class="sidebar-links">
            <a href="/youtube.com/subscription_manager" class="sub-manager-link">Subscription Manager</a>
            <form method="POST" class="refresh-all">
                <input type="submit" value="Check All">
                <input type="hidden" name="action" value="refresh">
                <input type="hidden" name="type" value="all">
            </form>
        </div>

        <hr>

        <ol class="sidebar-list tags">
            {% if current_tag %}
                <li class="sidebar-list-item">
                    <a href="/youtube.com/subscriptions" class="sidebar-item-name">Any tag</a>
                </li>
            {% endif %}

            {% for tag in tags %}
                <li class="sidebar-list-item">
                    {% if tag == current_tag %}
                        <span class="sidebar-item-name">{{ tag }}</span>
                    {% else %}
                        <a href="?tag={{ tag|urlencode }}" class="sidebar-item-name">{{ tag }}</a>
                    {% endif %}
                    <form method="POST" class="sidebar-item-refresh">
                        <input type="submit" value="Check">
                        <input type="hidden" name="action" value="refresh">
                        <input type="hidden" name="type" value="tag">
                        <input type="hidden" name="tag_name" value="{{ tag }}">
                    </form>
                </li>
            {% endfor %}
        </ol>

        <hr>

        <ol class="sidebar-list sub-refresh-list">
            {% for subscription in subscription_list %}
                <li class="sidebar-list-item {{ 'muted' if subscription['muted'] else '' }}">
                    <a href="{{ subscription['channel_url'] }}" class="sidebar-item-name" title="{{ subscription['channel_name'] }}">{{ subscription['channel_name'] }}</a>
                    <form method="POST" class="sidebar-item-refresh">
                        <input type="submit" value="Check">
                        <input type="hidden" name="action" value="refresh">
                        <input type="hidden" name="type" value="channel">
                        <input type="hidden" name="channel_id" value="{{ subscription['channel_id'] }}">
                    </form>
                </li>
            {% endfor %}
        </ol>

    </div>

{% endblock main %}
19
youtube/templates/unsubscribe_verify.html
Normal file
@ -0,0 +1,19 @@
{% set page_title = 'Unsubscribe?' %}
{% extends "base.html" %}

{% block main %}
    <span>Are you sure you want to unsubscribe from these channels?</span>
    <form class="subscriptions-import-form" action="/youtube.com/subscription_manager" method="POST">
        {% for channel_id, channel_name in unsubscribe_list %}
            <input type="hidden" name="channel_ids" value="{{ channel_id }}">
        {% endfor %}

        <input type="hidden" name="action" value="unsubscribe">
        <input type="submit" value="Yes, unsubscribe">
    </form>
    <ul>
        {% for channel_id, channel_name in unsubscribe_list %}
            <li><a href="{{ '/https://www.youtube.com/channel/' + channel_id }}" title="{{ channel_name }}">{{ channel_name }}</a></li>
        {% endfor %}
    </ul>
{% endblock main %}
@ -6,6 +6,9 @@ import urllib.parse
import re
import time
import os
import gevent
import gevent.queue
import gevent.lock

# The trouble with the requests library: It ships its own certificate bundle via certifi
# instead of using the system certificate store, meaning self-signed certificates
@ -183,6 +186,84 @@ desktop_ua = (('User-Agent', desktop_user_agent),)


class RateLimitedQueue(gevent.queue.Queue):
    '''Allows an initial burst of initial_burst (default 30) gets, then alternates
    between waiting waiting_period (default 5) seconds and allowing
    subsequent_bursts (default 10) more gets. After 5 seconds with nothing left
    in the queue, the rate limiting resets.'''

    def __init__(self, initial_burst=30, waiting_period=5, subsequent_bursts=10):
        self.initial_burst = initial_burst
        self.waiting_period = waiting_period
        self.subsequent_bursts = subsequent_bursts

        self.count_since_last_wait = 0
        self.surpassed_initial = False

        self.lock = gevent.lock.BoundedSemaphore(1)
        self.currently_empty = False
        self.empty_start = 0
        gevent.queue.Queue.__init__(self)

    def get(self):
        self.lock.acquire()  # blocks if another greenlet currently has the lock
        if self.count_since_last_wait >= self.subsequent_bursts and self.surpassed_initial:
            gevent.sleep(self.waiting_period)
            self.count_since_last_wait = 0

        elif self.count_since_last_wait >= self.initial_burst and not self.surpassed_initial:
            self.surpassed_initial = True
            gevent.sleep(self.waiting_period)
            self.count_since_last_wait = 0

        self.count_since_last_wait += 1

        if not self.currently_empty and self.empty():
            self.currently_empty = True
            self.empty_start = time.monotonic()

        item = gevent.queue.Queue.get(self)  # blocks when nothing is left

        if self.currently_empty:
            if time.monotonic() - self.empty_start >= self.waiting_period:
                self.count_since_last_wait = 0
                self.surpassed_initial = False

            self.currently_empty = False

        self.lock.release()

        return item

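# Hypothetical usage sketch (not part of this commit): worker greenlets draining
# a RateLimitedQueue; get() sleeps as needed to enforce the burst-then-wait pacing.
def example_rate_limited_workers():
    queue = RateLimitedQueue(initial_burst=3, waiting_period=1, subsequent_bursts=2)
    for n in range(10):
        queue.put(n)
    queue.put(None)  # one sentinel per worker
    queue.put(None)

    def worker(name):
        while True:
            job = queue.get()  # blocks, and rate limits across all workers
            if job is None:
                return
            print(name, 'got job', job)

    gevent.joinall([gevent.spawn(worker, 'w1'), gevent.spawn(worker, 'w2')])
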
def download_thumbnail(save_directory, video_id):
    url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
    save_location = os.path.join(save_directory, video_id + ".jpg")
    try:
        thumbnail = fetch_url(url, report_text="Saved thumbnail: " + video_id)
    except urllib.error.HTTPError as e:
        print("Failed to download thumbnail for " + video_id + ": " + str(e))
        return False
    try:
        f = open(save_location, 'wb')
    except FileNotFoundError:
        os.makedirs(save_directory, exist_ok=True)
        f = open(save_location, 'wb')
    f.write(thumbnail)
    f.close()
    return True


def download_thumbnails(save_directory, ids):
    if not isinstance(ids, (list, tuple)):
        ids = list(ids)
    # Download the thumbnails 5 at a time: first the full groups of 5, then the remainder
    i = -1  # so the remainder range below starts at 0 when len(ids) < 5
    for i in range(len(ids)//5):
        gevent.joinall([gevent.spawn(download_thumbnail, save_directory, ids[j]) for j in range(i*5, i*5 + 5)])
    # The remainder (fewer than 5)
    gevent.joinall([gevent.spawn(download_thumbnail, save_directory, ids[j]) for j in range(i*5 + 5, len(ids))])

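# Worked example (ids hypothetical): for 12 video ids the loop above downloads
# ids[0:5] and ids[5:10] as two concurrent batches of 5, and the final joinall
# handles the remainder ids[10:12].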