From 103ba21ce191f621bab18b43543606976ac8e207 Mon Sep 17 00:00:00 2001 From: Mike Lang Date: Mon, 22 May 2023 15:30:20 +1000 Subject: [PATCH] fixts first attempt --- common/common/fixts.py | 245 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100644 common/common/fixts.py diff --git a/common/common/fixts.py b/common/common/fixts.py new file mode 100644 index 0000000..c26bd01 --- /dev/null +++ b/common/common/fixts.py @@ -0,0 +1,245 @@ + + +import struct + +class FixTS(): + """Does stream processing on an MPEG-TS stream, adjusting all timestamps in it. + The stream will be adjusted such that the first packet starts at the given start_time, + with all other packets adjusted to be the same time relative to that packet. + In other words, a video that goes from 01:23:45 to 01:24:45 will be retimed to instead + go from (for example) 00:10:00 to 00:11:00. + + The object maintains an internal buffer of data. + Use feed() to add more data. Data will be removed from the buffer when a whole packet + can be parsed, and any completed data will be returned from feed(). + Finally, when there is no more data, call end() to assert there is no left-over data + and provide the final video end time. + + All timestamps are in seconds as a float. + + Example use: + fixer = FixTimestamps(0) + for chunk in input: + fixed_data = fixer.feed(chunk) + output(fixed_data) + end_time = fixer.end() + """ + PACKET_SIZE = 188 + + def __init__(self, start_time): + self.start_time = start_time + self.end_time = start_time + self.offset = None + self.data = b"" + + def feed(self, data): + """Takes more data as a bytestring to add to buffer. + Fixes any whole packets in the buffer and returns them as a single bytestring.""" + self.data += data + output = [] + while len(self.data) >= self.PACKET_SIZE: + packet = self.data[:self.PACKET_SIZE] + self.data = self.data[self.PACKET_SIZE:] + output.append(self._fix_packet(packet)) + return b''.join(output) + + def end(self): + """Should be called when no more data will be added. + Checks no data was left over, and returns the final end time (ie. start time + video duration). + """ + if len(self.data) > 0: + raise ValueError("Stream has a partial packet remaining: {!r}", self.data) + return self.end_time + + # TODO we should really be only using PCR to calibrate the offset (ie. we want the first PCR + # to be = start_time, not the first PTS we see which might be the audio stream). + # Also we need to pad the end_time to the time of the NEXT expected frame, or else + # we'll overlap the last frame here with the first frame of the next segment. + # How to determine expected frame? Easiest way is probably average or median difference + # between PCRs, with a reasonable fallback. + def _convert_time(self, old_time): + # If this is the first timestamp we've seen, use it to calibrate offset. + if self.offset is None: + self.offset = self.start_time - old_time + new_time = old_time + self.offset + # It's rare but possible that when resetting times to start at 0, the second packet + # might start slightly earlier than the first and thus have a negative time. + # This isn't encodable in the data format, so just clamp to 0. + new_time = max(0, new_time) + # keep track of latest new_time as the end time + self.end_time = max(self.end_time, new_time) + return new_time + + def _fix_packet(self, packet): + """ + - If an adaptation field is present and contains a PCR, fix the PCR + - If packet is the start of a unit, and the unit begins with 0x0001 + (ie. it's an elementary stream and not a table): + - If the packet header contains a PTS, fix the PTS + - If the packet header cannot be decoded far enough (not enough data in first packet), + bail - we don't care about this edge case. + """ + assert len(packet) == self.PACKET_SIZE + + def check(expr, reason): + if not expr: + raise ValueError("Packet cannot be parsed: {}\n{!r}".format(reason, packet)) + + # Note this is a very simple, hacky parser that only parses as much as we need. + # Useful links: https://en.wikipedia.org/wiki/MPEG_transport_stream + + # 4 byte header: "G" | TEI(1) PUSI(1) PRI(1) PID(5) | PID(8) | TSC(2) AFC(2) CONT(4) + # Of interest to us: + # TEI: If set, data is known to be corrupt + # PUSI: If set, this packet contains a new payload unit + # This matters because payload unit headers contain a timestamp we need to edit + # TSC: If non-zero, indicates data is scrambled (we don't implement handling that) + # AFC: First bit indicates an adaptation field header is present, second bit indicates a payload + check(packet[0:1] == b"G", "Sync byte is incorrect") + check(packet[1] & 0x80 == 0, "Transport error indicator is set") + pusi = bool(packet[1] & 0x40) + check(packet[3] & 0xc0 == 0, "TSC indicates data is scrambled") + has_adaptation_field = bool(packet[3] & 0x20) + has_payload = bool(packet[3] & 0x10) + + has_pcr = False + if has_adaptation_field: + field_length = packet[4] + payload_index = 5 + field_length + # According to the spec, the adaptation field header is at least 1 byte. + # But in the wild we see a header as "present" except 0 bytes long. + # We should just treat this as "not present" + if field_length > 0: + # The adaptation field is a bit field of 8 flags indicating whether optional + # sections are present. Thankfully, the only one we're interested in (the PCR) + # is always the first field if present, so we don't even need to check the others. + has_pcr = bool(packet[5] & 0x10) + if has_pcr: + check(field_length >= 7, "Adaptation field indicates PCR but is too small") + old_time = decode_pcr(packet[6:12]) + new_time = self._convert_time(old_time) + encoded = encode_pcr(new_time) + packet = packet[:6] + encoded + packet[12:] + assert len(packet) == 188 + else: + # No adapatation field, payload starts immediately after the packet header + payload_index = 4 + + if pusi: + # Payload Unit Start Indicator indicates there is a new payload unit in this packet. + # When set, there is an extra byte before the payload indicating where within the + # payload the new payload unit starts. + # A payload unit is a thing like a video frame, audio packet, etc. The payload unit header + # contains a timestamp we need to edit. + check(has_payload, "PUSI set but no payload is present") + payload_pointer = packet[payload_index] + # move index past payload pointer, then seek into payload to find start of payload unit. + unit_index = payload_index + 1 + payload_pointer + # The header we're after is only present in elementary streams, not in program tables. + # We can tell the difference because streams start with a 0x0001 prefix, + # whereas program tables start with a header where at least bits 0x0030 must be set. + # Note wikipedia in https://en.wikipedia.org/wiki/Packetized_elementary_stream + # claims the prefix is 0x000001, but that is including the payload pointer, which seems + # to always be set to 0 for an elementary stream + # (compare https://en.wikipedia.org/wiki/Program-specific_information which also includes + # the payload pointer but says it can be any 8-bit value). + if packet[unit_index : unit_index + 2] == b"\x00\x01": + # unit header looks like: 00, 01, stream id, length(2 bytes), then PES header + # The only thing we care about in the PES header is the top two bits of the second byte, + # which indicates if timestamps are present. + # It's possible that we didn't get enough of the payload in this one packet + # to read the whole header, but exceedingly unlikely. + check(unit_index + 6 < self.PACKET_SIZE, "Payload too small to read unit header") + flags = packet[unit_index + 6] + has_pts = bool(flags & 0x80) + has_dts = bool(flags & 0x40) + check(not has_dts, "DTS timestamp is present, we cannot handle fixing it") + # Once again, PTS is the first optional field, so we don't need to worry + # about other fields being present. + if has_pts: + pts_index = unit_index + 8 + check(pts_index + 5 <= self.PACKET_SIZE, "Payload too small to read PTS") + raw = packet[pts_index : pts_index + 5] + pts = decode_ts(raw, 2) + pts = self._convert_time(pts) + encoded = encode_ts(pts, 2) + packet = packet[:pts_index] + encoded + packet[pts_index + 5:] + assert len(packet) == 188 + return packet + + +def bits(value, start, end): + """Extract bits [START, END) from value, where 0 is LSB""" + size = end - start + return (value >> start) & ((1 << size) - 1) + + +def decode_padded(value, spec): + size = struct.calcsize(spec) + pad = size - len(value) + assert pad >= 0 + value = b"\0" * pad + value + return struct.unpack(spec, value)[0] + + +def encode_pcr(seconds): + assert seconds >= 0 + raw = int(seconds * 27000000) + base, ext = divmod(raw, 300) + assert base < 2**33 + value = (base << 15) + ext + value = struct.pack('!Q', value) + return value[2:] + + +def decode_pcr(value): + value = decode_padded(value, '!Q') + base = bits(value, 15, 48) + extension = bits(value, 0, 9) + raw = 300 * base + extension + seconds = float(raw) / 27000000 + return seconds + + +def encode_ts(seconds, tag): + # bits: TTTTxxx1 xxxxxxxx xxxxxxx1 xxxxxxxx xxxxxxx1 + # T is tag, x is bits of actual number + assert seconds >= 0 + raw = int(seconds * 90000) + a = bits(raw, 30, 33) + b = bits(raw, 15, 30) + c = bits(raw, 0, 15) + value = 1 + (1 << 16) + (1 << 32) + (tag << 36) + (a << 33) + (b << 17) + (c << 1) + value = struct.pack('!Q', value) + return value[3:] + + +def decode_ts(value, tag): + # bits: TTTTxxx1 xxxxxxxx xxxxxxx1 xxxxxxxx xxxxxxx1 + # T is tag, x is bits of actual number + value = decode_padded(value, '!Q') + assert bits(value, 36, 40) == tag + assert all(value & (1 << bit) for bit in [0, 16, 32]) + a = bits(value, 33, 36) + b = bits(value, 17, 32) + c = bits(value, 1, 16) + value = (a << 30) + (b << 15) + c + seconds = float(value) / 90000 + return seconds + + +if __name__ == '__main__': + # simple test: read file from stdin, set start to first arg, output to stdout. + import sys + start_time = float(sys.argv[1]) + fixer = FixTS(start_time) + chunk = None + while chunk != b"": + chunk = sys.stdin.buffer.read(8192) + if chunk: + output = fixer.feed(chunk) + while output: + written = sys.stdout.buffer.write(output) + output = output[written:] + end_time = fixer.end() + sys.stderr.write(str(end_time) + '\n')