NewsBlur/utils/S3.py
2024-04-24 09:50:42 -04:00

625 lines
21 KiB
Python

#!/usr/bin/env python
# This software code is made available "AS IS" without warranties of any
# kind. You may copy, display, modify and redistribute the software
# code either by itself or as incorporated into your code; provided that
# you do not remove any proprietary notices. Your use of this software
# code is at your own risk and you waive any claim against Amazon
# Digital Services, Inc. or its affiliates with respect to your use of
# this software code. (c) 2006-2007 Amazon Digital Services, Inc. or its
# affiliates.
import base64
import hmac
import http.client
import re
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
import xml.sax
import sha
# Endpoint used when no explicit server is given.
DEFAULT_HOST = "s3.amazonaws.com"

# Default TCP port, keyed by whether the connection uses TLS.
PORTS_BY_SECURITY = {
    True: 443,  # https
    False: 80,  # http
}

# Header-name prefixes.  Headers starting with AMAZON_HEADER_PREFIX take part
# in the signed canonical string; user metadata travels in METADATA_PREFIX
# headers.
AMAZON_HEADER_PREFIX = "x-amz-"
METADATA_PREFIX = AMAZON_HEADER_PREFIX + "meta-"
# generates the aws canonical string for the given parameters
def canonical_string(method, bucket="", key="", query_args=None, headers=None, expires=None):
    """Build the AWS signature-v2 string to sign for a request.

    :param method: HTTP verb, e.g. "GET" or "PUT".
    :param bucket: bucket name; left out of the resource when empty.
    :param key: object key, quoted with ``urllib.parse.quote_plus``.  A "/" is
        always appended to the resource, even when the key is empty.
    :param query_args: mapping of query arguments (None means no arguments).
        Only the presence of "acl", "torrent", "logging" or "location" matters
        here; the first one found, in that order, is appended as "?<name>".
    :param headers: request headers (None means no headers).  Content-MD5,
        Content-Type, Date and every header whose lower-cased name starts with
        AMAZON_HEADER_PREFIX are included; values are converted with str()
        and stripped.
    :param expires: when truthy, replaces the Date line (query-string auth).
    :returns: the canonical string.
    """
    if query_args is None:
        query_args = {}
    if headers is None:
        headers = {}

    interesting_headers = {}
    for header_key in headers:
        lk = header_key.lower()
        if lk in ("content-md5", "content-type", "date") or lk.startswith(AMAZON_HEADER_PREFIX):
            # str() so non-string values (e.g. integer metadata merged in via
            # merge_meta) don't crash on .strip()
            interesting_headers[lk] = str(headers[header_key]).strip()

    # these keys get empty strings if they don't exist
    interesting_headers.setdefault("content-type", "")
    interesting_headers.setdefault("content-md5", "")

    # just in case someone used this. it's not necessary in this lib.
    if "x-amz-date" in interesting_headers:
        interesting_headers["date"] = ""

    # if you're using expires for query string auth, then it trumps date
    # (and x-amz-date)
    if expires:
        interesting_headers["date"] = str(expires)

    parts = ["%s\n" % method]
    for header_key in sorted(interesting_headers):
        if header_key.startswith(AMAZON_HEADER_PREFIX):
            parts.append("%s:%s\n" % (header_key, interesting_headers[header_key]))
        else:
            parts.append("%s\n" % interesting_headers[header_key])

    # append the bucket if it exists
    if bucket != "":
        parts.append("/%s" % bucket)
    # add the key. even if it doesn't exist, add the slash
    parts.append("/%s" % urllib.parse.quote_plus(key))

    # handle special query string arguments (first match wins)
    for sub_resource in ("acl", "torrent", "logging", "location"):
        if sub_resource in query_args:
            parts.append("?" + sub_resource)
            break
    return "".join(parts)
# computes the base64'ed hmac-sha hash of the canonical string and the secret
# access key, optionally urlencoding the result
def encode(aws_secret_access_key, str, urlencode=False):
    """Return the base64-encoded HMAC-SHA1 of ``str`` under the secret key.

    :param aws_secret_access_key: secret key as ``str`` (UTF-8 encoded) or bytes.
    :param str: the string to sign, as ``str`` (UTF-8 encoded) or bytes.  The
        name shadows the builtin but is kept for keyword-argument compatibility.
    :param urlencode: when true, pass the result through ``quote_plus`` so it
        can be embedded in a query string.
    :returns: the signature as a ``str``.
    """
    import hashlib  # replaces the Python 2-only ``sha`` module

    key = aws_secret_access_key
    if not isinstance(key, (bytes, bytearray)):
        key = key.encode("utf-8")
    message = str
    if not isinstance(message, (bytes, bytearray)):
        message = message.encode("utf-8")

    digest = hmac.new(key, message, hashlib.sha1).digest()
    # b64encode (unlike the removed encodestring) adds no trailing newline;
    # decode so callers get text, not a bytes repr, when formatting headers.
    b64_hmac = base64.b64encode(digest).decode("ascii")
    if urlencode:
        return urllib.parse.quote_plus(b64_hmac)
    return b64_hmac
def merge_meta(headers, metadata):
    """Return a copy of ``headers`` with every ``metadata`` item added as a header.

    Each metadata key is prefixed with METADATA_PREFIX.  Neither argument is
    modified.
    """
    merged = headers.copy()
    merged.update((METADATA_PREFIX + name, value) for name, value in metadata.items())
    return merged
# builds the query arg string
def query_args_hash_to_string(query_args):
    """Serialize a query-argument mapping to a query string (without "?").

    An argument whose value is None is emitted as a bare key (e.g. "acl");
    any other value is converted with str() and quoted with ``quote_plus``.
    Keys are emitted as given.  Pairs are joined with "&", so an empty
    mapping yields "".
    """
    return "&".join(
        k if v is None else "%s=%s" % (k, urllib.parse.quote_plus(str(v)))
        for k, v in query_args.items()
    )
class CallingFormat:
    """Where the bucket name goes in request URLs.

    PATH      -- protocol://server:port/bucket
    SUBDOMAIN -- protocol://bucket.server:port
    VANITY    -- protocol://bucket:port (the bucket name is used as the host)
    """

    PATH = 1
    SUBDOMAIN = 2
    VANITY = 3

    @staticmethod
    def build_url_base(protocol, server, port, bucket, calling_format):
        """Return the URL prefix (scheme, host, port and optional "/bucket").

        With an empty bucket the host is always ``server``.  Any
        ``calling_format`` other than SUBDOMAIN or VANITY uses ``server`` as
        host; the "/bucket" path segment is added only for exactly PATH.
        """
        is_host_style = calling_format in (CallingFormat.SUBDOMAIN, CallingFormat.VANITY)
        if bucket == "" or not is_host_style:
            host = server
        elif calling_format == CallingFormat.SUBDOMAIN:
            host = "%s.%s" % (bucket, server)
        else:
            host = bucket

        url_base = "%s://" % protocol + host + ":%s" % port
        if bucket != "" and calling_format == CallingFormat.PATH:
            url_base += "/%s" % bucket
        return url_base
class Location:
    """Location constraints understood by ``create_located_bucket``."""

    # No constraint: create_located_bucket sends an empty request body.
    DEFAULT = None
    # Sent inside <CreateBucketConstraint><LocationConstraint>.
    EU = "EU"
class AWSAuthConnection:
    """S3 client that signs every request with an ``Authorization`` header.

    Each public method sends one request through ``_make_request`` and wraps
    the ``http.client`` response in the matching ``Response`` class
    (``check_bucket_exists`` returns the raw response).  Header, option and
    metadata arguments default to None, meaning "none", instead of a shared
    mutable ``{}``.
    """

    # Upper bound on the 3xx redirects _make_request follows; once reached,
    # the redirect response itself is returned to the caller.
    MAX_REDIRECTS = 10

    def __init__(
        self,
        aws_access_key_id,
        aws_secret_access_key,
        is_secure=True,
        server=DEFAULT_HOST,
        port=None,
        calling_format=CallingFormat.SUBDOMAIN,
    ):
        """
        :param is_secure: use HTTPS when true, plain HTTP otherwise.
        :param server: S3 endpoint host name.
        :param port: TCP port; when falsy, PORTS_BY_SECURITY[is_secure] is used.
        :param calling_format: a ``CallingFormat`` constant deciding whether
            the bucket goes into the host name or the path.
        """
        if not port:
            port = PORTS_BY_SECURITY[is_secure]
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.is_secure = is_secure
        self.server = server
        self.port = port
        self.calling_format = calling_format

    def create_bucket(self, bucket, headers=None):
        return Response(self._make_request("PUT", bucket, "", {}, headers))

    def create_located_bucket(self, bucket, location=Location.DEFAULT, headers=None):
        # Location.DEFAULT sends no body; anything else is wrapped in a
        # CreateBucketConstraint document.
        if location == Location.DEFAULT:
            body = ""
        else:
            body = (
                "<CreateBucketConstraint><LocationConstraint>"
                + location
                + "</LocationConstraint></CreateBucketConstraint>"
            )
        return Response(self._make_request("PUT", bucket, "", {}, headers, body))

    def check_bucket_exists(self, bucket):
        # Returns the raw http.client response, unread.
        return self._make_request("HEAD", bucket, "", {}, {})

    def list_bucket(self, bucket, options=None, headers=None):
        return ListBucketResponse(self._make_request("GET", bucket, "", options, headers))

    def delete_bucket(self, bucket, headers=None):
        return Response(self._make_request("DELETE", bucket, "", {}, headers))

    def put(self, bucket, key, object, headers=None):
        if not isinstance(object, S3Object):
            object = S3Object(object)
        return Response(self._make_request("PUT", bucket, key, {}, headers, object.data, object.metadata))

    def get(self, bucket, key, headers=None):
        return GetResponse(self._make_request("GET", bucket, key, {}, headers))

    def delete(self, bucket, key, headers=None):
        return Response(self._make_request("DELETE", bucket, key, {}, headers))

    def get_bucket_logging(self, bucket, headers=None):
        return GetResponse(self._make_request("GET", bucket, "", {"logging": None}, headers))

    def put_bucket_logging(self, bucket, logging_xml_doc, headers=None):
        return Response(self._make_request("PUT", bucket, "", {"logging": None}, headers, logging_xml_doc))

    def get_bucket_acl(self, bucket, headers=None):
        return self.get_acl(bucket, "", headers)

    def get_acl(self, bucket, key, headers=None):
        return GetResponse(self._make_request("GET", bucket, key, {"acl": None}, headers))

    def put_bucket_acl(self, bucket, acl_xml_document, headers=None):
        return self.put_acl(bucket, "", acl_xml_document, headers)

    def put_acl(self, bucket, key, acl_xml_document, headers=None):
        return Response(self._make_request("PUT", bucket, key, {"acl": None}, headers, acl_xml_document))

    def list_all_my_buckets(self, headers=None):
        return ListAllMyBucketsResponse(self._make_request("GET", "", "", {}, headers))

    def get_bucket_location(self, bucket):
        return LocationResponse(self._make_request("GET", bucket, "", {"location": None}))

    # end public methods

    def _make_request(self, method, bucket="", key="", query_args=None, headers=None, data="", metadata=None):
        """Send a signed request, following at most MAX_REDIRECTS redirects.

        :returns: the ``http.client.HTTPResponse`` of the last request, not yet
            read.
        :raises http.client.InvalidURL: if a redirect Location uses a scheme
            other than http or https.
        """
        if query_args is None:
            query_args = {}
        if headers is None:
            headers = {}
        if metadata is None:
            metadata = {}

        if bucket == "":
            server = self.server
        elif self.calling_format == CallingFormat.SUBDOMAIN:
            server = "%s.%s" % (bucket, self.server)
        elif self.calling_format == CallingFormat.VANITY:
            server = bucket
        else:
            server = self.server

        path = ""
        if bucket != "" and self.calling_format == CallingFormat.PATH:
            path += "/%s" % bucket
        # add the slash after the bucket regardless;
        # the key will be appended if it is non-empty
        path += "/%s" % urllib.parse.quote_plus(key)
        if query_args:
            path += "?" + query_args_hash_to_string(query_args)

        is_secure = self.is_secure
        host = "%s:%d" % (server, self.port)
        redirects = 0
        while True:
            if is_secure:
                connection = http.client.HTTPSConnection(host)
            else:
                connection = http.client.HTTPConnection(host)
            final_headers = merge_meta(headers, metadata)
            # add auth header
            self._add_aws_auth_header(final_headers, method, bucket, key, query_args)
            connection.request(method, path, data, final_headers)
            resp = connection.getresponse()
            if resp.status < 300 or resp.status >= 400:
                return resp

            # handle redirect
            location = resp.getheader("location")
            if not location or redirects >= self.MAX_REDIRECTS:
                return resp
            redirects += 1
            # drain the body and close the connection before reconnecting
            resp.read()
            connection.close()
            scheme, host, path, params, query, fragment = urllib.parse.urlparse(location)
            if scheme == "http":
                is_secure = False
            elif scheme == "https":
                is_secure = True
            else:
                raise http.client.InvalidURL("Not http/https: " + location)
            if query:
                path += "?" + query
            # retry with redirect

    def _add_aws_auth_header(self, headers, method, bucket, key, query_args):
        """Add ``Date`` (when absent) and ``Authorization`` to ``headers`` in place."""
        import email.utils

        if "Date" not in headers:
            # RFC 1123 date in GMT.  formatdate() is locale-independent, unlike
            # time.strftime's %a/%b/%X, which follow the process locale.
            headers["Date"] = email.utils.formatdate(usegmt=True)
        c_string = canonical_string(method, bucket, key, query_args, headers)
        headers["Authorization"] = "AWS %s:%s" % (
            self.aws_access_key_id,
            encode(self.aws_secret_access_key, c_string),
        )
class QueryStringAuthGenerator:
    """Builds pre-signed S3 URLs (signature-v2 query-string authentication).

    The public methods mirror the ``AWSAuthConnection`` methods of the same
    name but return a URL string instead of sending a request.  Expiry is
    either relative (``set_expires_in``; initially DEFAULT_EXPIRES_IN seconds
    after each URL is generated) or an absolute epoch time (``set_expires``).
    """

    # by default, expire in 1 minute
    DEFAULT_EXPIRES_IN = 60

    def __init__(
        self,
        aws_access_key_id,
        aws_secret_access_key,
        is_secure=True,
        server=DEFAULT_HOST,
        port=None,
        calling_format=CallingFormat.SUBDOMAIN,
    ):
        if not port:
            port = PORTS_BY_SECURITY[is_secure]
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.protocol = "https" if is_secure else "http"
        self.is_secure = is_secure
        self.server = server
        self.port = port
        self.calling_format = calling_format
        self.__expires_in = QueryStringAuthGenerator.DEFAULT_EXPIRES_IN
        self.__expires = None
        # for backwards compatibility with older versions
        self.server_name = "%s:%s" % (self.server, self.port)

    def set_expires_in(self, expires_in):
        """Make URLs expire ``expires_in`` seconds after they are generated."""
        self.__expires_in = expires_in
        self.__expires = None

    def set_expires(self, expires):
        """Make URLs expire at the absolute epoch time ``expires``."""
        self.__expires = expires
        self.__expires_in = None

    def create_bucket(self, bucket, headers=None):
        return self.generate_url("PUT", bucket, "", {}, headers)

    def list_bucket(self, bucket, options=None, headers=None):
        return self.generate_url("GET", bucket, "", options, headers)

    def delete_bucket(self, bucket, headers=None):
        return self.generate_url("DELETE", bucket, "", {}, headers)

    def put(self, bucket, key, object, headers=None):
        if not isinstance(object, S3Object):
            object = S3Object(object)
        if headers is None:
            headers = {}
        # metadata headers are part of the signature, so merge them in here
        return self.generate_url("PUT", bucket, key, {}, merge_meta(headers, object.metadata))

    def get(self, bucket, key, headers=None):
        return self.generate_url("GET", bucket, key, {}, headers)

    def delete(self, bucket, key, headers=None):
        return self.generate_url("DELETE", bucket, key, {}, headers)

    def get_bucket_logging(self, bucket, headers=None):
        return self.generate_url("GET", bucket, "", {"logging": None}, headers)

    # the document itself is not part of the URL, so it is ignored here.
    def put_bucket_logging(self, bucket, logging_xml_doc, headers=None):
        return self.generate_url("PUT", bucket, "", {"logging": None}, headers)

    def get_bucket_acl(self, bucket, headers=None):
        return self.get_acl(bucket, "", headers)

    def get_acl(self, bucket, key="", headers=None):
        return self.generate_url("GET", bucket, key, {"acl": None}, headers)

    def put_bucket_acl(self, bucket, acl_xml_document, headers=None):
        return self.put_acl(bucket, "", acl_xml_document, headers)

    # don't really care what the doc is here.
    def put_acl(self, bucket, key, acl_xml_document, headers=None):
        return self.generate_url("PUT", bucket, key, {"acl": None}, headers)

    def list_all_my_buckets(self, headers=None):
        return self.generate_url("GET", "", "", {}, headers)

    def make_bare_url(self, bucket, key=""):
        """Return the URL for ``bucket``/``key`` without any query string."""
        # generate_url always appends "?...", so index() cannot fail.
        full_url = self.generate_url("GET", bucket, key)
        return full_url[: full_url.index("?")]

    def generate_url(self, method, bucket="", key="", query_args=None, headers=None):
        """Return a pre-signed URL for the described request.

        :param query_args: extra query arguments; copied, never modified.
        :param headers: headers the request will carry; they take part in the
            signature (see ``canonical_string``).
        :raises ValueError: when neither a relative nor an absolute expiry is
            configured (e.g. after ``set_expires_in(None)``).
        """
        # Work on a copy so Signature/Expires/AWSAccessKeyId never leak into
        # the caller's dict.
        query_args = dict(query_args) if query_args else {}
        if headers is None:
            headers = {}

        if self.__expires_in is not None:
            expires = int(time.time() + self.__expires_in)
        elif self.__expires is not None:
            expires = int(self.__expires)
        else:
            raise ValueError("Invalid expires state")

        canonical_str = canonical_string(method, bucket, key, query_args, headers, expires)
        encoded_canonical = encode(self.aws_secret_access_key, canonical_str)
        url = CallingFormat.build_url_base(self.protocol, self.server, self.port, bucket, self.calling_format)
        url += "/%s" % urllib.parse.quote_plus(key)
        query_args["Signature"] = encoded_canonical
        query_args["Expires"] = expires
        query_args["AWSAccessKeyId"] = self.aws_access_key_id
        url += "?%s" % query_args_hash_to_string(query_args)
        return url
class S3Object:
    """An object payload together with its user metadata.

    :param data: the payload (what gets sent on PUT, or the body read on GET).
    :param metadata: dict of user metadata whose keys carry no
        ``x-amz-meta-`` prefix.  Defaults to a new empty dict per instance.
    """

    def __init__(self, data, metadata=None):
        self.data = data
        # A fresh dict per instance: a shared ``{}`` default would leak
        # metadata added to one object into every default-constructed one.
        self.metadata = {} if metadata is None else metadata
class Owner:
    """Owner of a bucket entry, as reported in a bucket listing."""

    def __init__(self, id="", display_name=""):
        # ``id`` shadows the builtin but is kept for caller compatibility.
        self.id, self.display_name = id, display_name
class ListEntry:
    """One ``<Contents>`` record of a bucket listing.

    Filled in field by field by ListBucketHandler; ``owner`` is an Owner
    instance when the listing includes one, otherwise None.
    """

    def __init__(self, key="", last_modified=None, etag="", size=0, storage_class="", owner=None):
        self.key, self.last_modified, self.etag = key, last_modified, etag
        self.size, self.storage_class, self.owner = size, storage_class, owner
class CommonPrefixEntry:
    """One ``<CommonPrefixes>`` record of a delimited bucket listing."""

    def __init__(self, prefix=""):
        # Was misspelled ``__init``, so ``prefix`` was never initialised and
        # passing a prefix raised TypeError.
        self.prefix = prefix
class Bucket:
    """A bucket as listed by ListAllMyBuckets: its name and creation-date text."""

    def __init__(self, name="", creation_date=""):
        self.name, self.creation_date = name, creation_date
class Response:
    """Base wrapper that immediately reads an HTTP response.

    Attributes:
        http_response: the wrapped response object.
        body: whatever ``http_response.read()`` returned.
        message: the body when the status is >= 300 and the body is
            non-empty; otherwise "<status> <reason>" with the status
            zero-padded to three digits.
    """

    def __init__(self, http_response):
        self.http_response = http_response
        # The body must be read even when it is not needed; otherwise the
        # next request fails.
        self.body = http_response.read()
        is_error_with_body = http_response.status >= 300 and self.body
        self.message = (
            self.body
            if is_error_with_body
            else "%03d %s" % (http_response.status, http_response.reason)
        )
class ListBucketResponse(Response):
    """Response to a bucket listing (GET on a bucket).

    On success (status < 300) the body is parsed with ListBucketHandler and
    its results are copied onto this object.  On failure every listing
    attribute is still defined, using the handler's initial empty values, so
    callers can read e.g. ``is_truncated`` without an AttributeError.
    """

    def __init__(self, http_response):
        Response.__init__(self, http_response)
        # A fresh handler carries the "empty listing" defaults; it is only fed
        # the body when the request succeeded.
        handler = ListBucketHandler()
        if http_response.status < 300:
            xml.sax.parseString(self.body, handler)
        self.entries = handler.entries
        self.common_prefixes = handler.common_prefixes
        self.name = handler.name
        self.marker = handler.marker
        self.prefix = handler.prefix
        self.is_truncated = handler.is_truncated
        self.delimiter = handler.delimiter
        self.max_keys = handler.max_keys
        self.next_marker = handler.next_marker
class ListAllMyBucketsResponse(Response):
    """Response to a list-all-buckets request (GET on the service root).

    ``entries`` holds the Bucket objects parsed from the body when the status
    is below 300, and is an empty list otherwise.
    """

    def __init__(self, http_response):
        Response.__init__(self, http_response)
        if http_response.status >= 300:
            self.entries = []
            return
        parser_target = ListAllMyBucketsHandler()
        xml.sax.parseString(self.body, parser_target)
        self.entries = parser_target.entries
class GetResponse(Response):
    """Response whose body and user metadata are exposed as ``object``.

    ``object`` is an S3Object built from the body and from the response
    headers that start with METADATA_PREFIX.
    """

    def __init__(self, http_response):
        Response.__init__(self, http_response)
        # older pythons don't have getheaders, so use the message object
        self.object = S3Object(self.body, self.get_aws_metadata(http_response.msg))

    def get_aws_metadata(self, headers):
        """Pop user-metadata headers out of ``headers`` and return them.

        Every header whose lower-cased name starts with METADATA_PREFIX is
        deleted from ``headers`` (mutated in place) and stored in the returned
        dict under its name with the prefix sliced off.
        """
        prefix_len = len(METADATA_PREFIX)
        metadata = {}
        for name in list(headers.keys()):
            if not name.lower().startswith(METADATA_PREFIX):
                continue
            metadata[name[prefix_len:]] = headers[name]
            del headers[name]
        return metadata
class LocationResponse(Response):
    """Response to a bucket-location request.

    ``location`` is the value produced by LocationHandler when the status is
    below 300, and None when the request failed.
    """

    def __init__(self, http_response):
        Response.__init__(self, http_response)
        # Always define the attribute so failed requests don't raise
        # AttributeError when ``location`` is read.
        self.location = None
        if http_response.status < 300:
            handler = LocationHandler()
            xml.sax.parseString(self.body, handler)
            self.location = handler.location
class ListBucketHandler(xml.sax.ContentHandler):
    """SAX handler for a bucket-listing document.

    After parsing:
    - ``entries`` holds one ListEntry per ``<Contents>``.
    - ``common_prefixes`` holds one CommonPrefixEntry per ``<CommonPrefixes>``.
    - The listing-level fields (name, marker, prefix, is_truncated, delimiter,
      max_keys, next_marker) are stored as attributes.
    """

    def __init__(self):
        xml.sax.ContentHandler.__init__(self)
        self.entries = []
        self.curr_entry = None
        self.curr_text = ""
        self.common_prefixes = []
        self.curr_common_prefix = None
        self.name = ""
        self.marker = ""
        self.prefix = ""
        self.is_truncated = False
        self.delimiter = ""
        self.max_keys = 0
        self.next_marker = ""
        # The first <Prefix> seen is the listing's echoed prefix; every later
        # one belongs to the current <CommonPrefixes> entry.
        self.is_echoed_prefix_set = False

    def startElement(self, name, attrs):
        # Reset the text buffer so whitespace between elements (e.g. the
        # indentation of pretty-printed XML) is not prepended to the value.
        self.curr_text = ""
        if name == "Contents":
            self.curr_entry = ListEntry()
        elif name == "Owner":
            self.curr_entry.owner = Owner()
        elif name == "CommonPrefixes":
            self.curr_common_prefix = CommonPrefixEntry()

    def endElement(self, name):
        if name == "Contents":
            self.entries.append(self.curr_entry)
        elif name == "CommonPrefixes":
            self.common_prefixes.append(self.curr_common_prefix)
        elif name == "Key":
            self.curr_entry.key = self.curr_text
        elif name == "LastModified":
            self.curr_entry.last_modified = self.curr_text
        elif name == "ETag":
            self.curr_entry.etag = self.curr_text
        elif name == "Size":
            self.curr_entry.size = int(self.curr_text)
        elif name == "ID":
            self.curr_entry.owner.id = self.curr_text
        elif name == "DisplayName":
            self.curr_entry.owner.display_name = self.curr_text
        elif name == "StorageClass":
            self.curr_entry.storage_class = self.curr_text
        elif name == "Name":
            self.name = self.curr_text
        elif name == "Prefix" and self.is_echoed_prefix_set:
            self.curr_common_prefix.prefix = self.curr_text
        elif name == "Prefix":
            self.prefix = self.curr_text
            self.is_echoed_prefix_set = True
        elif name == "Marker":
            self.marker = self.curr_text
        elif name == "IsTruncated":
            self.is_truncated = self.curr_text == "true"
        elif name == "Delimiter":
            self.delimiter = self.curr_text
        elif name == "MaxKeys":
            self.max_keys = int(self.curr_text)
        elif name == "NextMarker":
            self.next_marker = self.curr_text
        self.curr_text = ""

    def characters(self, content):
        # SAX may deliver one text node in several calls; accumulate them.
        self.curr_text += content
class ListAllMyBucketsHandler(xml.sax.ContentHandler):
    """SAX handler for a list-all-buckets document.

    Collects one Bucket per ``<Bucket>`` element into ``entries``, filling in
    its ``name`` and ``creation_date`` from the nested elements.
    """

    def __init__(self):
        xml.sax.ContentHandler.__init__(self)
        self.entries = []
        self.curr_entry = None
        self.curr_text = ""

    def startElement(self, name, attrs):
        # Start each element with an empty buffer so stale text from a
        # previous element cannot leak into an empty one.
        self.curr_text = ""
        if name == "Bucket":
            self.curr_entry = Bucket()

    def endElement(self, name):
        if name == "Name":
            self.curr_entry.name = self.curr_text
        elif name == "CreationDate":
            self.curr_entry.creation_date = self.curr_text
        elif name == "Bucket":
            self.entries.append(self.curr_entry)
        self.curr_text = ""

    def characters(self, content):
        # SAX may split one text node across several calls, so accumulate
        # instead of overwriting.
        self.curr_text += content
class LocationHandler(xml.sax.ContentHandler):
    """SAX handler that extracts the bucket location constraint.

    Only a document whose root element is ``<LocationConstraint>`` is
    accepted.  In that case ``location`` ends up as its text ("" when empty)
    and ``state`` as "done".  Any other structure leaves ``state`` == "bad".
    """

    def __init__(self):
        # Stays None unless <LocationConstraint> is the first element seen.
        self.location = None
        # One of: "init", "tag_location", "done", "bad".
        self.state = "init"

    def startElement(self, name, attrs):
        opens_root_constraint = self.state == "init" and name == "LocationConstraint"
        if not opens_root_constraint:
            self.state = "bad"
            return
        self.state = "tag_location"
        self.location = ""

    def endElement(self, name):
        closes_constraint = self.state == "tag_location" and name == "LocationConstraint"
        self.state = "done" if closes_constraint else "bad"

    def characters(self, content):
        if self.state != "tag_location":
            return
        self.location += content