Commit
·
b07203d
1
Parent(s):
815053b
minor fixes
Browse files
app.py
CHANGED
|
@@ -106,8 +106,8 @@ def process_uploaded_video_file(
|
|
| 106 |
|
| 107 |
logging.info(f"Processing uploaded file: {in_filename}")
|
| 108 |
|
| 109 |
-
ans
|
| 110 |
-
return (in_filename, ans[0]), ans[0], ans[1], ans[2],
|
| 111 |
|
| 112 |
|
| 113 |
def process_uploaded_audio_file(
|
|
@@ -137,6 +137,10 @@ def process(language: str, repo_id: str, add_punctuation: str, in_filename: str)
|
|
| 137 |
logging.info(f"add_punctuation: {add_punctuation}")
|
| 138 |
recognizer = get_pretrained_model(repo_id)
|
| 139 |
vad = get_vad()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
if add_punctuation == "Yes":
|
| 141 |
punct = get_punct_model()
|
| 142 |
else:
|
|
@@ -144,7 +148,6 @@ def process(language: str, repo_id: str, add_punctuation: str, in_filename: str)
|
|
| 144 |
|
| 145 |
result, all_text = decode(recognizer, vad, punct, in_filename)
|
| 146 |
logging.info(result)
|
| 147 |
-
logging.info(all_text)
|
| 148 |
|
| 149 |
srt_filename = Path(in_filename).with_suffix(".srt")
|
| 150 |
with open(srt_filename, "w", encoding="utf-8") as f:
|
|
|
|
| 106 |
|
| 107 |
logging.info(f"Processing uploaded file: {in_filename}")
|
| 108 |
|
| 109 |
+
ans = process(language, repo_id, add_punctuation, in_filename)
|
| 110 |
+
return (in_filename, ans[0]), ans[0], ans[1], ans[2], ans[3]
|
| 111 |
|
| 112 |
|
| 113 |
def process_uploaded_audio_file(
|
|
|
|
| 137 |
logging.info(f"add_punctuation: {add_punctuation}")
|
| 138 |
recognizer = get_pretrained_model(repo_id)
|
| 139 |
vad = get_vad()
|
| 140 |
+
|
| 141 |
+
if "whisper" in repo_id:
|
| 142 |
+
add_punctuation = "No"
|
| 143 |
+
|
| 144 |
if add_punctuation == "Yes":
|
| 145 |
punct = get_punct_model()
|
| 146 |
else:
|
|
|
|
| 148 |
|
| 149 |
result, all_text = decode(recognizer, vad, punct, in_filename)
|
| 150 |
logging.info(result)
|
|
|
|
| 151 |
|
| 152 |
srt_filename = Path(in_filename).with_suffix(".srt")
|
| 153 |
with open(srt_filename, "w", encoding="utf-8") as f:
|
decode.py
CHANGED
|
@@ -129,9 +129,7 @@ def decode(
|
|
| 129 |
if punct is not None:
|
| 130 |
seg.text = punct.add_punctuation(seg.text)
|
| 131 |
segment_list.append(seg)
|
| 132 |
-
logging.info(f"all text: {all_text}")
|
| 133 |
all_text = "".join(all_text)
|
| 134 |
-
logging.info(f"all text: {all_text}")
|
| 135 |
if punct is not None:
|
| 136 |
all_text = punct.add_punctuation(all_text)
|
| 137 |
|
|
|
|
| 129 |
if punct is not None:
|
| 130 |
seg.text = punct.add_punctuation(seg.text)
|
| 131 |
segment_list.append(seg)
|
|
|
|
| 132 |
all_text = "".join(all_text)
|
|
|
|
| 133 |
if punct is not None:
|
| 134 |
all_text = punct.add_punctuation(all_text)
|
| 135 |
|