early-access version 3828

This commit is contained in:
pineappleEA 2023-08-22 19:22:27 +02:00
parent af2061a2c2
commit 3e1870d567
18 changed files with 506 additions and 213 deletions

View file

@ -1,7 +1,7 @@
yuzu emulator early access yuzu emulator early access
============= =============
This is the source code for early-access 3827. This is the source code for early-access 3828.
## Legal Notice ## Legal Notice

View file

@ -160,6 +160,11 @@ android {
} }
} }
// Registers a Delete task that removes the "intermediates/ktLint" folder
// under the build directory.
tasks.create<Delete>("ktlintReset") {
    val ktlintIntermediates = File(buildDir, "intermediates/ktLint")
    delete(ktlintIntermediates)
}
tasks.getByPath("loadKtlintReporters").dependsOn("ktlintReset")
tasks.getByPath("preBuild").dependsOn("ktlintCheck") tasks.getByPath("preBuild").dependsOn("ktlintCheck")
ktlint { ktlint {

View file

@ -3,19 +3,25 @@
package org.yuzu.yuzu_emu.adapters package org.yuzu.yuzu_emu.adapters
import android.text.TextUtils
import android.view.LayoutInflater import android.view.LayoutInflater
import android.view.View import android.view.View
import android.view.ViewGroup import android.view.ViewGroup
import androidx.appcompat.app.AppCompatActivity import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import androidx.core.content.res.ResourcesCompat import androidx.core.content.res.ResourcesCompat
import androidx.lifecycle.LifecycleOwner
import androidx.recyclerview.widget.RecyclerView import androidx.recyclerview.widget.RecyclerView
import org.yuzu.yuzu_emu.R import org.yuzu.yuzu_emu.R
import org.yuzu.yuzu_emu.databinding.CardHomeOptionBinding import org.yuzu.yuzu_emu.databinding.CardHomeOptionBinding
import org.yuzu.yuzu_emu.fragments.MessageDialogFragment import org.yuzu.yuzu_emu.fragments.MessageDialogFragment
import org.yuzu.yuzu_emu.model.HomeSetting import org.yuzu.yuzu_emu.model.HomeSetting
class HomeSettingAdapter(private val activity: AppCompatActivity, var options: List<HomeSetting>) : class HomeSettingAdapter(
private val activity: AppCompatActivity,
private val viewLifecycle: LifecycleOwner,
var options: List<HomeSetting>
) :
RecyclerView.Adapter<HomeSettingAdapter.HomeOptionViewHolder>(), RecyclerView.Adapter<HomeSettingAdapter.HomeOptionViewHolder>(),
View.OnClickListener { View.OnClickListener {
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): HomeOptionViewHolder { override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): HomeOptionViewHolder {
@ -79,6 +85,22 @@ class HomeSettingAdapter(private val activity: AppCompatActivity, var options: L
binding.optionDescription.alpha = 0.5f binding.optionDescription.alpha = 0.5f
binding.optionIcon.alpha = 0.5f binding.optionIcon.alpha = 0.5f
} }
option.details.observe(viewLifecycle) { updateOptionDetails(it) }
binding.optionDetail.postDelayed(
{
binding.optionDetail.ellipsize = TextUtils.TruncateAt.MARQUEE
binding.optionDetail.isSelected = true
},
3000
)
}
private fun updateOptionDetails(detailString: String) {
if (detailString.isNotEmpty()) {
binding.optionDetail.text = detailString
binding.optionDetail.visibility = View.VISIBLE
}
} }
} }
} }

View file

