diff --git a/README.md b/README.md index 591e8b5..59f6bb0 100644 --- a/README.md +++ b/README.md @@ -166,6 +166,7 @@ The `formatters` submodule provides a few basic formatters to wrap around you tr We provided a few subclasses of formatters to use: - JSONFormatter +- PrettyPrintFormatter - TextFormatter - WebVTTFormatter (a basic implementation) diff --git a/youtube_transcript_api/_cli.py b/youtube_transcript_api/_cli.py index bf83331..870be3a 100644 --- a/youtube_transcript_api/_cli.py +++ b/youtube_transcript_api/_cli.py @@ -1,11 +1,9 @@ -import json - -import pprint - import argparse from ._api import YouTubeTranscriptApi +from .formatters import FormatterLoader + class YouTubeTranscriptCli(object): def __init__(self, args): @@ -34,7 +32,7 @@ class YouTubeTranscriptCli(object): return '\n\n'.join( [str(exception) for exception in exceptions] - + ([json.dumps(transcripts) if parsed_args.json else pprint.pformat(transcripts)] if transcripts else []) + + ([FormatterLoader().load(parsed_args.format).format_transcripts(transcripts)] if transcripts else []) ) def _fetch_transcript(self, parsed_args, proxies, cookies, video_id): @@ -98,11 +96,10 @@ class YouTubeTranscriptCli(object): help='If this flag is set transcripts which have been manually created will not be retrieved.', ) parser.add_argument( - '--json', - action='store_const', - const=True, - default=False, - help='If this flag is set the output will be JSON formatted.', + '--format', + type=str, + default='pretty', + choices=tuple(FormatterLoader.TYPES.keys()), ) parser.add_argument( '--translate', diff --git a/youtube_transcript_api/formatters.py b/youtube_transcript_api/formatters.py index 1cc6e9d..d957a41 100644 --- a/youtube_transcript_api/formatters.py +++ b/youtube_transcript_api/formatters.py @@ -9,49 +9,75 @@ class Formatter(object): Formatter classes should inherit from this class and implement their own .format() method which should return a string. A transcript is represented by a List of Dictionary items. - - :param transcript: list representing 1 or more transcripts - :type transcript: list """ - def __init__(self, transcript): - if not isinstance(transcript, list): - raise TypeError("'transcript' must be of type: List") - self._transcript = transcript - - def format(self, **kwargs): + def format_transcript(self, transcript, **kwargs): raise NotImplementedError('A subclass of Formatter must implement ' \ - 'their own .format() method.') + 'their own .format_transcript() method.') + + def format_transcripts(self, transcripts, **kwargs): + raise NotImplementedError('A subclass of Formatter must implement ' \ + 'their own .format_transcripts() method.') class PrettyPrintFormatter(Formatter): - def format(self, **kwargs): + def format_transcript(self, transcript, **kwargs): """Pretty prints a transcript. - :return: A pretty printed string representation of the transcript dict.' + :param transcript: + :return: A pretty printed string representation of the transcript.' :rtype str """ - return pprint.pformat(self._transcript, **kwargs) + return pprint.pformat(transcript, **kwargs) + + def format_transcripts(self, transcripts, **kwargs): + """Pretty prints a list of transcripts. + + :param transcripts: + :return: A pretty printed string representation of the transcripts.' + :rtype str + """ + return self.format_transcript(transcripts, **kwargs) class JSONFormatter(Formatter): - def format(self, **kwargs): + def format_transcript(self, transcript, **kwargs): """Converts a transcript into a JSON string. + :param transcript: :return: A JSON string representation of the transcript.' :rtype str """ - return json.dumps(self._transcript, **kwargs) + return json.dumps(transcript, **kwargs) + + def format_transcripts(self, transcripts, **kwargs): + """Converts a list of transcripts into a JSON string. + + :param transcripts: + :return: A JSON string representation of the transcript.' + :rtype str + """ + return self.format_transcript(transcripts, **kwargs) class TextFormatter(Formatter): - def format(self, **kwargs): + def format_transcript(self, transcript, **kwargs): """Converts a transcript into plain text with no timestamps. + :param transcript: :return: all transcript text lines separated by newline breaks.' :rtype str """ - return "\n".join(line['text'] for line in self._transcript) + return '\n'.join(line['text'] for line in transcript) + + def format_transcripts(self, transcripts, **kwargs): + """Converts a list of transcripts into plain text with no timestamps. + + :param transcripts: + :return: all transcript text lines separated by newline breaks.' + :rtype str + """ + return '\n\n\n'.join([self.format_transcript(transcript, **kwargs) for transcript in transcripts]) class WebVTTFormatter(Formatter): @@ -77,19 +103,20 @@ class WebVTTFormatter(Formatter): ms = int(round((time - int(time))*1000, 2)) return "{:02d}:{:02d}:{:02d}.{:03d}".format(hours, mins, secs, ms) - def format(self, **kwargs): + def format_transcript(self, transcript, **kwargs): """A basic implementation of WEBVTT formatting. + :param transcript: :reference: https://www.w3.org/TR/webvtt1/#introduction-caption """ lines = [] - for i, line in enumerate(self._transcript): - if i < len(self._transcript) - 1: + for i, line in enumerate(transcript): + if i < len(transcript) - 1: # Looks ahead, use next start time since duration value # would create an overlap between start times. time_text = "{} --> {}".format( self._seconds_to_timestamp(line['start']), - self._seconds_to_timestamp(self._transcript[i + 1]['start']) + self._seconds_to_timestamp(transcript[i + 1]['start']) ) else: # Reached the end, cannot look ahead, use duration now. @@ -102,6 +129,14 @@ class WebVTTFormatter(Formatter): return "WEBVTT\n\n" + "\n\n".join(lines) + "\n" + def format_transcripts(self, transcripts, **kwargs): + """A basic implementation of WEBVTT formatting for a list of transcripts. + + :param transcripts: + :reference: https://www.w3.org/TR/webvtt1/#introduction-caption + """ + return '\n\n\n'.join([self.format_transcript(transcript, **kwargs) for transcript in transcripts]) + class FormatterLoader(object): TYPES = { @@ -118,10 +153,13 @@ class FormatterLoader(object): f'Choose one of the following formats: {", ".join(FormatterLoader.TYPES.keys())}' ) - def __init__(self, formatter_type='pretty'): + def load(self, formatter_type='pretty'): + """ + Loads the Formatter for the given formatter type. + + :param formatter_type: + :return: Formatter object + """ if formatter_type not in FormatterLoader.TYPES.keys(): raise FormatterLoader.UnknownFormatterType(formatter_type) - self._formatter = FormatterLoader.TYPES[formatter_type] - - def load(self, transcript): - return self._formatter(transcript) + return FormatterLoader.TYPES[formatter_type]() diff --git a/youtube_transcript_api/test/test_cli.py b/youtube_transcript_api/test/test_cli.py index 158cd35..1cb3eff 100644 --- a/youtube_transcript_api/test/test_cli.py +++ b/youtube_transcript_api/test/test_cli.py @@ -25,50 +25,52 @@ class TestYouTubeTranscriptCli(TestCase): YouTubeTranscriptApi.list_transcripts = MagicMock(return_value=self.transcript_list_mock) def test_argument_parsing(self): - parsed_args = YouTubeTranscriptCli('v1 v2 --json --languages de en'.split())._parse_args() + parsed_args = YouTubeTranscriptCli('v1 v2 --format json --languages de en'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.http_proxy, '') self.assertEqual(parsed_args.https_proxy, '') - parsed_args = YouTubeTranscriptCli('v1 v2 --languages de en --json'.split())._parse_args() + parsed_args = YouTubeTranscriptCli('v1 v2 --languages de en --format json'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.http_proxy, '') self.assertEqual(parsed_args.https_proxy, '') - parsed_args = YouTubeTranscriptCli(' --json v1 v2 --languages de en'.split())._parse_args() + parsed_args = YouTubeTranscriptCli(' --format json v1 v2 --languages de en'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.http_proxy, '') self.assertEqual(parsed_args.https_proxy, '') parsed_args = YouTubeTranscriptCli( - 'v1 v2 --languages de en --json --http-proxy http://user:pass@domain:port --https-proxy https://user:pass@domain:port'.split() + 'v1 v2 --languages de en --format json ' + '--http-proxy http://user:pass@domain:port ' + '--https-proxy https://user:pass@domain:port'.split() )._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.http_proxy, 'http://user:pass@domain:port') self.assertEqual(parsed_args.https_proxy, 'https://user:pass@domain:port') parsed_args = YouTubeTranscriptCli( - 'v1 v2 --languages de en --json --http-proxy http://user:pass@domain:port'.split() + 'v1 v2 --languages de en --format json --http-proxy http://user:pass@domain:port'.split() )._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.http_proxy, 'http://user:pass@domain:port') self.assertEqual(parsed_args.https_proxy, '') parsed_args = YouTubeTranscriptCli( - 'v1 v2 --languages de en --json --https-proxy https://user:pass@domain:port'.split() + 'v1 v2 --languages de en --format json --https-proxy https://user:pass@domain:port'.split() )._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.https_proxy, 'https://user:pass@domain:port') self.assertEqual(parsed_args.http_proxy, '') @@ -76,28 +78,28 @@ class TestYouTubeTranscriptCli(TestCase): def test_argument_parsing__only_video_ids(self): parsed_args = YouTubeTranscriptCli('v1 v2'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, False) + self.assertEqual(parsed_args.format, 'pretty') self.assertEqual(parsed_args.languages, ['en']) def test_argument_parsing__fail_without_video_ids(self): with self.assertRaises(SystemExit): - YouTubeTranscriptCli('--json'.split())._parse_args() + YouTubeTranscriptCli('--format json'.split())._parse_args() def test_argument_parsing__json(self): - parsed_args = YouTubeTranscriptCli('v1 v2 --json'.split())._parse_args() + parsed_args = YouTubeTranscriptCli('v1 v2 --format json'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['en']) - parsed_args = YouTubeTranscriptCli('--json v1 v2'.split())._parse_args() + parsed_args = YouTubeTranscriptCli('--format json v1 v2'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, True) + self.assertEqual(parsed_args.format, 'json') self.assertEqual(parsed_args.languages, ['en']) def test_argument_parsing__languages(self): parsed_args = YouTubeTranscriptCli('v1 v2 --languages de en'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, False) + self.assertEqual(parsed_args.format, 'pretty') self.assertEqual(parsed_args.languages, ['de', 'en']) def test_argument_parsing__proxies(self): @@ -135,13 +137,13 @@ class TestYouTubeTranscriptCli(TestCase): def test_argument_parsing__translate(self): parsed_args = YouTubeTranscriptCli('v1 v2 --languages de en --translate cz'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, False) + self.assertEqual(parsed_args.format, 'pretty') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.translate, 'cz') parsed_args = YouTubeTranscriptCli('v1 v2 --translate cz --languages de en'.split())._parse_args() self.assertEqual(parsed_args.video_ids, ['v1', 'v2']) - self.assertEqual(parsed_args.json, False) + self.assertEqual(parsed_args.format, 'pretty') self.assertEqual(parsed_args.languages, ['de', 'en']) self.assertEqual(parsed_args.translate, 'cz') @@ -188,7 +190,9 @@ class TestYouTubeTranscriptCli(TestCase): def test_run__exclude_manually_created_and_generated(self): self.assertEqual( - YouTubeTranscriptCli('v1 v2 --languages de en --exclude-manually-created --exclude-generated'.split()).run(), + YouTubeTranscriptCli( + 'v1 v2 --languages de en --exclude-manually-created --exclude-generated'.split() + ).run(), '' ) @@ -204,7 +208,7 @@ class TestYouTubeTranscriptCli(TestCase): YouTubeTranscriptApi.list_transcripts.assert_any_call('v2', proxies=None, cookies=None) def test_run__json_output(self): - output = YouTubeTranscriptCli('v1 v2 --languages de en --json'.split()).run() + output = YouTubeTranscriptCli('v1 v2 --languages de en --format json'.split()).run() # will fail if output is not valid json json.loads(output) diff --git a/youtube_transcript_api/test/test_formatters.py b/youtube_transcript_api/test/test_formatters.py index bb0b274..748ed02 100644 --- a/youtube_transcript_api/test/test_formatters.py +++ b/youtube_transcript_api/test/test_formatters.py @@ -20,23 +20,16 @@ class TestFormatters(TestCase): {'text': 'line between', 'start': 1.5, 'duration': 2.0}, {'text': 'testing the end line', 'start': 2.5, 'duration': 3.25} ] - - def test_base_formatter_valid_type(self): - with self.assertRaises(TypeError) as err: - Formatter({"test": []}) - expected_err = "'transcript' must be of type: List" - self.assertEqual(expected_err, str(err.exception)) - + self.transcripts = [self.transcript, self.transcript] + def test_base_formatter_format_call(self): - with self.assertRaises(NotImplementedError) as err: - Formatter(self.transcript).format() - - expected_err = "A subclass of Formatter must implement their own " \ - ".format() method." - self.assertEqual(expected_err, str(err.exception)) + with self.assertRaises(NotImplementedError): + Formatter().format_transcript(self.transcript) + with self.assertRaises(NotImplementedError): + Formatter().format_transcripts([self.transcript]) def test_webvtt_formatter_starting(self): - content = WebVTTFormatter(self.transcript).format() + content = WebVTTFormatter().format_transcript(self.transcript) lines = content.split('\n') # test starting lines @@ -44,42 +37,66 @@ class TestFormatters(TestCase): self.assertEqual(lines[1], "") def test_webvtt_formatter_ending(self): - content = WebVTTFormatter(self.transcript).format() + content = WebVTTFormatter().format_transcript(self.transcript) lines = content.split('\n') # test ending lines self.assertEqual(lines[-2], self.transcript[-1]['text']) self.assertEqual(lines[-1], "") + def test_webvtt_formatter_many(self): + formatter = WebVTTFormatter() + content = formatter.format_transcripts(self.transcripts) + formatted_single_transcript = formatter.format_transcript(self.transcript) + + self.assertEqual(content, formatted_single_transcript + '\n\n\n' + formatted_single_transcript) + def test_pretty_print_formatter(self): - content = PrettyPrintFormatter(self.transcript).format() + content = PrettyPrintFormatter().format_transcript(self.transcript) self.assertEqual(content, pprint.pformat(self.transcript)) + def test_pretty_print_formatter_many(self): + content = PrettyPrintFormatter().format_transcripts(self.transcripts) + + self.assertEqual(content, pprint.pformat(self.transcripts)) + def test_json_formatter(self): - content = JSONFormatter(self.transcript).format() + content = JSONFormatter().format_transcript(self.transcript) self.assertEqual(json.loads(content), self.transcript) + def test_json_formatter_many(self): + content = JSONFormatter().format_transcripts(self.transcripts) + + self.assertEqual(json.loads(content), self.transcripts) + def test_text_formatter(self): - content = TextFormatter(self.transcript).format() + content = TextFormatter().format_transcript(self.transcript) lines = content.split('\n') self.assertEqual(lines[0], self.transcript[0]["text"]) self.assertEqual(lines[-1], self.transcript[-1]["text"]) + def test_text_formatter_many(self): + formatter = TextFormatter() + content = formatter.format_transcripts(self.transcripts) + formatted_single_transcript = formatter.format_transcript(self.transcript) + + self.assertEqual(content, formatted_single_transcript + '\n\n\n' + formatted_single_transcript) + def test_formatter_loader(self): - loader = FormatterLoader('json') - formatter = loader.load(self.transcript) + loader = FormatterLoader() + formatter = loader.load('json') self.assertTrue(isinstance(formatter, JSONFormatter)) def test_formatter_loader__default_formatter(self): loader = FormatterLoader() - formatter = loader.load(self.transcript) + formatter = loader.load() self.assertTrue(isinstance(formatter, PrettyPrintFormatter)) def test_formatter_loader__unknown_format(self): with self.assertRaises(FormatterLoader.UnknownFormatterType): - FormatterLoader('png') + FormatterLoader().load('png')