PyTorch implementation of video upscaling using Real-ESRGAN.

This commit is contained in:
2026-02-28 12:25:57 +01:00
parent 72fe4e04d0
commit 0420add695
6 changed files with 342 additions and 94 deletions

View File

@@ -5,6 +5,7 @@ import subprocess
import sys
import tempfile
import time
import urllib.request
from pathlib import Path
try:
@@ -14,6 +15,46 @@ except ImportError:
HAS_TQDM = False
# Maps CLI model name -> spec for the PyTorch backend:
#   arch:     network architecture ("rrdb" = RRDBNet, "srvgg" = SRVGGNetCompact,
#             "cugan" = Real-CUGAN checkpoint)
#   netscale: native upscale factor of the network
#   filename: weight file name inside the weights cache directory
#   url:      download location for the weights
#   num_conv: (srvgg only) number of conv blocks in the compact net
MODEL_SPECS = {
    "realesrgan-x4plus": {
        "arch": "rrdb",
        "netscale": 4,
        "filename": "RealESRGAN_x4plus.pth",
        "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    },
    "realesrnet-x4plus": {
        "arch": "rrdb",
        "netscale": 4,
        "filename": "RealESRNet_x4plus.pth",
        "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRNet_x4plus.pth",
    },
    # FIX: this entry previously pointed at RealESRGAN_x4plus.pth (both
    # filename and URL), so selecting it silently ran the x4plus model and
    # collided with the x4plus file in the weights cache. The official
    # realesr-general-x4v3 model is a compact SRVGG net (32 conv blocks)
    # published in the Real-ESRGAN v0.3.0 release.
    "realesr-general-x4v3": {
        "arch": "srvgg",
        "netscale": 4,
        "num_conv": 32,
        "filename": "realesr-general-x4v3.pth",
        "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.3.0/realesr-general-x4v3.pth",
    },
    "flashvsr-x4": {
        # NOTE(review): this URL could not be verified against the official
        # Real-ESRGAN releases — confirm it before relying on auto-download.
        "arch": "rrdb",
        "netscale": 4,
        "filename": "flashvsr_x4.pth",
        "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.3.0/flashvsr_x4.pth",
    },
    "real-cugan-x4": {
        "arch": "cugan",
        "netscale": 4,
        "filename": "real_cugan_x4.pth",
        "url": "https://huggingface.co/Hacksider/Real-CUGAN/resolve/main/models/real_cugan_x4.pth",
    },
    "realesr-animevideov3": {
        "arch": "srvgg",
        "netscale": 4,
        "filename": "realesr-animevideov3.pth",
        "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    },
}
def run(cmd: list[str]) -> None:
print("\n$", " ".join(cmd))
completed = subprocess.run(cmd)
@@ -25,16 +66,17 @@ def command_exists(name: str) -> bool:
return shutil.which(name) is not None
def assert_prerequisites(realesrgan_bin: str) -> None:
def assert_prerequisites(backend: str, realesrgan_bin: str) -> None:
missing = []
if not command_exists("ffmpeg"):
missing.append("ffmpeg")
if not command_exists("ffprobe"):
missing.append("ffprobe")
real_esrgan_ok = Path(realesrgan_bin).exists() or command_exists(realesrgan_bin)
if not real_esrgan_ok:
missing.append(realesrgan_bin)
if backend == "ncnn":
real_esrgan_ok = Path(realesrgan_bin).exists() or command_exists(realesrgan_bin)
if not real_esrgan_ok:
missing.append(realesrgan_bin)
if missing:
items = ", ".join(missing)
@@ -44,6 +86,31 @@ def assert_prerequisites(realesrgan_bin: str) -> None:
)
def ensure_pytorch_deps() -> tuple:
    """Import the optional PyTorch-backend dependencies lazily.

    Returns the tuple ``(cv2, torch, RRDBNet, RealESRGANer, SRVGGNetCompact)``
    so callers never import these heavy modules at startup.

    Raises:
        RuntimeError: with installation instructions when any dependency
            (cv2, torch, torchvision, basicsr, realesrgan) is missing.
    """
    try:
        import cv2
        import importlib
        import sys as _sys

        import torch

        # basicsr still imports torchvision.transforms.functional_tensor,
        # which newer torchvision versions renamed to _functional_tensor.
        # Register a compatibility alias before basicsr is imported.
        try:
            importlib.import_module("torchvision.transforms.functional_tensor")
        except ModuleNotFoundError:
            legacy_alias = importlib.import_module("torchvision.transforms._functional_tensor")
            _sys.modules["torchvision.transforms.functional_tensor"] = legacy_alias

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
    except ImportError as exc:
        raise RuntimeError(
            "PyTorch backend dependencies are missing. Install with:\n"
            "python -m pip install -r requirements.txt\n"
            "python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128"
        ) from exc
    return cv2, torch, RRDBNet, RealESRGANer, SRVGGNetCompact
def has_audio_stream(input_video: Path) -> bool:
cmd = [
@@ -66,18 +133,47 @@ def count_png_frames(folder: Path) -> int:
return sum(1 for _ in folder.glob("*.png"))
def run_upscale_with_progress(cmd: list[str], input_frames: Path, output_frames: Path) -> None:
def resolve_model_weights(model: str, model_path: str | None, weights_dir: Path) -> tuple[Path, dict]:
    """Locate the weight file for *model*, downloading it if necessary.

    Args:
        model: Key into MODEL_SPECS.
        model_path: Optional user-supplied .pth file or directory containing
            the spec's filename. Takes precedence over the cache/download path.
        weights_dir: Cache directory for auto-downloaded weights.

    Returns:
        (path to the .pth file, the model's spec dict from MODEL_SPECS).

    Raises:
        RuntimeError: for unknown models or when a user-supplied path cannot
            be resolved to an existing weight file.
    """
    if model not in MODEL_SPECS:
        supported = ", ".join(MODEL_SPECS.keys())
        raise RuntimeError(f"Unsupported model '{model}'. Supported: {supported}")
    spec = MODEL_SPECS[model]

    if model_path:
        candidate = Path(model_path)
        if candidate.is_file():
            return candidate, spec
        if candidate.is_dir():
            resolved = candidate / spec["filename"]
            if resolved.exists():
                return resolved, spec
            raise RuntimeError(f"Model file not found: {resolved}")
        # FIX: the original silently fell through to the download path when an
        # explicit --model-path pointed at nothing; error out instead.
        raise RuntimeError(f"Model path does not exist: {candidate}")

    weights_dir.mkdir(parents=True, exist_ok=True)
    resolved = weights_dir / spec["filename"]
    if not resolved.exists():
        print(f"Downloading model weights to: {resolved}")
        # FIX: download to a temporary name and rename on success, so an
        # interrupted download never leaves a corrupt file that later runs
        # would treat as valid weights.
        partial = resolved.with_name(resolved.name + ".part")
        try:
            urllib.request.urlretrieve(spec["url"], partial)
            partial.replace(resolved)
        except BaseException:
            partial.unlink(missing_ok=True)
            raise
    return resolved, spec
def run_ncnn_upscale_with_progress(cmd: list[str], input_frames: Path, output_frames: Path) -> None:
total_frames = count_png_frames(input_frames)
if total_frames == 0:
raise RuntimeError("No extracted frames found before upscaling.")
started = time.time()
# Suppress Real-ESRGAN's verbose output by redirecting stdout/stderr
process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
if HAS_TQDM:
pbar = tqdm(total=total_frames, unit="frames", desc="Upscaling",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]")
pbar = tqdm(
total=total_frames,
unit="frames",
desc="Upscaling",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
)
last_count = 0
else:
print(f"Upscaling: 0/{total_frames} frames (0.0%) | ETA --:--:--")
@@ -95,9 +191,9 @@ def run_upscale_with_progress(cmd: list[str], input_frames: Path, output_frames:
last_count = done_frames
else:
if now - last_print >= 2.0:
progress = min(100.0, (done_frames / total_frames) * 100)
elapsed = max(now - started, 1e-6)
fps = done_frames / elapsed
progress = min(100.0, (done_frames / total_frames) * 100)
if done_frames > 0 and fps > 0:
remaining_frames = max(total_frames - done_frames, 0)
eta_seconds = int(remaining_frames / fps)
@@ -106,10 +202,7 @@ def run_upscale_with_progress(cmd: list[str], input_frames: Path, output_frames:
eta_str = f"{eta_h:02d}:{eta_m:02d}:{eta_s:02d}"
else:
eta_str = "--:--:--"
print(
f"Upscaling: {done_frames}/{total_frames} "
f"({progress:.1f}%) | {fps:.2f} fps | ETA {eta_str}"
)
print(f"Upscaling: {done_frames}/{total_frames} ({progress:.1f}%) | {fps:.2f} fps | ETA {eta_str}")
last_print = now
if return_code is not None:
@@ -121,10 +214,7 @@ def run_upscale_with_progress(cmd: list[str], input_frames: Path, output_frames:
pbar.close()
elapsed = max(time.time() - started, 1e-6)
fps = done_frames / elapsed
print(
f"Upscaling complete: {done_frames}/{total_frames} frames | "
f"avg {fps:.2f} fps | total time {elapsed:.1f}s"
)
print(f"Upscaling complete: {done_frames}/{total_frames} frames | avg {fps:.2f} fps | total time {elapsed:.1f}s")
if return_code != 0:
raise RuntimeError(f"Command failed ({return_code}): {' '.join(cmd)}")
break
@@ -132,13 +222,89 @@ def run_upscale_with_progress(cmd: list[str], input_frames: Path, output_frames:
time.sleep(0.2)
def run_pytorch_upscale_with_progress(
    input_frames: Path,
    output_frames: Path,
    model_name: str,
    model_path: str | None,
    weights_dir: Path,
    scale: int,
    tile_size: int,
    gpu_id: str,
    fp32: bool,
) -> None:
    """Upscale every PNG frame in *input_frames* into *output_frames* with RealESRGANer.

    Args:
        input_frames: Directory of extracted ``*.png`` frames.
        output_frames: Directory receiving the upscaled frames (same names).
        model_name / model_path / weights_dir: Forwarded to resolve_model_weights.
        scale: Output scale passed to ``enhance(outscale=...)``.
        tile_size: Tile size for VRAM-limited GPUs (0 = no tiling).
        gpu_id: CUDA device index as a string, or "auto".
        fp32: Force FP32 inference (default is FP16 when CUDA is available).

    Raises:
        RuntimeError: when no frames exist, a frame cannot be read/written,
            or the selected model architecture is unsupported.
    """
    cv2, torch, RRDBNet, RealESRGANer, SRVGGNetCompact = ensure_pytorch_deps()
    weights_file, spec = resolve_model_weights(model_name, model_path, weights_dir)

    if spec["arch"] == "rrdb":
        model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
            num_grow_ch=32, scale=spec["netscale"],
        )
    elif spec["arch"] == "cugan":
        # FIX: the original set model=None here and still constructed
        # RealESRGANer with it, which crashes inside the upsampler with an
        # opaque error. Real-CUGAN checkpoints are not RealESRGANer
        # compatible, so fail fast with an actionable message.
        raise RuntimeError(
            f"Model '{model_name}' (arch 'cugan') is not supported by the "
            "PyTorch Real-ESRGAN backend; choose an rrdb or srvgg model."
        )
    else:
        # Compact SRVGG net; num_conv varies per model (16 for animevideov3,
        # 32 for general-x4v3), so read it from the spec with the old default.
        model = SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_conv=spec.get("num_conv", 16),
            upscale=spec["netscale"], act_type="prelu",
        )

    if gpu_id == "auto":
        selected_gpu = 0 if torch.cuda.is_available() else None
    else:
        selected_gpu = int(gpu_id)

    upsampler = RealESRGANer(
        scale=spec["netscale"],
        model_path=str(weights_file),
        model=model,
        tile=tile_size,
        tile_pad=10,
        pre_pad=0,
        half=(not fp32 and torch.cuda.is_available()),
        gpu_id=selected_gpu,
    )

    frame_files = sorted(input_frames.glob("*.png"))
    total_frames = len(frame_files)
    if total_frames == 0:
        raise RuntimeError("No extracted frames found before upscaling.")

    started = time.time()
    if HAS_TQDM:
        progress_iter = tqdm(frame_files, total=total_frames, unit="frames", desc="Upscaling")
    else:
        progress_iter = frame_files

    done = 0
    for frame_file in progress_iter:
        image = cv2.imread(str(frame_file), cv2.IMREAD_COLOR)
        if image is None:
            raise RuntimeError(f"Failed to read input frame: {frame_file}")
        output, _ = upsampler.enhance(image, outscale=scale)
        target = output_frames / frame_file.name
        ok = cv2.imwrite(str(target), output)
        if not ok:
            raise RuntimeError(f"Failed to write output frame: {target}")
        done += 1
        # Textual fallback progress every 20 frames when tqdm is unavailable.
        if not HAS_TQDM and done % 20 == 0:
            elapsed = max(time.time() - started, 1e-6)
            fps = done / elapsed
            progress = min(100.0, (done / total_frames) * 100)
            print(f"Upscaling: {done}/{total_frames} ({progress:.1f}%) | {fps:.2f} fps")

    elapsed = max(time.time() - started, 1e-6)
    fps = total_frames / elapsed
    print(f"Upscaling complete: {total_frames}/{total_frames} frames | avg {fps:.2f} fps | total time {elapsed:.1f}s")
def upscale_video(
input_video: Path,
output_video: Path,
backend: str,
realesrgan_bin: str,
model: str,
model_path: str | None,
weights_dir: Path,
scale: int,
tile_size: int,
jobs: str,
@@ -150,7 +316,11 @@ def upscale_video(
temp_root: Path,
gpu_id: str,
test_seconds: float | None,
start_time: float | None,
skip_sar_correction: bool,
pre_vf: str | None,
final_res: str | None,
fp32: bool,
) -> None:
if not input_video.exists():
raise FileNotFoundError(f"Input video does not exist: {input_video}")
@@ -167,29 +337,38 @@ def upscale_video(
frames_out.mkdir(parents=True, exist_ok=True)
print(f"Working directory: {tmp_dir}")
if test_seconds is not None:
print(f"Test mode: processing first {test_seconds:.2f} seconds")
if start_time is not None or test_seconds is not None:
msg = "Test mode:"
if start_time is not None:
msg += f" starting at {start_time:.2f}s"
if test_seconds is not None:
msg += f" processing {test_seconds:.2f}s"
print(msg)
seek_args = ["-ss", str(start_time)] if start_time is not None else []
test_duration_args = ["-t", str(test_seconds)] if test_seconds is not None else []
filter_chain = []
if pre_vf:
filter_chain.append(pre_vf)
filter_chain.extend([
"scale=ceil(iw*sar/2)*2:ih",
"setsar=1",
])
if not skip_sar_correction:
filter_chain.extend([
"scale=ceil(iw*sar/2)*2:ih",
"setsar=1",
])
# Build extract command with optional video filter
extract_cmd = [
"ffmpeg",
"-y",
"-i",
str(input_video),
*seek_args,
*test_duration_args,
"-vf",
",".join(filter_chain),
str(frames_in / "%08d.png"),
]
if filter_chain:
extract_cmd.extend(["-vf", ",".join(filter_chain)])
extract_cmd.append(str(frames_in / "%08d.png"))
run(extract_cmd)
audio_present = has_audio_stream(input_video)
@@ -199,6 +378,7 @@ def upscale_video(
"-y",
"-i",
str(input_video),
*seek_args,
*test_duration_args,
"-vn",
"-c:a",
@@ -209,31 +389,43 @@ def upscale_video(
]
run(extract_audio_cmd)
upscale_cmd = [
realesrgan_bin,
"-i",
str(frames_in),
"-o",
str(frames_out),
"-n",
model,
"-s",
str(scale),
"-f",
"png",
"-t",
str(tile_size),
"-j",
jobs,
"-g",
gpu_id,
]
if model_path:
upscale_cmd.extend(["-m", model_path])
run_upscale_with_progress(upscale_cmd, frames_in, frames_out)
if backend == "pytorch":
run_pytorch_upscale_with_progress(
input_frames=frames_in,
output_frames=frames_out,
model_name=model,
model_path=model_path,
weights_dir=weights_dir,
scale=scale,
tile_size=tile_size,
gpu_id=gpu_id,
fp32=fp32,
)
else:
upscale_cmd = [
realesrgan_bin,
"-i",
str(frames_in),
"-o",
str(frames_out),
"-n",
model,
"-s",
str(scale),
"-f",
"png",
"-t",
str(tile_size),
"-j",
jobs,
"-g",
gpu_id,
]
if model_path:
upscale_cmd.extend(["-m", model_path])
run_ncnn_upscale_with_progress(upscale_cmd, frames_in, frames_out)
fps_args = ["-r", fps] if fps else []
encode_cmd = [
"ffmpeg",
"-y",
@@ -245,20 +437,21 @@ def upscale_video(
if audio_present and audio_file.exists():
encode_cmd.extend(["-i", str(audio_file), "-c:a", "copy"])
encode_cmd.extend(
[
"-c:v",
codec,
"-crf",
str(crf),
"-preset",
preset,
"-pix_fmt",
"yuv420p",
str(output_video),
]
)
# Add scaling filter if final resolution specified
if final_res:
encode_cmd.extend(["-vf", f"scale={final_res}:flags=lanczos"])
encode_cmd.extend([
"-c:v",
codec,
"-crf",
str(crf),
"-preset",
preset,
"-pix_fmt",
"yuv420p",
str(output_video),
])
run(encode_cmd)
if keep_temp:
@@ -269,39 +462,45 @@ def upscale_video(
print(f"Temporary files copied to: {kept}")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Upscale a video locally with Real-ESRGAN (RTX GPU via Vulkan)."
description="Upscale a video locally with Real-ESRGAN (PyTorch CUDA default, ncnn optional)."
)
parser.add_argument("-i", "--input", required=True, help="Input video path")
parser.add_argument("-o", "--output", required=True, help="Output video path")
parser.add_argument("--backend", choices=["pytorch", "ncnn"], default="pytorch")
parser.add_argument(
"--realesrgan-bin",
default="realesrgan-ncnn-vulkan",
help="Path or command name of realesrgan-ncnn-vulkan",
help="Path or command name of realesrgan-ncnn-vulkan (ncnn backend only)",
)
parser.add_argument(
"--model",
default="realesr-animevideov3",
help="Model name (e.g. realesr-animevideov3, realesrgan-x4plus)",
default="realesrgan-x4plus",
choices=["realesrgan-x4plus", "realesrnet-x4plus", "realesr-general-x4v3", "flashvsr-x4", "real-cugan-x4", "realesr-animevideov3"],
help="Model name",
)
parser.add_argument(
"--model-path",
default=None,
help="Path to models directory (required if models not in default location)",
help="Model file or model directory. For pytorch: .pth file or folder containing model .pth",
)
parser.add_argument(
"--weights-dir",
default=str(Path.home() / ".cache" / "realesrgan"),
help="Where to download/store PyTorch model weights",
)
parser.add_argument("--scale", type=int, default=2, choices=[2, 3, 4])
parser.add_argument(
"--tile-size",
type=int,
default=0,
help="Tile size for VRAM-limited cases (0 = auto)",
help="Tile size (0 = auto/no tile for backend defaults)",
)
parser.add_argument(
"--jobs",
default="2:2:2",
help="NCNN worker threads as load:proc:save",
help="NCNN worker threads as load:proc:save (ncnn backend only)",
)
parser.add_argument(
"--fps",
@@ -311,6 +510,12 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--codec", default="libx264", help="Output video codec")
parser.add_argument("--crf", type=int, default=16, help="Quality (lower = better)")
parser.add_argument("--preset", default="medium", help="Encoder preset")
parser.add_argument("--fp32", action="store_true", help="Use FP32 inference for PyTorch backend")
parser.add_argument(
"--skip-sar-correction",
action="store_true",
help="Skip SAR (aspect ratio) correction before upscaling (for testing native resolution)",
)
parser.add_argument(
"--keep-temp",
action="store_true",
@@ -324,34 +529,46 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--gpu-id",
default="auto",
help="Vulkan GPU id for Real-ESRGAN (e.g. 0, 1, 0,1). Use 'auto' by default",
help="GPU id (e.g. 0,1) or 'auto'. For ncnn this maps to Vulkan id",
)
parser.add_argument(
"--test-seconds",
type=float,
default=None,
help="Only process first N seconds (for quick test runs)",
help="Only process N seconds (for quick test runs)",
)
parser.add_argument(
"--start-time",
type=float,
default=None,
help="Start at specific time in video (seconds, for testing specific frames)",
)
parser.add_argument(
"--pre-vf",
default=None,
help="Optional ffmpeg video filter(s) applied before upscaling (e.g. hqdn3d=1.5:1.5:6:6)",
)
parser.add_argument(
"--final-res",
default=None,
help="Final output resolution (e.g. 1920x1080 for Full HD) - scales downsampled frames before encoding",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
try:
assert_prerequisites(args.realesrgan_bin)
assert_prerequisites(args.backend, args.realesrgan_bin)
upscale_video(
input_video=Path(args.input),
output_video=Path(args.output),
backend=args.backend,
realesrgan_bin=args.realesrgan_bin,
model=args.model,
model_path=args.model_path,
weights_dir=Path(args.weights_dir),
scale=args.scale,
tile_size=args.tile_size,
jobs=args.jobs,
@@ -363,7 +580,11 @@ def main() -> int:
temp_root=Path(args.temp_root),
gpu_id=args.gpu_id,
test_seconds=args.test_seconds,
start_time=args.start_time,
skip_sar_correction=args.skip_sar_correction,
pre_vf=args.pre_vf,
final_res=args.final_res,
fp32=args.fp32,
)
except Exception as exc:
print(f"Error: {exc}", file=sys.stderr)