Function bodies 86 total
ResBlock.__init__ method · python · L20-L34 (15 LOC)models/archs/vqgan_arch.py
def __init__(self, in_channels, out_channels=None):
    """Build the two norm/conv stages and the optional shortcut projection.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block. Falls back
            to ``in_channels`` when ``None``.
    """
    super().__init__()
    if out_channels is None:
        out_channels = in_channels
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Main branch: two norm + 3x3 conv stages (stride 1, padding 1).
    self.norm1 = normalize(in_channels)
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    self.norm2 = normalize(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    # Shortcut: a 1x1 projection is only needed when the channel count changes.
    self.conv_out = (
        nn.Conv2d(in_channels, out_channels, kernel_size=1)
        if in_channels != out_channels
        else None
    )
def forward(self, x_in):
    """Run the residual block on ``x_in``.

    Applies norm1 -> swish -> conv1 -> norm2 -> swish -> conv2 and adds the
    result to the input, which is first passed through ``conv_out`` when
    that projection exists.

    Args:
        x_in: Input tensor fed to both the main branch and the shortcut.

    Returns:
        Sum of the main-branch output and the (optionally projected) input.
    """
    hidden = self.conv1(swish(self.norm1(x_in)))
    hidden = self.conv2(swish(self.norm2(hidden)))
    shortcut = x_in if self.conv_out is None else self.conv_out(x_in)
    return hidden + shortcut
def __init__(self, in_channels):
    """Create the norm layer and the four 1x1 convolutions used for attention.

    Args:
        in_channels: Channel count of the input; every projection keeps it.
    """
    super().__init__()
    self.in_channels = in_channels
    self.norm = normalize(in_channels)
    # Query, key, value and output projections all have the same shape.
    # They are created in this fixed order so module registration (and the
    # order of random initialisation) stays q, k, v, proj_out.
    for proj_name in ("q", "k", "v", "proj_out"):
        setattr(self, proj_name, nn.Conv2d(in_channels, in_channels, kernel_size=1))
def forward(self, x):
    """Apply single-head self-attention across all spatial positions.

    Args:
        x: Input tensor; its projections are unpacked as ``(b, c, h, w)``.

    Returns:
        ``x`` plus the output projection of the attended values.
    """
    normed = self.norm(x)
    query, key, value = self.q(normed), self.k(normed), self.v(normed)
    batch, channels, height, width = query.shape
    num_pos = height * width

    # Pairwise logits between positions, scaled by c ** -0.5, then
    # normalised over the key axis.
    query = query.reshape(batch, channels, num_pos).permute(0, 2, 1)  # b, hw, c
    key = key.reshape(batch, channels, num_pos)  # b, c, hw
    attn = torch.bmm(query, key) * (int(channels) ** (-0.5))
    attn = F.softmax(attn, dim=2)

    # Weighted sum of the values for every query position.
    value = value.reshape(batch, channels, num_pos)  # b, c, hw
    attended = torch.bmm(value, attn.permute(0, 2, 1))  # b, c, hw
    attended = attended.reshape(batch, channels, height, width)
    return x + self.proj_out(attended)
def __init__(self, codebook_size, emb_dim, beta=0.25):
    """Create the codebook embedding and initialise it uniformly.

    Args:
        codebook_size: Number of codebook entries.
        emb_dim: Dimensionality of each codebook vector.
        beta: Loss weighting factor, stored as ``self.beta``.
    """
    super().__init__()
    self.codebook_size, self.emb_dim, self.beta = codebook_size, emb_dim, beta
    self.embedding = nn.Embedding(codebook_size, emb_dim)
    # Uniform init in [-1/K, 1/K], K being the codebook size.
    bound = 1.0 / codebook_size
    self.embedding.weight.data.uniform_(-bound, bound)
def forward(self, z):
    """Replace every feature vector in ``z`` by its nearest codebook entry.

    Args:
        z: 4-D tensor whose axis 1 is moved last and then viewed as rows of
            length ``emb_dim`` (so axis 1 is expected to have ``emb_dim``
            entries).

    Returns:
        Tuple ``(z_q, loss, min_encoding_indices)``: the quantized tensor in
        the input's axis order, the scalar quantization loss, and the flat
        index of the selected codebook entry for every feature vector.
    """
    z = z.permute(0, 2, 3, 1).contiguous()
    flat = z.view(-1, self.emb_dim)
    codebook = self.embedding.weight

    # Squared distances via ||z||^2 + ||e||^2 - 2 * z.e
    z_sq = torch.sum(flat**2, dim=1, keepdim=True)
    e_sq = torch.sum(codebook**2, dim=1)
    cross = torch.einsum("bd,dn->bn", flat, codebook.t())
    distances = z_sq + e_sq - 2 * cross

    min_encoding_indices = torch.argmin(distances, dim=1)
    quantized = self.embedding(min_encoding_indices).view(z.shape)

    # First term moves z towards the (frozen) codes and is weighted by
    # beta; second term moves the codes towards the (frozen) z.
    toward_codes = torch.mean((quantized.detach() - z) ** 2)
    toward_inputs = torch.mean((quantized - z.detach()) ** 2)
    loss = self.beta * toward_codes + toward_inputs

    # Forward value is the quantized tensor; gradients flow to z unchanged.
    quantized = z + (quantized - z).detach()
    return quantized.permute(0, 3, 1, 2).contiguous(), loss, min_encoding_indices
def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
# initial conv
blocks.append(nn.Conv2d(in_channels, nf, 3, 1, 1))
# body
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
Powered by Repobility — scan your code at https://repobility.com
Generator.__init__ method · python · L191-L230 (40 LOC)models/archs/vqgan_arch.py
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, 3, 1, 1))
# mid
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# body (reverse of encoder)
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_resFuse_sft_block.__init__ method · python · L239-L251 (13 LOC)models/archs/vqgan_arch.py
def __init__(self, in_ch, out_ch):
    """Build the fusion ResBlock and the scale/shift prediction branches.

    Args:
        in_ch: Channel count of each input feature map; the fusion block
            receives ``2 * in_ch`` channels.
        out_ch: Channel count produced by the fusion block and by both
            branches.
    """
    super().__init__()

    def conv_branch():
        # 3x3 conv -> LeakyReLU(0.2, in-place) -> 3x3 conv.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1),
        )

    # NOTE(review): both branches expect ``in_ch`` input channels but are
    # fed the ``out_ch``-channel output of ``encode_enc`` in ``forward``,
    # so this appears to require in_ch == out_ch — confirm with callers.
    self.encode_enc = ResBlock(2 * in_ch, out_ch)
    self.scale = conv_branch()
    self.shift = conv_branch()
