# ============================================================
# STRATEGY 2 - SHORT STRADDLE (DB INTEGRATED SIMULATOR)
# FULL SAFE VERSION
# ============================================================

import time
import sqlite3
import logging
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime, timedelta, time as dtime
import pytz
import pandas as pd
from functools import lru_cache
from fyers_apiv3 import fyersModel

# ============================================================
# SETTINGS
# ============================================================

@dataclass
class Settings:
    """Run-time configuration for the short-straddle simulator.

    Values are validated on construction so a bad configuration fails at
    start-up instead of producing a nonsensical trade loop. For example, a
    non-negative ``stop`` would trigger an exit on the very first P/L check,
    when the P/L of a fresh position is still about zero.

    Raises:
        ValueError: if any field is outside its meaningful range.
    """

    db_path: Path = Path("stock2.db")        # SQLite file opened by Database
    check_interval: int = 60                 # seconds slept between loop passes inside the window
    nifty_symbol: str = "NSE:NIFTY50-INDEX"  # underlying whose option chain is requested
    strike_count: int = 9                    # "strikecount" sent to the option-chain API
    lots: int = 1                            # lots per leg (qty = lot_size * lots)
    trade_start: dtime = dtime(9, 20)        # IST wall-clock time, inclusive
    trade_end: dtime = dtime(15, 20)         # IST wall-clock time, inclusive
    target: float = 5000                     # exit once combined P/L >= target
    stop: float = -5000                      # exit once combined P/L <= stop

    def __post_init__(self):
        """Reject configurations the trading loop cannot run sensibly."""
        if self.check_interval <= 0:
            raise ValueError("check_interval must be positive")
        if self.strike_count < 1:
            raise ValueError("strike_count must be at least 1")
        if self.lots < 1:
            raise ValueError("lots must be at least 1")
        if self.trade_start >= self.trade_end:
            raise ValueError("trade_start must be earlier than trade_end")
        # A fresh position has P/L ~ 0. The thresholds must straddle zero,
        # or the very first check exits.
        if self.target <= 0:
            raise ValueError("target must be positive")
        if self.stop >= 0:
            raise ValueError("stop must be negative")

settings = Settings()

# ============================================================
# LOGGING
# ============================================================

# Root-logger configuration for the whole process: pipe-separated lines.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

logger = logging.getLogger("Strategy2-DB")

# Indian Standard Time; now_ist() (and therefore the trading-window check and
# the database timestamps) is based on this zone.
IST = pytz.timezone("Asia/Kolkata")

def now_ist():
    """Return the current wall-clock time as a timezone-aware IST datetime."""
    current = datetime.now(tz=IST)
    return current

def within_trading_window():
    """Return True on a weekday between ``settings.trade_start`` and ``settings.trade_end``.

    Both bounds are inclusive and compared against IST wall-clock time.
    Saturday and Sunday are always treated as outside the window. Without that
    check the bot would open a simulated straddle at stale last-traded prices
    over the weekend. Exchange holidays (and rare special weekend sessions)
    are NOT detected here.
    """
    current = now_ist()
    if current.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
        return False
    return settings.trade_start <= current.time() <= settings.trade_end

# ============================================================
# DATABASE LAYER
# ============================================================

