#!/usr/bin/env python3
"""
RAG CLI – pure Python, keine pip-Abhängigkeiten.

Verwendung:
  python rag.py <verzeichnis> <dateifilter> [--reindex]

  --reindex   Index neu aufbauen, auch wenn er noch aktuell ist.

Beispiele:
  python rag.py ~/projekte "*.py"
  python rag.py ~/dokumente "*.txt"
  python rag.py ~/daten "*.csv"
"""

# ──────────────────────────────────────────────────────────────────
# KONFIGURATION
# ──────────────────────────────────────────────────────────────────

# Path to the llama.cpp server binary ("~" is expanded at startup).
LLAMA_SERVER_BIN  = "~/llama.cpp/build/bin/llama-server"

# Models (GGUF files): one for chat completions, one for embeddings.
CHAT_MODEL        = "~/models/qwen2.5-coder-14b-instruct-q4_k_m.gguf"
EMBED_MODEL       = "~/models/nomic-embed-text-v2-moe.Q8_0.gguf"

# Servers (both are bound to and contacted on 127.0.0.1)
CHAT_PORT         = 8080
EMBED_PORT        = 8081
GPU_LAYERS        = 99        # layers offloaded to the GPU (-ngl); 0 = CPU only
SERVER_TIMEOUT    = 60        # seconds to wait for each server to become healthy

# Server autostart: True  = this script launches llama-server itself
#                   False = the servers must already be running on the ports above
AUTOSTART_SERVER  = True

# Index
INDEX_ROOT        = "~/.ragindex"   # directory holding one index per (directory, filter) pair

# Chunking, measured in characters (not tokens).
# CHUNK_OVERLAP must be smaller than CHUNK_SIZE, otherwise chunking cannot advance.
CHUNK_SIZE        = 500      # characters per chunk
CHUNK_OVERLAP     = 50       # characters shared by consecutive chunks

# Retrieval
TOP_K             = 4        # number of best-matching chunks passed to the model as context

# System prompt for the chat model. It contains no placeholders: the retrieved
# context and the user's question are sent separately, in the user message.
SYSTEM_PROMPT = "Du bist ein hilfreicher Assistent. Beantworte die Frage ausschließlich auf Basis des folgenden Kontexts."
# ──────────────────────────────────────────────────────────────────

import sys
import os
import json
import math
import time
import signal
import atexit
import hashlib
import sqlite3
import fnmatch
import subprocess
import urllib.request
import urllib.error
from pathlib import Path


# ── Pfade expandieren ─────────────────────────────────────────────

# Expand "~" in every configured path once, at import time.
BIN, CHAT, EMBD, IROOT = (
    Path(raw).expanduser()
    for raw in (LLAMA_SERVER_BIN, CHAT_MODEL, EMBED_MODEL, INDEX_ROOT)
)

# Base URLs of the two local llama-server instances (loopback only).
CHAT_URL, EMBED_URL = (
    "http://127.0.0.1:{}".format(port) for port in (CHAT_PORT, EMBED_PORT)
)

# ── Server-Verwaltung ─────────────────────────────────────────────

# Handles of every llama-server process started by _start_server();
# _stop_servers() terminates whatever is recorded here.
_procs: list = list()


