from submit import submission
import json
from openai import OpenAI
from zhipuai import ZhipuAI
import qianfan
import requests
import os
import sys
import time
import json
import types
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
# Print the directory containing this file and add it to the import search path.
# NOTE(review): this is the script's directory (dirname of __file__), not the
# process working directory. Also, `from submit import submission` at the top
# of the file runs before this path entry is added — confirm the append is
# still needed / placed correctly.
print(os.path.dirname(__file__))
sys.path.append(os.path.dirname(__file__))
# This script is a test harness: contestants may call their own local LLM
# endpoint for testing. The interface follows the OpenAI API format.
client = ZhipuAI(api_key="***") # Fill in your own API key
# Install (Python >= 3.7): pip install qianfan

# Credentials for the qianfan SDK, supplied through environment variables.
os.environ["QIANFAN_AK"] = "***"
os.environ["QIANFAN_SK"] = "***"
from volcenginesdkarkruntime import Ark
import os
# Presumably read by Ark() below, which is constructed without an explicit key
# — TODO confirm against the Ark SDK.
os.environ["ARK_API_KEY"] ="***"
# NOTE(review): this rebinds `client`, replacing the ZhipuAI client created above.
client = Ark(
    base_url="https://ark.cn-beijing.volces.com/api/v3",
)

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage

# WebSocket URL of the Spark model; for other model versions look up the URL
# in the docs (https://www.xfyun.cn/doc/spark/Web.html).
# NOTE(review): the original comment names "Spark Max", but the URL path is
# v4.0 and the domain below is '4.0Ultra' — confirm which model is intended.
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v4.0/chat'
# Spark API credentials; obtain them from the iFlytek open-platform console
# (https://console.xfyun.cn/services/bm35).
SPARKAI_APP_ID = 'e7575959'
SPARKAI_API_SECRET = '***'
SPARKAI_API_KEY = '***'
# Domain value of the Spark model; for other model versions look up the domain
# in the docs (https://www.xfyun.cn/doc/spark/Web.html).
SPARKAI_DOMAIN = '4.0Ultra'
# Non-streaming Spark chat client built from the constants above.
spark = ChatSparkLLM(
    spark_api_url=SPARKAI_URL,
    spark_app_id=SPARKAI_APP_ID,
    spark_api_key=SPARKAI_API_KEY,
    spark_api_secret=SPARKAI_API_SECRET,
    spark_llm_domain=SPARKAI_DOMAIN,
    streaming=False,
)

def evaluate_mcq(predicted_answer, label):
    """Return True when the prediction matches the label.

    Both strings are upper-cased and stripped of surrounding whitespace
    before being compared.
    """
    normalized_prediction = predicted_answer.upper().strip()
    normalized_label = label.upper().strip()
    return normalized_prediction == normalized_label

def evaluate_sql(predicted_answer, label):
    """Debug-only SQL check: case-insensitive, whitespace-trimmed string equality.

    Prints a notice that this is only a dummy evaluation (the printed text
    states that real scoring is based on the SQL execution result), echoes
    the generated SQL next to the reference answer, and returns whether the
    two normalized strings are equal.
    """
    print("****该评估方法仅提供一个虚拟评测用例，仅提供给选手进行调试使用。实际评测时，正确性则由SQL执行结果是否正确的判断。****")
    print(f"选手Prompt生成的SQL:{predicted_answer},\n标准答案：{label}")
    generated_sql = predicted_answer.upper().strip()
    reference_sql = label.upper().strip()
    return generated_sql == reference_sql

from dashscope import Generation

