Skip to content

Commit

Permalink
Implement Gallery for face selector
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Aug 10, 2023
1 parent 31a50e1 commit 27d9e96
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 30 deletions.
71 changes: 45 additions & 26 deletions roop/uis/__components__/face_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,58 @@

import roop.globals
from roop.capturer import get_video_frame
from roop.face_analyser import get_faces_total
from roop.face_analyser import get_many_faces
from roop.face_reference import clear_face_reference
from roop.uis import core as ui
from roop.uis.typing import ComponentName, Update
from roop.utilities import is_image, is_video

REFERENCE_FACE_POSITION_SLIDER: Optional[gradio.Slider] = None
REFERENCE_FACE_POSITION_GALLERY: Optional[gradio.Gallery] = None
SIMILAR_FACE_DISTANCE_SLIDER: Optional[gradio.Slider] = None


def render() -> None:
global REFERENCE_FACE_POSITION_SLIDER
global REFERENCE_FACE_POSITION_GALLERY
global SIMILAR_FACE_DISTANCE_SLIDER

with gradio.Box():
reference_face_position_slider_args = {
'label': 'REFERENCE FACE POSITION',
'value': roop.globals.reference_face_position,
'step': 1,
'maximum': 0
reference_face_gallery_args = {
'label': 'REFERENCE FACE',
'height': 120,
'object_fit': 'cover',
'columns': 10,
'allow_preview': False,
'visible': True
}
faces = []
if is_image(roop.globals.target_path):
target_frame = cv2.imread(roop.globals.target_path)
reference_face_position_slider_args['maximum'] = get_faces_total(target_frame)
reference_frame = cv2.imread(roop.globals.target_path)
faces = get_many_faces(reference_frame)
if is_video(roop.globals.target_path):
temp_frame = get_video_frame(roop.globals.target_path, roop.globals.reference_frame_number)
reference_face_position_slider_args['maximum'] = get_faces_total(temp_frame)
REFERENCE_FACE_POSITION_SLIDER = gradio.Slider(**reference_face_position_slider_args)
reference_frame = get_video_frame(roop.globals.target_path, roop.globals.reference_frame_number)
faces = get_many_faces(reference_frame)
if faces:
value = []
for face in faces:
start_x, start_y, end_x, end_y = map(int, face['bbox'])
crop_frame = reference_frame[start_y:end_y, start_x:end_x]
value.append(ui.normalize_frame(crop_frame))
reference_face_gallery_args['value'] = value
else:
reference_face_gallery_args['value'] = []
REFERENCE_FACE_POSITION_GALLERY = gradio.Gallery(**reference_face_gallery_args)
SIMILAR_FACE_DISTANCE_SLIDER = gradio.Slider(
label='SIMILAR FACE DISTANCE',
value=roop.globals.similar_face_distance,
maximum=2,
step=0.05
)
ui.register_component('reference_face_position_slider', REFERENCE_FACE_POSITION_SLIDER)
ui.register_component('reference_face_position_gallery', REFERENCE_FACE_POSITION_GALLERY)
ui.register_component('similar_face_distance_slider', SIMILAR_FACE_DISTANCE_SLIDER)


def listen() -> None:
REFERENCE_FACE_POSITION_GALLERY.select(clear_and_update_face_reference_position, outputs=REFERENCE_FACE_POSITION_GALLERY)
SIMILAR_FACE_DISTANCE_SLIDER.change(update_similar_face_distance, inputs=SIMILAR_FACE_DISTANCE_SLIDER)
component_names: List[ComponentName] = [
'target_file',
Expand All @@ -53,26 +66,32 @@ def listen() -> None:
for component_name in component_names:
component = ui.get_component(component_name)
if component:
component.change(update_face_reference_position, inputs=REFERENCE_FACE_POSITION_SLIDER, outputs=REFERENCE_FACE_POSITION_SLIDER)
REFERENCE_FACE_POSITION_SLIDER.change(clear_and_update_face_reference_position, inputs=REFERENCE_FACE_POSITION_SLIDER)
component.change(update_face_reference_position, outputs=REFERENCE_FACE_POSITION_GALLERY)


def clear_and_update_face_reference_position(reference_face_position: int) -> Update:
def clear_and_update_face_reference_position(event: gradio.SelectData) -> Update:
clear_face_reference()
return update_face_reference_position(reference_face_position)
return update_face_reference_position(event.index)


def update_face_reference_position(reference_face_position: int) -> Update:
sleep(0.5)
maximum = 0
def update_face_reference_position(reference_face_position: int = 0) -> Update:
sleep(0.2)
roop.globals.reference_face_position = reference_face_position
faces = []
if is_image(roop.globals.target_path):
target_frame = cv2.imread(roop.globals.target_path)
maximum = max(get_faces_total(target_frame) - 1, 0)
reference_frame = cv2.imread(roop.globals.target_path)
faces = get_many_faces(reference_frame)
if is_video(roop.globals.target_path):
temp_frame = get_video_frame(roop.globals.target_path, roop.globals.reference_frame_number)
maximum = max(get_faces_total(temp_frame) - 1, 0)
return gradio.update(value=reference_face_position, maximum=maximum)
reference_frame = get_video_frame(roop.globals.target_path, roop.globals.reference_frame_number)
faces = get_many_faces(reference_frame)
if faces:
value = []
for face in faces:
start_x, start_y, end_x, end_y = map(int, face['bbox'])
crop_frame = reference_frame[start_y:end_y, start_x:end_x]
value.append(ui.normalize_frame(crop_frame))
return gradio.update(value=value)
return gradio.update(value=None)


def update_similar_face_distance(similar_face_distance: float) -> Update:
Expand Down
6 changes: 4 additions & 2 deletions roop/uis/__components__/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def listen() -> None:
component_names: List[ComponentName] = [
'source_file',
'target_file',
'reference_face_position_slider',
'similar_face_distance_slider',
'frame_processors_checkbox_group',
'many_faces_checkbox'
Expand All @@ -63,10 +62,13 @@ def listen() -> None:
component = ui.get_component(component_name)
if component:
component.change(update, inputs=PREVIEW_FRAME_SLIDER, outputs=[PREVIEW_IMAGE, PREVIEW_FRAME_SLIDER])
reference_face_position_gallery = ui.get_component('reference_face_position_gallery')
if reference_face_position_gallery:
reference_face_position_gallery.select(update, inputs=PREVIEW_FRAME_SLIDER, outputs=[PREVIEW_IMAGE, PREVIEW_FRAME_SLIDER])


def update(frame_number: int = 0) -> Tuple[Update, Update]:
sleep(0.5)
sleep(0.1)
if is_image(roop.globals.target_path):
target_frame = cv2.imread(roop.globals.target_path)
preview_frame = get_preview_frame(target_frame)
Expand Down
2 changes: 1 addition & 1 deletion roop/uis/__components__/trim_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def listen() -> None:


def remote_update() -> Tuple[Update, Update]:
sleep(0.5)
sleep(0.1)
if is_video(roop.globals.target_path):
video_frame_total = get_video_frame_total(roop.globals.target_path)
roop.globals.trim_frame_start = 0
Expand Down
2 changes: 1 addition & 1 deletion roop/uis/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'source_file',
'target_file',
'preview_frame_slider',
'reference_face_position_slider',
'reference_face_position_gallery',
'similar_face_distance_slider',
'frame_processors_checkbox_group',
'many_faces_checkbox'
Expand Down

0 comments on commit 27d9e96

Please sign in to comment.