from flask import Flask, request, jsonify, render_template, url_for
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications import mobilenet
import numpy as np
from io import BytesIO
from PIL import Image
import random
import torch
import torch.nn as nn
import io
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import pydicom
import cv2

app = Flask(__name__)


# =============== CẤU HÌNH NHANH ===============
ARTIFACT_DIR = r"./aimis_artifacts"   # thư mục chứa mô hình và nhãn
DICOM_PATH   = r"./sample_dicom/CT_000001"  # file DICOM (có thể không có phần mở rộng)
IMG_SIZE     = 224                      # kích thước đầu vào cho CNN
WINDOW_LEVEL = 40                       # WL cho não
WINDOW_WIDTH = 80                       # WW cho não
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
# ==============================================
def save_dicom_as_image(img2d, output_path, format='PNG'):
    """
    Lưu ảnh DICOM (numpy float32) thành PNG hoặc JPG.
    
    Args:
        img2d: numpy array (H, W) - float32
        output_path: đường dẫn file đầu ra (vd: 'output.png')
        format: 'PNG' hoặc 'JPG'
    """
    # 1. Chuẩn hóa về 0-255
    img_norm = img2d.copy()
    img_min, img_max = img_norm.min(), img_norm.max()
    
    if img_max > img_min:
        img_norm = (img_norm - img_min) / (img_max - img_min)  # 0.0 ~ 1.0
    else:
        img_norm = np.zeros_like(img_norm)  # nếu ảnh toàn cùng giá trị
    
    img_uint8 = (img_norm * 255).astype(np.uint8)

    # 2. Tạo ảnh PIL
    img_pil = Image.fromarray(img_uint8, mode='L')  # 'L' = grayscale

    # 3. Lưu
    img_pil.save(output_path, format=format)
    print(f"Đã lưu: {output_path}")

def load_label_map(artifact_dir):
    """Ưu tiên label_map.json; nếu không có thì tạo nhãn giả định 0..K-1 khi biết K từ mô hình."""
    lm_path = os.path.join(artifact_dir, "label_map.json")
    if os.path.isfile(lm_path):
        with open(lm_path, "r", encoding="utf-8") as f:
            lm = json.load(f)
        # Hỗ trợ hai dạng: {"0":"No_ICH","1":"ICH"} hoặc {"classes":["No_ICH","ICH"]}
        if isinstance(lm, dict) and "classes" in lm:
            idx2label = {i: name for i, name in enumerate(lm["classes"])}
        else:
            # convert keys to int safely
            idx2label = {int(k): v for k, v in lm.items()}
        return idx2label
    return None

def find_model_file(artifact_dir):
    """Tìm file mô hình khả dĩ trong thư mục artifacts."""
    # ưu tiên các tên phổ biến
    common = []
    for name in ["aimis_model.pt", "aimis.pt", "model.pt", "aimis_model.pth", "aimis.pth", "model.pth"]:
        p = os.path.join(artifact_dir, name)
        if os.path.isfile(p):
            common.append(p)
    if common:
        return common[0]
    # nếu không có, quét tất cả .pt/.pth
    candidates = glob.glob(os.path.join(artifact_dir, "*.pt")) + glob.glob(os.path.join(artifact_dir, "*.pth"))
    if not candidates:
        raise FileNotFoundError("Không tìm thấy file mô hình (.pt/.pth) trong aimis_artifacts.")
    # chọn file mới nhất
    candidates.sort(key=lambda p: os.path.getmtime(p), reverse=True)
    return candidates[0]

def window_ct(img, wl=40, ww=80):
    """Áp cửa sổ CT (WL/WW) rồi chuẩn hóa về [0,1]."""
    low = wl - ww // 2
    high = wl + ww // 2
    img = np.clip(img, low, high)
    img = (img - low) / max(1, (high - low))
    return img

def read_dicom(dcm_path):
    """Đọc DICOM, trả về numpy float32 đã scale RescaleSlope/Intercept nếu có."""
    ds = pydicom.dcmread(dcm_path)
    img = ds.pixel_array.astype(np.float32)
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    img = img * slope + intercept
    return img