class eval_submission(submission):
    """Local evaluation harness layered on the imported ``submission`` class.

    Provides table-metadata loading and a concrete LLM call so that prompts
    can be tried against a hosted model during offline debugging.
    """

    def parse_table(self, table_meta_path):
        """Load the table metadata JSON and group its entries by ``db_id``.

        Args:
            table_meta_path: Path to a JSON file holding a list of dicts, each
                with a ``db_id`` key.

        Returns:
            A dict mapping every ``db_id`` to the list of entries sharing it,
            in the order they appear in the file.
        """
        # Read as UTF-8 explicitly, like the other files this script opens, so
        # the result does not depend on the platform's default locale encoding.
        with open(table_meta_path, 'r', encoding='utf-8') as db_meta:
            db_meta_info = json.load(db_meta)

        grouped_by_db_id = {}
        for item in db_meta_info:
            # Create the bucket on first sight of a db_id, then append to it.
            grouped_by_db_id.setdefault(item['db_id'], []).append(item)
        return grouped_by_db_id

    def run_inference_llm(self, constructed_prompts):
        """Send the constructed prompt to the LLM and return its text reply.

        Args:
            constructed_prompts: Chat messages, forwarded unchanged as the
                ``messages`` argument of ``Generation.call``.

        Returns:
            The ``content`` of the first choice's message in the response.

        Raises:
            RuntimeError: If the DashScope request does not come back with
                HTTP status 200.
        """
        messages = constructed_prompts

        # ------------------------------------------------------------------
        # Alternative backends kept for reference. Uncomment exactly one of
        # them (and disable the DashScope call at the bottom) to switch.
        # ------------------------------------------------------------------

        # --- OpenAI-compatible endpoint ---
        # print(f"调用模型进行测试,用户的prompt为：{constructed_prompts}")
        # llm_response = client.chat.completions.create(
        #     messages=messages,
        #     model="gpt-3.5-turbo",  # LLM name; Qwen1.5-14B-Chat is recommended for offline development (Qwen1.5-14B-Chat, gpt-3.5-turbo)
        #     max_tokens=2048,
        #     temperature= 0.0,           # the real scoring run defaults to 0.0 to avoid randomness
        #     stream=False
        # )
        # llm_outputs = llm_response.choices[0].message.content

        # Model names that have been tried:
        #   qwen-max
        #   baichuan2-turbo
        #   qwen1.5-14b-chat
        #   baichuan2-turbo
        #   moonshot-v1-8k
        #   chatglm3-6b

        # --- Tencent Hunyuan ---
        # SecretId='***'
        # SecretKey='***'
        # # Build a credential object from the Tencent Cloud SecretId / SecretKey.
        # # Keep the key pair confidential: leaked code can leak the keys and put
        # # every resource under the account at risk. This sample is for reference
        # # only; prefer a safer way of supplying keys, see
        # # https://cloud.tencent.com/document/product/1278/85305
        # # Keys can be obtained at https://console.cloud.tencent.com/cam/capi
        # cred = credential.Credential(f"{SecretId}", f"{SecretKey}")
        # # Optional HTTP profile; skip if there are no special requirements.
        # httpProfile = HttpProfile()
        # httpProfile.endpoint = "hunyuan.tencentcloudapi.com"

        # # Optional client profile; skip if there are no special requirements.
        # clientProfile = ClientProfile()
        # clientProfile.httpProfile = httpProfile
        # # Client for the requested product; clientProfile is optional.
        # client = hunyuan_client.HunyuanClient(cred, "", clientProfile)

        # # Every API has its own request object.
        # req = models.ChatCompletionsRequest()
        # params = {
        #     "Model": "hunyuan-pro",
        #     "Messages": [
        #         {
        #             "Role": "system",
        #             "Content": messages[0]['content']
        #         },
        #         {
        #             "Role": "user",
        #             "Content": messages[1]['content']
        #         },
        #     ]
        # }
        # req.from_json_string(json.dumps(params))
        # resp = client.ChatCompletions(req)
        # return resp.Choices[0].Message.Content

        # --- Volcengine Ark ---
        # completion = client.chat.completions.create(
        #     model="ep-20240730121736-vwd7p",
        #     messages = messages,
        # )
        # return completion.choices[0].message.content

        # --- iFlytek Spark ---
        # messages = [ChatMessage(
        #     role=messages[0]['role'],
        #     content=messages[0]['content']
        # )]
        # handler = ChunkPrintHandler()
        # res = spark.generate([messages], callbacks=[handler])
        # return res.generations[0][0].message.content

        # --- Baidu Qianfan ---
        # resp = qianfan.ChatCompletion().do(endpoint="ernie-4.0-8k-latest", messages=messages) # ernie-speed-128k  ernie-4.0-8k-latest
        # return resp.body['result']

        # --- ZhipuAI ---
        # response = client.chat.completions.create(
        #     model="charglm-3",  # model name: glm-4-0520  glm-4 glm-4-air  glm-4-airx  glm-4-flash glm-4v  glm-4v  charglm-3
        #     messages=messages,
        #     # stream=True,
        #     )
        # # prompt optimisation
        # # CoT: try few-shot optimisation
        # # print(response)
        # llm_outputs = response.choices[0].message.content
        # return llm_outputs

        # --- DashScope (active backend) ---
        response = Generation.call(
            "qwen1.5-14b-chat",  # qwen1.5-14b-chat yi-large-turbo baichuan2-turbo
            messages=messages,
            result_format='message',  # return the result in 'message' format
            # stream=True,  # stream the output
            # incremental_output=True,  # incremental streaming output
            api_key='***',
        )
        # prompt optimisation
        # CoT: try few-shot optimisation
        print(response)

        # Fail loudly with the service's own diagnostics instead of letting a
        # failed request surface later as an opaque attribute error.
        if response.status_code != 200:
            raise RuntimeError(
                "DashScope request failed: "
                f"status={response.status_code}, "
                f"code={response.code}, message={response.message}"
            )

        llm_outputs = response.output.choices[0].message.content
        return llm_outputs
