import numpy as np
import tensorrt_rtx as trt
import pycuda.driver as cuda
import pycuda.autoinit
from PIL import Image, ImageFile
import cv2
import os
import time
from pathlib import Path
import rawpy
import concurrent.futures
import argparse
from tqdm.auto import tqdm
import sys
import traceback

ImageFile.LOAD_TRUNCATED_IMAGES = True

def load_image(image_path: Path):
"""Učitava sliku različitih formata uključujući RAW formate."""
ext = image_path.suffix.lower()
try:
if ext in ['.arw', '.cr2', '.nef', '.dng', '.raf', '.raw']:
with rawpy.imread(str(image_path)) as raw:
return raw.postprocess(use_camera_wb=True, output_bps=8)
else:
img = Image.open(image_path)
if img.mode != 'RGB':
img = img.convert('RGB')
return np.array(img)
except FileNotFoundError:
print(f"Greška: Fajl nije pronađen - {image_path}")
return None
except (rawpy.LibRawError, IOError, SyntaxError) as e:
print(f"Greška pri učitavanju {image_path.name} ({type(e).__name__}): {e}")
return None
except Exception as e:
print(f"Neočekivana greška pri učitavanju {image_path.name}: {e}")
traceback.print_exc()
return None

def process_image(
image_path: Path,
context: trt.IExecutionContext,
bindings: list,
d_input: cuda.DeviceAllocation,
d_output: cuda.DeviceAllocation,
stream: cuda.Stream,
input_shape: tuple,
output_shape: tuple,
output_np_dtype: np.dtype,
output_dir: Path,
scale_factor: int,
tile_size: int,
overlap: int,
output_format: str,
jpg_quality: int,
numpy_dtype: np.dtype
):
"""Obrađuje jednu sliku kroz TensorRT engine koristeći tiling pristup."""
start_time = time.time()
img = load_image(image_path)
if img is None:
return None
if not isinstance(img, np.ndarray) or img.ndim != 3 or img.shape[2] != 3:
print(f"Greška: Slika {image_path.name} nije u očekivanom RGB formatu.")
return None

h, w = img.shape[:2]
output_h, output_w = h * scale_factor, w * scale_factor
output_img = np.zeros((output_h, output_w, 3), dtype=np.float32)
weight_map = np.zeros((output_h, output_w, 3), dtype=np.float32)

# Kreiranje Gaussian weight template za blending
y_grid, x_grid = np.mgrid[0:tile_size, 0:tile_size].astype(np.float64)
center = (tile_size - 1) / 2.0
sigma = tile_size / 4.0
weight_template = np.exp(-((x_grid - center)**2 + (y_grid - center)**2) / (2 * sigma**2))
weight_template = np.repeat(weight_template[:, :, np.newaxis], 3, axis=2).astype(np.float32)
scaled_tile_size = tile_size * scale_factor
scaled_weight_template = cv2.resize(weight_template, (scaled_tile_size, scaled_tile_size), interpolation=cv2.INTER_LINEAR)

# Priprema za tiling
stride = tile_size - overlap
x_tiles = (w + stride - 1) // stride
y_tiles = (h + stride - 1) // stride
total_tiles = max(1, x_tiles * y_tiles)
tile_times = []

# Alokacija HOST memorije
h_input = cuda.pagelocked_empty(trt.volume(input_shape), dtype=numpy_dtype)
h_output = cuda.pagelocked_empty(trt.volume(output_shape), dtype=output_np_dtype)

with tqdm(total=total_tiles, desc=f"Pločice ({image_path.name})", unit="pločica",
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
leave=False) as pbar:
for y in range(0, h, stride):
for x in range(0, w, stride):
tile_start_time = time.time()
y_start, y_end = y, min(y + tile_size, h)
x_start, x_end = x, min(x + tile_size, w)
tile_orig = img[y_start:y_end, x_start:x_end]
tile_h, tile_w = tile_orig.shape[:2]

# Padding ako je potrebno
if tile_h < tile_size or tile_w < tile_size:
pad_h, pad_w = tile_size - tile_h, tile_size - tile_w
pad_top, pad_bottom = pad_h // 2, pad_h - (pad_h // 2)
pad_left, pad_right = pad_w // 2, pad_w - (pad_w // 2)
tile = cv2.copyMakeBorder(tile_orig, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_REFLECT_101)
else:
tile = tile_orig

try:
# Priprema ulaza za TensorRT
tile_input = tile.transpose(2, 0, 1).astype(numpy_dtype) / 255.0
tile_input = np.expand_dims(tile_input, axis=0)

if tile_input.shape != tuple(input_shape):
print(f"\nGreška: Oblik pločice {tile_input.shape} != očekivani {input_shape} za {image_path.name}")
pbar.update(1)
continue