class Database:
    """SQLite access layer for users, strategies, credentials, orders and daily P/L.

    Every write commits immediately. Timestamps and dates are bound as
    ISO-8601 text explicitly. Relying on sqlite3's implicit ``datetime``/``date``
    adapters is deprecated since Python 3.12 and emits a DeprecationWarning.
    The explicit form stores exactly the same text those adapters produced.
    """

    def __init__(self):
        # check_same_thread=False allows use from a thread other than the one
        # that opened the connection. NOTE(review): nothing here serialises
        # concurrent access, so it must not be used from several threads at once.
        self.conn = sqlite3.connect(settings.db_path, check_same_thread=False)
        # Rows support access by column name, e.g. row["id"].
        self.conn.row_factory = sqlite3.Row

    @staticmethod
    def _timestamp():
        """Return the current IST time as ISO-8601 text with a space separator."""
        # Same text as the deprecated default adapter: datetime.isoformat(" ").
        return now_ist().isoformat(" ")

    def get_active_user(self):
        """Return the first user row with ``status = 1``, or None if there is none."""
        return self.conn.execute(
            "SELECT * FROM users WHERE status = 1 LIMIT 1"
        ).fetchone()

    def get_strategy(self, user_id):
        """Return the first strategy row for *user_id* with ``status = 1``, or None."""
        return self.conn.execute(
            "SELECT * FROM strategies WHERE user_id = ? AND status = 1 LIMIT 1",
            (user_id,)
        ).fetchone()

    def get_fyers_credentials(self, user_id):
        """Return ``(client_id, access_token)`` for *user_id* as a Row, or None."""
        return self.conn.execute(
            "SELECT client_id, access_token FROM fyers WHERE user_id = ? LIMIT 1",
            (user_id,)
        ).fetchone()

    def update_wallet(self, user_id, new_balance):
        """Set the user's ``wallet_amount`` to *new_balance* and stamp ``updated_at``."""
        self.conn.execute(
            "UPDATE users SET wallet_amount = ?, updated_at = ? WHERE id = ?",
            (new_balance, self._timestamp(), user_id)
        )
        self.conn.commit()

    def insert_order(self, user_id, strategy_id, symbol, qty, price, balance, signal):
        """Insert one order row with ``trade_value = qty * price`` and ``status = 1``."""
        self.conn.execute(
            """
            INSERT INTO orders
            (user_id, strategy_id, signal_type, balance,
             trade_qty, trade_price, trade_value,
             symbol, created_at, status)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                user_id,
                strategy_id,
                signal,
                balance,
                qty,
                price,
                qty * price,
                symbol,
                self._timestamp(),
                1
            )
        )
        self.conn.commit()

    def update_daily_pnl(self, pnl):
        """Add one closed trade with result *pnl* to today's (IST) ``daily_pnl`` row.

        Creates the row if it does not exist yet. A trade with ``pnl > 0`` counts
        as a win; ``pnl <= 0`` (including break-even) counts as a loss.
        NOTE(review): the row is keyed by date only, not by user.
        """
        # Same text as the deprecated default date adapter: date.isoformat().
        today = now_ist().date().isoformat()
        won = 1 if pnl > 0 else 0
        lost = 1 if pnl <= 0 else 0

        exists = self.conn.execute(
            "SELECT 1 FROM daily_pnl WHERE date = ?",
            (today,)
        ).fetchone()

        if exists:
            self.conn.execute(
                """
                UPDATE daily_pnl
                SET total_pnl = total_pnl + ?,
                    total_trades = total_trades + 1,
                    winning_trades = winning_trades + ?,
                    losing_trades = losing_trades + ?
                WHERE date = ?
                """,
                (pnl, won, lost, today)
            )
        else:
            self.conn.execute(
                """
                INSERT INTO daily_pnl
                (date, total_pnl, total_trades, winning_trades, losing_trades)
                VALUES (?, ?, ?, ?, ?)
                """,
                (today, pnl, 1, won, lost)
            )

        self.conn.commit()

# ============================================================
# FYERS CLIENT
# ============================================================

class FyersClient:
    """Build a synchronous Fyers API client from the credentials stored for a user."""

    def __init__(self, db, user_id):
        row = db.get_fyers_credentials(user_id)
        if not row:
            raise Exception("Fyers credentials not found")

        token, client_id = row["access_token"], row["client_id"]
        self.client = fyersModel.FyersModel(
            client_id=client_id,
            is_async=False,
            token=token,
        )

    def get(self):
        """Return the underlying ``FyersModel`` instance."""
        return self.client

# ============================================================
# LOT SIZE
# ============================================================

@lru_cache(maxsize=1)
def load_lot_map():
    """Download the Fyers NSE F&O symbol master and map symbol -> lot size.

    A successful download is cached (single entry, no arguments) and reused
    until the cache is cleared. A failed download raises and is not cached.
    """
    master = pd.read_csv(
        "https://public.fyers.in/sym_details/NSE_FO.csv",
        header=None,
    )
    # Headerless file: as used here, column 9 is the symbol and column 3 the
    # lot size. On duplicate symbols the last row wins.
    return {ticker: lot for ticker, lot in zip(master[9], master[3])}

def get_lot_size(symbol):
    """Return the lot size for *symbol* from the cached Fyers F&O symbol master.

    ``load_lot_map`` is cached for the life of the process, so a symbol listed
    after the first download (e.g. a newly added expiry) would be missing. On
    a miss the master is therefore re-downloaded once. If the symbol is still
    absent, the historical fallback of 1 is returned. That understates qty and
    P/L, so it is logged as a warning rather than passing silently.
    """
    lot_map = load_lot_map()
    if symbol not in lot_map:
        load_lot_map.cache_clear()
        lot_map = load_lot_map()
    if symbol not in lot_map:
        logger.warning(f"Lot size not found for {symbol}; defaulting to 1")
        return 1
    return int(lot_map[symbol])

# ============================================================
# TRADER
# ============================================================

class Trader:
    """Simulated (paper) short-straddle trader backed by the SQLite database.

    Sells the ATM call and put of ``settings.nifty_symbol`` on paper. Marks the
    basket to market every ``settings.check_interval`` seconds inside the
    trading window. Books the result (orders, wallet, daily P/L) when the
    combined P/L reaches ``settings.target`` / ``settings.stop``, or when the
    window closes with a position still open.
    """

    def __init__(self):
        """Load the active user, their strategy and a Fyers client.

        Raises:
            Exception: if there is no active user, no active strategy, or no
                stored Fyers credentials.
        """
        self.db = Database()
        self.user = self.db.get_active_user()

        if not self.user:
            raise Exception("No active user found")

        self.strategy = self.db.get_strategy(self.user["id"])
        if not self.strategy:
            # exit_trade() needs strategy["id"]. Failing here beats crashing at
            # exit time with a half-booked trade that the loop then retries.
            raise Exception("No active strategy found")

        self.fyers = FyersClient(self.db, self.user["id"]).get()

        self.balance = float(self.user["wallet_amount"] or 0)
        self.entry_df = None    # open basket (one row per leg) or None when flat
        self.entry_time = None  # IST time the open basket was entered

    # --------------------------------------------------------

    def safe_option_chain(self):
        """Fetch the option chain as a DataFrame, or return None on any API error."""
        try:
            response = self.fyers.optionchain({
                "symbol": settings.nifty_symbol,
                "strikecount": settings.strike_count
            })

            if response.get("s") != "ok":
                logger.error(f"OptionChain Error: {response}")
                return None

            if "data" not in response or "optionsChain" not in response["data"]:
                logger.error("Invalid OptionChain Structure")
                return None

            return pd.DataFrame(response["data"]["optionsChain"])

        except Exception:
            logger.exception("OptionChain Exception")
            return None

    # --------------------------------------------------------

    def _build_atm_basket(self, df):
        """Return the ATM CE + PE rows with lot size, qty and entry price, or None."""
        # The chain carries the underlying itself as a row with strike_price == -1.
        underlying_rows = df[df["strike_price"] == -1]
        if underlying_rows.empty:
            logger.error("Underlying not found")
            return None

        underlying = underlying_rows["ltp"].iloc[0]

        strikes = df["strike_price"].unique()
        strikes = strikes[strikes != -1]
        if strikes.size == 0:
            logger.error("No option strikes in chain")
            return None

        atm = min(strikes, key=lambda strike: abs(strike - underlying))

        at_strike = df[df["strike_price"] == atm]
        call_atm = at_strike[at_strike["option_type"] == "CE"]
        put_atm = at_strike[at_strike["option_type"] == "PE"]
        if call_atm.empty or put_atm.empty:
            logger.error("ATM options not found")
            return None

        basket = pd.concat([call_atm.iloc[[0]], put_atm.iloc[[0]]]).copy()

        # A zero/NaN LTP would make the entry price (and every later P/L) meaningless.
        if not (basket["ltp"] > 0).all():
            logger.error("ATM option LTP unavailable")
            return None

        basket["lot_size"] = basket["symbol"].apply(get_lot_size)
        basket["qty"] = basket["lot_size"] * settings.lots
        basket["entry_price"] = basket["ltp"]
        return basket

    def enter_trade(self):
        """Open a simulated short straddle at the ATM strike (no-op on any data problem).

        NOTE(review): entry legs are not written to the ``orders`` table; only
        EXIT rows are recorded by exit_trade().
        """
        logger.info("Entering trade (SIM)")

        df = self.safe_option_chain()
        if df is None or df.empty:
            return

        basket = self._build_atm_basket(df)
        if basket is None:
            return

        self.entry_df = basket
        self.entry_time = now_ist()

        logger.info("Short Straddle Entered")

    # --------------------------------------------------------

    def _fetch_prices(self):
        """Return ``{symbol: last_price}`` for the open legs, or None if the call failed."""
        symbols = ",".join(self.entry_df["symbol"].tolist())

        try:
            response = self.fyers.quotes({"symbols": symbols})
        except Exception:
            logger.exception("Quotes API failed")
            return None

        if not isinstance(response, dict) or "d" not in response:
            logger.error(f"Invalid quotes response: {response}")
            return None

        prices = {}
        for item in response["d"]:
            if not isinstance(item, dict):
                continue
            quote = item.get("v")
            # Per-symbol errors come back without a usable "lp"; leave them out.
            if isinstance(quote, dict) and quote.get("lp") is not None:
                prices[item.get("n")] = quote["lp"]
        return prices

    def _mark_to_market(self):
        """Return ``(pnl, prices)`` for the open basket, or None if any leg is unpriced.

        P/L of a short leg is ``(entry_price - current_price) * qty``. A partial
        sum over only the quoted legs could fake a target/stop hit, so a
        missing quote skips this evaluation entirely.
        """
        prices = self._fetch_prices()
        if prices is None:
            return None

        missing = [sym for sym in self.entry_df["symbol"] if sym not in prices]
        if missing:
            logger.warning(f"Missing quotes for {missing}; skipping P/L evaluation")
            return None

        pnl = sum(
            (row["entry_price"] - prices[row["symbol"]]) * row["qty"]
            for _, row in self.entry_df.iterrows()
        )
        return pnl, prices

    def monitor_trade(self):
        """Mark the open straddle to market and exit when target or stop is reached."""
        if self.entry_df is None:
            return

        marked = self._mark_to_market()
        if marked is None:
            return
        pnl, prices = marked

        logger.info(f"P/L: {pnl}")

        if pnl >= settings.target or pnl <= settings.stop:
            self.exit_trade(pnl, prices)

    def square_off(self):
        """Close the open straddle at current quotes regardless of P/L.

        Used when the trading window has closed. If quotes are unavailable the
        position stays open and the next loop pass tries again.
        """
        if self.entry_df is None:
            return

        marked = self._mark_to_market()
        if marked is None:
            return
        pnl, prices = marked

        logger.info(f"Trading window closed; squaring off at P/L: {pnl}")
        self.exit_trade(pnl, prices)

    # --------------------------------------------------------

    def exit_trade(self, pnl, exit_prices=None):
        """Book the closed straddle: EXIT orders, wallet balance and daily P/L.

        Args:
            pnl: realised P/L of the whole basket.
            exit_prices: optional ``{symbol: price}`` used as each EXIT order's
                trade price. Legs without an entry fall back to their entry price.

        The in-memory balance is only updated after every DB write succeeded.
        A failure part-way therefore cannot add *pnl* to the balance twice when
        the main loop retries. NOTE(review): order rows already committed before
        such a failure would still be inserted again on retry.
        """
        pnl = float(pnl)
        exit_prices = exit_prices or {}
        new_balance = self.balance + pnl

        for _, row in self.entry_df.iterrows():
            price = exit_prices.get(row["symbol"], row["entry_price"])
            self.db.insert_order(
                user_id=self.user["id"],
                strategy_id=self.strategy["id"],
                symbol=row["symbol"],
                qty=int(row["qty"]),
                price=float(price),
                balance=new_balance,
                signal="EXIT"
            )

        self.db.update_wallet(self.user["id"], new_balance)
        self.db.update_daily_pnl(pnl)

        self.balance = new_balance
        self.entry_df = None
        self.entry_time = None

        logger.info(f"Trade exited. New balance: {self.balance}")

    # --------------------------------------------------------

    def run(self):
        """Main loop: enter/monitor inside the window, square off outside it. Never returns."""
        logger.info("DB Integrated Trader Started")

        while True:

            try:

                if not within_trading_window():
                    # Intraday strategy: never carry an unmonitored straddle
                    # past the end of the trading window.
                    if self.entry_df is not None:
                        self.square_off()
                    time.sleep(60)
                    continue

                if self.entry_df is None:
                    self.enter_trade()
                else:
                    self.monitor_trade()

                time.sleep(settings.check_interval)

            except Exception:
                logger.exception("Main Loop Error")
                time.sleep(30)

# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    # Construction raises if there is no active user or no stored Fyers
    # credentials; otherwise run() loops indefinitely.
    Trader().run()