def forward(self, enc_feat, dec_feat, w=1):
    """Modulate decoder features with a scale/shift predicted from both inputs.

    Args:
        enc_feat: Encoder features, concatenated with ``dec_feat`` on dim 1.
        dec_feat: Decoder features to be modulated.
        w: Multiplier applied to the residual (default 1).

    Returns:
        ``dec_feat + w * (dec_feat * scale + shift)``.
    """
    fused = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
    scale, shift = self.scale(fused), self.shift(fused)
    return dec_feat + w * (dec_feat * scale + shift)
def __init__(
    self,
    img_size,
    nf,
    ch_mult,
    quantizer="nearest",
    res_blocks=2,
    attn_resolutions=None,
    codebook_size=1024,
    emb_dim=256,
    beta=0.25,
):
    """Assemble the encoder, vector quantizer and generator.

    Args:
        img_size: Forwarded to ``Encoder`` and ``Generator`` as their
            resolution argument.
        nf: Base feature count, forwarded to both networks.
        ch_mult: Channel multipliers, forwarded to both networks.
        quantizer: Not read by this constructor.
        res_blocks: Residual-block count, forwarded to both networks.
        attn_resolutions: Forwarded to both networks; ``[16]`` is used
            when ``None``.
        codebook_size: Number of entries of the ``VectorQuantizer``.
        emb_dim: Embedding dimension shared by all three sub-modules.
        beta: Forwarded to the ``VectorQuantizer``.
    """
    super().__init__()
    self.in_channels = 3
    self.nf = nf
    self.n_blocks = res_blocks
    self.codebook_size = codebook_size
    self.emb_dim = emb_dim

    resolved_attn = [16] if attn_resolutions is None else attn_resolutions
    # Encoder and generator receive the same architecture arguments; the
    # encoder additionally takes the input channel count first.
    shared_args = (nf, emb_dim, ch_mult, self.n_blocks, img_size, resolved_attn)
    self.encoder = Encoder(self.in_channels, *shared_args)
    self.quantize = VectorQuantizer(codebook_size, emb_dim, beta=beta)
    self.generator = Generator(*shared_args)