if __name__ == "__main__":
    # Offline evaluation driver: load the reference answers, run every sample
    # question through the prompt builder and the LLM, score each reply and
    # print the overall accuracy.
    user_submission = eval_submission(table_meta_path = "样例数据/sample_tables.json")

    data_type = 'small'  # 'all' selects the full sample set; anything else the small one
    if data_type == 'all':
        question_jsonl_filename = "android_malware_detection/sample_question_all.jsonl"
        gt_jsonl_filename = "android_malware_detection/sample_answer_all.jsonl"
    else:
        question_jsonl_filename = "android_malware_detection/sample_question.jsonl"
        gt_jsonl_filename = "android_malware_detection/sample_answer.jsonl"

    # question_id -> reference answer
    answer_dict = {}
    # num_questions counts the reference answers, so it is the accuracy denominator.
    num_questions, num_correct_answer = 0, 0
    with open(gt_jsonl_filename, 'r', encoding='utf-8') as gt_file:
        for line in gt_file:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
                continue
            data = json.loads(line)
            answer_dict[data['question_id']] = data['answer']
            num_questions += 1

    with open(question_jsonl_filename, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            data = json.loads(line)
            question_id = data['question_id']
            question_type = data['question_type']
            # Debug filters: uncomment to restrict the run to some question types.
            # if question_type != "true_false_question":
            #     continue
            # if question_type == "text2sql":
            #     continue
            print(data)
            message = user_submission.construct_prompt(data)
            response = user_submission.run_inference_llm(message)

            # Exactly one branch assigns judge_result for every row, so a row
            # can never reuse the previous row's verdict or hit an unbound name.
            if question_type == "text2sql":
                judge_result = evaluate_sql(response, answer_dict[question_id])
            elif question_type in ("multiple_choice", "true_false_question"):
                judge_result = evaluate_mcq(response, answer_dict[question_id])
            elif question_type == "android_malware_detection":
                label = answer_dict[question_id]
                judge_result = str(response).strip() == str(label).strip()
            else:
                print(f"Unsupported question_type {question_type!r} "
                      f"for question {question_id!r}; counted as incorrect")
                judge_result = False

            if judge_result:
                print("right")
                num_correct_answer += 1
            else:
                print(response, type(response))
            # Pause between requests to the LLM service.
            time.sleep(1)

    # Guard against an empty answer file, which would otherwise divide by zero.
    accuracy = num_correct_answer / num_questions if num_questions else 0.0
    print(f"Accuracy:{accuracy}")