# Inferencija
np.copyto(h_input, tile_input.ravel())
cuda.memcpy_htod_async(d_input, h_input, stream)
context.execute_async_v3(stream_handle=stream.handle)
cuda.memcpy_dtoh_async(h_output, d_output, stream)
stream.synchronize()

# Obrada izlaza
tile_output_full = h_output.reshape(output_shape)
if tile_output_full.shape[0] == 1:
tile_output_full = tile_output_full[0]

# Konverzija u FP32 za blending ako je izlaz bio FP16
if tile_output_full.dtype == np.float16:
tile_output_full = tile_output_full.astype(np.float32)

tile_output_full = tile_output_full.transpose(1, 2, 0) # CHW -> HWC

# Sečenje paddinga i blending
out_y_start, out_x_start = y_start * scale_factor, x_start * scale_factor
out_h_orig, out_w_orig = tile_h * scale_factor, tile_w * scale_factor

if tile_h < tile_size or tile_w < tile_size:
pad_top_scaled, pad_left_scaled = pad_top * scale_factor, pad_left * scale_factor
end_row, end_col = pad_top_scaled + out_h_orig, pad_left_scaled + out_w_orig
if end_row > tile_output_full.shape[0] or end_col > tile_output_full.shape[1]:
print(f"\nUpozorenje: Neispravne granice sečenja za padding {image_path.name} ({y},{x}).")
tile_output = tile_output_full
else:
tile_output = tile_output_full[pad_top_scaled:end_row, pad_left_scaled:end_col, :]
else:
tile_output = tile_output_full

out_y_end = out_y_start + tile_output.shape[0]
out_x_end = out_x_start + tile_output.shape[1]

# Provera granica
if out_y_end > output_img.shape[0] or out_x_end > output_img.shape[1]:
out_y_end = min(out_y_end, output_img.shape[0])
out_x_end = min(out_x_end, output_img.shape[1])
tile_output = tile_output[:out_y_end-out_y_start, :out_x_end-out_x_start, :]

# Primena težina za blending
current_weight = scaled_weight_template[:tile_output.shape[0], :tile_output.shape[1], :]

if tile_output.shape == current_weight.shape:
output_img[out_y_start:out_y_end, out_x_start:out_x_end] += tile_output * current_weight
weight_map[out_y_start:out_y_end, out_x_start:out_x_end] += current_weight
else:
print(f"\nUpozorenje: Neslaganje oblika tile/weight {image_path.name} ({y},{x}).")

except Exception as e:
print(f"\nGreška pri TRT obradi pločice {image_path.name} ({y},{x}): {e}")
traceback.print_exc()

# Merenje vremena i update progress bara
tile_time = time.time() - tile_start_time
tile_times.append(tile_time)
if tile_times:
avg_tile_time = sum(tile_times) / len(tile_times)
remaining_tiles = total_tiles - pbar.n - 1
if remaining_tiles > 0:
estimated_remaining_time = remaining_tiles * avg_tile_time
pbar.set_postfix_str(f"Avg: {avg_tile_time:.3f}s/pločica, Preostalo: {estimated_remaining_time:.1f}s")
else:
pbar.set_postfix_str(f"Avg: {avg_tile_time:.3f}s/pločica")
pbar.update(1)

# Normalizacija i čuvanje
mask = weight_map > 1e-6
output_img[mask] /= weight_map[mask]
output_img = np.clip(output_img * 255.0, 0, 255).astype(np.uint8)

output_filename = f"{image_path.stem}_upscaled_.{output_format.lower()}"
output_path = output_dir / output_filename
try:
output_image = Image.fromarray(output_img)
save_params = {}
fmt = output_format.lower()
if fmt in ['jpg', 'jpeg']:
save_params['quality'] = jpg_quality
save_params['subsampling'] = 0
elif fmt == 'png':
save_params['compress_level'] = 4
output_image.save(output_path, **save_params)
elapsed_time = time.time() - start_time
avg_tile_t = sum(tile_times) / len(tile_times) if tile_times else 0
tqdm.write(f"Završeno: {image_path.name} -> {output_path.name} (vreme: {elapsed_time:.2f}s, avg pločica: {avg_tile_t:.3f}s)")
return output_path
except Exception as e:
print(f"\nGreška pri čuvanju {output_path.name}: {e}")
traceback.print_exc()
return None


def main():
parser = argparse.ArgumentParser(description='TensorRT Batch Upscaling.')
parser.add_argument('--input_dir', type=str, default='input')
parser.add_argument('--output_dir', type=str, default='output')
parser.add_argument('--engine_path', type=str, default='scunet_color_real_gan_fp16_4600.engine')
parser.add_argument('--scale_factor', type=int, default=1)
parser.add_argument('--workers', type=int, default=1)
parser.add_argument('--force', action='store_true')
parser.add_argument('--tile_size', type=int, default=4600)
parser.add_argument('--overlap', type=int, default=2000)
parser.add_argument('--output_format', type=str, default='jpg', choices=['png', 'jpg', 'jpeg', 'bmp'])
parser.add_argument('--jpg_quality', type=int, default=100)
parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'float32'])
args = parser.parse_args()