def _start_server(model: Path, port: int, embeddings: bool):
    """Launch llama-server for *model* on 127.0.0.1:*port* in the background.

    The process handle is appended to ``_procs`` so it can be stopped later;
    the server's stdout and stderr are discarded. When *embeddings* is true the
    server gets ``--embeddings`` and is announced as "embed", otherwise "chat".
    """
    role = "chat"
    extra_flags = []
    if embeddings:
        role = "embed"
        extra_flags = ["--embeddings"]

    argv = [
        str(BIN),
        "-m", str(model),
        "--port", str(port),
        "-ngl", str(GPU_LAYERS),
        "--host", "127.0.0.1",
        *extra_flags,
    ]
    print(f"[*] Starte {role}-server (Port {port}) ...")
    _procs.append(
        subprocess.Popen(argv, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    )


def _wait_for_server(base_url: str) -> bool:
    """Poll ``<base_url>/health`` until one request succeeds or time runs out.

    Any failure (connection refused, an HTTP error status, a timeout) counts as
    "not ready yet" and is retried about once per second for up to
    SERVER_TIMEOUT seconds.

    Returns:
        True as soon as a health request succeeds, False once the deadline passes.
    """
    # monotonic() is unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + SERVER_TIMEOUT
    while time.monotonic() < deadline:
        try:
            # Close the response so each successful poll does not leak a socket.
            with urllib.request.urlopen(f"{base_url}/health", timeout=2):
                return True
        except Exception:
            time.sleep(1)
    return False


def _stop_servers():
    """Stop every llama-server process recorded in ``_procs``.

    All processes are first asked to terminate (Popen.terminate) and only then
    awaited, so they shut down concurrently instead of one after another. A
    process still running after 5 seconds is killed. The handle list is emptied
    afterwards, so a second call (e.g. the atexit hook after an explicit call)
    does nothing.
    """
    if not _procs:
        return
    print("[*] Stoppe Server ...")

    for proc in _procs:
        if proc.poll() is None:
            try:
                proc.terminate()
            except OSError:
                pass  # exited between poll() and terminate(); nothing to stop

    for proc in _procs:
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            try:
                proc.kill()
            except OSError:
                pass  # exited just after the timeout; keep stopping the others

    _procs.clear()


# ── HTTP-Helpers ──────────────────────────────────────────────────

def _post(url: str, payload: dict) -> dict:
    """POST *payload* as a JSON body to *url* and return the decoded JSON reply.

    Uses a 120-second timeout; HTTP and connection errors propagate from urllib.
    """
    body = json.dumps(payload).encode()
    request = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=120) as response:
        raw_reply = response.read()
        return json.loads(raw_reply)


def get_embedding(text: str) -> list:
    """Return the embedding vector for *text* from the embedding server.

    Calls the server's ``/v1/embeddings`` endpoint and returns the
    ``embedding`` of the first item in the reply's ``data`` list.
    """
    request_body = {"model": "local", "input": text}
    reply = _post(f"{EMBED_URL}/v1/embeddings", request_body)
    first_item = reply["data"][0]
    return first_item["embedding"]


def chat_completion(context: str, question: str) -> str:
    """Ask the chat model *question*, grounded on the retrieved *context*.

    SYSTEM_PROMPT is sent as the system message; context and question are
    combined into a single user message. Requests at most 1024 tokens at
    temperature 0.2 and returns the reply text with surrounding whitespace
    stripped.
    """
    user_text = f"Kontext:\n{context}\n\nFrage: {question}"
    request_body = {
        "model": "local",
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ],
        "max_tokens": 1024,
        "temperature": 0.2,
    }
    reply = _post(f"{CHAT_URL}/v1/chat/completions", request_body)
    first_choice = reply["choices"][0]
    return first_choice["message"]["content"].strip()


# ── Datei-Lesen ───────────────────────────────────────────────────

def read_file(path: Path, max_csv_rows: int = 5000):
    """Return the text content of *path*, or None if unsupported or unreadable.

    Plain-text and source formats (selected by file suffix) are read whole.
    CSV files are flattened to one ", "-joined line per row. At most
    *max_csv_rows* rows are kept, and a "... (gekürzt)" marker is appended
    only when further rows were actually dropped. Files are always decoded as
    UTF-8 with undecodable bytes replaced, so results do not depend on the
    platform's locale encoding. Read errors are reported and yield None, so
    indexing can continue with the next file.
    """
    suffix = path.suffix.lower()
    try:
        if suffix in (".txt", ".md", ".py", ".go", ".js", ".ts", ".sh",
                      ".yaml", ".yml", ".toml", ".ini", ".cfg", ".json"):
            return path.read_text(encoding="utf-8", errors="replace")

        if suffix == ".csv":
            import csv
            rows = []
            with path.open(newline="", encoding="utf-8", errors="replace") as f:
                for i, row in enumerate(csv.reader(f)):
                    if i >= max_csv_rows:
                        # At least one row beyond the limit exists: mark truncation.
                        rows.append("... (gekürzt)")
                        break
                    rows.append(", ".join(row))
            return "\n".join(rows)

    except Exception as e:
        # Deliberate best-effort: one unreadable file must not abort indexing.
        print(f"  [!] Konnte {path} nicht lesen: {e}")
    return None


