import json
import os
import sys

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_elasticsearch import ElasticsearchStore
from langchain_ollama import ChatOllama
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from get_embedding_function import get_embedding_function

# Ollama endpoint and model used for answer generation.
base_url = "http://localhost:11434"
model = "qwen2.5:32b"

# Elasticsearch connection settings. Each value can be overridden through the
# environment (ES_URL, ES_USER, ES_PASSWORD) so that deployments do not have to
# edit this file or keep credentials in source control.
# NOTE(review): the fallback password below is kept only so existing setups
# keep working unchanged; it is committed in plain text and should be rotated
# and then removed from the source.
es_url = os.environ.get("ES_URL", "http://localhost:9200")
es_user = os.environ.get("ES_USER", "elastic")
es_password = os.environ.get("ES_PASSWORD", "HMkJyNKW")

# Create embeddings and vector store
embeddings = get_embedding_function()

elastic_vector_search = ElasticsearchStore(
    es_url=es_url,
    index_name="rag_index",
    embedding=embeddings,
    es_user=es_user,
    es_password=es_password,
    distance_strategy="COSINE",
    strategy=ElasticsearchStore.ApproxRetrievalStrategy(),
    vector_query_field="dense_vector",
)

# Initialize LLM
llm = ChatOllama(model=model, temperature=0.7, base_url=base_url)

# State definition
class State(TypedDict):
    """Shared state passed between the steps of the graph.

    ``retrieve`` reads ``question`` and produces ``context``; ``generate``
    reads ``question`` and ``context`` and produces ``answer``.
    """
    # The user's question; supplied as the graph input.
    question: str
    # Documents returned by the vector-store similarity search in ``retrieve``.
    context: List[Document]
    # Content of the LLM response, set by ``generate``.
    answer: str

# Steps
def retrieve(state: State):
    """Look up documents similar to the question in the vector store.

    Args:
        state: Current graph state; only ``state["question"]`` is read.

    Returns:
        A partial state update of the form ``{"context": <matching documents>}``.
    """
    query = state["question"]
    matches = elastic_vector_search.similarity_search(query)
    return {"context": matches}

# Generation step: the chat prompt is assembled inside this function.
def generate(state: State):
    """Ask the LLM to answer the question using the retrieved documents.

    Args:
        state: Current graph state; ``state["context"]`` and
            ``state["question"]`` are read.

    Returns:
        A partial state update of the form ``{"answer": <response content>}``.
    """
    # Flatten the retrieved documents into one text block, with a blank line
    # between consecutive documents.
    passages = [doc.page_content for doc in state["context"]]
    context_text = "\n\n".join(passages)

    messages = [
        ("system", "Answer the question based on the context."),
        ("human", "Question: {question}\n\nContext:\n{context}"),
    ]
    # Fill the {question} and {context} placeholders of the template.
    filled_prompt = ChatPromptTemplate.from_messages(messages).invoke(
        {"question": state["question"], "context": context_text}
    )

    return {"answer": llm.invoke(filled_prompt).content}


# The question to answer is the first command-line argument. Exit with a usage
# message (instead of an IndexError traceback) when it is missing, and reject a
# blank question before the graph is built and run.
if len(sys.argv) < 2 or not sys.argv[1].strip():
    sys.exit(f"Usage: {sys.argv[0]} <question>")
prompt = sys.argv[1]

# Build graph: add_sequence chains retrieve -> generate, and the explicit edge
# makes "retrieve" the entry point from START.
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

# Invoke graph with dictionary input
response = graph.invoke({"question": prompt})

# Prints the whole result returned by the graph, not only the answer text.
print(response)