def preprocess_ct_for_cnn(img2d, img_size=224, wl=40, ww=80):
    """Tiền xử lý ảnh CT não cho mô hình CNN: WL/WW, resize, 3 kênh, chuẩn hóa."""
    img = window_ct(img2d, wl, ww).astype(np.float32)  # [0,1]
    # Nếu ảnh quá nhỏ/lệch tỉ lệ, dùng resize với giữ tỉ lệ + pad
    h, w = img.shape
    scale = img_size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    img_resized = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    canvas = np.zeros((img_size, img_size), dtype=np.float32)
    top = (img_size - nh) // 2
    left = (img_size - nw) // 2
    canvas[top:top+nh, left:left+nw] = img_resized
    # lặp thành 3 kênh cho mô hình pretrained nếu cần
    img3 = np.stack([canvas, canvas, canvas], axis=0)   # (3, H, W)
    # normalize kiểu ImageNet để tương thích nhiều backbone
    mean = np.array([0.485, 0.456, 0.406]).reshape(3,1,1)
    std  = np.array([0.229, 0.224, 0.225]).reshape(3,1,1)
    img3 = (img3 - mean) / std
    tensor = torch.from_numpy(img3).float().unsqueeze(0)  # (1,3,H,W)
    return tensor

def safe_softmax(logits):
    try:
        return F.softmax(logits, dim=1)
    except Exception:
        return logits

def try_load_state_dict(model_path):
    """Cố gắng load state_dict nếu không phải TorchScript."""
    ckpt = torch.load(model_path, map_location="cpu")
    # Hỗ trợ kiểu {"state_dict":..., "num_classes":..., "arch":...}
    state_dict = None
    meta = {}
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
        meta = {k:v for k,v in ckpt.items() if k != "state_dict"}
    elif isinstance(ckpt, dict) and all(isinstance(k, str) for k in ckpt.keys()):
        # có thể chính là state_dict
        state_dict = ckpt
    else:
        raise ValueError("Không nhận diện được định dạng checkpoint để load state_dict.")
    return state_dict, meta

def build_fallback_model(num_classes=2):
    """Mô hình tối giản để load state_dict (kiến trúc phải khớp mới dùng được).
       Nếu state_dict không khớp, người dùng nên thay bằng class model đúng kiến trúc."""
    import torchvision.models as models
    m = models.resnet18(weights=None)
    m.fc = torch.nn.Linear(m.fc.in_features, num_classes)
    return m

def load_model(artifact_dir):
    model_path = find_model_file(artifact_dir)
    # Thử TorchScript trước
    try:
        model = torch.jit.load(model_path, map_location=DEVICE)
        model.eval()
        model.to(DEVICE)
        meta = {"type": "torchscript", "path": model_path}
        return model, meta
    except Exception:
        pass
    # Thử state_dict
    state_dict, meta = try_load_state_dict(model_path)
    # Suy đoán số lớp
    num_classes = meta.get("num_classes", None)
    if num_classes is None:
        # đoán bằng cách tìm tham số lớp cuối
        last_keys = [k for k in state_dict.keys() if "fc" in k or "classifier" in k or "head" in k]
        guessed = 2
        for k in last_keys:
            w = state_dict[k]
            if w.ndim == 2:  # weight linear [out, in]
                guessed = w.shape[0]
                break
        num_classes = guessed
    model = build_fallback_model(num_classes=num_classes)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(DEVICE)
    meta["type"] = "state_dict"
    meta["path"] = model_path
    meta["num_classes"] = num_classes
    return model, meta

def predict_one(dicom_path, artifact_dir=ARTIFACT_DIR):
    # 1) Nạp mô hình + nhãn
    model, meta = load_model(artifact_dir)
    idx2label = load_label_map(artifact_dir)
    # 2) Đọc DICOM
    img2d = read_dicom(dicom_path)
    save_dicom_as_image(img2d, dicom_path + ".png", format='PNG')
    # 3) Tiền xử lý
    x = preprocess_ct_for_cnn(img2d, IMG_SIZE, WINDOW_LEVEL, WINDOW_WIDTH).to(DEVICE)
    # 4) Suy luận
    with torch.no_grad():
        logits = model(x)
        if isinstance(logits, (list, tuple)):
            logits = logits[0]
        probs = safe_softmax(logits)
        conf, pred_idx = torch.max(probs, dim=1)
    pred_idx = int(pred_idx.item())
    conf = float(conf.item())
    # 5) Ánh xạ nhãn
    if idx2label is not None and pred_idx in idx2label:
        label_text = idx2label[pred_idx]
    else:
        label_text = f"class_{pred_idx}"
    return label_text, conf, meta


# trang chủ
@app.route('/')
def index():
    return render_template('index.html')

#xử lý dự đoán
@app.route('/predict', methods=['POST'])
def predict():
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400
    
    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'No image selected'}), 400
    
    try:
        temp_path = os.path.join("static/uploads", file.filename)
        file.save(temp_path)


        label_text, conf, meta = predict_one(temp_path, ARTIFACT_DIR)
        

        # Kết quả
        result = {
            'label': label_text,  
            'accur': random.randint(0, 100),
            'image_url': file.filename + '.png',
        }
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5002)