diff --git a/roop/core.py b/roop/core.py index 608a075c4..32665f71a 100755 --- a/roop/core.py +++ b/roop/core.py @@ -20,7 +20,7 @@ import roop.globals import roop.metadata import roop.ui as ui -from roop.predicter import predict_image, predict_video +from roop.predictor import predict_image, predict_video from roop.processors.frame.core import get_frame_processors_modules from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path diff --git a/roop/predicter.py b/roop/predictor.py similarity index 80% rename from roop/predicter.py rename to roop/predictor.py index 7ebc2b62e..4b1ace20d 100644 --- a/roop/predicter.py +++ b/roop/predictor.py @@ -1,18 +1,24 @@ +import threading import numpy import opennsfw2 from PIL import Image +from keras import Model from roop.typing import Frame MAX_PROBABILITY = 0.85 +def clear_predictor() -> None: + global PREDICTOR + + PREDICTOR = None + def predict_frame(target_frame: Frame) -> bool: image = Image.fromarray(target_frame) image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) - model = opennsfw2.make_open_nsfw_model() views = numpy.expand_dims(image, axis=0) - _, probability = model.predict(views)[0] + _, probability = get_predictor().predict(views)[0] return probability > MAX_PROBABILITY diff --git a/roop/ui.py b/roop/ui.py index f54426870..8ad11382a 100644 --- a/roop/ui.py +++ b/roop/ui.py @@ -11,7 +11,7 @@ from roop.face_analyser import get_one_face from roop.capturer import get_video_frame, get_video_frame_total from roop.face_reference import get_face_reference, set_face_reference, clear_face_reference -from roop.predicter import predict_frame +from roop.predictor import predict_frame, clear_predictor from roop.processors.frame.core import get_frame_processors_modules from roop.utilities import is_image, is_video, resolve_relative_path @@ -212,6 +212,7 @@ def render_video_preview(video_path: str, size: Tuple[int, int], frame_number: i def toggle_preview() -> None: if PREVIEW.state() == 'normal': PREVIEW.withdraw() + clear_predictor() elif roop.globals.source_path and roop.globals.target_path: init_preview() update_preview(roop.globals.reference_frame_number) @@ -252,12 +253,10 @@ def update_preview(frame_number: int = 0) -> None: preview_label.configure(image=image) -def update_face_reference(delta: int) -> None: - global preview_slider - +def update_face_reference(step: int) -> None: clear_face_reference() reference_frame_number = preview_slider.get() - roop.globals.reference_face_position += delta # type: ignore + roop.globals.reference_face_position += step # type: ignore roop.globals.reference_frame_number = reference_frame_number update_preview(reference_frame_number)