Joshua Lochner committed · Commit cfbd4d5 · 1 Parent(s): de9c8c4

Update preprocessing script to use logging module

src/preprocess.py CHANGED (+28 -27)
@@ -20,6 +20,9 @@ import time
 import requests
 
 
+logger = logging.getLogger(__name__)
+
+
 PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
 PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
 
@@ -204,7 +207,7 @@ def get_words(video_id, process=True, transcript_type='auto', fallback='manual',
             pass # Mark as empty transcript
 
     except json.decoder.JSONDecodeError:
-
+        logger.warning(f'JSONDecodeError for {video_id}')
         if os.path.exists(transcript_path):
             os.remove(transcript_path) # Remove file and try again
         return get_words(video_id, process, transcript_type, fallback, granularity)
@@ -543,12 +546,12 @@ def main():
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
 
     if preprocess_args.update_database:
-
+        logger.info('Updating database')
         for mirror in MIRRORS:
-
+            logger.info(f'Downloading from {mirror}')
             if download_file(mirror, raw_dataset_path):
                 break
-
+            logger.warning('Failed, trying next')
 
     os.makedirs(dataset_args.data_dir, exist_ok=True)
     processed_db_path = os.path.join(
@@ -558,11 +561,10 @@ def main():
     @lru_cache(maxsize=1)
     def read_db():
         if not preprocess_args.overwrite and os.path.exists(processed_db_path):
-
-            'Using cached processed database (use `--overwrite` to avoid this behaviour).')
+            logger.info('Using cached processed database (use `--overwrite` to avoid this behaviour).')
             with open(processed_db_path) as fp:
                 return json.load(fp)
-
+        logger.info('Processing raw database')
         db = {}
 
         allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
@@ -618,7 +620,7 @@ def main():
 
         # Remove duplicate sponsor segments by choosing best (most votes)
         if not preprocess_args.keep_duplicate_segments:
-
+            logger.info('Remove duplicate segments')
            for key in db:
                db[key] = remove_duplicate_segments(db[key])
 
@@ -646,7 +648,7 @@ def main():
 
        # TODO remove videos that contain a full-video label?
 
-
+        logger.info(f'Saved {len(db)} videos')
 
        with open(processed_db_path, 'w') as fp:
            json.dump(db, fp)
@@ -660,7 +662,7 @@ def main():
    # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
    # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
    if preprocess_args.do_transcribe:
-
+        logger.info('Collecting videos')
        parsed_database = read_db()
 
        # Remove transcripts already processed
@@ -678,7 +680,7 @@ def main():
            get_words(video_id)
            return video_id
 
-
+        logger.info('Setting up ThreadPoolExecutor')
        with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
                tqdm(total=len(video_ids)) as progress:
 
@@ -698,21 +700,21 @@ def main():
                    progress.update()
 
            except KeyboardInterrupt:
-
+                logger.info('Gracefully shutting down: Cancelling unscheduled tasks')
 
                # only futures that are not done will prevent exiting
                for future in to_process:
                    future.cancel()
 
-
+                logger.info('Waiting for in-progress tasks to complete')
                concurrent.futures.wait(to_process, timeout=None)
-
+                logger.info('Cancellation successful')
 
    final_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_file)
 
    if preprocess_args.do_create:
-
+        logger.info('Create final data')
 
        final_data = {}
 
@@ -786,7 +788,7 @@ def main():
        dataset_args.data_dir, dataset_args.negative_file)
 
    if preprocess_args.do_generate:
-
+        logger.info('Generating')
        # max_videos=preprocess_args.max_videos,
        # max_segments=preprocess_args.max_segments,
        # , max_videos, max_segments
@@ -868,8 +870,8 @@ def main():
                print(json.dumps(d), file=negative)
 
    if preprocess_args.do_split:
-
-
+        logger.info('Splitting')
+        logger.info('Read files')
 
        with open(positive_file, encoding='utf-8') as positive:
            sponsors = positive.readlines()
@@ -877,11 +879,11 @@ def main():
        with open(negative_file, encoding='utf-8') as negative:
            non_sponsors = negative.readlines()
 
-
+        logger.info('Shuffle')
        random.shuffle(sponsors)
        random.shuffle(non_sponsors)
 
-
+        logger.info('Calculate ratios')
        # Ensure correct ratio of positive to negative segments
        percentage_negative = 1 - preprocess_args.percentage_positive
 
@@ -901,12 +903,12 @@ def main():
        excess = non_sponsors[z:]
        non_sponsors = non_sponsors[:z]
 
-
+        logger.info('Join')
        all_labelled_segments = sponsors + non_sponsors
 
        random.shuffle(all_labelled_segments)
 
-
+        logger.info('Split')
        ratios = [preprocess_args.train_split,
                  preprocess_args.test_split,
                  preprocess_args.valid_split]
@@ -927,9 +929,9 @@ def main():
            with open(outfile, 'w', encoding='utf-8') as fp:
                fp.writelines(items)
        else:
-
+            logger.info(f'Skipping {name}')
 
-
+        logger.info('Write')
        # Save excess items
        excess_path = os.path.join(
            dataset_args.data_dir, dataset_args.excess_file)
@@ -937,10 +939,9 @@ def main():
        with open(excess_path, 'w', encoding='utf-8') as fp:
            fp.writelines(excess)
        else:
-
+            logger.info(f'Skipping {dataset_args.excess_file}')
 
-
-            'sponsors,', len(non_sponsors), 'non sponsors')
+        logger.info(f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
 
 
 def split(arr, ratios):
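Note that the commit only adds `logger = logging.getLogger(__name__)` and the call sites; it does not show where (or whether) the logging system itself is configured. A minimal sketch, assuming configuration happens once at startup via `logging.basicConfig` (an assumption, not part of this diff), of how the new `logger.info`/`logger.warning` calls become visible:

    import logging

    # Module-level logger, as introduced by this commit in src/preprocess.py
    logger = logging.getLogger(__name__)


    def main():
        # Assumed setup: without a configured handler, Python's "last resort"
        # handler only emits WARNING and above, so the new logger.info()
        # messages would otherwise be dropped.
        logging.basicConfig(
            format='%(asctime)s %(levelname)s %(name)s: %(message)s',
            level=logging.INFO,
        )

        logger.info('Updating database')       # visible once basicConfig has run
        logger.warning('Failed, trying next')  # visible even without basicConfig


    if __name__ == '__main__':
        main()

Keeping the call sites on a named logger means verbosity can later be adjusted centrally (per-module levels, file handlers) without editing each message.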
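Several of the new log lines narrate the KeyboardInterrupt path around the transcription pool: cancel futures that have not started, then wait for in-flight ones before exiting. A self-contained sketch of that pattern; `transcribe`, `run`, and the one-second sleep are illustrative stand-ins for the script's per-video `get_words` jobs, not its actual code:

    import concurrent.futures
    import logging
    import time

    logger = logging.getLogger(__name__)


    def transcribe(video_id):
        # Illustrative stand-in for the per-video work (get_words in the script)
        time.sleep(1)
        return video_id


    def run(video_ids, num_jobs=4):
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_jobs) as pool:
            to_process = {pool.submit(transcribe, video_id) for video_id in video_ids}
            try:
                for future in concurrent.futures.as_completed(to_process):
                    future.result()
            except KeyboardInterrupt:
                logger.info('Gracefully shutting down: Cancelling unscheduled tasks')
                # Only futures that have not started yet can be cancelled;
                # futures already running must be allowed to finish.
                for future in to_process:
                    future.cancel()

                logger.info('Waiting for in-progress tasks to complete')
                concurrent.futures.wait(to_process, timeout=None)
                logger.info('Cancellation successful')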