fixts first attempt

mike/backfill-extras
Mike Lang 2 years ago committed by Mike Lang
parent 30d5ccc483
commit eaf3ed2e54

@ -0,0 +1,245 @@
import struct
class FixTS():
"""Does stream processing on an MPEG-TS stream, adjusting all timestamps in it.
The stream will be adjusted such that the first packet starts at the given start_time,
with all other packets adjusted to be the same time relative to that packet.
In other words, a video that goes from 01:23:45 to 01:24:45 will be retimed to instead
go from (for example) 00:10:00 to 00:11:00.
The object maintains an internal buffer of data.
Use feed() to add more data. Data will be removed from the buffer when a whole packet
can be parsed, and any completed data will be returned from feed().
Finally, when there is no more data, call end() to assert there is no left-over data
and provide the final video end time.
All timestamps are in seconds as a float.
Example use:
fixer = FixTimestamps(0)
for chunk in input:
fixed_data = fixer.feed(chunk)
output(fixed_data)
end_time = fixer.end()
"""
PACKET_SIZE = 188
def __init__(self, start_time):
self.start_time = start_time
self.end_time = start_time
self.offset = None
self.data = b""
def feed(self, data):
"""Takes more data as a bytestring to add to buffer.
Fixes any whole packets in the buffer and returns them as a single bytestring."""
self.data += data
output = []
while len(self.data) >= self.PACKET_SIZE:
packet = self.data[:self.PACKET_SIZE]
self.data = self.data[self.PACKET_SIZE:]
output.append(self._fix_packet(packet))
return b''.join(output)
def end(self):
"""Should be called when no more data will be added.
Checks no data was left over, and returns the final end time (ie. start time + video duration).
"""
if len(self.data) > 0:
raise ValueError("Stream has a partial packet remaining: {!r}", self.data)
return self.end_time
# TODO we should really be only using PCR to calibrate the offset (ie. we want the first PCR
# to be = start_time, not the first PTS we see which might be the audio stream).
# Also we need to pad the end_time to the time of the NEXT expected frame, or else
# we'll overlap the last frame here with the first frame of the next segment.
# How to determine expected frame? Easiest way is probably average or median difference
# between PCRs, with a reasonable fallback.
def _convert_time(self, old_time):
# If this is the first timestamp we've seen, use it to calibrate offset.
if self.offset is None:
self.offset = self.start_time - old_time
new_time = old_time + self.offset
# It's rare but possible that when resetting times to start at 0, the second packet
# might start slightly earlier than the first and thus have a negative time.
# This isn't encodable in the data format, so just clamp to 0.
new_time = max(0, new_time)
# keep track of latest new_time as the end time
self.end_time = max(self.end_time, new_time)
return new_time
def _fix_packet(self, packet):
"""
- If an adaptation field is present and contains a PCR, fix the PCR
- If packet is the start of a unit, and the unit begins with 0x0001
(ie. it's an elementary stream and not a table):
- If the packet header contains a PTS, fix the PTS
- If the packet header cannot be decoded far enough (not enough data in first packet),
bail - we don't care about this edge case.
"""
assert len(packet) == self.PACKET_SIZE
def check(expr, reason):
if not expr:
raise ValueError("Packet cannot be parsed: {}\n{!r}".format(reason, packet))
# Note this is a very simple, hacky parser that only parses as much as we need.
# Useful links: https://en.wikipedia.org/wiki/MPEG_transport_stream
# 4 byte header: "G" | TEI(1) PUSI(1) PRI(1) PID(5) | PID(8) | TSC(2) AFC(2) CONT(4)
# Of interest to us:
# TEI: If set, data is known to be corrupt
# PUSI: If set, this packet contains a new payload unit
# This matters because payload unit headers contain a timestamp we need to edit
# TSC: If non-zero, indicates data is scrambled (we don't implement handling that)
# AFC: First bit indicates an adaptation field header is present, second bit indicates a payload
check(packet[0:1] == b"G", "Sync byte is incorrect")
check(packet[1] & 0x80 == 0, "Transport error indicator is set")
pusi = bool(packet[1] & 0x40)
check(packet[3] & 0xc0 == 0, "TSC indicates data is scrambled")
has_adaptation_field = bool(packet[3] & 0x20)
has_payload = bool(packet[3] & 0x10)
has_pcr = False
if has_adaptation_field:
field_length = packet[4]
payload_index = 5 + field_length
# According to the spec, the adaptation field header is at least 1 byte.
# But in the wild we see a header as "present" except 0 bytes long.
# We should just treat this as "not present"
if field_length > 0:
# The adaptation field is a bit field of 8 flags indicating whether optional
# sections are present. Thankfully, the only one we're interested in (the PCR)
# is always the first field if present, so we don't even need to check the others.
has_pcr = bool(packet[5] & 0x10)
if has_pcr:
check(field_length >= 7, "Adaptation field indicates PCR but is too small")
old_time = decode_pcr(packet[6:12])
new_time = self._convert_time(old_time)
encoded = encode_pcr(new_time)
packet = packet[:6] + encoded + packet[12:]
assert len(packet) == 188
else:
# No adapatation field, payload starts immediately after the packet header
payload_index = 4
if pusi:
# Payload Unit Start Indicator indicates there is a new payload unit in this packet.
# When set, there is an extra byte before the payload indicating where within the
# payload the new payload unit starts.
# A payload unit is a thing like a video frame, audio packet, etc. The payload unit header
# contains a timestamp we need to edit.
check(has_payload, "PUSI set but no payload is present")
payload_pointer = packet[payload_index]
# move index past payload pointer, then seek into payload to find start of payload unit.
unit_index = payload_index + 1 + payload_pointer
# The header we're after is only present in elementary streams, not in program tables.
# We can tell the difference because streams start with a 0x0001 prefix,
# whereas program tables start with a header where at least bits 0x0030 must be set.
# Note wikipedia in https://en.wikipedia.org/wiki/Packetized_elementary_stream
# claims the prefix is 0x000001, but that is including the payload pointer, which seems
# to always be set to 0 for an elementary stream
# (compare https://en.wikipedia.org/wiki/Program-specific_information which also includes
# the payload pointer but says it can be any 8-bit value).
if packet[unit_index : unit_index + 2] == b"\x00\x01":
# unit header looks like: 00, 01, stream id, length(2 bytes), then PES header
# The only thing we care about in the PES header is the top two bits of the second byte,
# which indicates if timestamps are present.
# It's possible that we didn't get enough of the payload in this one packet
# to read the whole header, but exceedingly unlikely.
check(unit_index + 6 < self.PACKET_SIZE, "Payload too small to read unit header")
flags = packet[unit_index + 6]
has_pts = bool(flags & 0x80)
has_dts = bool(flags & 0x40)
check(not has_dts, "DTS timestamp is present, we cannot handle fixing it")
# Once again, PTS is the first optional field, so we don't need to worry
# about other fields being present.
if has_pts:
pts_index = unit_index + 8
check(pts_index + 5 <= self.PACKET_SIZE, "Payload too small to read PTS")
raw = packet[pts_index : pts_index + 5]
pts = decode_ts(raw, 2)
pts = self._convert_time(pts)
encoded = encode_ts(pts, 2)
packet = packet[:pts_index] + encoded + packet[pts_index + 5:]
assert len(packet) == 188
return packet
def bits(value, start, end):
"""Extract bits [START, END) from value, where 0 is LSB"""
size = end - start
return (value >> start) & ((1 << size) - 1)
def decode_padded(value, spec):
size = struct.calcsize(spec)
pad = size - len(value)
assert pad >= 0
value = b"\0" * pad + value
return struct.unpack(spec, value)[0]
def encode_pcr(seconds):
assert seconds >= 0
raw = int(seconds * 27000000)
base, ext = divmod(raw, 300)
assert base < 2**33
value = (base << 15) + ext
value = struct.pack('!Q', value)
return value[2:]
def decode_pcr(value):
value = decode_padded(value, '!Q')
base = bits(value, 15, 48)
extension = bits(value, 0, 9)
raw = 300 * base + extension
seconds = float(raw) / 27000000
return seconds
def encode_ts(seconds, tag):
# bits: TTTTxxx1 xxxxxxxx xxxxxxx1 xxxxxxxx xxxxxxx1
# T is tag, x is bits of actual number
assert seconds >= 0
raw = int(seconds * 90000)
a = bits(raw, 30, 33)
b = bits(raw, 15, 30)
c = bits(raw, 0, 15)
value = 1 + (1 << 16) + (1 << 32) + (tag << 36) + (a << 33) + (b << 17) + (c << 1)
value = struct.pack('!Q', value)
return value[3:]
def decode_ts(value, tag):
# bits: TTTTxxx1 xxxxxxxx xxxxxxx1 xxxxxxxx xxxxxxx1
# T is tag, x is bits of actual number
value = decode_padded(value, '!Q')
assert bits(value, 36, 40) == tag
assert all(value & (1 << bit) for bit in [0, 16, 32])
a = bits(value, 33, 36)
b = bits(value, 17, 32)
c = bits(value, 1, 16)
value = (a << 30) + (b << 15) + c
seconds = float(value) / 90000
return seconds
if __name__ == '__main__':
# simple test: read file from stdin, set start to first arg, output to stdout.
import sys
start_time = float(sys.argv[1])
fixer = FixTS(start_time)
chunk = None
while chunk != b"":
chunk = sys.stdin.buffer.read(8192)
if chunk:
output = fixer.feed(chunk)
while output:
written = sys.stdout.buffer.write(output)
output = output[written:]
end_time = fixer.end()
sys.stderr.write(str(end_time) + '\n')
Loading…
Cancel
Save