59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel
|
|
|
|
from ...utils import logging
|
|
|
|
|
|
# Module-level logger; the safety checker in this file uses it to emit warnings
# when NSFW or watermarked content is detected.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class IFSafetyChecker(PreTrainedModel):
    """Safety checker that blanks out images flagged as NSFW or watermarked.

    The first output of a ``CLIPVisionModelWithProjection`` is used as the
    image embedding. Two single-output linear heads score that embedding:
    ``p_head`` drives the NSFW decision and ``w_head`` the watermark decision.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        """Build the vision model and the two scoring heads from ``config``."""
        super().__init__(config)

        vision_config = config.vision_config
        embed_dim = vision_config.projection_dim

        self.vision_model = CLIPVisionModelWithProjection(vision_config)

        # Each head maps one projected embedding to a single raw score.
        self.p_head = nn.Linear(embed_dim, 1)
        self.w_head = nn.Linear(embed_dim, 1)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """Flag NSFW / watermarked images and replace flagged ones with zeros.

        Args:
            clip_input: Input handed directly to the vision model
                (presumably preprocessed pixel values -- confirm with caller).
            images: Indexable collection whose items expose ``.shape``.
                It is modified in place: flagged entries are overwritten.
            p_threshold: A ``p_head`` score strictly above this marks NSFW.
            w_threshold: A ``w_head`` score strictly above this marks a
                watermark.

        Returns:
            A tuple ``(images, nsfw_detected, watermark_detected)`` where the
            last two are lists of booleans, one entry per score.
        """
        embeddings = self.vision_model(clip_input)[0]

        # --- NSFW pass -----------------------------------------------------
        nsfw_detected = (self.p_head(embeddings).flatten() > p_threshold).tolist()
        nsfw_positions = [pos for pos, hit in enumerate(nsfw_detected) if hit]

        if nsfw_positions:
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        for pos in nsfw_positions:
            # Same-shaped array of zeros stands in for the flagged image.
            images[pos] = np.zeros(images[pos].shape)

        # --- Watermark pass ------------------------------------------------
        watermark_detected = (self.w_head(embeddings).flatten() > w_threshold).tolist()
        watermark_positions = [pos for pos, hit in enumerate(watermark_detected) if hit]

        if watermark_positions:
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        for pos in watermark_positions:
            images[pos] = np.zeros(images[pos].shape)

        return images, nsfw_detected, watermark_detected
|