# Provera i podešavanje broja radnika
if args.workers == 0:
args.workers = os.cpu_count() or 1
print(f"Koristi se {args.workers} radnika.")
if args.workers < 0:
print("Greška: Negativan broj radnika.")
sys.exit(1)

# Priprema putanja
input_dir, output_dir = Path(args.input_dir), Path(args.output_dir)
engine_path = Path(args.engine_path)
if not input_dir.is_dir():
print(f"Greška: Ulazni dir '{input_dir}' nije pronađen.")
sys.exit(1)
if not engine_path.is_file():
print(f"Greška: Engine fajl '{engine_path}' nije pronađen.")
sys.exit(1)
output_dir.mkdir(parents=True, exist_ok=True)

# Podešavanje tipa podataka
numpy_dtype = np.float16 if args.dtype == 'float16' else np.float32
print(f"Koristiće se {args.dtype} ({numpy_dtype.__name__}) za pripremu ulaza.")

# Učitavanje TensorRT engine-a
print(f"--- DEBUG: TensorRT Python verzija: {trt.__version__} ---")
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
print(f"Učitavam TensorRT engine: {engine_path}")
engine, context = None, None
try:
with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
print(f"TensorRT engine uspešno učitan.")
context = engine.create_execution_context()
if not context:
raise RuntimeError("Ne mogu kreirati execution context.")
print("TensorRT execution context kreiran.")
except Exception as e:
print(f"Greška pri učitavanju engine-a/kreiranju konteksta: {e}")
traceback.print_exc()
sys.exit(1)

# Dobijanje informacija o I/O tenzorima
input_binding_idx, output_binding_idx = -1, -1
input_shape, output_shape = None, None
input_name, output_name = "", ""
input_dtype_trt, output_dtype_trt = None, None
output_np_dtype = np.float32 # Default NumPy tip za izlaz

print("--- Informacije o Engine Tenzorima (TRT 10.x API) ---")
num_io_tensors = engine.num_io_tensors
tensor_indices = list(range(num_io_tensors))

for i in tensor_indices:
name = engine.get_tensor_name(i)
shape = engine.get_tensor_shape(name)
dtype = engine.get_tensor_dtype(name)
mode = engine.get_tensor_mode(name)

print(f"Tensor {i}: Name='{name}', Shape={shape}, Dtype={dtype}, Mode={mode}")

if mode == trt.TensorIOMode.INPUT:
if input_binding_idx != -1:
print("Upozorenje: Više ulaznih tenzora.")
else:
input_binding_idx = i
input_shape = tuple(shape)
input_name = name
input_dtype_trt = dtype
if len(input_shape) == 4 and (input_shape[2] != args.tile_size or input_shape[3] != args.tile_size):
print(f"UPOZORENJE: Veličina pločice engine-a ({input_shape[2]}x{input_shape[3]}) != --tile_size ({args.tile_size}).")
elif len(input_shape) != 4:
print(f"UPOZORENJE: Ulazni oblik {input_shape} nije 4D.")
elif mode == trt.TensorIOMode.OUTPUT:
if output_binding_idx != -1:
print("Upozorenje: Više izlaznih tenzora.")
else:
output_binding_idx = i
output_shape = tuple(shape)
output_name = name
output_dtype_trt = dtype
# Određujemo NumPy tip na osnovu TRT tipa
if output_dtype_trt == trt.float16:
output_np_dtype = np.float16
elif output_dtype_trt == trt.int32:
output_np_dtype = np.int32

if not input_name or not output_name:
print("Greška: Nije identifikovan ulazni/izlazni tenzor po imenu.")
sys.exit(1)

print(f"Identifikovan ulaz: Index={input_binding_idx}, Ime='{input_name}', Oblik={input_shape}, Tip={input_dtype_trt}")
print(f"Identifikovan izlaz: Index={output_binding_idx}, Ime='{output_name}', Oblik={output_shape}, Tip={output_dtype_trt} (NumPy: {output_np_dtype.__name__})")
print("------------------------------------")

# Korekcija veličine pločice prema engine-u ako je potrebno
if len(input_shape) == 4 and (args.tile_size != input_shape[2] or args.tile_size != input_shape[3]):
print(f"Korigujem tile_size na {input_shape[2]}x{input_shape[3]} prema engine-u.")
args.tile_size = input_shape[2]
elif len(input_shape) != 4:
print(f"Ne mogu korigovati tile_size, ulaz nije 4D.")