# ── Chunking ──────────────────────────────────────────────────────

def chunk_text(text: str, source: str, chunk_size=None, overlap=None) -> list:
    """Split *text* into overlapping character windows tagged with *source*.

    Args:
        text: The text to split.
        source: Value stored under "source" in every chunk.
        chunk_size: Window length in characters; defaults to CHUNK_SIZE.
        overlap: Characters shared by consecutive windows; defaults to
            CHUNK_OVERLAP.

    Each window is stripped, and whitespace-only windows are dropped. Splitting
    stops as soon as a window reaches the end of *text*, so no trailing window
    is produced that would only repeat the previous window's tail.

    Returns:
        List of ``{"text": ..., "source": source}`` dicts in text order.

    Raises:
        ValueError: if chunk_size <= 0 or overlap is outside [0, chunk_size).
            Such values would otherwise keep the loop from ever advancing.
    """
    size = CHUNK_SIZE if chunk_size is None else chunk_size
    shared = CHUNK_OVERLAP if overlap is None else overlap
    if size <= 0 or not 0 <= shared < size:
        raise ValueError(
            f"Ungültige Chunk-Parameter: chunk_size={size}, overlap={shared}"
        )

    step   = size - shared
    chunks = []
    start  = 0
    while start < len(text):
        end   = min(start + size, len(text))
        chunk = text[start:end].strip()
        if chunk:
            chunks.append({"text": chunk, "source": source})
        if end == len(text):
            break  # the rest of the text is already covered by this window
        start += step
    return chunks


# ── Index (SQLite) ────────────────────────────────────────────────

def cosine_similarity(a: list, b: list) -> float:
    """Return the cosine of the angle between vectors *a* and *b*.

    Returns 0.0 when either vector has zero length. Components are paired with
    zip(), so only the overlapping prefix contributes to the dot product.
    """
    norm_a = math.sqrt(sum(v * v for v in a))
    norm_b = math.sqrt(sum(v * v for v in b))
    if not (norm_a and norm_b):
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / (norm_a * norm_b)


def get_db(index_path: Path) -> sqlite3.Connection:
    """Open ``index.db`` inside *index_path* and ensure the schema exists.

    Creates (if missing) the ``chunks`` table (source, text, JSON vector) and
    the ``meta`` key/value table, commits, and returns the open connection.
    """
    chunks_ddl = """
        CREATE TABLE IF NOT EXISTS chunks (
            id      INTEGER PRIMARY KEY,
            source  TEXT,
            text    TEXT,
            vector  TEXT
        )
    """
    meta_ddl = "CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)"

    conn = sqlite3.connect(index_path / "index.db")
    for statement in (chunks_ddl, meta_ddl):
        conn.execute(statement)
    conn.commit()
    return conn


def _get_mtimes(directory: Path, file_filter: str) -> dict:
    """Map each file below *directory* whose name matches *file_filter* to its mtime.

    The search is recursive, keys are ``str(path)``, and entries appear in
    sorted path order.
    """
    return {
        str(candidate): candidate.stat().st_mtime
        for candidate in sorted(directory.rglob("*"))
        if candidate.is_file() and fnmatch.fnmatch(candidate.name, file_filter)
    }


def _index_config() -> dict:
    """Return the settings that shape the index contents besides the files.

    Stored next to the file mtimes so that switching the embedding model or the
    chunking parameters invalidates an existing index. Vectors from another
    model (possibly of another dimension) or chunks of another size would
    otherwise be silently mixed with new queries.
    """
    return {
        "embed_model": str(EMBD),
        "chunk_size": CHUNK_SIZE,
        "chunk_overlap": CHUNK_OVERLAP,
    }