def __init__(self, model_path: str, device: str, variant: str = "paper_tiny") -> None:
"""Load DDColor model and build the colorization pipeline.
Args:
model_path: Path to the DDColor checkpoint file.
device: Compute device string (``"cuda"``, ``"mps"``, or ``"cpu"``).
variant: Model variant name used to select model size.
"""
from ddcolor import ColorizationPipeline, DDColor, build_ddcolor_model
model_size = self.MODEL_SIZE_MAP.get(variant, "tiny")
torch_device = torch.device(device)
self.model = build_ddcolor_model(
DDColor,
model_path=model_path,
input_size=512,
model_size=model_size,
device=torch_device,
)
self.pipeline = ColorizationPipeline(self.model, input_size=512, device=torch_device)
self.device = device
logger.info("DDColorWrapper loaded — %s (size=%s) on %s", model_path, model_size, deDDColorWrapper.predict method · python · L46-L56 (11 LOC)models/wrappers.py
def predict(self, image: np.ndarray, **kwargs) -> np.ndarray:
    """Run the colorization pipeline on ``image``.

    Args:
        image: BGR uint8 numpy array.
        **kwargs: Accepted for interface compatibility; none are read here.

    Returns:
        The result of ``self.pipeline.process(image)`` (a colorized BGR
        uint8 numpy array).
    """
    colorized = self.pipeline.process(image)
    return colorized
def __init__(self, model_path: str, device: str, variant: str = "x4plus") -> None:
"""Load a Real-ESRGAN RRDBNet checkpoint, inferring architecture from keys.
Args:
model_path: Path to the Real-ESRGAN checkpoint file.
device: Compute device string.
variant: Model variant name (informational, architecture is inferred).
"""
import re
from models.archs.rrdbnet_arch import RRDBNet
self.device = device
state_dict = torch.load(model_path, map_location=device, weights_only=False)
if "params_ema" in state_dict:
state_dict = state_dict["params_ema"]
elif "params" in state_dict:
state_dict = state_dict["params"]
# Infer architecture from checkpoint so any RRDBNet weights work.
num_feat = state_dict["conv_first.weight"].shape[0]
num_in_ch = state_dict["conv_first.weight"].shape[1]
num_out_ch = state_dict["conv_last.weight"].shRealESRGANWrapper.predict method · python · L114-L137 (24 LOC)models/wrappers.py
def predict(self, image: np.ndarray, **kwargs) -> np.ndarray:
    """Upscale an image using the loaded Real-ESRGAN model.

    Args:
        image: BGR uint8 numpy array.
        **kwargs: Optional ``tile_size`` (int) for tiled processing; a
            missing value, ``None`` or ``0`` runs the whole image in one
            pass. ``scale`` is accepted but not read here.

    Returns:
        Upscaled BGR uint8 numpy array.
    """
    # An explicit ``tile_size=None`` means "no tiling"; without the ``or``
    # the ``> 0`` comparison below would raise TypeError.
    tile_size = kwargs.get("tile_size") or 0

    # BGR uint8 -> RGB float32 [0, 1]
    rgb = image[:, :, ::-1].astype(np.float32) / 255.0
    img_t = torch.from_numpy(rgb.copy()).permute(2, 0, 1).unsqueeze(0).to(self.device)

    with torch.no_grad():
        if tile_size > 0:
            output = self._tile_process(img_t, tile_size)
        else:
            output = self.model(img_t)

    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round before casting: astype(uint8) truncates toward zero, which
    # would darken values that land just below an integer (e.g. 2.9999).
    output = np.clip(output * 255.0, 0, 255).round().astype(np.uint8)
    # RGB -> BGR
    return output[:, :, ::-1].copy()
RealESRGANWrapper._tile_process method · python · L139-L200 (62 LOC)models/wrappers.py
def _tile_process(self, img, tile_size=256, tile_pad=10):
"""Process an image tensor in tiles to limit VRAM usage.
Splits the input into a grid of overlapping tiles, runs each through the
model, strips the overlap padding, and assembles the output.
Args:
img: Input tensor of shape ``(B, C, H, W)``.
tile_size: Tile dimension in pixels (default 256).
tile_pad: Overlap padding in pixels (default 10).
Returns:
Output tensor of shape ``(B, C, H*scale, W*scale)``.
"""
batch, channel, height, width = img.shape
output_height = height * self.scale
output_width = width * self.scale
output = img.new_zeros(batch, channel, output_height, output_width)
tiles_x = math.ceil(width / tile_size)
tiles_y = math.ceil(height / tile_size)
for y in range(tiles_y):
for x in range(tiles_x):
ofs_x = x * tile_size
NAFNetWrapper.__init__ method · python · L206-L256 (51 LOC)models/wrappers.py
def __init__(self, model_path: str, device: str, variant: str = "denoise") -> None:
"""Load a NAFNet checkpoint, inferring architecture from keys.
Args:
model_path: Path to the NAFNet checkpoint file.
device: Compute device string.
variant: Model variant name (informational, architecture is inferred).
"""
import re
from models.archs.nafnet_arch import NAFNet
self.device = device
state_dict = torch.load(model_path, map_location=device, weights_only=False)
if "params" in state_dict:
state_dict = state_dict["params"]
# Infer architecture from checkpoint keys so we don't need
# hard-coded block counts per variant.
enc_blk_nums = self._count_blocks(state_dict, r"encoders\.(\d+)\.(\d+)\.")
dec_blk_nums = self._count_blocks(state_dict, r"decoders\.(\d+)\.(\d+)\.")
middle_blk_num = 1 + max(
(
int(m.group(1))NAFNetWrapper._count_blocks method · python · L259-L277 (19 LOC)models/wrappers.py
def _count_blocks(state_dict, pattern):
"""Count blocks per stage from checkpoint key names.
Args:
state_dict: Model state dict mapping key names to tensors.
pattern: Regex with two capture groups: ``(stage_idx, block_idx)``.
Returns:
List of block counts, one per stage, in order.
"""
import re
stages: dict[int, int] = {}
for k in state_dict:
m = re.match(pattern, k)
if m:
stage, idx = int(m.group(1)), int(m.group(2))
stages[stage] = max(stages.get(stage, 0), idx + 1)
return [stages[i] for i in range(len(stages))]NAFNetWrapper.predict method · python · L279-L299 (21 LOC)models/wrappers.py
def predict(self, image: np.ndarray, **kwargs) -> np.ndarray:
    """Remove noise or blur from an image using the loaded NAFNet model.

    Args:
        image: BGR uint8 numpy array.
        **kwargs: Accepted for interface compatibility; none are read here.

    Returns:
        Restored BGR uint8 numpy array.
    """
    # BGR uint8 -> RGB float32 [0, 1]
    rgb = image[:, :, ::-1].astype(np.float32) / 255.0
    img_t = torch.from_numpy(rgb.copy()).permute(2, 0, 1).unsqueeze(0).to(self.device)

    with torch.no_grad():
        output = self.model(img_t)

    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round before casting: astype(uint8) truncates toward zero, which
    # would darken values that land just below an integer (e.g. 2.9999).
    output = np.clip(output * 255.0, 0, 255).round().astype(np.uint8)
    # RGB -> BGR
    return output[:, :, ::-1].copy()
def __init__(self, model_path: str, device: str, variant: str = "v0.1") -> None:
"""Load CodeFormer checkpoint and initialize the face detection helper.
Args:
model_path: Path to the CodeFormer checkpoint file.
device: Compute device string.
variant: Model variant name (informational).
"""
from models.archs.codeformer_arch import CodeFormer
self.device = device
model = CodeFormer(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
)
state_dict = torch.load(model_path, map_location=device, weights_only=False)
if "params_ema" in state_dict:
state_dict = state_dict["params_ema"]
elif "params" in state_dict:
state_dict = state_dict["params"]
model.load_state_dict(state_dict, strict=True)
model.eval()
model.to(deviceCodeFormerWrapper.predict method · python · L353-L419 (67 LOC)models/wrappers.py
def predict(self, image: np.ndarray, **kwargs) -> np.ndarray:
"""Detect and restore faces in an image.
Detects all faces, restores each via the CodeFormer transformer, and
pastes them back into the (optionally upscaled) original. If no faces
are detected, returns a bicubic upscale.
Args:
image: BGR uint8 numpy array.
**kwargs: Optional ``fidelity`` (float, 0-1) and ``upscale`` (int, 1-4).
Returns:
Face-restored BGR uint8 numpy array.
"""
fidelity = kwargs.get("fidelity", 0.5)
upscale = kwargs.get("upscale", 2)
self.face_helper.clean_all()
self.face_helper.upscale_factor = upscale
# facexlib expects BGR uint8 numpy array
self.face_helper.read_image(image)
self.face_helper.get_face_landmarks_5(
only_center_face=False,
resize=640,
eye_dist_threshold=5,
)
self.face_helper.align_warp_fLaMaWrapper.__init__ method · python · L428-L439 (12 LOC)models/wrappers.py
def __init__(self, model_path: str, device: str, variant: str = "big") -> None:
    """Load the LaMa TorchScript model and switch it to eval mode.

    Args:
        model_path: Path to the ``.pt`` TorchScript model file.
        device: Compute device string (``"cuda"``, ``"mps"``, or ``"cpu"``),
            used as ``map_location`` for loading.
        variant: Model variant name; not read by this constructor.
    """
    self.device = device
    scripted = torch.jit.load(model_path, map_location=device)
    scripted.eval()
    self.model = scripted
    logger.info("LaMaWrapper loaded — %s on %s", model_path, device)
def predict(self, image: np.ndarray, **kwargs) -> np.ndarray:
"""Inpaint masked regions of an image.
Args:
image: BGR uint8 numpy array.
**kwargs: Required ``mask`` (grayscale uint8 numpy array, 255 = inpaint).
Returns:
Inpainted BGR uint8 numpy array.
Raises:
ValueError: If no mask is provided or mask dimensions don't match.
"""
mask = kwargs.get("mask")
if mask is None:
raise ValueError("Inpainting requires a mask image")
h, w = image.shape[:2]
mh, mw = mask.shape[:2]
if (mh, mw) != (h, w):
raise ValueError(f"Mask dimensions ({mw}x{mh}) do not match image dimensions ({w}x{h})")
# Ensure mask is single-channel
if len(mask.shape) == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
# BGR -> RGB, float32 [0,1]
img_rgb = image[:, :, ::-1].astype(np.float32) / 255.0
mask_f = maskRepobility · code-quality intelligence platform · https://repobility.com
OldPhotoRestoreWrapper.__init__ method · python · L524-L610 (87 LOC)models/wrappers.py
def __init__(self, model_path: str, device: str, variant: str = "v1") -> None:
"""Load all sub-networks for old photo restoration.
Args:
model_path: Path to the directory containing all weight files.
device: Compute device string (``"cuda"``, ``"mps"``, or ``"cpu"``).
variant: Model variant name (informational).
"""
import dlib
from models.archs.old_photo_detect_arch import UNet
from models.archs.old_photo_face_arch import SPADEGenerator
from models.archs.old_photo_global_arch import (
GlobalGenerator_DCDCv2,
Mapping_Model_with_mask_2,
)
self.device = device
# Scratch detection UNet
self.scratch_net = UNet(in_channels=1, out_channels=1, depth=4, conv_num=2, wf=6)
scratch_sd = torch.load(
os.path.join(model_path, "scratch_detection.pt"),
map_location=device,
weights_only=False,
)
OldPhotoRestoreWrapper.predict method · python · L612-L644 (33 LOC)models/wrappers.py
def predict(self, image: np.ndarray, **kwargs) -> np.ndarray:
"""Restore an old/damaged photo.
Args:
image: BGR uint8 numpy array.
**kwargs: Optional parameters:
- ``with_scratch`` (bool, default True): detect and repair scratches.
- ``with_face`` (bool, default True): enhance detected faces.
- ``scratch_threshold`` (float, default 0.4): scratch detection threshold.
Returns:
Restored BGR uint8 numpy array.
"""
with_scratch = kwargs.get("with_scratch", True)
with_face = kwargs.get("with_face", True)
scratch_threshold = kwargs.get("scratch_threshold", 0.4)
h, w = image.shape[:2]
# Step 1: Detect scratches
if with_scratch:
scratch_mask = self._detect_scratches(image, threshold=scratch_threshold)
else:
scratch_mask = np.zeros((h, w), dtype=np.uint8)
# Step 2: Global restoration
OldPhotoRestoreWrapper._detect_scratches method · python · L646-L680 (35 LOC)models/wrappers.py
def _detect_scratches(self, image: np.ndarray, threshold: float = 0.4) -> np.ndarray:
"""Detect scratches in the image using UNet.
Args:
image: BGR uint8 numpy array.
threshold: Sigmoid probability below which pixels are zeroed out.
Returns:
Soft scratch mask (uint8, 0-255) at the original image resolution.
Values represent scratch confidence: 0 = clean, 255 = definite scratch.
"""
h, w = image.shape[:2]
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Resize to fixed size for the UNet
resized = cv2.resize(gray, (256, 256), interpolation=cv2.INTER_AREA)
inp = resized.astype(np.float32) / 255.0
inp_t = torch.from_numpy(inp).unsqueeze(0).unsqueeze(0).to(self.device)
with torch.no_grad():
out = self.scratch_net(inp_t)
out = torch.sigmoid(out)
prob = out.squeeze().cpu().numpy()
# Zero out below threshold, keOldPhotoRestoreWrapper._global_restore method · python · L682-L753 (72 LOC)models/wrappers.py
def _global_restore(self, image: np.ndarray, scratch_mask: np.ndarray) -> np.ndarray:
"""Globally restore the image using VAE encoder → mapping → decoder.
Args:
image: BGR uint8 numpy array.
scratch_mask: Soft mask (uint8, 0-255) at image resolution.
Returns:
Restored BGR uint8 numpy array at original resolution.
"""
h, w = image.shape[:2]
# Resize to 256x256 for the VAE — binarize for the mapping network
# (trained with binary masks) while keeping the soft version for blending.
img_256 = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
mask_256 = cv2.resize(scratch_mask, (256, 256), interpolation=cv2.INTER_LINEAR)
mask_256_bin = ((mask_256 > 127).astype(np.uint8)) * 255
# BGR → RGB, normalize to [-1, 1]
img_rgb = img_256[:, :, ::-1].astype(np.float32) / 127.5 - 1.0
img_t = torch.from_numpy(img_rgb.copy()).permute(2, 0, 1).unsqueOldPhotoRestoreWrapper._enhance_faces method · python · L755-L842 (88 LOC)models/wrappers.py
def _enhance_faces(
self, image: np.ndarray, detect_image: np.ndarray | None = None
) -> np.ndarray:
"""Detect and enhance faces using dlib + SPADE generator.
Args:
image: BGR uint8 numpy array (restored image for face cropping).
detect_image: BGR uint8 numpy array to run face detection on.
If ``None``, detection runs on *image*.
Returns:
Image with enhanced faces blended back, BGR uint8.
"""
det_img = detect_image if detect_image is not None else image
gray = cv2.cvtColor(det_img, cv2.COLOR_BGR2GRAY)
faces = self.face_detector(gray, 1)
if len(faces) == 0:
return image
result = image.copy()
for face in faces:
landmarks = self.landmark_predictor(gray, face)
# Extract 5-point landmarks (eye centers, nose, mouth corners)
left_eye = np.mean(
[[landmarks.part(i).x, landmarks.pa_resolve_google_drive_url function · python · L53-L74 (22 LOC)utils/downloader.py
def _resolve_google_drive_url(url: str) -> str:
"""Convert a Google Drive share/uc URL to a direct-download URL.
Extracts the file ID from the URL and constructs a
``drive.usercontent.google.com`` URL that bypasses the virus-scan
interstitial page.
Args:
url: Original Google Drive URL containing an ``id`` query parameter.
Returns:
Direct-download URL, or the original URL if no file ID is found.
"""
import re
match = re.search(r"[?&]id=([a-zA-Z0-9_-]+)", url)
if match:
file_id = match.group(1)
return (
f"https://drive.usercontent.google.com/download?id={file_id}&export=download&confirm=t"
)
return url_download_file function · python · L77-L122 (46 LOC)utils/downloader.py
def _download_file(url: str, dest: Path, part: Path, category: str, variant: str) -> None:
"""Download a model weight file to ``part``, validating it is not HTML.
Google Drive URLs are resolved to direct-download URLs before fetching.
After download, the first bytes are checked to reject HTML error pages
that may be served instead of the real model file.
Args:
url: Source download URL.
dest: Final destination path (not written by this function).
part: Temporary ``.part`` path to stream into.
category: Model category name (for logging).
variant: Model variant name (for logging).
Raises:
requests.HTTPError: If the HTTP response status is not OK.
RuntimeError: If the downloaded file appears to be HTML.
"""
session = requests.Session()
if "drive.google.com" in url:
url = _resolve_google_drive_url(url)
resp = session.get(url, stream=True, timeout=600)
resp.raise_for_status()
ensure_model_exists function · python · L125-L190 (66 LOC)utils/downloader.py
def ensure_model_exists(
category: str,
variant: str,
weights_dir: str = os.environ.get("WEIGHTS_DIR", "/app/weights"),
) -> str:
"""Ensure the weight file for the given model category/variant exists on disk.
If the file is already present and valid, returns immediately. Otherwise
downloads it from the URL registered in ``MODEL_URLS``.
Args:
category: Model category (e.g. ``"colorize"``, ``"restore"``).
variant: Variant name within the category.
weights_dir: Root directory for weight storage.
Returns:
Absolute path to the weight file.
Raises:
ValueError: If the category or variant is unknown.
RuntimeError: If the download produces a corrupt (HTML) file.
"""
if category not in MODEL_URLS:
raise ValueError(f"Unknown model category: {category}")
variants = MODEL_URLS[category]
if variant not in variants:
raise ValueError(
f"Unknown variant '{variant}' for catRepobility · open methodology · https://repobility.com/research/
ensure_model_files_exist function · python · L193-L254 (62 LOC)utils/downloader.py
def ensure_model_files_exist(
category: str,
variant: str,
weights_dir: str = os.environ.get("WEIGHTS_DIR", "/app/weights"),
) -> str:
"""Ensure all weight files for a multi-file model exist on disk.
Downloads any missing files from the URLs registered in ``MODEL_URLS_MULTI``.
Args:
category: Model category (e.g. ``"old_photo_restore"``).
variant: Variant name within the category.
weights_dir: Root directory for weight storage.
Returns:
Absolute path to the directory containing the weight files.
Raises:
ValueError: If the category or variant is unknown.
RuntimeError: If a download produces a corrupt (HTML) file.
"""
if category not in MODEL_URLS_MULTI:
raise ValueError(f"Unknown multi-file model category: {category}")
variants = MODEL_URLS_MULTI[category]
if variant not in variants:
raise ValueError(
f"Unknown variant '{variant}' for category '{category}'. "
read_image function · python · L9-L25 (17 LOC)utils/image_ops.py
def read_image(file_bytes: bytes) -> np.ndarray:
    """Decode raw file bytes into a BGR uint8 numpy image.

    Args:
        file_bytes: Raw image file content (JPEG, PNG, WebP, etc.).

    Returns:
        BGR uint8 numpy array of the decoded image.

    Raises:
        ValueError: If the input is empty or cannot be decoded as a valid
            image.
    """
    # An empty buffer makes cv2.imdecode raise cv2.error instead of
    # returning None, so reject it here to keep the ValueError contract.
    if not file_bytes:
        raise ValueError("Could not decode image — invalid or unsupported format")
    buf = np.frombuffer(file_bytes, dtype=np.uint8)
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("Could not decode image — invalid or unsupported format")
    return image
def validate_and_resize(image: np.ndarray, max_dim: int = 2048) -> np.ndarray:
    """Ensure the image does not exceed ``max_dim`` on either side.

    If either dimension exceeds ``max_dim``, the image is scaled down
    proportionally using area interpolation.

    Args:
        image: BGR uint8 numpy array.
        max_dim: Maximum allowed dimension in pixels (default 2048).

    Returns:
        The original image if within bounds, or a resized copy.
    """
    h, w = image.shape[:2]
    if h <= max_dim and w <= max_dim:
        return image

    scale = max_dim / max(h, w)
    # int() truncates, so the short side of a very elongated image could
    # become 0 pixels, which cv2.resize rejects; keep at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    logger.warning("Image too large (%dx%d), resizing to %dx%d", w, h, new_w, new_h)
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
def encode_image(image: np.ndarray, output_format: str = "png") -> bytes:
"""Encode a BGR uint8 numpy image to the specified format.
Args:
image: BGR uint8 numpy array.
output_format: Target format (``"png"``, ``"jpg"``, ``"jpeg"``, or ``"webp"``).
Returns:
Encoded image bytes.
Raises:
ValueError: If the format is unsupported or encoding fails.
"""
fmt = output_format.lower().strip(".")
ext_map = {
"png": ".png",
"jpg": ".jpg",
"jpeg": ".jpg",
"webp": ".webp",
}
if fmt not in ext_map:
raise ValueError(f"Unsupported output format: '{output_format}'. Supported: png, jpg, webp")
params = []
if fmt in ("jpg", "jpeg"):
params = [cv2.IMWRITE_JPEG_QUALITY, 95]
elif fmt == "webp":
params = [cv2.IMWRITE_WEBP_QUALITY, 95]
success, buf = cv2.imencode(ext_map[fmt], image, params)
if not success:
raise ValueError(f"Failed to encode image as {fmsetup_logging function · python · L9-L29 (21 LOC)utils/logging.py
def setup_logging(level: int = logging.INFO) -> None:
    """Install a single JSON-formatted stdout handler on the root logger.

    Any handlers already attached to the root logger are removed first.
    The formatter is given the ``asctime``, ``levelname``, ``name`` and
    ``message`` fields, with the first three renamed to ``timestamp``,
    ``level`` and ``logger``.

    Args:
        level: Logging level for the root logger (default: logging.INFO).
    """
    json_formatter = JsonFormatter(
        fmt="%(asctime)s %(levelname)s %(name)s %(message)s",
        rename_fields={"asctime": "timestamp", "levelname": "level", "name": "logger"},
    )
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(json_formatter)

    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(stdout_handler)
    root_logger.setLevel(level)