from flask import Flask, request, jsonify
from flask_jwt_extended import JWTManager, create_access_token, create_refresh_token, get_jwt_identity, jwt_required, get_jwt
from flask_sqlalchemy import SQLAlchemy
import bcrypt
from datetime import timedelta
from dotenv import load_dotenv
import os
from crida_api_rag import chat
from introduccio_docs import afegirDoc

app = Flask(__name__)

# Populate os.environ from a local .env file (if any) before reading config.
load_dotenv()


def _env_int(name, default):
    """Return the environment variable *name* parsed as an int.

    Falls back to *default* when the variable is unset or blank, rather than
    failing with an opaque ``TypeError`` from ``int(None)``.

    Raises:
        RuntimeError: if the variable is set but is not a valid integer.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


app.config['SECRET_KEY'] = os.getenv('SECRET_KEY')
app.config["JWT_SECRET_KEY"] = os.getenv('JWT_SECRET_KEY')
# Tokens are only accepted from request headers.
app.config['JWT_TOKEN_LOCATION'] = ['headers']
app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv('SQLALCHEMY_DATABASE_URI')
# Fallbacks (15 minutes / 30 days) mirror Flask-JWT-Extended's built-in defaults.
app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(
    minutes=_env_int('JWT_ACCESS_TOKEN_EXPIRES_MINUTES', 15))
app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(
    days=_env_int('JWT_REFRESH_TOKEN_EXPIRES_DAYS', 30))

jwt = JWTManager(app)
db = SQLAlchemy(app)

class User(db.Model):
    """A registered account that can authenticate against the API."""

    # Integer primary key; the basis of the JWT identity issued at login.
    id = db.Column(db.Integer, primary_key=True)
    # Login name, unique across all accounts.
    username = db.Column(db.String(20), unique=True, nullable=False)
    # bcrypt hash decoded to text -- never the plain-text password.
    password = db.Column(db.String(80), nullable=False)
    is_active = db.Column(db.Boolean(), default=True)

    def __repr__(self):
        return "<User {}>".format(self.username)

class RevokedToken(db.Model):
    """Blocklist entry: a JWT whose ``jti`` is stored here is treated as revoked."""

    id = db.Column(db.Integer, primary_key=True)
    # The revoked token's unique ``jti`` claim.
    jti = db.Column(db.String(120), unique=True, nullable=False)

    def __repr__(self):
        return "<RevokedToken {}>".format(self.jti)

@app.route('/register', methods=['POST'])
def register():
    """Create a user account from a JSON body ``{"username": ..., "password": ...}``.

    The password is stored only as a bcrypt hash (the salt is embedded in it).

    Returns:
        201 on success; 400 when a field is missing/blank, is not a string,
        is longer than the ``User.username`` column allows, or the username
        is already taken.
    """
    # silent=True: a missing or non-JSON body yields None instead of an error
    # response or an AttributeError on ``None.get`` (HTTP 500).
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')

    if not username or not password:
        return jsonify({"msg": "Missing username or password"}), 400
    # Numbers, lists, etc. would crash on .encode() below.
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({"msg": "Username and password must be strings"}), 400
    # Some databases raise on over-long values instead of truncating; validate
    # against the column's declared length so that surfaces as a 400, not a 500.
    if len(username) > User.__table__.c.username.type.length:
        return jsonify({"msg": "Username too long"}), 400

    if User.query.filter_by(username=username).first():
        return jsonify({"msg": "User already exists"}), 400

    hashed_pw = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    user = User(username=username, password=hashed_pw.decode('utf-8'))
    db.session.add(user)
    db.session.commit()
    return jsonify({"msg": "User registered successfully"}), 201

@app.route('/login', methods=['POST'])
def login():
    """Verify credentials and issue a JWT access/refresh token pair.

    Expects a JSON body ``{"username": str, "password": str}``.

    Returns:
        200 with ``access_token`` and ``refresh_token``; 400 if the body or
        either field is missing or not a string; 401 for an unknown user or a
        wrong password (same response for both).
    """
    # silent=True: a missing/non-JSON body yields None rather than raising;
    # missing keys used to raise KeyError and surface as HTTP 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({'message': 'Missing username or password'}), 400

    user = User.query.filter_by(username=username).first()
    if user and bcrypt.checkpw(password.encode('utf-8'), user.password.encode('utf-8')):
        # PyJWT >= 2.10 rejects a non-string ``sub`` claim when decoding, which
        # would make every protected endpoint fail; issue the id as a string.
        identity = str(user.id)
        access_token = create_access_token(identity=identity)
        refresh_token = create_refresh_token(identity=identity)
        return jsonify({'message': 'Login Success', 'access_token': access_token, 'refresh_token': refresh_token}), 200
    return jsonify({'message': 'Login Failed'}), 401

@app.route('/refresh', methods=['POST'])
@jwt_required(refresh=True)
def refresh():
    """Exchange a valid refresh token for a new access token.

    The new access token carries the same identity as the refresh token
    presented in the request.
    """
    payload = {'access_token': create_access_token(identity=get_jwt_identity())}
    return jsonify(payload), 200

@jwt.token_in_blocklist_loader
def check_if_token_revoked(jwt_header, jwt_payload):
    """Tell Flask-JWT-Extended whether a decoded token has been revoked.

    The extension calls this for tokens on protected endpoints and expects a
    bool: True rejects the token (the extension then sends its default 401
    "Token has been revoked" response), False lets it through.

    This must not also be registered with ``@app.before_request``: Flask calls
    before-request hooks with no arguments, so every request would fail with
    a TypeError.
    """
    jti = jwt_payload["jti"]
    return RevokedToken.query.filter_by(jti=jti).first() is not None

@app.route('/logout', methods=['DELETE'])
@jwt_required(verify_type=False)
def logout():
    """Revoke the token presented in the request by storing its ``jti``.

    ``verify_type=False`` accepts either an access or a refresh token, so a
    client can call this once with each to revoke both. Without it only
    access tokens could be revoked, leaving the refresh token able to mint
    new access tokens after logout.
    """
    jti = get_jwt()['jti']
    db.session.add(RevokedToken(jti=jti))
    db.session.commit()
    return jsonify({"msg": "Successfully logged out"}), 200


@app.route('/chat', methods=['POST'])
@jwt_required()
def chat_endpoint():
    """Pass the JSON ``prompt`` to ``chat()`` and return its result as JSON.

    Returns:
        200 with the ``chat()`` result; 404 if the token's user no longer
        exists; 400 if ``prompt`` is missing or empty.
    """
    current_user = get_jwt_identity()
    # The identity may be an int or a numeric string (PyJWT >= 2.10 requires a
    # string ``sub``); normalise it before the integer primary-key lookup.
    try:
        user_id = int(current_user)
    except (TypeError, ValueError):
        user_id = None
    if user_id is None or db.session.get(User, user_id) is None:
        return jsonify({"msg": "User not found"}), 404

    # silent=True: a missing/non-JSON body is treated as "no prompt" (400)
    # instead of raising an error response or AttributeError.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not prompt:
        return jsonify({"msg": "Missing prompt"}), 400

    response = chat(prompt)
    return jsonify(response), 200

@app.route('/add_documents', methods=['POST'])
@jwt_required()
def add_documents():
    """Pass the JSON ``data_path`` to ``afegirDoc`` and report the outcome.

    Returns:
        200 when ``afegirDoc`` returns "success" or "nothing_new_added";
        404 if the token's user no longer exists; 400 if ``data_path`` is
        missing or ``afegirDoc`` returns anything else (echoed as ``msg``).
    """
    current_user = get_jwt_identity()
    # The identity may be an int or a numeric string (PyJWT >= 2.10 requires a
    # string ``sub``); normalise it before the integer primary-key lookup.
    try:
        user_id = int(current_user)
    except (TypeError, ValueError):
        user_id = None
    if user_id is None or db.session.get(User, user_id) is None:
        return jsonify({"msg": "User not found"}), 404

    # silent=True: a missing/non-JSON body is treated as "no data_path" (400).
    data = request.get_json(silent=True)
    data_path = data.get('data_path') if isinstance(data, dict) else None
    if not data_path:
        return jsonify({"msg": "Missing data_path"}), 400

    # NOTE(review): data_path is client-controlled and handed straight to
    # afegirDoc, and /register is open to anyone. If afegirDoc reads from the
    # server filesystem, any registered user can point it at arbitrary paths;
    # consider restricting it to an allowed base directory -- confirm what
    # afegirDoc does with the value.
    result = afegirDoc(data_path)
    if result == "success":
        return jsonify({"msg": "Documents added successfully"}), 200
    if result == "nothing_new_added":
        return jsonify({"msg": "No new documents to add"}), 200
    return jsonify({"msg": result}), 400

if __name__ == '__main__':
    # Create any tables for the models above that do not exist yet.
    with app.app_context():
        db.create_all()
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so only enable it when FLASK_DEBUG is explicitly set to a truthy value
    # (same "0"/"false"/"no" convention Flask itself uses).
    debug_flag = os.getenv('FLASK_DEBUG', '').strip().lower()
    app.run(debug=bool(debug_flag) and debug_flag not in ('0', 'false', 'no'))