def _read_meta(db: sqlite3.Connection, key: str):
    """Return the JSON-decoded ``meta`` value stored under *key*, or None if absent."""
    row = db.execute("SELECT value FROM meta WHERE key=?", (key,)).fetchone()
    return json.loads(row[0]) if row else None


def index_is_current(db: sqlite3.Connection, directory: Path, file_filter: str) -> bool:
    """Return True if *db* matches the current files and index settings.

    The stored configuration fingerprint must equal the current one, and the
    stored path -> mtime map must equal the live state of *directory*. A
    missing entry (fresh database, or an index written before the fingerprint
    existed) counts as "not current".
    """
    if _read_meta(db, "config") != _index_config():
        return False
    stored_mtimes = _read_meta(db, "file_mtimes")
    if stored_mtimes is None:
        return False
    return stored_mtimes == _get_mtimes(directory, file_filter)


def build_index(db: sqlite3.Connection, directory: Path, file_filter: str):
    """Rebuild the chunk index for all files in *directory* matching *file_filter*.

    Exits the program with status 1 if no file matches. Every readable file is
    chunked, and each chunk is embedded via the embedding server.

    File mtimes are snapshotted *before* any file is read. A file modified
    during indexing is therefore seen as changed on the next run, rather than
    recorded as up to date. The whole rebuild runs in a single transaction: if
    it is interrupted (server error, Ctrl+C, SIGTERM), the previous chunks and
    metadata are rolled back intact. It never leaves a partial index that
    index_is_current() would still report as current.
    """
    mtimes = _get_mtimes(directory, file_filter)
    files = [Path(p) for p in mtimes]  # same sorted order as the snapshot

    if not files:
        print(f"[!] Keine Dateien für Filter '{file_filter}' in {directory} gefunden.")
        sys.exit(1)

    print(f"[*] {len(files)} Datei(en) gefunden, indexiere ...")

    total = 0
    with db:  # commits on success, rolls back on any exception
        db.execute("DELETE FROM chunks")
        for path in files:
            text = read_file(path)
            if not text:
                continue
            chunks = chunk_text(text, str(path))
            print(f"  {path.name}: {len(chunks)} Chunk(s)", end="", flush=True)
            for chunk in chunks:
                vec = get_embedding(chunk["text"])
                db.execute(
                    "INSERT INTO chunks (source, text, vector) VALUES (?, ?, ?)",
                    (chunk["source"], chunk["text"], json.dumps(vec)),
                )
                print(".", end="", flush=True)
            total += len(chunks)
            print()

        db.execute("INSERT OR REPLACE INTO meta VALUES ('file_mtimes', ?)", (json.dumps(mtimes),))
        db.execute("INSERT OR REPLACE INTO meta VALUES ('config', ?)", (json.dumps(_index_config()),))

    print(f"[*] Index fertig: {total} Chunks aus {len(files)} Datei(en).")

def retrieve(db: sqlite3.Connection, query: str, top_k=None) -> list:
    """Return the texts of the chunks most similar to *query*, best first.

    The query is embedded once and compared by cosine similarity with every
    stored chunk vector. Rows are streamed from the cursor instead of being
    loaded all at once. heapq.nlargest picks the best *top_k* (default TOP_K)
    without sorting the whole index. Its result equals
    ``sorted(..., reverse=True)[:top_k]``, so ranking and tie-breaking are
    unchanged.
    """
    import heapq

    k = TOP_K if top_k is None else top_k
    q_vec = get_embedding(query)
    scored = (
        (cosine_similarity(q_vec, json.loads(vector)), text)
        for text, vector in db.execute("SELECT text, vector FROM chunks")
    )
    return [text for _, text in heapq.nlargest(k, scored)]


# ── Chat-Loop ─────────────────────────────────────────────────────