@ -129,7 +129,11 @@ class HomeSettingsFragment : Fragment() {
mainActivity.getGamesDirectory.launch( mainActivity.getGamesDirectory.launch(
Intent(Intent.ACTION_OPEN_DOCUMENT_TREE).data Intent(Intent.ACTION_OPEN_DOCUMENT_TREE).data
) )
} },
{ true },
0,
0,
homeViewModel.gamesDir
) )
) )
add( add(
@ -201,7 +205,11 @@ class HomeSettingsFragment : Fragment() {
binding.homeSettingsList.apply { binding.homeSettingsList.apply {
layoutManager = LinearLayoutManager(requireContext()) layoutManager = LinearLayoutManager(requireContext())
adapter = HomeSettingAdapter(requireActivity() as AppCompatActivity, optionsList) adapter = HomeSettingAdapter(
requireActivity() as AppCompatActivity,
viewLifecycleOwner,
optionsList
)
} }
setInsets() setInsets()

View file

@ -3,6 +3,9 @@
package org.yuzu.yuzu_emu.model package org.yuzu.yuzu_emu.model
import androidx.lifecycle.LiveData
import androidx.lifecycle.MutableLiveData
data class HomeSetting( data class HomeSetting(
val titleId: Int, val titleId: Int,
val descriptionId: Int, val descriptionId: Int,
@ -10,5 +13,6 @@ data class HomeSetting(
val onClick: () -> Unit, val onClick: () -> Unit,
val isEnabled: () -> Boolean = { true }, val isEnabled: () -> Boolean = { true },
val disabledTitleId: Int = 0, val disabledTitleId: Int = 0,
val disabledMessageId: Int = 0 val disabledMessageId: Int = 0,
val details: LiveData<String> = MutableLiveData("")
) )

View file

@ -3,9 +3,15 @@
package org.yuzu.yuzu_emu.model package org.yuzu.yuzu_emu.model
import android.net.Uri
import androidx.fragment.app.FragmentActivity
import androidx.lifecycle.LiveData import androidx.lifecycle.LiveData
import androidx.lifecycle.MutableLiveData import androidx.lifecycle.MutableLiveData
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import androidx.preference.PreferenceManager
import org.yuzu.yuzu_emu.YuzuApplication
import org.yuzu.yuzu_emu.utils.GameHelper
class HomeViewModel : ViewModel() { class HomeViewModel : ViewModel() {
private val _navigationVisible = MutableLiveData<Pair<Boolean, Boolean>>() private val _navigationVisible = MutableLiveData<Pair<Boolean, Boolean>>()
@ -17,6 +23,14 @@ class HomeViewModel : ViewModel() {
private val _shouldPageForward = MutableLiveData(false) private val _shouldPageForward = MutableLiveData(false)
val shouldPageForward: LiveData<Boolean> get() = _shouldPageForward val shouldPageForward: LiveData<Boolean> get() = _shouldPageForward
// Backing state for gamesDir. Seeded from the URI string stored under
// GameHelper.KEY_GAME_PATH in the default shared preferences; only the
// path component of that URI is kept, and an empty string is used when
// the parsed URI has no path.
private val _gamesDir = MutableLiveData(
    PreferenceManager.getDefaultSharedPreferences(YuzuApplication.appContext)
        .getString(GameHelper.KEY_GAME_PATH, "")
        .let { storedUri -> Uri.parse(storedUri).path }
        .orEmpty()
)

// Read-only view of the games directory path for observers.
val gamesDir: LiveData<String> get() = _gamesDir
var navigatedToSetup = false var navigatedToSetup = false
init { init {
@ -40,4 +54,9 @@ class HomeViewModel : ViewModel() {
fun setShouldPageForward(pageForward: Boolean) { fun setShouldPageForward(pageForward: Boolean) {
_shouldPageForward.value = pageForward _shouldPageForward.value = pageForward
} }
/**
 * Reloads the game list and then publishes [dir] to observers of the games directory.
 *
 * @param activity Owner used to look up the activity-scoped GamesViewModel.
 * @param dir Path string to expose as the new games directory.
 */
fun setGamesDir(activity: FragmentActivity, dir: String) {
    // NOTE(review): the MainActivity call site visible in this change already calls
    // reloadGames(true) immediately before this function - confirm the second reload is intended.
    val gamesViewModel = ViewModelProvider(activity)[GamesViewModel::class.java]
    gamesViewModel.reloadGames(true)
    _gamesDir.value = dir
}
} }

View file

@ -290,6 +290,7 @@ class MainActivity : AppCompatActivity(), ThemeProvider {
).show() ).show()
gamesViewModel.reloadGames(true) gamesViewModel.reloadGames(true)
homeViewModel.setGamesDir(this, result.path!!)
} }
val getProdKey = val getProdKey =

View file

@ -53,6 +53,23 @@
android:layout_marginTop="5dp" android:layout_marginTop="5dp"
tools:text="@string/install_prod_keys_description" /> tools:text="@string/install_prod_keys_description" />
<!-- Optional single-line detail text for a home option. Hidden (gone) by default.
     HomeSettingAdapter (presumably via this layout's generated binding) makes it visible
     when it has a non-empty detail string, and switches ellipsize from "none" to marquee
     after a 3000 ms delay so long values scroll. -->
<com.google.android.material.textview.MaterialTextView
style="@style/TextAppearance.Material3.LabelMedium"
android:id="@+id/option_detail"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:textAlignment="viewStart"
android:textSize="14sp"
android:textStyle="bold"
android:singleLine="true"
android:marqueeRepeatLimit="marquee_forever"
android:ellipsize="none"
android:requiresFadingEdge="horizontal"
android:layout_marginTop="5dp"
android:visibility="gone"
tools:visibility="visible"
tools:text="/tree/primary:Games" />
</LinearLayout> </LinearLayout>
</LinearLayout> </LinearLayout>

View file

@ -35,7 +35,6 @@ namespace Core::Crypto {
namespace { namespace {
constexpr u64 CURRENT_CRYPTO_REVISION = 0x5; constexpr u64 CURRENT_CRYPTO_REVISION = 0x5;
constexpr u64 FULL_TICKET_SIZE = 0x400;
using Common::AsArray; using Common::AsArray;
@ -156,6 +155,10 @@ u64 GetSignatureTypePaddingSize(SignatureType type) {
UNREACHABLE(); UNREACHABLE();
} }
// Reports whether this ticket holds one of the concrete ticket layouts,
// i.e. anything other than the std::monostate placeholder.
bool Ticket::IsValid() const {
    const bool holds_placeholder = std::get_if<std::monostate>(&data) != nullptr;
    return !holds_placeholder;
}
SignatureType Ticket::GetSignatureType() const { SignatureType Ticket::GetSignatureType() const {
if (const auto* ticket = std::get_if<RSA4096Ticket>(&data)) { if (const auto* ticket = std::get_if<RSA4096Ticket>(&data)) {
return ticket->sig_type; return ticket->sig_type;
@ -210,6 +213,54 @@ Ticket Ticket::SynthesizeCommon(Key128 title_key, const std::array<u8, 16>& righ
return Ticket{out}; return Ticket{out};
} }
// Reads a ticket from the start of a file.
//
// Returns a Ticket holding std::monostate when fewer bytes than a signature type could be
// read; otherwise defers to the buffer-based overload.
Ticket Ticket::Read(const FileSys::VirtualFile& file) {
    // Zero-initialised scratch buffer sized for the largest ticket layout (RSA-4096), so a
    // shorter file simply leaves the tail zeroed.
    std::array<u8, sizeof(RSA4096Ticket)> buffer{};
    const auto bytes_read = file->Read(buffer.data(), buffer.size(), 0);

    // We need at least the signature type to decide which layout to parse.
    if (bytes_read >= sizeof(SignatureType)) {
        return Read(std::span{buffer});
    }

    LOG_WARNING(Crypto, "Attempted to read ticket file with invalid size {}.", bytes_read);
    return Ticket{std::monostate()};
}
// Parses a ticket from a raw memory buffer.
//
// The leading SignatureType value selects the ticket layout (RSA-4096, RSA-2048 or ECDSA).
// Buffers shorter than the selected layout are accepted: the bytes that are present are
// copied and the remainder of the ticket stays zero-initialised.
//
// Returns a Ticket holding std::monostate when the buffer is too small to contain a
// signature type, or when the signature type is not one of the known values.
Ticket Ticket::Read(std::span<const u8> raw_data) {
    // Some tools read only 0x180 bytes of ticket data instead of 0x2C0, so
    // just make sure we have at least the bare minimum of data to work with.
    if (raw_data.size() < sizeof(SignatureType)) {
        LOG_WARNING(Crypto, "Attempted to parse ticket buffer with invalid size {}.",
                    raw_data.size());
        return Ticket{std::monostate()};
    }

    SignatureType sig_type;
    std::memcpy(&sig_type, raw_data.data(), sizeof(sig_type));

    // Because truncated buffers are allowed through the check above, the copy length must be
    // clamped to what the caller actually supplied; copying sizeof(ticket) unconditionally
    // would read past the end of a short buffer.
    const auto fill_from_buffer = [raw_data](auto& ticket) {
        const std::size_t length =
            raw_data.size() < sizeof(ticket) ? raw_data.size() : sizeof(ticket);
        std::memcpy(&ticket, raw_data.data(), length);
    };

    switch (sig_type) {
    case SignatureType::RSA_4096_SHA1:
    case SignatureType::RSA_4096_SHA256: {
        RSA4096Ticket ticket{};
        fill_from_buffer(ticket);
        return Ticket{ticket};
    }
    case SignatureType::RSA_2048_SHA1:
    case SignatureType::RSA_2048_SHA256: {
        RSA2048Ticket ticket{};
        fill_from_buffer(ticket);
        return Ticket{ticket};
    }
    case SignatureType::ECDSA_SHA1:
    case SignatureType::ECDSA_SHA256: {
        ECDSATicket ticket{};
        fill_from_buffer(ticket);
        return Ticket{ticket};
    }
    default:
        LOG_WARNING(Crypto, "Attempted to parse ticket buffer with invalid type {}.", sig_type);
        return Ticket{std::monostate()};
    }
}
Key128 GenerateKeyEncryptionKey(Key128 source, Key128 master, Key128 kek_seed, Key128 key_seed) { Key128 GenerateKeyEncryptionKey(Key128 source, Key128 master, Key128 kek_seed, Key128 key_seed) {
Key128 out{}; Key128 out{};
@ -290,9 +341,9 @@ void KeyManager::DeriveGeneralPurposeKeys(std::size_t crypto_revision) {
} }
} }
RSAKeyPair<2048> KeyManager::GetETicketRSAKey() const { void KeyManager::DeriveETicketRSAKey() {
if (IsAllZeroArray(eticket_extended_kek) || !HasKey(S128KeyType::ETicketRSAKek)) { if (IsAllZeroArray(eticket_extended_kek) || !HasKey(S128KeyType::ETicketRSAKek)) {
return {}; return;
} }
const auto eticket_final = GetKey(S128KeyType::ETicketRSAKek); const auto eticket_final = GetKey(S128KeyType::ETicketRSAKek);
@ -304,12 +355,12 @@ RSAKeyPair<2048> KeyManager::GetETicketRSAKey() const {
rsa_1.Transcode(eticket_extended_kek.data() + 0x10, eticket_extended_kek.size() - 0x10, rsa_1.Transcode(eticket_extended_kek.data() + 0x10, eticket_extended_kek.size() - 0x10,
extended_dec.data(), Op::Decrypt); extended_dec.data(), Op::Decrypt);
RSAKeyPair<2048> rsa_key{}; std::memcpy(eticket_rsa_keypair.decryption_key.data(), extended_dec.data(),
std::memcpy(rsa_key.decryption_key.data(), extended_dec.data(), rsa_key.decryption_key.size()); eticket_rsa_keypair.decryption_key.size());
std::memcpy(rsa_key.modulus.data(), extended_dec.data() + 0x100, rsa_key.modulus.size()); std::memcpy(eticket_rsa_keypair.modulus.data(), extended_dec.data() + 0x100,
std::memcpy(rsa_key.exponent.data(), extended_dec.data() + 0x200, rsa_key.exponent.size()); eticket_rsa_keypair.modulus.size());
std::memcpy(eticket_rsa_keypair.exponent.data(), extended_dec.data() + 0x200,
return rsa_key; eticket_rsa_keypair.exponent.size());
} }
Key128 DeriveKeyblobMACKey(const Key128& keyblob_key, const Key128& mac_source) { Key128 DeriveKeyblobMACKey(const Key128& keyblob_key, const Key128& mac_source) {
@ -447,10 +498,12 @@ std::vector<Ticket> GetTicketblob(const Common::FS::IOFile& ticket_save) {
for (std::size_t offset = 0; offset + 0x4 < buffer.size(); ++offset) { for (std::size_t offset = 0; offset + 0x4 < buffer.size(); ++offset) {
if (buffer[offset] == 0x4 && buffer[offset + 1] == 0x0 && buffer[offset + 2] == 0x1 && if (buffer[offset] == 0x4 && buffer[offset + 1] == 0x0 && buffer[offset + 2] == 0x1 &&
buffer[offset + 3] == 0x0) { buffer[offset + 3] == 0x0) {
out.emplace_back(); // NOTE: Assumes ticket blob will only contain RSA-2048 tickets.
auto& next = out.back(); auto ticket = Ticket::Read(std::span{buffer.data() + offset, sizeof(RSA2048Ticket)});
std::memcpy(&next, buffer.data() + offset, sizeof(Ticket)); offset += sizeof(RSA2048Ticket);
offset += FULL_TICKET_SIZE; if (ticket.IsValid()) {
out.push_back(ticket);
}
} }
} }
@ -503,25 +556,35 @@ static std::optional<u64> FindTicketOffset(const std::array<u8, size>& data) {
return offset; return offset;
} }
std::optional<std::pair<Key128, Key128>> ParseTicket(const Ticket& ticket, std::optional<Key128> KeyManager::ParseTicketTitleKey(const Ticket& ticket) {
const RSAKeyPair<2048>& key) { if (eticket_rsa_keypair == RSAKeyPair<2048>{}) {
LOG_WARNING(Crypto,
"Skipping ticket title key parsing due to missing ETicket RSA key-pair.");
return std::nullopt;
}
if (!ticket.IsValid()) {
LOG_WARNING(Crypto, "Attempted to parse title key of invalid ticket.");
return std::nullopt;
}
if (ticket.GetData().rights_id == Key128{}) {
LOG_WARNING(Crypto, "Attempted to parse title key of ticket with no rights ID.");
return std::nullopt;
}
const auto issuer = ticket.GetData().issuer; const auto issuer = ticket.GetData().issuer;
if (IsAllZeroArray(issuer)) { if (IsAllZeroArray(issuer)) {
LOG_WARNING(Crypto, "Attempted to parse title key of ticket with invalid issuer.");
return std::nullopt; return std::nullopt;
} }
if (issuer[0] != 'R' || issuer[1] != 'o' || issuer[2] != 'o' || issuer[3] != 't') { if (issuer[0] != 'R' || issuer[1] != 'o' || issuer[2] != 'o' || issuer[3] != 't') {
LOG_INFO(Crypto, "Attempting to parse ticket with non-standard certificate authority."); LOG_WARNING(Crypto, "Parsing ticket with non-standard certificate authority.");
} }
Key128 rights_id = ticket.GetData().rights_id; if (ticket.GetData().type == TitleKeyType::Common) {
return ticket.GetData().title_key_common;
if (rights_id == Key128{}) {
return std::nullopt;
}
if (!std::any_of(ticket.GetData().title_key_common_pad.begin(),
ticket.GetData().title_key_common_pad.end(), [](u8 b) { return b != 0; })) {
return std::make_pair(rights_id, ticket.GetData().title_key_common);
} }
mbedtls_mpi D; // RSA Private Exponent mbedtls_mpi D; // RSA Private Exponent
@ -534,9 +597,12 @@ std::optional<std::pair<Key128, Key128>> ParseTicket(const Ticket& ticket,
mbedtls_mpi_init(&S); mbedtls_mpi_init(&S);
mbedtls_mpi_init(&M); mbedtls_mpi_init(&M);
mbedtls_mpi_read_binary(&D, key.decryption_key.data(), key.decryption_key.size()); const auto& title_key_block = ticket.GetData().title_key_block;
mbedtls_mpi_read_binary(&N, key.modulus.data(), key.modulus.size()); mbedtls_mpi_read_binary(&D, eticket_rsa_keypair.decryption_key.data(),
mbedtls_mpi_read_binary(&S, ticket.GetData().title_key_block.data(), 0x100); eticket_rsa_keypair.decryption_key.size());
mbedtls_mpi_read_binary(&N, eticket_rsa_keypair.modulus.data(),
eticket_rsa_keypair.modulus.size());
mbedtls_mpi_read_binary(&S, title_key_block.data(), title_key_block.size());
mbedtls_mpi_exp_mod(&M, &S, &D, &N, nullptr); mbedtls_mpi_exp_mod(&M, &S, &D, &N, nullptr);
@ -564,8 +630,7 @@ std::optional<std::pair<Key128, Key128>> ParseTicket(const Ticket& ticket,
Key128 key_temp{}; Key128 key_temp{};
std::memcpy(key_temp.data(), m_2.data() + *offset, key_temp.size()); std::memcpy(key_temp.data(), m_2.data() + *offset, key_temp.size());
return key_temp;
return std::make_pair(rights_id, key_temp);
} }
KeyManager::KeyManager() { KeyManager::KeyManager() {
@ -669,6 +734,14 @@ void KeyManager::LoadFromFile(const std::filesystem::path& file_path, bool is_ti
encrypted_keyblobs[index] = Common::HexStringToArray<0xB0>(out[1]); encrypted_keyblobs[index] = Common::HexStringToArray<0xB0>(out[1]);
} else if (out[0].compare(0, 20, "eticket_extended_kek") == 0) { } else if (out[0].compare(0, 20, "eticket_extended_kek") == 0) {
eticket_extended_kek = Common::HexStringToArray<576>(out[1]); eticket_extended_kek = Common::HexStringToArray<576>(out[1]);
} else if (out[0].compare(0, 19, "eticket_rsa_keypair") == 0) {
const auto key_data = Common::HexStringToArray<528>(out[1]);
std::memcpy(eticket_rsa_keypair.decryption_key.data(), key_data.data(),
eticket_rsa_keypair.decryption_key.size());
std::memcpy(eticket_rsa_keypair.modulus.data(), key_data.data() + 0x100,
eticket_rsa_keypair.modulus.size());
std::memcpy(eticket_rsa_keypair.exponent.data(), key_data.data() + 0x200,
eticket_rsa_keypair.exponent.size());
} else { } else {
for (const auto& kv : KEYS_VARIABLE_LENGTH) { for (const auto& kv : KEYS_VARIABLE_LENGTH) {
if (!ValidCryptoRevisionString(out[0], kv.second.size(), 2)) { if (!ValidCryptoRevisionString(out[0], kv.second.size(), 2)) {
@ -1110,13 +1183,12 @@ void KeyManager::DeriveETicket(PartitionDataManager& data,
eticket_extended_kek = data.GetETicketExtendedKek(); eticket_extended_kek = data.GetETicketExtendedKek();
WriteKeyToFile(KeyCategory::Console, "eticket_extended_kek", eticket_extended_kek); WriteKeyToFile(KeyCategory::Console, "eticket_extended_kek", eticket_extended_kek);
DeriveETicketRSAKey();
PopulateTickets(); PopulateTickets();
} }
void KeyManager::PopulateTickets() { void KeyManager::PopulateTickets() {
const auto rsa_key = GetETicketRSAKey(); if (eticket_rsa_keypair == RSAKeyPair<2048>{}) {
if (rsa_key == RSAKeyPair<2048>{}) {
return; return;
} }
@ -1136,30 +1208,12 @@ void KeyManager::PopulateTickets() {
const Common::FS::IOFile save_e2{system_save_e2_path, Common::FS::FileAccessMode::Read, const Common::FS::IOFile save_e2{system_save_e2_path, Common::FS::FileAccessMode::Read,
Common::FS::FileType::BinaryFile}; Common::FS::FileType::BinaryFile};
auto tickets = GetTicketblob(save_e1);
const auto blob2 = GetTicketblob(save_e2); const auto blob2 = GetTicketblob(save_e2);
auto res = GetTicketblob(save_e1); tickets.insert(tickets.end(), blob2.begin(), blob2.end());
const auto idx = res.size(); for (const auto& ticket : tickets) {
res.insert(res.end(), blob2.begin(), blob2.end()); AddTicket(ticket);
for (std::size_t i = 0; i < res.size(); ++i) {
const auto common = i < idx;
const auto pair = ParseTicket(res[i], rsa_key);
if (!pair) {
continue;
}
const auto& [rid, key] = *pair;
u128 rights_id;
std::memcpy(rights_id.data(), rid.data(), rid.size());
if (common) {
common_tickets[rights_id] = res[i];
} else {
personal_tickets[rights_id] = res[i];
}
SetKey(S128KeyType::Titlekey, key, rights_id[1], rights_id[0]);
} }
} }
@ -1291,41 +1345,33 @@ const std::map<u128, Ticket>& KeyManager::GetPersonalizedTickets() const {
return personal_tickets; return personal_tickets;
} }
bool KeyManager::AddTicketCommon(Ticket raw) { bool KeyManager::AddTicket(const Ticket& ticket) {
const auto rsa_key = GetETicketRSAKey(); if (!ticket.IsValid()) {
if (rsa_key == RSAKeyPair<2048>{}) { LOG_WARNING(Crypto, "Attempted to add invalid ticket.");
return false; return false;
} }
const auto pair = ParseTicket(raw, rsa_key); const auto& rid = ticket.GetData().rights_id;
if (!pair) {
return false;
}
const auto& [rid, key] = *pair;
u128 rights_id; u128 rights_id;
std::memcpy(rights_id.data(), rid.data(), rid.size()); std::memcpy(rights_id.data(), rid.data(), rid.size());
common_tickets[rights_id] = raw; if (ticket.GetData().type == Core::Crypto::TitleKeyType::Common) {
SetKey(S128KeyType::Titlekey, key, rights_id[1], rights_id[0]); common_tickets[rights_id] = ticket;
return true; } else {
} personal_tickets[rights_id] = ticket;
bool KeyManager::AddTicketPersonalized(Ticket raw) {
const auto rsa_key = GetETicketRSAKey();
if (rsa_key == RSAKeyPair<2048>{}) {
return false;
} }
const auto pair = ParseTicket(raw, rsa_key); if (HasKey(S128KeyType::Titlekey, rights_id[1], rights_id[0])) {
if (!pair) { LOG_DEBUG(Crypto,
return false; "Skipping parsing title key from ticket for known rights ID {:016X}{:016X}.",
rights_id[1], rights_id[0]);
return true;
} }
const auto& [rid, key] = *pair; const auto key = ParseTicketTitleKey(ticket);
u128 rights_id; if (!key) {
std::memcpy(rights_id.data(), rid.data(), rid.size()); return false;
common_tickets[rights_id] = raw; }
SetKey(S128KeyType::Titlekey, key, rights_id[1], rights_id[0]); SetKey(S128KeyType::Titlekey, key.value(), rights_id[1], rights_id[0]);
return true; return true;
} }
} // namespace Core::Crypto } // namespace Core::Crypto

View file

@ -7,6 +7,7 @@
#include <filesystem> #include <filesystem>
#include <map> #include <map>
#include <optional> #include <optional>
#include <span>
#include <string> #include <string>
#include <variant> #include <variant>
@ -29,8 +30,6 @@ enum class ResultStatus : u16;
namespace Core::Crypto { namespace Core::Crypto {
constexpr u64 TICKET_FILE_TITLEKEY_OFFSET = 0x180;
using Key128 = std::array<u8, 0x10>; using Key128 = std::array<u8, 0x10>;
using Key256 = std::array<u8, 0x20>; using Key256 = std::array<u8, 0x20>;
using SHA256Hash = std::array<u8, 0x20>; using SHA256Hash = std::array<u8, 0x20>;
@ -82,6 +81,7 @@ struct RSA4096Ticket {
INSERT_PADDING_BYTES(0x3C); INSERT_PADDING_BYTES(0x3C);
TicketData data; TicketData data;
}; };
static_assert(sizeof(RSA4096Ticket) == 0x500, "RSA4096Ticket has incorrect size.");
struct RSA2048Ticket { struct RSA2048Ticket {
SignatureType sig_type; SignatureType sig_type;
@ -89,6 +89,7 @@ struct RSA2048Ticket {
INSERT_PADDING_BYTES(0x3C); INSERT_PADDING_BYTES(0x3C);
TicketData data; TicketData data;
}; };
static_assert(sizeof(RSA2048Ticket) == 0x400, "RSA2048Ticket has incorrect size.");
struct ECDSATicket { struct ECDSATicket {
SignatureType sig_type; SignatureType sig_type;
@ -96,16 +97,41 @@ struct ECDSATicket {
INSERT_PADDING_BYTES(0x40); INSERT_PADDING_BYTES(0x40);
TicketData data; TicketData data;
}; };
static_assert(sizeof(ECDSATicket) == 0x340, "ECDSATicket has incorrect size.");
struct Ticket { struct Ticket {
std::variant<RSA4096Ticket, RSA2048Ticket, ECDSATicket> data; std::variant<std::monostate, RSA4096Ticket, RSA2048Ticket, ECDSATicket> data;
SignatureType GetSignatureType() const; [[nodiscard]] bool IsValid() const;
TicketData& GetData(); [[nodiscard]] SignatureType GetSignatureType() const;
const TicketData& GetData() const; [[nodiscard]] TicketData& GetData();
u64 GetSize() const; [[nodiscard]] const TicketData& GetData() const;
[[nodiscard]] u64 GetSize() const;
/**
* Synthesizes a common ticket given a title key and rights ID.
*
* @param title_key Title key to store in the ticket.
* @param rights_id Rights ID the ticket is for.
* @return The synthesized common ticket.
*/
static Ticket SynthesizeCommon(Key128 title_key, const std::array<u8, 0x10>& rights_id); static Ticket SynthesizeCommon(Key128 title_key, const std::array<u8, 0x10>& rights_id);
/**
* Reads a ticket from a file.
*
* @param file File to read the ticket from.
* @return The read ticket. If the ticket data is invalid, Ticket::IsValid() will be false.
*/
static Ticket Read(const FileSys::VirtualFile& file);
/**
* Reads a ticket from a memory buffer.
*
* @param raw_data Buffer to read the ticket from.
* @return The read ticket. If the ticket data is invalid, Ticket::IsValid() will be false.
*/
static Ticket Read(std::span<const u8> raw_data);
}; };
static_assert(sizeof(Key128) == 16, "Key128 must be 128 bytes big."); static_assert(sizeof(Key128) == 16, "Key128 must be 128 bytes big.");
@ -264,8 +290,7 @@ public:
const std::map<u128, Ticket>& GetCommonTickets() const; const std::map<u128, Ticket>& GetCommonTickets() const;
const std::map<u128, Ticket>& GetPersonalizedTickets() const; const std::map<u128, Ticket>& GetPersonalizedTickets() const;
bool AddTicketCommon(Ticket raw); bool AddTicket(const Ticket& ticket);
bool AddTicketPersonalized(Ticket raw);
void ReloadKeys(); void ReloadKeys();
bool AreKeysLoaded() const; bool AreKeysLoaded() const;
@ -283,6 +308,7 @@ private:
std::array<std::array<u8, 0xB0>, 0x20> encrypted_keyblobs{}; std::array<std::array<u8, 0xB0>, 0x20> encrypted_keyblobs{};
std::array<std::array<u8, 0x90>, 0x20> keyblobs{}; std::array<std::array<u8, 0x90>, 0x20> keyblobs{};
std::array<u8, 576> eticket_extended_kek{}; std::array<u8, 576> eticket_extended_kek{};
RSAKeyPair<2048> eticket_rsa_keypair{};
bool dev_mode; bool dev_mode;
void LoadFromFile(const std::filesystem::path& file_path, bool is_title_keys); void LoadFromFile(const std::filesystem::path& file_path, bool is_title_keys);
@ -293,10 +319,13 @@ private:
void DeriveGeneralPurposeKeys(std::size_t crypto_revision); void DeriveGeneralPurposeKeys(std::size_t crypto_revision);
RSAKeyPair<2048> GetETicketRSAKey() const; void DeriveETicketRSAKey();
void SetKeyWrapped(S128KeyType id, Key128 key, u64 field1 = 0, u64 field2 = 0); void SetKeyWrapped(S128KeyType id, Key128 key, u64 field1 = 0, u64 field2 = 0);
void SetKeyWrapped(S256KeyType id, Key256 key, u64 field1 = 0, u64 field2 = 0); void SetKeyWrapped(S256KeyType id, Key256 key, u64 field1 = 0, u64 field2 = 0);
/// Parses the title key section of a ticket.
std::optional<Key128> ParseTicketTitleKey(const Ticket& ticket);
}; };
Key128 GenerateKeyEncryptionKey(Key128 source, Key128 master, Key128 kek_seed, Key128 key_seed); Key128 GenerateKeyEncryptionKey(Key128 source, Key128 master, Key128 kek_seed, Key128 key_seed);
@ -311,9 +340,4 @@ Loader::ResultStatus DeriveSDKeys(std::array<Key256, 2>& sd_keys, KeyManager& ke
std::vector<Ticket> GetTicketblob(const Common::FS::IOFile& ticket_save); std::vector<Ticket> GetTicketblob(const Common::FS::IOFile& ticket_save);
// Returns a pair of {rights_id, titlekey}. Fails if the ticket has no certificate authority
// (offset 0x140-0x144 is zero)
std::optional<std::pair<Key128, Key128>> ParseTicket(const Ticket& ticket,
const RSAKeyPair<2048>& eticket_extended_key);
} // namespace Core::Crypto } // namespace Core::Crypto

View file

@ -164,24 +164,6 @@ VirtualFile NSP::GetNCAFile(u64 title_id, ContentRecordType type, TitleType titl
return nullptr; return nullptr;
} }
std::vector<Core::Crypto::Key128> NSP::GetTitlekey() const {
if (extracted)
LOG_WARNING(Service_FS, "called on an NSP that is of type extracted.");
std::vector<Core::Crypto::Key128> out;
for (const auto& ticket_file : ticket_files) {
if (ticket_file == nullptr ||
ticket_file->GetSize() <
Core::Crypto::TICKET_FILE_TITLEKEY_OFFSET + sizeof(Core::Crypto::Key128)) {
continue;
}
out.emplace_back();
ticket_file->Read(out.back().data(), out.back().size(),
Core::Crypto::TICKET_FILE_TITLEKEY_OFFSET);
}
return out;
}
std::vector<VirtualFile> NSP::GetFiles() const { std::vector<VirtualFile> NSP::GetFiles() const {
return pfs->GetFiles(); return pfs->GetFiles();
} }
@ -208,22 +190,11 @@ void NSP::SetTicketKeys(const std::vector<VirtualFile>& files) {
continue; continue;
} }
if (ticket_file->GetSize() < auto ticket = Core::Crypto::Ticket::Read(ticket_file);
Core::Crypto::TICKET_FILE_TITLEKEY_OFFSET + sizeof(Core::Crypto::Key128)) { if (!keys.AddTicket(ticket)) {
LOG_WARNING(Common_Filesystem, "Could not load NSP ticket {}", ticket_file->GetName());
continue; continue;
} }
Core::Crypto::Key128 key{};
ticket_file->Read(key.data(), key.size(), Core::Crypto::TICKET_FILE_TITLEKEY_OFFSET);
// We get the name without the extension in order to create the rights ID.
std::string name_only(ticket_file->GetName());
name_only.erase(name_only.size() - 4);
const auto rights_id_raw = Common::HexStringToArray<16>(name_only);
u128 rights_id;
std::memcpy(rights_id.data(), rights_id_raw.data(), sizeof(u128));
keys.SetKey(Core::Crypto::S128KeyType::Titlekey, key, rights_id[1], rights_id[0]);
} }
} }

View file

@ -53,7 +53,6 @@ public:
TitleType title_type = TitleType::Application) const; TitleType title_type = TitleType::Application) const;
VirtualFile GetNCAFile(u64 title_id, ContentRecordType type, VirtualFile GetNCAFile(u64 title_id, ContentRecordType type,
TitleType title_type = TitleType::Application) const; TitleType title_type = TitleType::Application) const;
std::vector<Core::Crypto::Key128> GetTitlekey() const;
std::vector<VirtualFile> GetFiles() const override; std::vector<VirtualFile> GetFiles() const override;

View file

@ -122,20 +122,18 @@ private:
} }
void ImportTicket(HLERequestContext& ctx) { void ImportTicket(HLERequestContext& ctx) {
const auto ticket = ctx.ReadBuffer(); const auto raw_ticket = ctx.ReadBuffer();
[[maybe_unused]] const auto cert = ctx.ReadBuffer(1); [[maybe_unused]] const auto cert = ctx.ReadBuffer(1);
if (ticket.size() < sizeof(Core::Crypto::Ticket)) { if (raw_ticket.size() < sizeof(Core::Crypto::Ticket)) {
LOG_ERROR(Service_ETicket, "The input buffer is not large enough!"); LOG_ERROR(Service_ETicket, "The input buffer is not large enough!");
IPC::ResponseBuilder rb{ctx, 2}; IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ERROR_INVALID_ARGUMENT); rb.Push(ERROR_INVALID_ARGUMENT);
return; return;
} }
Core::Crypto::Ticket raw{}; Core::Crypto::Ticket ticket = Core::Crypto::Ticket::Read(raw_ticket);
std::memcpy(&raw, ticket.data(), sizeof(Core::Crypto::Ticket)); if (!keys.AddTicket(ticket)) {
if (!keys.AddTicketPersonalized(raw)) {
LOG_ERROR(Service_ETicket, "The ticket could not be imported!"); LOG_ERROR(Service_ETicket, "The ticket could not be imported!");
IPC::ResponseBuilder rb{ctx, 2}; IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ERROR_INVALID_ARGUMENT); rb.Push(ERROR_INVALID_ARGUMENT);

View file

@ -42,6 +42,7 @@ set(SHADER_FILES
present_bicubic.frag present_bicubic.frag
present_gaussian.frag present_gaussian.frag
queries_prefix_scan_sum.comp queries_prefix_scan_sum.comp
queries_prefix_scan_sum_nosubgroups.comp
resolve_conditional_render.comp resolve_conditional_render.comp
smaa_edge_detection.vert smaa_edge_detection.vert
smaa_edge_detection.frag smaa_edge_detection.frag
@ -72,6 +73,7 @@ if ("${GLSLANGVALIDATOR}" STREQUAL "GLSLANGVALIDATOR-NOTFOUND")
endif() endif()
set(GLSL_FLAGS "") set(GLSL_FLAGS "")
set(SPIR_V_VERSION "spirv1.3")
set(QUIET_FLAG "--quiet") set(QUIET_FLAG "--quiet")
set(SHADER_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/include) set(SHADER_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/include)
@ -125,7 +127,7 @@ foreach(FILENAME IN ITEMS ${SHADER_FILES})
OUTPUT OUTPUT
${SPIRV_HEADER_FILE} ${SPIRV_HEADER_FILE}
COMMAND COMMAND
${GLSLANGVALIDATOR} -V ${QUIET_FLAG} -I"${FIDELITYFX_INCLUDE_DIR}" ${GLSL_FLAGS} --variable-name ${SPIRV_VARIABLE_NAME} -o ${SPIRV_HEADER_FILE} ${SOURCE_FILE} ${GLSLANGVALIDATOR} -V ${QUIET_FLAG} -I"${FIDELITYFX_INCLUDE_DIR}" ${GLSL_FLAGS} --variable-name ${SPIRV_VARIABLE_NAME} -o ${SPIRV_HEADER_FILE} ${SOURCE_FILE} --target-env ${SPIR_V_VERSION}
MAIN_DEPENDENCY MAIN_DEPENDENCY
${SOURCE_FILE} ${SOURCE_FILE}
) )

View file

@ -1,26 +1,24 @@
// SPDX-FileCopyrightText: Copyright 2015 Graham Sellers, Richard Wright Jr. and Nicholas Haemel // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: GPL-3.0-or-later
// Code obtained from OpenGL SuperBible, Seventh Edition by Graham Sellers, Richard Wright Jr. and
// Nicholas Haemel. Modified to suit needs and optimize for subgroup
#version 460 core #version 460 core
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#ifdef VULKAN #ifdef VULKAN
#extension GL_KHR_shader_subgroup_arithmetic : enable
#define HAS_EXTENDED_TYPES 1 #define HAS_EXTENDED_TYPES 1
#define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants { #define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants {
#define END_PUSH_CONSTANTS \ #define END_PUSH_CONSTANTS };
} \
;
#define UNIFORM(n) #define UNIFORM(n)
#define BINDING_INPUT_BUFFER 0 #define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 1 #define BINDING_OUTPUT_IMAGE 1
#else // ^^^ Vulkan ^^^ // vvv OpenGL vvv #else // ^^^ Vulkan ^^^ // vvv OpenGL vvv
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_NV_gpu_shader5 : enable #extension GL_NV_gpu_shader5 : enable
#ifdef GL_NV_gpu_shader5 #ifdef GL_NV_gpu_shader5
#define HAS_EXTENDED_TYPES 1 #define HAS_EXTENDED_TYPES 1
@ -43,19 +41,20 @@ END_PUSH_CONSTANTS
layout(local_size_x = 32) in; layout(local_size_x = 32) in;
layout(std430, binding = 0) readonly buffer block1 { layout(std430, binding = 0) readonly buffer block1 {
uvec2 input_data[gl_WorkGroupSize.x]; uvec2 input_data[];
}; };
layout(std430, binding = 1) writeonly coherent buffer block2 { layout(std430, binding = 1) coherent buffer block2 {
uvec2 output_data[gl_WorkGroupSize.x]; uvec2 output_data[];
}; };
layout(std430, binding = 2) coherent buffer block3 { layout(std430, binding = 2) coherent buffer block3 {
uvec2 accumulated_data; uvec2 accumulated_data;
}; };
shared uvec2 shared_data[gl_WorkGroupSize.x * 2]; shared uvec2 shared_data[2];
// Simple Uint64 add that uses 2 uint variables for GPUs that don't support uint64
uvec2 AddUint64(uvec2 value_1, uvec2 value_2) { uvec2 AddUint64(uvec2 value_1, uvec2 value_2) {
uint carry = 0; uint carry = 0;
uvec2 result; uvec2 result;
@ -64,61 +63,102 @@ uvec2 AddUint64(uvec2 value_1, uvec2 value_2) {
return result; return result;
} }
void main(void) { // do subgroup Prefix Sum using Hillis and Steele's algorithm
uint id = gl_LocalInvocationID.x; uvec2 subgroupInclusiveAddUint64(uvec2 value) {
uvec2 base_value_1 = (id * 2) < max_accumulation_base ? accumulated_data : uvec2(0); uvec2 result = value;
uvec2 base_value_2 = (id * 2 + 1) < max_accumulation_base ? accumulated_data : uvec2(0); for (uint i = 1; i < gl_SubgroupSize; i *= 2) {
uint work_size = gl_WorkGroupSize.x; if (i <= gl_SubgroupInvocationID) {
uint rd_id; uvec2 other = subgroupShuffleUp(result, i); // get value from subgroup_inv_id - i;
uint wr_id; result = AddUint64(result, other);
uint mask; }
uvec2 input_1 = input_data[id * 2]; }
uvec2 input_2 = input_data[id * 2 + 1]; return result;
// The number of steps is the log base 2 of the }
// work group size, which should be a power of 2
const uint steps = uint(log2(work_size)) + 1;
uint step = 0;
// Each invocation is responsible for the content of // Writes down the results to the output buffer and to the accumulation buffer
// two elements of the output array void WriteResults(uvec2 result) {
shared_data[id * 2] = input_1; uint current_global_id = gl_GlobalInvocationID.x;
shared_data[id * 2 + 1] = input_2; uvec2 base_data = current_global_id < max_accumulation_base ? accumulated_data : uvec2(0);
// Synchronize to make sure that everyone has initialized output_data[current_global_id] = result + base_data;
// their elements of shared_data[] with data loaded from if (max_accumulation_base >= accumulation_limit + 1) {
// the input arrays if (current_global_id == accumulation_limit) {
accumulated_data = result;
}
return;
}
// We have that ugly case in which the accumulation data is reset in the middle somewhere.
barrier();
groupMemoryBarrier();
if (current_global_id == accumulation_limit) {
uvec2 value_1 = output_data[max_accumulation_base];
accumulated_data = AddUint64(result, -value_1);
}
}
void main() {
uint subgroup_inv_id = gl_SubgroupInvocationID;
uint subgroup_id = gl_SubgroupID;
uint last_subgroup_id = subgroupMax(subgroup_inv_id);
uint current_global_id = gl_GlobalInvocationID.x;
uint total_work = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
uvec2 data = input_data[current_global_id];
// make sure all input data has been loaded
subgroupBarrier();
subgroupMemoryBarrier();
uvec2 result = subgroupInclusiveAddUint64(data);
// If there are no more queries than fit in a single subgroup, just write down the results.
if (total_work <= gl_SubgroupSize) { // This condition is constant per dispatch.
WriteResults(result);
return;
}
// We now have more, so lets write the last result into shared memory.
// Only pick the last subgroup.
if (subgroup_inv_id == last_subgroup_id) {
shared_data[subgroup_id] = result;
}
// Wait until every subgroup has stored its partial result.
barrier(); barrier();
memoryBarrierShared(); memoryBarrierShared();
// For each step...
for (step = 0; step < steps; step++) {
// Calculate the read and write index in the
// shared array
mask = (1 << step) - 1;
rd_id = ((id >> step) << (step + 1)) + mask;
wr_id = rd_id + 1 + (id & mask);
// Accumulate the read data into our element
shared_data[wr_id] = AddUint64(shared_data[rd_id], shared_data[wr_id]); // Case 1: the total work for the grouped results can be calculated in a single subgroup
// Synchronize again to make sure that everyone // operation (about 1024 queries).
// has caught up with us uint total_extra_work = gl_NumSubgroups * gl_NumWorkGroups.x;
if (total_extra_work <= gl_SubgroupSize) { // This condition is constant per dispatch.
if (subgroup_id != 0) {
uvec2 tmp = shared_data[subgroup_inv_id];
subgroupBarrier();
subgroupMemoryBarrierShared();
tmp = subgroupInclusiveAddUint64(tmp);
result = AddUint64(result, subgroupShuffle(tmp, subgroup_id - 1));
}
WriteResults(result);
return;
}
// Case 2: our work amount is huge, so lets do it in O(log n) steps.
const uint extra = (total_extra_work ^ (total_extra_work - 1)) != 0 ? 1 : 0;
const uint steps = 1 << (findMSB(total_extra_work) + extra);
uint step;
// Hillis and Steele's algorithm
for (step = 1; step < steps; step *= 2) {
if (current_global_id < steps && current_global_id >= step) {
uvec2 current = shared_data[current_global_id];
uvec2 other = shared_data[current_global_id - step];
shared_data[current_global_id] = AddUint64(current, other);
}
// steps is constant, so this will always execute in every workgroup's thread.
barrier(); barrier();
memoryBarrierShared(); memoryBarrierShared();
} }
// Add the accumulation // Only add results for groups higher than 0
shared_data[id * 2] = AddUint64(shared_data[id * 2], base_value_1); if (subgroup_id != 0) {
shared_data[id * 2 + 1] = AddUint64(shared_data[id * 2 + 1], base_value_2); result = AddUint64(result, shared_data[subgroup_id - 1]);
barrier();
memoryBarrierShared();
// Finally write our data back to the output buffer
output_data[id * 2] = shared_data[id * 2];
output_data[id * 2 + 1] = shared_data[id * 2 + 1];
if (id == 0) {
if (max_accumulation_base >= accumulation_limit + 1) {
accumulated_data = shared_data[accumulation_limit];
return;
}
uvec2 value_1 = shared_data[max_accumulation_base];
uvec2 value_2 = shared_data[accumulation_limit];
accumulated_data = AddUint64(value_1, -value_2);
} }
// Just write the final results. We are done
WriteResults(result);
} }

View file

@ -0,0 +1,120 @@
// SPDX-FileCopyrightText: Copyright 2015 Graham Sellers, Richard Wright Jr. and Nicholas Haemel
// SPDX-License-Identifier: MIT
// Code obtained from OpenGL SuperBible, Seventh Edition by Graham Sellers, Richard Wright Jr. and
// Nicholas Haemel. Modified to suit needs.
#version 460 core
// Backend selection: the macros below let the uniform declarations further down compile either as
// members of a Vulkan push-constant block or as plain OpenGL uniforms with explicit locations.
#ifdef VULKAN
#define HAS_EXTENDED_TYPES 1
// On Vulkan the uniforms are wrapped in a push_constant block, so UNIFORM(n) expands to nothing.
#define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants {
#define END_PUSH_CONSTANTS };
#define UNIFORM(n)
// NOTE(review): the BINDING_* macros and HAS_EXTENDED_TYPES are defined here but not referenced by
// the rest of this shader, which hard-codes its buffer bindings — presumably shared boilerplate.
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 1
#else // ^^^ Vulkan ^^^ // vvv OpenGL vvv
#extension GL_NV_gpu_shader5 : enable
// HAS_EXTENDED_TYPES records whether the GL_NV_gpu_shader5 extension macro is defined.
#ifdef GL_NV_gpu_shader5
#define HAS_EXTENDED_TYPES 1
#else
#define HAS_EXTENDED_TYPES 0
#endif
// On OpenGL there is no enclosing block; each uniform gets an explicit location via UNIFORM(n).
#define BEGIN_PUSH_CONSTANTS
#define END_PUSH_CONSTANTS
#define UNIFORM(n) layout(location = n) uniform
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 0
#endif
BEGIN_PUSH_CONSTANTS
// Elements whose index is below this value have accumulated_data added to their prefix sum.
UNIFORM(0) uint max_accumulation_base;
// Index of the element used by invocation 0 when it updates accumulated_data at the end of main().
UNIFORM(1) uint accumulation_limit;
END_PUSH_CONSTANTS
// One workgroup of 32 invocations; main() has every invocation process two elements.
layout(local_size_x = 32) in;

// Values to scan. Each element is a 64-bit quantity stored as (low word, high word), which is the
// layout AddUint64 operates on.
// The array is runtime-sized: main() indexes it with id * 2 + 1, i.e. up to
// gl_WorkGroupSize.x * 2 - 1, which would be out of bounds for an array declared with only
// gl_WorkGroupSize.x elements.
layout(std430, binding = 0) readonly buffer block1 {
    uvec2 input_data[];
};

// Scan results, written at the same indices as input_data (so runtime-sized for the same reason).
layout(std430, binding = 1) writeonly coherent buffer block2 {
    uvec2 output_data[];
};

// Single 64-bit value that main() reads at the start and invocation 0 rewrites at the end.
layout(std430, binding = 2) coherent buffer block3 {
    uvec2 accumulated_data;
};

// Scratch space for the scan: two elements per invocation.
shared uvec2 shared_data[gl_WorkGroupSize.x * 2];
// Adds two 64-bit unsigned integers that are each packed into a uvec2 as (low word, high word).
// The carry out of the low-word addition is propagated into the high word; overflow of the high
// word wraps around.
uvec2 AddUint64(uvec2 value_1, uvec2 value_2) {
    uint low_carry;
    uint low_word = uaddCarry(value_1.x, value_2.x, low_carry);
    uint high_word = value_1.y + value_2.y + low_carry;
    return uvec2(low_word, high_word);
}
// Computes an inclusive prefix sum over gl_WorkGroupSize.x * 2 64-bit values within one workgroup,
// using shared memory only (no subgroup operations).
//
// Each invocation loads two consecutive elements of input_data, the workgroup scans them in
// shared_data, accumulated_data is added to every element whose index is below
// max_accumulation_base, and the results are written to output_data. Invocation 0 finally updates
// accumulated_data from the scanned values.
void main(void) {
    uint id = gl_LocalInvocationID.x;
    // Read the carried-in total up front: invocation 0 overwrites accumulated_data at the end.
    uvec2 base_value_1 = (id * 2) < max_accumulation_base ? accumulated_data : uvec2(0);
    uvec2 base_value_2 = (id * 2 + 1) < max_accumulation_base ? accumulated_data : uvec2(0);
    uint work_size = gl_WorkGroupSize.x;
    uint rd_id;
    uint wr_id;
    uint mask;
    uvec2 input_1 = input_data[id * 2];
    uvec2 input_2 = input_data[id * 2 + 1];
    // The number of steps is the log base 2 of the work group size (which should be a power of 2)
    // plus one, because each invocation covers two elements.
    const uint steps = uint(log2(work_size)) + 1;
    uint step = 0;

    // Each invocation is responsible for the content of two elements of the output array.
    shared_data[id * 2] = input_1;
    shared_data[id * 2 + 1] = input_2;
    // Synchronize to make sure that everyone has initialized their elements of shared_data[] with
    // data loaded from the input arrays.
    barrier();
    memoryBarrierShared();

    for (step = 0; step < steps; step++) {
        // At each step the array is viewed as blocks of 2^(step + 1) elements. rd_id is the last
        // element of the lower half of this invocation's block, and wr_id is this invocation's
        // element in the upper half, which receives the lower half's total.
        mask = (1 << step) - 1;
        rd_id = ((id >> step) << (step + 1)) + mask;
        wr_id = rd_id + 1 + (id & mask);
        shared_data[wr_id] = AddUint64(shared_data[rd_id], shared_data[wr_id]);
        // Synchronize again to make sure that everyone has caught up with us.
        barrier();
        memoryBarrierShared();
    }

    // Add the accumulation.
    shared_data[id * 2] = AddUint64(shared_data[id * 2], base_value_1);
    shared_data[id * 2 + 1] = AddUint64(shared_data[id * 2 + 1], base_value_2);
    barrier();
    memoryBarrierShared();

    // Finally write our data back to the output buffer.
    output_data[id * 2] = shared_data[id * 2];
    output_data[id * 2 + 1] = shared_data[id * 2 + 1];

    if (id == 0) {
        if (max_accumulation_base >= accumulation_limit + 1) {
            accumulated_data = shared_data[accumulation_limit];
            return;
        }
        uvec2 value_1 = shared_data[max_accumulation_base];
        uvec2 value_2 = shared_data[accumulation_limit];
        // Subtract as a 64-bit quantity. Negating the uvec2 directly would negate each 32-bit word
        // on its own and leave the high word off by one whenever the low word is non-zero, so
        // build the two's complement explicitly: invert both words, then add one with carry.
        // NOTE(review): the subgroup variant of this shader subtracts in the opposite order
        // (value at accumulation_limit minus value at max_accumulation_base) — confirm which
        // order is intended here; the original operand order is preserved.
        uvec2 negated_value_2 = AddUint64(~value_2, uvec2(1u, 0u));
        accumulated_data = AddUint64(value_1, negated_value_2);
    }
}

View file

@ -13,6 +13,7 @@
#include "common/div_ceil.h" #include "common/div_ceil.h"
#include "video_core/host_shaders/astc_decoder_comp_spv.h" #include "video_core/host_shaders/astc_decoder_comp_spv.h"
#include "video_core/host_shaders/queries_prefix_scan_sum_comp_spv.h" #include "video_core/host_shaders/queries_prefix_scan_sum_comp_spv.h"
#include "video_core/host_shaders/queries_prefix_scan_sum_nosubgroups_comp_spv.h"
#include "video_core/host_shaders/resolve_conditional_render_comp_spv.h" #include "video_core/host_shaders/resolve_conditional_render_comp_spv.h"
#include "video_core/host_shaders/vulkan_quad_indexed_comp_spv.h" #include "video_core/host_shaders/vulkan_quad_indexed_comp_spv.h"
#include "video_core/host_shaders/vulkan_uint8_comp_spv.h" #include "video_core/host_shaders/vulkan_uint8_comp_spv.h"
@ -187,7 +188,8 @@ ComputePass::ComputePass(const Device& device_, DescriptorPool& descriptor_pool,
vk::Span<VkDescriptorSetLayoutBinding> bindings, vk::Span<VkDescriptorSetLayoutBinding> bindings,
vk::Span<VkDescriptorUpdateTemplateEntry> templates, vk::Span<VkDescriptorUpdateTemplateEntry> templates,
const DescriptorBankInfo& bank_info, const DescriptorBankInfo& bank_info,
vk::Span<VkPushConstantRange> push_constants, std::span<const u32> code) vk::Span<VkPushConstantRange> push_constants, std::span<const u32> code,
std::optional<u32> optional_subgroup_size)
: device{device_} { : device{device_} {
descriptor_set_layout = device.GetLogical().CreateDescriptorSetLayout({ descriptor_set_layout = device.GetLogical().CreateDescriptorSetLayout({
.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, .sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
@ -228,13 +230,19 @@ ComputePass::ComputePass(const Device& device_, DescriptorPool& descriptor_pool,
.pCode = code.data(), .pCode = code.data(),
}); });
device.SaveShader(code); device.SaveShader(code);
const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroup_size_ci{
.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,
.pNext = nullptr,
.requiredSubgroupSize = optional_subgroup_size ? *optional_subgroup_size : 32U,
};
bool use_setup_size = device.IsExtSubgroupSizeControlSupported() && optional_subgroup_size;
pipeline = device.GetLogical().CreateComputePipeline({ pipeline = device.GetLogical().CreateComputePipeline({
.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
.pNext = nullptr, .pNext = nullptr,
.flags = 0, .flags = 0,
.stage{ .stage{
.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
.pNext = nullptr, .pNext = use_setup_size ? &subgroup_size_ci : nullptr,
.flags = 0, .flags = 0,
.stage = VK_SHADER_STAGE_COMPUTE_BIT, .stage = VK_SHADER_STAGE_COMPUTE_BIT,
.module = *module, .module = *module,
@ -374,7 +382,7 @@ void ConditionalRenderingResolvePass::Resolve(VkBuffer dst_buffer, VkBuffer src_
static constexpr VkMemoryBarrier read_barrier{ static constexpr VkMemoryBarrier read_barrier{
.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER,
.pNext = nullptr, .pNext = nullptr,
.srcAccessMask = VK_ACCESS_NONE, .srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_SHADER_WRITE_BIT,
.dstAccessMask = VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, .dstAccessMask = VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT,
}; };
static constexpr VkMemoryBarrier write_barrier{ static constexpr VkMemoryBarrier write_barrier{
@ -399,10 +407,17 @@ void ConditionalRenderingResolvePass::Resolve(VkBuffer dst_buffer, VkBuffer src_
QueriesPrefixScanPass::QueriesPrefixScanPass( QueriesPrefixScanPass::QueriesPrefixScanPass(
const Device& device_, Scheduler& scheduler_, DescriptorPool& descriptor_pool_, const Device& device_, Scheduler& scheduler_, DescriptorPool& descriptor_pool_,
ComputePassDescriptorQueue& compute_pass_descriptor_queue_) ComputePassDescriptorQueue& compute_pass_descriptor_queue_)
: ComputePass(device_, descriptor_pool_, QUERIES_SCAN_DESCRIPTOR_SET_BINDINGS, : ComputePass(
QUERIES_SCAN_DESCRIPTOR_UPDATE_TEMPLATE, QUERIES_SCAN_BANK_INFO, device_, descriptor_pool_, QUERIES_SCAN_DESCRIPTOR_SET_BINDINGS,
COMPUTE_PUSH_CONSTANT_RANGE<sizeof(QueriesPrefixScanPushConstants)>, QUERIES_SCAN_DESCRIPTOR_UPDATE_TEMPLATE, QUERIES_SCAN_BANK_INFO,
QUERIES_PREFIX_SCAN_SUM_COMP_SPV), COMPUTE_PUSH_CONSTANT_RANGE<sizeof(QueriesPrefixScanPushConstants)>,
device_.IsSubgroupFeatureSupported(VK_SUBGROUP_FEATURE_BASIC_BIT) &&
device_.IsSubgroupFeatureSupported(VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) &&
device_.IsSubgroupFeatureSupported(VK_SUBGROUP_FEATURE_SHUFFLE_BIT) &&
device_.IsSubgroupFeatureSupported(VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT)
? std::span<const u32>(QUERIES_PREFIX_SCAN_SUM_COMP_SPV)
: std::span<const u32>(QUERIES_PREFIX_SCAN_SUM_NOSUBGROUPS_COMP_SPV),
{32}),
scheduler{scheduler_}, compute_pass_descriptor_queue{compute_pass_descriptor_queue_} {} scheduler{scheduler_}, compute_pass_descriptor_queue{compute_pass_descriptor_queue_} {}
void QueriesPrefixScanPass::Run(VkBuffer accumulation_buffer, VkBuffer dst_buffer, void QueriesPrefixScanPass::Run(VkBuffer accumulation_buffer, VkBuffer dst_buffer,
@ -422,7 +437,7 @@ void QueriesPrefixScanPass::Run(VkBuffer accumulation_buffer, VkBuffer dst_buffe
static constexpr VkMemoryBarrier read_barrier{ static constexpr VkMemoryBarrier read_barrier{
.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER,
.pNext = nullptr, .pNext = nullptr,
.srcAccessMask = VK_ACCESS_NONE, .srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT,
.dstAccessMask = VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, .dstAccessMask = VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT,
}; };
static constexpr VkMemoryBarrier write_barrier{ static constexpr VkMemoryBarrier write_barrier{

View file

@ -3,6 +3,7 @@
#pragma once #pragma once
#include <optional>
#include <span> #include <span>
#include <utility> #include <utility>
@ -31,7 +32,8 @@ public:
vk::Span<VkDescriptorSetLayoutBinding> bindings, vk::Span<VkDescriptorSetLayoutBinding> bindings,
vk::Span<VkDescriptorUpdateTemplateEntry> templates, vk::Span<VkDescriptorUpdateTemplateEntry> templates,
const DescriptorBankInfo& bank_info, const DescriptorBankInfo& bank_info,
vk::Span<VkPushConstantRange> push_constants, std::span<const u32> code); vk::Span<VkPushConstantRange> push_constants, std::span<const u32> code,
std::optional<u32> optional_subgroup_size = std::nullopt);
~ComputePass(); ~ComputePass();
protected: protected: