proto_debug.py: Use new implementations from proto

And change base64u to base64p to match (u too easily confused
with "unpadded")

Signed-off-by: Jesús <heckyel@hyperbola.info>
This commit is contained in:
James Taylor 2021-02-25 20:00:37 -08:00 committed by Jesús
parent 889dabb112
commit 3a73953e6c
No known key found for this signature in database
GPG Key ID: F6EE7BC59A315766

View File

@ -79,11 +79,11 @@ The function pp will pretty print the recursive structure:
make_proto will take a recursive_pb structure and make a ctoken out of it: make_proto will take a recursive_pb structure and make a ctoken out of it:
- base64 means a base64 encode with equals sign paddings - base64 means a base64 encode with equals sign paddings
- base64s means a base64 encode without padding - base64s means a base64 encode without padding
- base64u means a url base64 encode with equals signs replaced with %3D - base64p means a url base64 encode with equals signs replaced with %3D
recursive_pb cannot detect between base64 or base64u or base64s so recursive_pb cannot detect between base64 or base64p or base64s so
those must be manually specified if recreating the token. Will not have those must be manually specified if recreating the token. Will not have
make_proto(recursive_pb(x)) == x if x is using base64u or base64s make_proto(recursive_pb(x)) == x if x is using base64p or base64s
There are some other functions I wrote while reverse engineering stuff There are some other functions I wrote while reverse engineering stuff
that may or may not be useful. that may or may not be useful.
@ -123,7 +123,8 @@ def varint_encode(offset):
for i in range(0, needed_bytes - 1): for i in range(0, needed_bytes - 1):
encoded_bytes[i] = (offset & 127) | 128 # 7 least significant bits encoded_bytes[i] = (offset & 127) | 128 # 7 least significant bits
offset = offset >> 7 offset = offset >> 7
encoded_bytes[-1] = offset & 127 # leave first bit as zero for last byte # leave first bit as zero for last byte
encoded_bytes[-1] = offset & 127
return bytes(encoded_bytes) return bytes(encoded_bytes)
@ -198,11 +199,88 @@ def read_group(data, end_sequence):
return data.original[start:index] return data.original[start:index]
def parse(data): def parse(data, include_wire_type=False):
return { '''Returns a dict mapping field numbers to values
field_number: value for _,
field_number, value in read_protobuf(data) data is the protobuf structure, which must not be b64-encoded'''
} if include_wire_type:
return {field_number: [wire_type, value]
for wire_type, field_number, value in read_protobuf(data)}
return {field_number: value
for _, field_number, value in read_protobuf(data)}
base64_enc_funcs = {
'base64': base64.urlsafe_b64encode,
'base64s': unpadded_b64encode,
'base64p': percent_b64encode,
}
def _make_protobuf(data):
# must be dict mapping field_number to [wire_type, value]
if isinstance(data, dict):
new_data = []
for field_num, (wire_type, value) in sorted(data.items()):
new_data.append((wire_type, field_num, value))
data = new_data
if isinstance(data, str):
return data.encode('utf-8')
elif len(data) == 2 and data[0] in list(base64_enc_funcs.keys()):
return base64_enc_funcs[data[0]](_make_protobuf(data[1]))
elif isinstance(data, list):
result = b''
for field in data:
if field[0] == 0:
result += uint(field[1], field[2])
elif field[0] == 2:
result += string(field[1], _make_protobuf(field[2]))
else:
raise NotImplementedError('Wire type ' + str(field[0])
+ ' not implemented')
return result
return data
def make_protobuf(data):
return _make_protobuf(data).decode('ascii')
make_proto = make_protobuf
def _set_protobuf_value(data, *path, value):
if not path:
return value
op = path[0]
if op in base64_enc_funcs:
inner_data = b64_to_bytes(data)
return base64_enc_funcs[op](
_set_protobuf_value(inner_data, *path[1:], value=value)
)
pb_dict = parse(data, include_wire_type=True)
pb_dict[op][1] = _set_protobuf_value(
pb_dict[op][1], *path[1:], value=value
)
return _make_protobuf(pb_dict)
def set_protobuf_value(data, *path, value):
'''Set a field's value in a raw protobuf structure
path is a list of field numbers and/or base64 encoding directives
The directives are
base64: normal base64 encoding with equal signs padding
base64s ("stripped"): no padding
base64p: %3D instead of = for padding
return new_protobuf, err'''
try:
new_protobuf = _set_protobuf_value(data, *path, value=value)
return new_protobuf.decode('ascii'), None
except Exception:
return None, traceback.format_exc()
def b64_to_bytes(data): def b64_to_bytes(data):
@ -287,35 +365,13 @@ def parse_protobuf(data, mutable=False, spec=()):
yield (wire_type, field_number, value) yield (wire_type, field_number, value)
read_protobuf = parse_protobuf
def pb(data, mutable=False): def pb(data, mutable=False):
return list(parse_protobuf(data, mutable=mutable)) return list(parse_protobuf(data, mutable=mutable))
def make_proto(fields):
if len(fields) == 2 and fields[0] == 'base64':
return enc(make_proto(fields[1]))
result = b''
for field in fields:
if field[0] == 0:
result += _proto_field(0, field[1], varint_encode(field[2]))
elif field[0] == 2:
data = field[2]
if isinstance(data, str):
data = data.encode('utf-8')
elif len(data) == 2 and data[0] == 'base64':
data = base64.urlsafe_b64encode(make_proto(data[1]))
elif len(data) == 2 and data[0] == 'base64s':
data = base64.urlsafe_b64encode(make_proto(data[1])).rstrip(b'=')
elif len(data) == 2 and data[0] == 'base64u':
data = base64.urlsafe_b64encode(make_proto(data[1])).replace(b'=', b'%3D')
elif isinstance(data, list):
data = make_proto(data)
result += _proto_field(2, field[1], varint_encode(len(data)) + data)
else:
raise NotImplementedError('Wire type ' + str(field[0]) + ' not implemented')
return result
def bytes_to_base4(data): def bytes_to_base4(data):
result = '' result = ''
for b in data: for b in data:
@ -424,7 +480,7 @@ def b32decode(s, casefold=False, map01=None):
def dec32(data): def dec32(data):
if isinstance(data, bytes): if isinstance(data, bytes):
data = data.decode('ascii') data = data.decode('ascii')
return b32decode(data + "="*((8 - len(data) % 8) % 8)) return b32decode(data + "="*((8 - len(data)%8)%8))
def recursive_pb(data, filt=True): def recursive_pb(data, filt=True):