# Alokacija Device Memorije
d_input, d_output, stream, bindings = None, None, None, None
try:
if numpy_dtype == np.float16 and input_dtype_trt != trt.float16:
print(f"UPOZORENJE: Ulazni dtype {numpy_dtype} != očekivani TRT dtype {input_dtype_trt}")
elif numpy_dtype == np.float32 and input_dtype_trt == trt.float16:
print(f"UPOZORENJE: Ulazni dtype {numpy_dtype} != očekivani TRT dtype {input_dtype_trt}")

d_input_size = trt.volume(input_shape) * np.dtype(numpy_dtype).itemsize
d_input = cuda.mem_alloc(d_input_size)
d_output_size = trt.volume(output_shape) * np.dtype(output_np_dtype).itemsize
d_output = cuda.mem_alloc(d_output_size)
stream = cuda.Stream()

# Postavljanje adresa u kontekst koristeći imena
context.set_tensor_address(input_name, int(d_input))
context.set_tensor_address(output_name, int(d_output))
bindings = None # Označavamo da ne koristimo staru listu

print("Device (GPU) memorija alocirana.")
print(f" Ulazni bafer: {d_input_size / (1024**2):.2f} MiB")
print(f" Izlazni bafer: {d_output_size / (1024**2):.2f} MiB (tip: {output_np_dtype.__name__})")

except cuda.MemoryError:
print(f"Greška: Nedovoljno GPU memorije!")
sys.exit(1)
except Exception as e:
print(f"Greška pri alokaciji GPU memorije: {e}")
traceback.print_exc()
sys.exit(1)

# Pronalaženje slika za obradu
supported_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp', '.arw', '.cr2', '.nef', '.dng', '.raf', '.raw']
all_image_paths = []
print(f"Pretražujem: {input_dir}")
for ext in supported_extensions:
all_image_paths.extend(input_dir.glob(f'*{ext}'))
all_image_paths.extend(input_dir.glob(f'*{ext.upper()}'))

image_paths_to_process = []
processed_stems = set()
output_suffix = f"_ESC-XL-REAL-GAN.{args.output_format.lower()}"
existing_outputs_stems = set()

if not args.force:
for f in output_dir.glob(f'*{output_suffix}'):
existing_outputs_stems.add(f.stem.replace('_ESC-XL-REAL-GAN', ''))
if existing_outputs_stems:
print(f"Pronađeno {len(existing_outputs_stems)} postojećih izlaza.")

for img_path in all_image_paths:
stem = img_path.stem
if stem in processed_stems:
continue
if not args.force and stem in existing_outputs_stems:
continue
processed_stems.add(stem)
image_paths_to_process.append(img_path)

if not image_paths_to_process:
print(f"Nema novih slika za obradu.")
return

print(f"Pronađeno {len(image_paths_to_process)} slika za obradu.")

# Obrada slika
results = []
start_overall_time = time.time()

if args.workers > 1:
print(f"Pokrećem obradu sa {args.workers} radnika...")
with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
futures = {
executor.submit(
process_image, img_path, context, bindings, d_input, d_output, stream,
input_shape, output_shape, output_np_dtype, output_dir, args.scale_factor,
args.tile_size, args.overlap, args.output_format, args.jpg_quality, numpy_dtype
): img_path for img_path in image_paths_to_process
}

for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures),
desc="Ukupni napredak", unit="slika"):
img_path_completed = futures[future]
try:
result = future.result()
if result:
results.append(result)
except Exception as e:
tqdm.write(f"\nGREŠKA u threadu za {img_path_completed.name}: {e}")
else:
print("Pokrećem sekvencijalnu obradu...")
for image_path in tqdm(image_paths_to_process, desc="Ukupni napredak", unit="slika"):
try:
result = process_image(
image_path, context, bindings, d_input, d_output, stream,
input_shape, output_shape, output_np_dtype, output_dir, args.scale_factor,
args.tile_size, args.overlap, args.output_format, args.jpg_quality, numpy_dtype
)
if result:
results.append(result)
except Exception as e:
tqdm.write(f"\nGREŠKA pri obradi {image_path.name}: {e}")
traceback.print_exc()

# Finalni izveštaj
end_overall_time = time.time()
total_time = end_overall_time - start_overall_time
num_processed = len(results)
avg_time_per_image = total_time / num_processed if num_processed > 0 else 0

print("-" * 30)
print(f"Obrada završena za {total_time:.2f} s.")
print(f"Uspešno sačuvano {num_processed} slika.")
if num_processed > 0:
print(f"Prosečno vreme po slici: {avg_time_per_image:.2f} s.")

failed_count = len(image_paths_to_process) - num_processed
if failed_count > 0:
print(f"Neuspele/preskočene slike: {failed_count}")

print(f"Izlazni fajlovi u: {output_dir}")

if __name__ == "__main__":
main()