def chat_loop(db: sqlite3.Connection):
    """Run the interactive question/answer loop on stdin/stdout.

    For each non-empty question, the best-matching chunks are retrieved from
    *db*, joined with "---" separators and sent with the question to the chat
    model. The loop ends on "exit"/"quit"/"q", on EOF, or on Ctrl+C, including
    Ctrl+C while waiting for an answer. A failed request (server unreachable,
    HTTP error, malformed reply) is reported and the loop continues instead of
    crashing.
    """
    print("\nBereit. Fragen stellen (exit / Ctrl+C zum Beenden):\n")
    while True:
        try:
            question = input("Du: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nTschüss.")
            break
        if not question:
            continue
        if question.lower() in ("exit", "quit", "q"):
            print("Tschüss.")
            break

        try:
            chunks   = retrieve(db, question)
            context  = "\n\n---\n\n".join(chunks)
            response = chat_completion(context, question)
        except KeyboardInterrupt:
            # Ctrl+C is the advertised way to quit; honour it mid-request too.
            print("\nTschüss.")
            break
        except (OSError, ValueError, KeyError, IndexError) as e:
            # OSError covers urllib's URLError/HTTPError and socket timeouts;
            # ValueError covers an undecodable JSON reply; KeyError/IndexError
            # cover a reply without the expected fields.
            print(f"\n[!] Anfrage fehlgeschlagen: {e}\n")
            continue
        print(f"\nAssistent: {response}\n")


# ── Main ──────────────────────────────────────────────────────────

def main():
    """Command-line entry point.

    Usage: ``python rag.py <verzeichnis> <dateifilter> [--reindex]``. The
    ``--reindex`` flag may appear anywhere on the command line. It is removed
    before the two positional arguments are read, so it is never mistaken for
    the directory or the file filter.

    When AUTOSTART_SERVER is set, both llama-server instances are started.
    The function then opens (or creates) the index for this (directory, filter)
    pair, rebuilds it if forced or out of date, and runs the interactive chat
    loop. The database connection is closed on every exit path.
    """
    args          = sys.argv[1:]
    force_reindex = "--reindex" in args
    positional    = [a for a in args if a != "--reindex"]

    if len(positional) < 2:
        print("Verwendung: python rag.py <verzeichnis> <dateifilter> [--reindex]")
        print('Beispiel:   python rag.py ~/projekte "*.py"')
        sys.exit(1)

    directory   = Path(positional[0]).expanduser().resolve()
    file_filter = positional[1]

    if not directory.is_dir():
        print(f"Fehler: '{directory}' ist kein Verzeichnis.")
        sys.exit(1)

    if AUTOSTART_SERVER:
        for path, label in [(BIN, "llama-server"), (CHAT, "Chat-Modell"), (EMBD, "Embedding-Modell")]:
            if not path.exists():
                print(f"Fehler: {label} nicht gefunden: {path}")
                sys.exit(1)

        # Stop the servers at interpreter exit; SIGTERM is turned into
        # sys.exit() so that the atexit hook also runs in that case.
        atexit.register(_stop_servers)
        signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))

        # Launch both before waiting so they start up concurrently.
        _start_server(EMBD, EMBED_PORT, embeddings=True)
        _start_server(CHAT, CHAT_PORT,  embeddings=False)

        print("[*] Warte auf Server ...")
        for url, label in [(EMBED_URL, "Embed"), (CHAT_URL, "Chat")]:
            if not _wait_for_server(url):
                print(f"Fehler: {label}-Server nicht erreichbar ({url})")
                sys.exit(1)
        print("[*] Beide Server bereit.")
    else:
        print("[*] Verwende laufende Server ...")

    # One index per (directory, filter) pair; the hash suffix distinguishes
    # different directories/filters that share the same folder name.
    h = hashlib.md5(f"{directory}{file_filter}".encode()).hexdigest()[:10]
    index_path = IROOT / f"{directory.name}_{h}"
    index_path.mkdir(parents=True, exist_ok=True)

    db = get_db(index_path)
    try:
        if force_reindex or not index_is_current(db, directory, file_filter):
            if not force_reindex:
                print("[*] Neue oder geänderte Dateien erkannt, indexiere neu ...")
            build_index(db, directory, file_filter)
        else:
            count = db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
            print(f"[*] Index aktuell ({count} Chunks). Mit --reindex neu aufbauen.")

        chat_loop(db)
    finally:
        db.close()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
