import functools
import os

from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_elasticsearch import ElasticsearchStore
from langchain_ollama import ChatOllama
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from get_embedding_function import get_embedding_function

# Populate the process environment (python-dotenv) before any setting is read.
load_dotenv()

# Chat model settings. A variable that is not set yields None.
ollama_url = os.environ.get('OLLAMA_URL')
model = os.environ.get('CHAT_MODEL')

# Elasticsearch connection, index and query settings (None when unset).
elasticsearch_url = os.environ.get('ELASTICSEARCH_URL')
index_name = os.environ.get('ELASTICSEARCH_INDEX_NAME')
elasticsearch_user = os.environ.get('ELASTICSEARCH_USER')
elasticsearch_password = os.environ.get('ELASTICSEARCH_PASSWORD')
elasticsearch_strategy = os.environ.get('ELASTICSEARCH_DISTANCE_STRATEGY')
elasticsearch_querry_field = os.environ.get('ELASTICSEARCH_QUERRY_FIELD')

# Embedding function handed to the vector store below.
embeddings = get_embedding_function()

# Vector store used by the retrieval step; configured entirely from the
# settings read above.
elastic_vector_search = ElasticsearchStore(
    index_name=index_name,
    embedding=embeddings,
    es_url=elasticsearch_url,
    es_user=elasticsearch_user,
    es_password=elasticsearch_password,
    vector_query_field=elasticsearch_querry_field,
    distance_strategy=elasticsearch_strategy,
    strategy=ElasticsearchStore.ApproxRetrievalStrategy(),
)

# Chat model used by the generation step.
llm = ChatOllama(base_url=ollama_url, model=model, temperature=0.7)

# State definition
class State(TypedDict):
    """Shared state passed between the graph steps defined in this module."""

    # The user's question; read by both retrieve() and generate().
    question: str
    # Documents returned by the similarity search in retrieve().
    context: List[Document]
    # Content of the LLM response produced by generate().
    answer: str

# Steps
def retrieve(state: State):
    """Graph step: fetch documents similar to the question.

    Runs a similarity search on the module-level ``elastic_vector_search``
    store, using ``state["question"]`` as the query text.

    Args:
        state: Current graph state; only the ``"question"`` key is read.

    Returns:
        A partial state update ``{"context": <search results>}``.
    """
    query = state["question"]
    matches = elastic_vector_search.similarity_search(query)
    return {"context": matches}

# Use ChatPromptTemplate inside this step
def generate(state: State):
    """Graph step: answer the question from the retrieved context.

    Joins the ``page_content`` of every document in ``state["context"]``
    (separated by blank lines), fills a system + human chat prompt with that
    text and ``state["question"]``, and sends it to the module-level ``llm``.

    Args:
        state: Current graph state; the ``"question"`` and ``"context"`` keys
            are read.

    Returns:
        A partial state update ``{"answer": <content of the LLM response>}``.
    """
    # Collect the text of each retrieved document, then join with blank lines.
    passages = [document.page_content for document in state["context"]]
    joined_context = "\n\n".join(passages)

    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the question based on the context."),
        ("human", "Question: {question}\n\nContext:\n{context}"),
    ])
    template_values = {
        "question": state["question"],
        "context": joined_context,
    }
    filled_prompt = chat_prompt.invoke(template_values)

    reply = llm.invoke(filled_prompt)
    return {"answer": reply.content}

@functools.lru_cache(maxsize=1)
def _compiled_graph():
    """Build and compile the retrieve -> generate graph once, then reuse it."""
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    return graph_builder.compile()


def chat(prompt: str) -> dict:
    """Run the retrieve -> generate pipeline for a single question.

    The graph definition does not depend on ``prompt``, so it is compiled on
    the first call and reused afterwards instead of being rebuilt per request.
    If building the graph raises, nothing is cached and the next call retries.

    Args:
        prompt: The user's question; passed to the graph as ``"question"``.

    Returns:
        Whatever the compiled graph's ``invoke`` returns for the input
        ``{"question": prompt}`` -- expected to be the final state with the
        ``question``, ``context`` and ``answer`` keys declared on ``State``.
    """
    return _compiled_graph().invoke({"question": prompt})

